DistiLLM: Towards Streamlined Distillation for Large Language Models (ICML 2024)

Official PyTorch implementation of DistiLLM, as presented in our paper:

DistiLLM: Towards Streamlined Distillation for Large Language Models
Jongwoo Ko, Sungnyun Kim, Tianyi Chen, Se-Young Yun
KAIST AI and Microsoft

🚀 Updates

  • (24.05.01) Our paper has been accepted to ICML 2024. We are open to any discussion and will reflect it in the camera-ready version. Looking forward to seeing you in Vienna!
  • (24.03.13) Released LoRA checkpoints for OpenLLaMA2-3B.

Environment

bash install.sh

Following MiniLLM, our code is based on this commit of HuggingFace Transformers.

Data

Resources

  • The training/evaluation instruction-response data before processing can be downloaded from this link.
  • The plain-text corpus $\mathcal{D}_\text{PT}$ can be downloaded from the HuggingFace datasets repository (see the sketch after this list for one way to fetch it).
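
If you prefer to pull the corpus programmatically, the following is a minimal sketch using the HuggingFace datasets library. It assumes the public openwebtext dataset id and is not part of this repository's tooling; the official pipeline uses tools/get_openwebtext.py below.

from datasets import load_dataset

# Assumption: the plain-text corpus is OpenWebText as hosted on the HuggingFace Hub.
corpus = load_dataset("openwebtext", split="train")
print(corpus[0]["text"][:200])  # peek at the first document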

Data Processing

Get plain-text corpus $\mathcal{D}_\text{PT}$:

python3 tools/get_openwebtext.py

This script replaces consecutive \n characters in each document with a special token "<@x(x!>" and writes each OpenWebText document on a single line, which is convenient for parallel processing. In data/openwebtext/data.txt, we give an example of the resulting format. You can follow this format to prepare corpora other than OpenWebText.
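
As a rough illustration of that format, the sketch below writes an arbitrary collection of documents one per line, replacing newlines with the special token described above. The write_corpus helper and its inputs are hypothetical; only the newline-replacement convention comes from the description.

# Minimal sketch: one document per line, newlines replaced by the special token.
SPECIAL_NEWLINE = "<@x(x!>"

def write_corpus(documents, out_path):
    with open(out_path, "w", encoding="utf-8") as f:
        for doc in documents:
            f.write(doc.replace("\n", SPECIAL_NEWLINE) + "\n")

# Example (hypothetical input):
# write_corpus(["first doc\nwith lines", "second doc"], "data/mycorpus/data.txt")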

Tokenize the data and store them in binary files:

bash scripts/gpt2/tools/process_data_dolly.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process Dolly Train / Validation Data
bash scripts/gpt2/tools/process_data_pretrain.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process OpenWebText Train / Validation Data

bash scripts/opt/tools/process_data_dolly.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process Dolly Train / Validation Data
bash scripts/opt/tools/process_data_pretrain.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process OpenWebText Corpus Train / Validation Data

bash scripts/llama/tools/process_data_dolly.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process Dolly Train / Validation Data
bash scripts/llama/tools/process_data_pretrain.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM} # Process OpenWebText Corpus Train / Validation Data

Base Pre-trained Models

To run fine-tuning or standard KD baselines, you need to download the model checkpoints from the HuggingFace Model Hub and put them in checkpoints/. For example, for gpt2-large, you can download the model from this link and put it in checkpoints/gpt2-large.
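
One way to script this download is with huggingface_hub; the snippet below is a sketch using the gpt2-large example above, with the local directory mirroring the checkpoints/ layout the scripts expect.

from huggingface_hub import snapshot_download

# Download gpt2-large into the checkpoints/ layout described above.
snapshot_download(repo_id="gpt2-large", local_dir="checkpoints/gpt2-large")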

Alternatively, you can change the CKPT variable in each script to the corresponding model name so that Transformers downloads the base model automatically. For example, setting CKPT="gpt2-large" in scripts/gpt2/sft/sft_large.sh causes the gpt2-large base model to be downloaded from the HuggingFace model hub.
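
Under the hood this relies on Transformers' standard from_pretrained resolution, roughly as in the sketch below; this is illustrative only, since the training scripts already handle model loading.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Passing a Hub model name instead of a local path triggers an automatic download.
tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
model = AutoModelForCausalLM.from_pretrained("gpt2-large")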

Train

We provide example commands for GPT-2 models. Similar scripts for other model families can be found in scripts/opt and scripts/openllama2. All our experiments are conducted on 4 x A100 (40GB) GPUs; the number of GPUs can be reduced for smaller models.

Baselines

The final checkpoints are selected by the ROUGE-L scores.
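For reference, the selection metric can be computed with the rouge_score package as in the sketch below; this is only an illustration of ROUGE-L, not the repository's evaluation code.

from rouge_score import rouge_scorer

# ROUGE-L F1 between a reference response and a generated one.
scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
score = scorer.score("the reference response", "the generated response")["rougeL"].fmeasure
print(score)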

Fine-tune the teacher models

bash scripts/gpt2/sft/sft_xlarge.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

SFT Baselines

bash scripts/gpt2/sft/sft_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/sft/sft_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/sft/sft_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

KD Baselines

bash scripts/gpt2/kd/kd_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/kd/kd_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/kd/kd_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

SeqKD Baselines

Generate and process responses with the teacher:

bash scripts/gpt2/tools/generate_data_seqkd.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/tools/process_pseudo_data_seqkd.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

Fine-tune the model with SeqKD:

bash scripts/gpt2/seqkd/seqkd_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/seqkd/seqkd_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/seqkd/seqkd_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

Student Initialization

The final checkpoints are selected by the validation loss.

bash scripts/gpt2/init/init_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

ImitKD Baselines

bash scripts/gpt2/imitkd/imitkd_base_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/imitkd/imitkd_medium_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/imitkd/imitkd_large_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

MiniLLM Baselines

bash scripts/gpt2/minillm/train_base_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/minillm/train_medium_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/minillm/train_large_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

GKD Baselines

bash scripts/gpt2/gkd/gkd_base_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/gkd/gkd_medium_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/gkd/gkd_large_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

DistiLLM

The final checkpoints are selected by the validation loss.

bash scripts/gpt2/init/init_base.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_medium.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/init/init_large.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

The final checkpoints are selected by the ROUGE-L scores.

bash scripts/gpt2/distillm/train_base_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/distillm/train_medium_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}
bash scripts/gpt2/distillm/train_large_xl.sh ${/PATH/TO/DistiLLM} ${MASTER_PORT} ${GPU_NUM}

Run Evaluation

bash scripts/gpt2/eval/run_eval.sh ${GPU_IDX} ${/PATH/TO/DistiLLM}
bash scripts/opt/eval/run_eval.sh ${GPU_IDX} ${/PATH/TO/DistiLLM} 
bash scripts/openllama2/eval/run_eval.sh ${GPU_IDX} ${/PATH/TO/DistiLLM} 

Results

DistiLLM outperforms other KD baselines in terms of both generation performance and training speed for various model families such as GPT-2, OPT, and OpenLLaMA.

Checkpoints (OpenLLaMA-3B)

We share the LoRA weights for OpenLLaMA-3B on Google Drive.
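
To use them, you would typically attach the LoRA weights to the corresponding base model with peft, roughly as sketched below. The base model id and the local adapter path are assumptions; point them at the base checkpoint you use and at wherever you unpack the downloaded weights.

from transformers import AutoModelForCausalLM
from peft import PeftModel

# Load the base model, then attach the downloaded LoRA adapter.
base = AutoModelForCausalLM.from_pretrained("openlm-research/open_llama_3b_v2")  # assumed base model id
model = PeftModel.from_pretrained(base, "checkpoints/openllama2-3b-lora")  # hypothetical local path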

Acknowledgement

Our code is based on the code of ICLR2024 MiniLLM: Knowledge Distillation of Large Language Models.

BibTeX

If you find this repo useful for your research, please consider citing our paper:

@article{ko2024distillm,
      title={DistiLLM: Towards Streamlined Distillation for Large Language Models}, 
      author={Jongwoo Ko and Sungnyun Kim and Tianyi Chen and Se-Young Yun},
      year={2024},
      journal={arXiv preprint arXiv:2402.03898},
}

Contact
