akhilkedia / TranformersGetStable

[ICML 2024] Official Repository for the paper "Transformers Get Stable: An End-to-End Signal Propagation Theory for Language Models"

Home Page: https://arxiv.org/abs/2403.09635

Transformers Get Stable: An End-to-End Signal Propagation Theory for Language Models

🌟 ArXiv Preprint | ICML Poster

Train transformers with 1000 layers!

🔗 Quick Links

Brief Introduction

We share the code for:

  1. Running simulations of signal propagation for individual transformer components, written in simple PyTorch. See the section Simulations.
  2. Signal propagation through the entire transformer model. See Xavier Signal Propagation Figures.
  3. Signal propagation for our proposed changes to the transformer architecture, DeepScaleLM. See DSLM Signal Propagation Figures.
  4. Running baseline pretraining and downstream finetuning on BERT without DeepScaleLM to reproduce our results. See Baseline BERT Model Pretraining and Finetuning.

The entire-model signal propagation and the BERT training are based on NVIDIA Megatron. HuggingFace code for DeepScaleLM will be released soon!

Each folder has a README with instructions for running its code.

Simulations

Environment Setup

Run pip install -r simulations/requirements.txt. The simulations only require torch, numpy, tqdm, and matplotlib. A CUDA GPU is also required; tested on a single A100.

Running

Run the file run_all.sh to reproduce all our simulations for transformer components. The file approximations.py also plots the approximations of the ReLU/MLP covariance.
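Putting the setup and run steps together, the invocation might look like the following sketch (it assumes run_all.sh and approximations.py sit inside the simulations folder; adjust the paths to match the repository layout):

```bash
# Install the simulation dependencies (torch, numpy, tqdm, matplotlib); a CUDA GPU is assumed
pip install -r simulations/requirements.txt

# Assumed location of the scripts; adjust if they live elsewhere
cd simulations

# Reproduce all component-level signal-propagation simulations
bash run_all.sh

# Optionally, plot the ReLU/MLP covariance approximations
python approximations.py
```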

Expected Output

Expected output is provided in expected_output.txt.

Xavier Signal Propagation Figures

These figures show the forward and backward change in variances for vanilla transformer models with Xavier initialization.

Running

cd into xavier_signal_propagation_figures and run the file make_figs.sh.

It will create the Xavier figures preln_forward.png, preln_backward.png, and postln_backward.png, then exit.

These figures are already included in the repository; delete the .png files to recreate them.
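As a sketch, assuming the directory layout described above, the full run is:

```bash
cd xavier_signal_propagation_figures

# The figures are checked in; remove them first so the script regenerates them
rm -f preln_forward.png preln_backward.png postln_backward.png

# Produce the Xavier signal-propagation figures
bash make_figs.sh
```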

Expected Output

[Figures: preln_forward.png, preln_backward.png, postln_backward.png]

DSLM Signal Propagation Figures

These figures show unit forward and backward change in variances for DSLM models.

Environment Setup

Run conda env create -f environment.yml. Tested on 8x A100 80GB. The same environment is also required for the Xavier figures and for pre-training.

It also requires the pre-training data generated by bert_wiki_pretraining/prepare_data.sh.
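A sketch of the setup, assuming environment.yml sits at the repository root and the data script is run first (the environment name used here is an assumption; use the name defined in environment.yml):

```bash
# Create and activate the conda environment (also used for the Xavier figures and pre-training)
conda env create -f environment.yml
conda activate dslm   # assumed name; check environment.yml

# Prepare the pre-training data required for the DSLM figures
bash bert_wiki_pretraining/prepare_data.sh
```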

Running

cd into DSLM_signal_propagation_figures and run the file make_figs.sh.

It will create the figures preln_forward.png, preln_backward.png, and postln_backward.png, then exit.

These figures are already included in the repository; delete the .png files to recreate them.
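Analogously to the Xavier section, a sketch of the commands:

```bash
cd DSLM_signal_propagation_figures

# Remove the checked-in figures so the script regenerates them
rm -f preln_forward.png preln_backward.png postln_backward.png

bash make_figs.sh
```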

Expected Output

[Figures: preln_forward.png, preln_backward.png, postln_backward.png]

Baseline BERT Model Pretraining and Finetuning

Running

  1. cd into bert_wiki_pretraining
  2. Run the file prepare_data.sh to download and process the pre-training dataset. This is Wikipedia from TFDS.
  3. Run the file run_bert_wiki.sh to run the pretraining.
  4. Run the files examples/run_mnli.sh, examples/run_qqp.sh, examples/run_race.sh to run finetuning.
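End to end, the baseline run might look like the following sketch (it assumes the examples/ finetuning scripts are run relative to bert_wiki_pretraining, as listed above; adjust the paths to the repository layout):

```bash
cd bert_wiki_pretraining

# Download and preprocess the Wikipedia (TFDS) pre-training dataset
bash prepare_data.sh

# Pretrain the baseline BERT model
bash run_bert_wiki.sh

# Finetune on the downstream tasks
bash examples/run_mnli.sh
bash examples/run_qqp.sh
bash examples/run_race.sh
```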

Bugs or Questions?

If you have any questions related to the code or the paper, feel free to email Akhil Kedia (akhil.kedia @ samsung.com). If you encounter any problems when using the code, you can open an issue!

Citation

Please cite our paper if you find the repo helpful in your work:

@article{DSLM,
  author       = {Akhil Kedia and
                  Mohd Abbas Zaidi and
                  Sushil Khyalia and
                  Jungho Jung and
                  Harshith Goka and
                  Haejun Lee},
  title        = {Transformers Get Stable: An End-to-End Signal Propagation Theory for
                  Language Models},
  journal      = {CoRR},
  volume       = {abs/2403.09635},
  year         = {2024},
  url          = {https://doi.org/10.48550/arXiv.2403.09635},
  doi          = {10.48550/ARXIV.2403.09635},
  eprinttype    = {arXiv},
  eprint       = {2403.09635}
}
