w32zhong / tinyllama-bitnet

Train your own small bitnet model

Repository from Github: https://github.com/w32zhong/tinyllama-bitnet

Edit: this is getting stars, so I'm adding this note. For anyone trying to replicate, there are some fixes I still need to push. Will try to get it done ASAP. Thanks.

Tinyllama Bitnet

This repository demonstrates training your own BitNet model based on the Llama 2 architecture. Run unedited, the script trains an ~84M-parameter model on ~1.5B tokens.

File structure

train.py - the entire training process, including preparing the data, defining the model architecture, and training the model.

utils.py - contains the BitLinear implementation and the convert_to_bitnet function for converting Hugging Face's LlamaForCausalLM to BitNet.

inference.py - run inference with a trained BitNet model.

I wanted to make this process as straightforward and hackable as possible, so all of these scripts are minimal and easily adjustable.

Training Data

The script currently uses a 15% subset of openwebtext2 for training. This has been pretokenized at a context length of 256 for ease of testing, but code is also included to tokenize the data yourself. You can replace a couple of lines in the script to train on pretty much anything else you want.
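If you do tokenize the data yourself, the usual Hugging Face pattern is to chunk each document into fixed 256-token blocks and keep only the full-length chunks. Here is a minimal sketch of that idea (not necessarily the exact code in train.py; the dataset and tokenizer identifiers below are placeholders, so swap in whatever you actually want to train on):

```python
from datasets import load_dataset
from transformers import AutoTokenizer

CONTEXT_LENGTH = 256

# Placeholder identifiers -- the script itself uses an openwebtext2 subset;
# substitute your own corpus and tokenizer here.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
raw = load_dataset("Skylion007/openwebtext", split="train[:15%]")

def tokenize(batch):
    # Split each document into CONTEXT_LENGTH-token chunks and keep only the
    # full-length ones, so every training example is exactly 256 tokens.
    outputs = tokenizer(
        batch["text"],
        truncation=True,
        max_length=CONTEXT_LENGTH,
        return_overflowing_tokens=True,
        return_length=True,
    )
    input_ids = [
        ids
        for length, ids in zip(outputs["length"], outputs["input_ids"])
        if length == CONTEXT_LENGTH
    ]
    return {"input_ids": input_ids}

tokenized = raw.map(tokenize, batched=True, remove_columns=raw.column_names)
```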

Dependencies

You'll want to install these packages. The last two are optional and are for logging and HF auth.
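Based on what the scripts import, an install along these lines should cover it (the package names are an assumption; wandb and huggingface_hub are the optional logging and HF auth packages):

```
pip install torch transformers datasets wandb huggingface_hub
```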

BitNet

The BitLinear definition is copied straight from the released training details manuscript. The BitNet architecture is defined by loading a blank Llama 2 model with Hugging Face and then making the necessary replacements, as per the manuscript (a sketch of both pieces follows the list):

  1. Replace all nn.Linear in attention and SwiGLU with BitLinear
  2. Remove RMSNorm before attention and SwiGLU because BitLinear has built-in RMSNorm.
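For reference, the manuscript's BitLinear is essentially an nn.Linear with a parameter-free RMSNorm on the input, absmean ternary weight quantization, absmax 8-bit activation quantization, and a straight-through estimator so gradients flow in full precision. A minimal sketch along those lines (not verbatim from utils.py):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def activation_quant(x):
    # Per-token absmax quantization of activations to 8 bits.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale


def weight_quant(w):
    # Absmean quantization of weights to the ternary set {-1, 0, +1}.
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    return (w * scale).round().clamp_(-1, 1) / scale


class BitLinear(nn.Linear):
    def forward(self, x):
        w = self.weight
        # Built-in (parameter-free) RMSNorm -- this is why the decoder layers'
        # own input/post-attention RMSNorms can be removed.
        x_norm = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
        # Straight-through estimator: quantized values in the forward pass,
        # full-precision gradients in the backward pass.
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        return F.linear(x_quant, w_quant)
```

convert_to_bitnet then just walks the blank LlamaForCausalLM and applies the two replacements above. Something like the following would do it (a sketch assuming the module names used by the transformers Llama implementation):

```python
from transformers import LlamaForCausalLM


def convert_to_bitnet(model: LlamaForCausalLM) -> LlamaForCausalLM:
    for layer in model.model.layers:
        # 1. Swap every nn.Linear inside attention and the SwiGLU MLP for BitLinear.
        for name, child in list(layer.self_attn.named_children()) + list(layer.mlp.named_children()):
            if isinstance(child, nn.Linear):
                bitlinear = BitLinear(child.in_features, child.out_features,
                                      bias=child.bias is not None)
                bitlinear.weight = child.weight
                if child.bias is not None:
                    bitlinear.bias = child.bias
                parent = layer.self_attn if hasattr(layer.self_attn, name) else layer.mlp
                setattr(parent, name, bitlinear)
        # 2. Remove the RMSNorms in front of attention and the MLP;
        #    BitLinear normalizes its input internally.
        layer.input_layernorm = nn.Identity()
        layer.post_attention_layernorm = nn.Identity()
    return model
```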

About

Train your own small bitnet model

License: MIT License


Languages

Python 100.0%