arXiv Preprint | ICML Poster
Train transformers with 1000 layers!
Transformers Get Stable: An End-to-End Signal Propagation Theory for Language Models
This repository contains code for:
- Simulating signal propagation through individual transformer components, written in plain PyTorch. See the section Simulations.
- Signal propagation through the entire transformer model. See the section Xavier Signal Propagation Figures.
- Signal propagation with our proposed changes to the transformer architecture, DeepScaleLM. See the section DSLM Signal Propagation Figures.
- Running baseline BERT pre-training and downstream fine-tuning without DeepScaleLM, to reproduce our results.
The full-model signal propagation code and the BERT training code are based on NVIDIA Megatron. HuggingFace code for DeepScaleLM will be released soon!
Each folder has a README with instructions for running its code.
Simulations
Install the dependencies with pip install -r simulations/requirements.txt. Only torch, numpy, tqdm and matplotlib are required. A CUDA GPU is also required; the code was tested on a single A100.
Run run_all.sh to reproduce all our simulations for transformer components. The file approximations.py also plots the approximations of the ReLU/MLP covariance.
The expected output is provided in expected_output.txt.
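The exact formulas in approximations.py are not reproduced in this README. As a hedged, self-contained illustration of the kind of quantity it approximates, the sketch below compares the standard closed-form covariance of ReLU applied to two zero-mean, equal-variance, jointly Gaussian inputs (the degree-1 arc-cosine kernel) with a Monte Carlo estimate. The function names relu_cov_closed_form and relu_cov_monte_carlo are hypothetical, this is not necessarily the approximation used in the repository, and it uses only the Python standard library (not torch) so it runs without a GPU.

```python
import math
import random


def relu_cov_closed_form(rho: float, sigma2: float = 1.0) -> float:
    """Cov[ReLU(x), ReLU(y)] for zero-mean jointly Gaussian x, y with
    Var(x) = Var(y) = sigma2 and correlation rho (hypothetical helper)."""
    rho = max(-1.0, min(1.0, rho))  # guard against round-off outside [-1, 1]
    # E[ReLU(x) ReLU(y)]: the degree-1 arc-cosine kernel.
    second_moment = sigma2 / (2 * math.pi) * (
        math.sqrt(1.0 - rho * rho) + (math.pi - math.acos(rho)) * rho
    )
    # (E[ReLU(x)])^2, since E[ReLU(x)] = sigma / sqrt(2 * pi).
    mean_sq = sigma2 / (2 * math.pi)
    return second_moment - mean_sq


def relu_cov_monte_carlo(rho: float, sigma2: float = 1.0,
                         n: int = 200_000, seed: int = 0) -> float:
    """Empirical Cov[ReLU(x), ReLU(y)] from n correlated Gaussian pairs."""
    rng = random.Random(seed)
    sigma = math.sqrt(sigma2)
    ortho = math.sqrt(max(0.0, 1.0 - rho * rho))
    sum_x = sum_y = sum_xy = 0.0
    for _ in range(n):
        z1, z2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
        x = sigma * z1
        y = sigma * (rho * z1 + ortho * z2)  # gives corr(x, y) = rho
        rx, ry = max(x, 0.0), max(y, 0.0)
        sum_x += rx
        sum_y += ry
        sum_xy += rx * ry
    return sum_xy / n - (sum_x / n) * (sum_y / n)


if __name__ == "__main__":
    for rho in (-0.5, 0.0, 0.5, 0.9, 1.0):
        print(f"rho={rho:+.1f}  closed={relu_cov_closed_form(rho):.4f}  "
              f"mc={relu_cov_monte_carlo(rho):.4f}")
```

At rho = 1 the closed form reduces to sigma2 * (1/2 - 1/(2*pi)), the variance of a single ReLU output, which is a quick sanity check on any covariance approximation.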
Xavier Signal Propagation Figures
These figures show the forward and backward changes in variance for vanilla transformer models.
cd into xavier_signal_propagation_figures and run the file make_figs.sh. It will generate the Xavier figures preln_forward.png, preln_backward.png and postln_backward.png, and then exit.
These files are already included; delete the .png files to recreate them.
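As a rough, dependency-free illustration of what the Pre-LN forward figure measures, the toy simulation below tracks the variance of the residual stream in a simplified Pre-LN stack, x_{l+1} = x_l + W_l LN(x_l), where each block is only a parameter-free LayerNorm followed by one square Xavier-initialized linear layer. It is not the Megatron-based code behind make_figs.sh: attention, the MLP nonlinearity and learnable LayerNorm parameters are deliberately left out, and all function names are hypothetical. With Xavier normal initialization, Var(W) = 2 / (fan_in + fan_out) = 1/d for a d x d layer, so each branch output has roughly unit variance and the stream's variance grows roughly as Var(x_0) + L after L blocks.

```python
import math
import random


def layer_norm(v, eps=1e-5):
    """LayerNorm without learnable gain or bias: zero mean, unit variance per vector."""
    mu = sum(v) / len(v)
    var = sum((a - mu) ** 2 for a in v) / len(v)
    return [(a - mu) / math.sqrt(var + eps) for a in v]


def xavier_normal(fan_in, fan_out, rng):
    """fan_out x fan_in matrix with entries ~ N(0, 2 / (fan_in + fan_out))."""
    std = math.sqrt(2.0 / (fan_in + fan_out))
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]


def matvec(w, v):
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def variance(batch):
    """Variance over all elements of a batch of vectors."""
    flat = [a for vec in batch for a in vec]
    mu = sum(flat) / len(flat)
    return sum((a - mu) ** 2 for a in flat) / len(flat)


def preln_forward_variances(d=64, n_layers=8, batch=32, seed=0):
    """Var(x_l) for l = 0..n_layers in the toy stack x_{l+1} = x_l + W_l LN(x_l)."""
    rng = random.Random(seed)
    xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(batch)]
    history = [variance(xs)]
    for _ in range(n_layers):
        w = xavier_normal(d, d, rng)  # square layer, so Var(W) = 1 / d
        xs = [[a + b for a, b in zip(x, matvec(w, layer_norm(x)))] for x in xs]
        history.append(variance(xs))
    return history


if __name__ == "__main__":
    for layer, v in enumerate(preln_forward_variances()):
        print(f"layer {layer}: Var(x) = {v:.2f}")
```

Because the skip path passes x_l through unscaled while every block adds a roughly unit-variance branch, nothing in this toy stack counteracts the growth with depth.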
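DSLM Signal Propagation Figures
These figures show that, with DeepScaleLM, the forward and backward changes in variance stay at unit scale.
Create the environment with conda env create -f environment.yml. Tested on 8x A100 80GB. The same environment is also required for the Xavier figures and for pre-training.
This also requires the pre-training data produced by bert_wiki_pretraining/prepare_data.sh.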
cd into DSLM_signal_propagation_figures and run the file make_figs.sh. It will generate the figures preln_forward.png, preln_backward.png and postln_backward.png, and then exit.
These files are already included; delete the .png files to recreate them.
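This README does not spell out DeepScaleLM's architectural changes, so the sketch below only illustrates, under a loudly stated assumption, one way a residual stack can keep the forward variance at unit scale: scale the skip path by lam and the branch by beta with lam^2 + beta^2 = 1. The form of the update, the default beta^2 = 1/N and all function names are assumptions for illustration, not the repository's implementation; see the paper for the actual DeepScaleLM scheme. Each toy block is a parameter-free LayerNorm followed by one Xavier-initialized d x d linear layer, with attention and nonlinearities omitted.

```python
import math
import random


def layer_norm(v, eps=1e-5):
    """LayerNorm without learnable gain or bias."""
    mu = sum(v) / len(v)
    var = sum((a - mu) ** 2 for a in v) / len(v)
    return [(a - mu) / math.sqrt(var + eps) for a in v]


def xavier_normal(fan_in, fan_out, rng):
    """Entries ~ N(0, 2 / (fan_in + fan_out))."""
    std = math.sqrt(2.0 / (fan_in + fan_out))
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]


def matvec(w, v):
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def variance(batch):
    flat = [a for vec in batch for a in vec]
    mu = sum(flat) / len(flat)
    return sum((a - mu) ** 2 for a in flat) / len(flat)


def scaled_residual_variances(d=64, n_layers=8, batch=32, beta2=None, seed=0):
    """Var(x_l) for the toy stack x_{l+1} = lam * x_l + beta * W_l LN(x_l).

    ASSUMPTION (illustrative only): lam^2 + beta^2 = 1 and, by default,
    beta^2 = 1 / n_layers. Not taken from the repository code.
    """
    if beta2 is None:
        beta2 = 1.0 / n_layers
    beta, lam = math.sqrt(beta2), math.sqrt(1.0 - beta2)
    rng = random.Random(seed)
    xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(batch)]
    history = [variance(xs)]
    for _ in range(n_layers):
        w = xavier_normal(d, d, rng)  # Var(W) = 1 / d, so the branch has ~unit variance
        # If x_l has unit variance and is roughly uncorrelated with the branch,
        # Var(x_{l+1}) ~ lam^2 * 1 + beta^2 * 1 = 1.
        xs = [[lam * a + beta * b for a, b in zip(x, matvec(w, layer_norm(x)))]
              for x in xs]
        history.append(variance(xs))
    return history


if __name__ == "__main__":
    for layer, v in enumerate(scaled_residual_variances()):
        print(f"layer {layer}: Var(x) = {v:.2f}")
```

If x_l has unit variance and the branch output is roughly uncorrelated with it, Var(x_{l+1}) is about lam^2 + beta^2 = 1, so in this toy model the variance stays near 1 regardless of depth instead of growing with the number of layers.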
- cd into bert_wiki_pretraining.
- Run the file prepare_data.sh to download and process the pre-training dataset (Wikipedia from TFDS).
- Run the file run_bert_wiki.sh to run pre-training.
- Run the files examples/run_mnli.sh, examples/run_qqp.sh and examples/run_race.sh to run fine-tuning.
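The steps above can be chained in a small driver, sketched below under assumptions the README does not confirm: each script is run with bash, takes no arguments, and is invoked from inside bert_wiki_pretraining (so the examples/ paths are relative to it). run_pipeline and STEPS are hypothetical names; with dry_run=True the driver only prints the commands. The fine-tuning scripts may need paths or settings configured first, so check the folder's readme before a real run.

```python
import subprocess
from pathlib import Path

# Hypothetical step list; order follows the README, paths assumed relative
# to bert_wiki_pretraining/.
STEPS = [
    "prepare_data.sh",       # download and process Wikipedia (TFDS)
    "run_bert_wiki.sh",      # BERT pre-training
    "examples/run_mnli.sh",  # fine-tuning
    "examples/run_qqp.sh",
    "examples/run_race.sh",
]


def run_pipeline(repo_dir=".", dry_run=True):
    """Run each step with bash from bert_wiki_pretraining/, stopping on the first failure."""
    workdir = Path(repo_dir) / "bert_wiki_pretraining"
    commands = [["bash", step] for step in STEPS]
    for cmd in commands:
        if dry_run:
            print(f"(cd {workdir} && {' '.join(cmd)})")
        else:
            subprocess.run(cmd, cwd=workdir, check=True)
    return commands


if __name__ == "__main__":
    run_pipeline(dry_run=True)
```

check=True makes subprocess.run raise CalledProcessError on a non-zero exit, so a failed data preparation step stops the pipeline before pre-training starts.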
If you have any questions related to the code or the paper, feel free to email Akhil Kedia (akhil.kedia @ samsung.com). If you encounter any problems when using the code, you can open an issue!
Please cite our paper if you find the repo helpful in your work:
@article{DSLM,
author = {Akhil Kedia and
Mohd Abbas Zaidi and
Sushil Khyalia and
Jungho Jung and
Harshith Goka and
Haejun Lee},
title = {Transformers Get Stable: An End-to-End Signal Propagation Theory for
Language Models},
journal = {CoRR},
volume = {abs/2403.09635},
year = {2024},
url = {https://doi.org/10.48550/arXiv.2403.09635},
doi = {10.48550/ARXIV.2403.09635},
eprinttype = {arXiv},
eprint = {2403.09635}
}