This repository contains the complete implementation of a sophisticated Transformer-based language model, featuring a unique Multi Latent Attention (MLA) mechanism and a Mixture-of-Experts (MoE) feed-forward layer. The model is designed for high-performance text generation and is built to scale efficiently using distributed training.
A standard Transformer version of this model (a single FFN per block, no routing), available on the "Old" branch of this repo, beat GPT-2 Large's test perplexity on WikiText-2 after 2h36min of training on a node of 8 AMD MI300X GPUs, at ~300M parameters.
This project provides the full codebase, from the architectural backbone and data processing pipelines to single-GPU and distributed training scripts, and a ready-to-use interactive Streamlit application for inference.
- Multi Latent Attention (MLA): A novel attention mechanism first introduced in the DeepSeek-V3 paper that splits the query and key projections into two paths: a content-based path and a rotary-based path. This lets the model process and weigh contextual and positional information separately, leading to more nuanced text generation (see the sketch after this list).
- Mixture-of-Experts (MoE) Layers: The feed-forward network in each Transformer block is replaced with a sparse MoE layer. This gives the model a very high parameter count while activating only a small subset of expert networks for each token, drastically improving training and inference efficiency. The router architecture was inspired by the Hugging Face blog post on MoE (a router sketch also follows this list).
- Rotary Position Embeddings (RoPE): State-of-the-art relative position encoding, applied inside the MLA mechanism's rotary path.
- Distributed Training Ready: Includes a script (`main_distributed.py`) that leverages PyTorch's `DistributedDataParallel` (DDP) for robust and scalable multi-GPU training (tested on a node of 8 AMD MI300X GPUs).
- Custom Data Pipeline: A dedicated data loader (`OpenWebText.py`) for processing the OpenWebText dataset, including on-the-fly tokenization, cleaning, and batching.
- Interactive Demo: A user-friendly Streamlit application (`user.py`) to interact with the trained model, featuring real-time text generation and adjustable sampling parameters.
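To make the query/key split concrete, here is a minimal single-head sketch of the idea (causal masking and multi-head splitting are omitted for brevity). The names `SplitPathAttention`, `apply_rope`, `d_content`, and `d_rope` are illustrative assumptions, not the classes defined in `model.py`:

```python
import torch
import torch.nn as nn


def apply_rope(x, cos, sin):
    # Standard rotate-half RoPE, applied only to the rotary half of q/k.
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)


class SplitPathAttention(nn.Module):
    """Illustrative single-head attention with separate content and rotary paths."""

    def __init__(self, d_model, d_content, d_rope):
        super().__init__()
        self.q_content = nn.Linear(d_model, d_content)   # "what" path: no positional signal
        self.k_content = nn.Linear(d_model, d_content)
        self.q_rope = nn.Linear(d_model, d_rope)          # "where" path: RoPE is applied here
        self.k_rope = nn.Linear(d_model, d_rope)
        self.v = nn.Linear(d_model, d_model)
        self.scale = (d_content + d_rope) ** -0.5

    def forward(self, x, cos, sin):
        # Concatenate the two paths so the attention logits mix content and position scores.
        q = torch.cat([self.q_content(x), apply_rope(self.q_rope(x), cos, sin)], dim=-1)
        k = torch.cat([self.k_content(x), apply_rope(self.k_rope(x), cos, sin)], dim=-1)
        attn = torch.softmax(q @ k.transpose(-2, -1) * self.scale, dim=-1)
        return attn @ self.v(x)


# Shape check with made-up sizes: cos/sin come from the usual RoPE frequency table.
B, T, d_model, d_rope = 2, 16, 64, 32
freqs = 1.0 / (10000 ** (torch.arange(0, d_rope, 2) / d_rope))       # (d_rope/2,)
angles = torch.arange(T)[:, None] * freqs[None, :]                   # (T, d_rope/2)
layer = SplitPathAttention(d_model, d_content=32, d_rope=d_rope)
out = layer(torch.randn(B, T, d_model), angles.cos(), angles.sin())  # -> (B, T, d_model)
```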
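Likewise, a hedged sketch of top-k routing in the spirit of the Hugging Face MoE post; the names and the choice of `k = 2` are assumptions and do not mirror the repo's `GatingNetwork` exactly:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopKMoE(nn.Module):
    """Illustrative sparse MoE layer: each token is routed to its top-k experts."""

    def __init__(self, d_model, d_hidden, n_experts=8, k=2):
        super().__init__()
        self.router = nn.Linear(d_model, n_experts)   # one routing logit per expert
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d_model, d_hidden), nn.GELU(), nn.Linear(d_hidden, d_model))
            for _ in range(n_experts)
        )
        self.k = k

    def forward(self, x):                         # x: (batch, seq, d_model)
        tokens = x.reshape(-1, x.size(-1))        # flatten to (n_tokens, d_model)
        logits = self.router(tokens)              # (n_tokens, n_experts)
        gate, idx = logits.topk(self.k, dim=-1)   # keep only the k best experts per token
        gate = F.softmax(gate, dim=-1)            # renormalize over the chosen experts
        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            rows, slots = (idx == e).nonzero(as_tuple=True)   # tokens that selected expert e
            if rows.numel():
                out[rows] += gate[rows, slots].unsqueeze(-1) * expert(tokens[rows])
        return out.reshape_as(x)
```

Only `k` of the `n_experts` expert MLPs run for any given token, which is how the total parameter count can grow without a proportional increase in per-token compute.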
This repository is organized to provide a clear path from understanding the model's architecture to training it and finally using it for inference.
- `model.py`: This is the heart of the project. It defines the complete model architecture, including the components below (a sketch of how they compose into a block follows the list):
  - `Transformer`: The main class that assembles the entire model.
  - `MultiHeadAttention`: The custom Multi Latent Attention implementation.
  - `GatingNetwork` & `TransformerBlock`: The core components for the Mixture-of-Experts (MoE) layers.
  - `RotaryPositionEncoding`: The implementation for RoPE.
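As a rough orientation for how these pieces fit together, the following block uses the common pre-norm residual pattern; both the pre-norm layout and the constructor interface are assumptions rather than the actual `TransformerBlock` signature:

```python
import torch.nn as nn


class MoEBlock(nn.Module):
    """Illustrative Transformer block: attention sublayer followed by a sparse MoE sublayer."""

    def __init__(self, d_model, attention: nn.Module, moe: nn.Module):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = attention     # e.g. the SplitPathAttention sketch above
        self.norm2 = nn.LayerNorm(d_model)
        self.moe = moe            # e.g. the TopKMoE sketch above

    def forward(self, x, *rope_args):
        # Pre-norm residual layout (an assumption, not necessarily the repo's choice).
        x = x + self.attn(self.norm1(x), *rope_args)   # attention sublayer + residual
        x = x + self.moe(self.norm2(x))                # MoE feed-forward sublayer + residual
        return x
```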
The repository includes two scripts for training the model, catering to different hardware setups.
- `training.py` (Single-GPU Training)
  - Purpose: A straightforward script for training the model on a single GPU.
  - Details: It handles data loading, model initialization, a standard training loop with mixed-precision support (`torch.amp`), and a custom learning rate scheduler; a minimal sketch of such a loop is shown below.
  - Use Case: Ideal for debugging, running smaller-scale experiments, or for users who do not have a multi-GPU environment.
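A minimal mixed-precision loop of the kind `training.py` implements could look like this; the model, data loader, loss shape, and scheduler are placeholders, not the script's actual objects:

```python
import torch
import torch.nn.functional as F


def train_one_epoch(model, loader, optimizer, scheduler, device="cuda"):
    scaler = torch.amp.GradScaler("cuda")   # scales the loss so small fp16 gradients don't underflow
    model.train()
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type="cuda"):   # forward pass in mixed precision
            logits = model(inputs)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        scaler.scale(loss).backward()
        scaler.step(optimizer)   # unscales gradients, then steps the optimizer
        scaler.update()
        scheduler.step()         # the custom LR schedule is assumed to advance per step
```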
- `main_distributed.py` (Multi-GPU Distributed Training)
  - Purpose: The primary script for training the full-scale model efficiently across multiple GPUs.
  - Details: It leverages PyTorch's `DistributedDataParallel` (DDP) and `DistributedSampler` to parallelize the training process, and includes an optional token-dropping feature as a regularization technique; a condensed sketch of the DDP setup follows.
  - Use Case: The recommended script for training the model from scratch to achieve the best performance on large datasets.
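The DDP pattern this script relies on boils down to the following; the model, dataset, and function name are placeholders rather than the script's actual code:

```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler


def setup_ddp(model, dataset, batch_size):
    # torchrun sets RANK, LOCAL_RANK and WORLD_SIZE in the environment.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    model = DDP(model.to(local_rank), device_ids=[local_rank])
    # DistributedSampler hands each rank a disjoint shard of the dataset every epoch.
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return model, loader, sampler
```

A launch on an 8-GPU node would then look like `torchrun --nproc_per_node=8 main_distributed.py` (check the script for its actual CLI arguments), with `sampler.set_epoch(epoch)` called at the start of each epoch so shuffling changes between epochs.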
- `user.py` (Interactive Streamlit Demo)
  - Purpose: A web-based application for generating text with the trained model (a stripped-down sketch of the app's structure follows the steps below).
  - How to Use:
    1. Ensure you have a trained model checkpoint (e.g., `weights/mol.pth`). The script is pre-configured to look for this file.
    2. Install the required Python packages: `pip install -r requirements.txt`.
    3. Run the application from your terminal: `streamlit run user.py`
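For orientation, a stripped-down app with the same shape as `user.py` might look like the following; the checkpoint loading and the `generate` call are placeholders for whatever `model.py` actually exposes:

```python
import streamlit as st
import torch

from model import Transformer   # the class exists in model.py; its constructor args are not shown here


@st.cache_resource
def load_model(path="weights/mol.pth"):
    model = Transformer()   # placeholder: the real constructor likely takes config arguments
    model.load_state_dict(torch.load(path, map_location="cpu"))
    return model.eval()


st.title("MoE Transformer demo")
prompt = st.text_area("Prompt", "Once upon a time")
temperature = st.slider("Temperature", 0.1, 2.0, 0.8)
max_tokens = st.slider("Max new tokens", 16, 512, 128)

if st.button("Generate"):
    model = load_model()
    # Placeholder call: the actual sampling loop lives in user.py / model.py.
    st.write(model.generate(prompt, max_new_tokens=max_tokens, temperature=temperature))
```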