Simu's repositories
Griffin-Jax
Jax implementation of "Griffin: Mixing Gated Linear Recurrences with Local Attention for Efficient Language Models"
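A minimal sketch of the gated linear recurrence that Griffin's recurrent block builds on, written with jax.lax.scan. Names, shapes, and the gating form are illustrative assumptions, not the repo's API.

```python
import jax
import jax.numpy as jnp


def gated_linear_recurrence(a, x):
    """h_t = a_t * h_{t-1} + (1 - a_t) * x_t, scanned over the time axis.

    a, x: arrays of shape (seq_len, dim), with a in (0, 1) acting as a per-step gate.
    """
    def step(h, inputs):
        a_t, x_t = inputs
        h = a_t * h + (1.0 - a_t) * x_t
        return h, h

    h0 = jnp.zeros_like(x[0])
    _, hs = jax.lax.scan(step, h0, (a, x))
    return hs  # (seq_len, dim) sequence of hidden states


x = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
a = jax.nn.sigmoid(jax.random.normal(jax.random.PRNGKey(1), (16, 8)))  # gates in (0, 1)
print(gated_linear_recurrence(a, x).shape)  # (16, 8)
```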
miniF2F-code
Dataset of formal Olympiad-level mathematics problems solved with Python code instructions.
Tri-RMSNorm
Efficient kernel for RMS normalization with fused operations; includes both forward and backward passes and is compatible with PyTorch.
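For reference, this is the math such a fused kernel computes, shown as a plain (non-fused) PyTorch baseline: y = x / sqrt(mean(x^2) + eps) * weight. This is an illustrative sketch, not the repo's kernel.

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize over the last dimension by the root-mean-square of its elements.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight


x = torch.randn(4, 512)
w = torch.ones(512)
print(rms_norm(x, w).shape)  # torch.Size([4, 512])
```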
LongConv-Jax
Jax/Flax/Linen implementation of "Simple Hardware-Efficient Long Convolutions for Sequence Modeling"
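The core operation in that paper is a depthwise convolution whose kernel spans the whole sequence, computed with FFTs. Below is a single-function sketch under assumed shapes, not the repo's module.

```python
import jax.numpy as jnp
from jax import random


def long_conv(u, k):
    """Long convolution via zero-padded FFT.

    u: input of shape (seq_len, dim); k: kernel of shape (seq_len, dim).
    Returns the convolution (u * k) truncated back to seq_len.
    """
    seq_len = u.shape[0]
    n = 2 * seq_len  # pad to avoid wrap-around from the circular FFT convolution
    u_f = jnp.fft.rfft(u, n=n, axis=0)
    k_f = jnp.fft.rfft(k, n=n, axis=0)
    return jnp.fft.irfft(u_f * k_f, n=n, axis=0)[:seq_len]


u = random.normal(random.PRNGKey(0), (1024, 64))
k = random.normal(random.PRNGKey(1), (1024, 64))
print(long_conv(u, k).shape)  # (1024, 64)
```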
triton-activations
Collection of neural network activation function kernels for OpenAI's Triton language and compiler
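A minimal example of the kind of elementwise activation kernel such a collection contains: a ReLU written in Triton, launched from PyTorch on a CUDA device. Illustrative only; not the repo's own kernels.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def relu_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements              # guard the tail block
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(y_ptr + offsets, tl.maximum(x, 0.0), mask=mask)


def relu(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    relu_kernel[grid](x, y, n, BLOCK_SIZE=1024)
    return y


x = torch.randn(4096, device="cuda")
assert torch.allclose(relu(x), torch.relu(x))
```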
GradientAscent-Jax
Custom gradient ascent solver (optimizer) for JAX/Flax models
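A gradient ascent step simply moves parameters along the gradient of the objective instead of against it. A minimal hand-rolled sketch in JAX follows; the repo's actual solver API may differ.

```python
import jax
import jax.numpy as jnp


def objective(params, x):
    # Toy concave objective to maximize: -||x - params||^2
    return -jnp.sum((x - params) ** 2)


@jax.jit
def ascent_step(params, x, lr=0.1):
    grads = jax.grad(objective)(params, x)
    # Ascend: add the gradient rather than subtracting it.
    return jax.tree_util.tree_map(lambda p, g: p + lr * g, params, grads)


params = jnp.zeros(3)
target = jnp.array([1.0, -2.0, 0.5])
for _ in range(100):
    params = ascent_step(params, target)
print(params)  # approaches [1.0, -2.0, 0.5]
```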
lmppl-cli-csv-wrapper
A tiny CLI wrapper around lmppl for computing pre-trained language model perplexity over CSV files
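A sketch of what such a wrapper does under the hood: read a text column from a CSV and score it with lmppl's causal LM scorer. The file name, column name, and model choice here are assumptions, not the wrapper's actual interface.

```python
import pandas as pd
import lmppl

df = pd.read_csv("inputs.csv")               # hypothetical input file
texts = df["text"].astype(str).tolist()      # hypothetical column name

scorer = lmppl.LM("gpt2")                    # causal LM perplexity scorer
df["perplexity"] = scorer.get_perplexity(texts)
df.to_csv("inputs_with_ppl.csv", index=False)
```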
Mixture-of-Depths-Jax
Jax module for the paper: "Mixture-of-Depths: Dynamically allocating compute in transformer-based language models"
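An illustrative sketch of the Mixture-of-Depths idea: a router scores tokens, the top-k go through the expensive block, and the rest pass through on the residual path. The router, block, and capacity below are toy stand-ins, not the repo's module.

```python
import jax
import jax.numpy as jnp


def mixture_of_depths(x, router_w, block_fn, capacity):
    """x: (seq_len, dim); router_w: (dim,); capacity: number of tokens processed."""
    scores = x @ router_w                                  # (seq_len,) router logits
    _, top_idx = jax.lax.top_k(scores, capacity)           # tokens routed into the block
    selected = x[top_idx]                                  # (capacity, dim)
    processed = block_fn(selected)                         # run the expensive block
    # Scatter processed tokens back with a residual; unselected tokens stay unchanged.
    return x.at[top_idx].set(processed + selected)


x = jax.random.normal(jax.random.PRNGKey(0), (32, 16))
router_w = jax.random.normal(jax.random.PRNGKey(1), (16,))
block = lambda h: jax.nn.gelu(h)                           # toy stand-in for an MLP/attention block
print(mixture_of_depths(x, router_w, block, capacity=8).shape)  # (32, 16)
```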
Ring-Attention-Jax
Packaged implementation of "Ring Attention with Blockwise Transformers for Near-Infinite Context" in Jax + Flax.
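Ring Attention circulates key/value blocks around a ring of devices; the per-device work is blockwise attention with a running (online) softmax. Below is a single-device sketch of that blockwise accumulation only, with no ring communication and with illustrative names and shapes.

```python
import jax
import jax.numpy as jnp


def blockwise_attention(q, k, v, block_size):
    """q: (Lq, d); k, v: (Lk, d). Never materializes the full Lq x Lk score matrix."""
    Lq, d = q.shape
    scale = 1.0 / jnp.sqrt(d)
    m = jnp.full((Lq, 1), -jnp.inf)      # running max of scores
    l = jnp.zeros((Lq, 1))               # running softmax denominator
    acc = jnp.zeros_like(q)              # running weighted sum of values
    for start in range(0, k.shape[0], block_size):
        kb, vb = k[start:start + block_size], v[start:start + block_size]
        s = (q @ kb.T) * scale
        m_new = jnp.maximum(m, s.max(axis=-1, keepdims=True))
        p = jnp.exp(s - m_new)
        correction = jnp.exp(m - m_new)  # rescale previous partial sums
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ vb
        m = m_new
    return acc / l


q = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
k = jax.random.normal(jax.random.PRNGKey(1), (32, 4))
v = jax.random.normal(jax.random.PRNGKey(2), (32, 4))
ref = jax.nn.softmax((q @ k.T) / jnp.sqrt(4.0), axis=-1) @ v
assert jnp.allclose(blockwise_attention(q, k, v, block_size=8), ref, atol=1e-5)
```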
Python-Template
Python Package Template is all you need
simudt.github.io
blog for the AI era
Composable-Datasets
Transform JSONL Q&A datasets to instruct format with ease
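A toy example of the kind of transform meant here: map JSONL {"question", "answer"} records to an {"instruction", "input", "output"} instruct format. The field names and file names are assumptions, not the repo's actual schema or API.

```python
import json


def qa_to_instruct(in_path: str, out_path: str) -> None:
    with open(in_path) as src, open(out_path, "w") as dst:
        for line in src:
            record = json.loads(line)
            instruct = {
                "instruction": record["question"],
                "input": "",
                "output": record["answer"],
            }
            dst.write(json.dumps(instruct, ensure_ascii=False) + "\n")


qa_to_instruct("qa.jsonl", "instruct.jsonl")  # hypothetical file names
```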
jax-triton
jax-triton contains integrations between JAX and OpenAI Triton
MEGABYTE-pytorch-DS
Fork of MEGABYTE - PyTorch by lucidrains ("Predicting Million-byte Sequences with Multiscale Transformers") with a modified DeepSpeed training setup
PaLM-rlhf-pytorch-DS
Fork of RLHF (Reinforcement Learning from Human Feedback) by lucidrains on top of the PaLM architecture, with a modified DeepSpeed training setup. Basically ChatGPT, but with PaLM
Simba
A simpler PyTorch + Zeta implementation of the paper: "SiMBA: Simplified Mamba-based Architecture for Vision and Multivariate Time series"
zeta
Build high-performance AI models with modular building blocks