LiyuanLucasLiu / Transformer-Clinic

Understanding the Difficulty of Training Transformers

Home Page: https://arxiv.org/abs/2004.08249

Post-LN with 12-12 trains fine, but 12-3 diverges

ZhenYangIACAS opened this issue · comments

Hi, as expected, models with more Transformer layers are more likely to diverge during training. However, we find that the model with 12 encoder layers and 12 decoder layers trains fine, while the model with 12 encoder layers and 3 decoder layers diverges. Have you seen this result in your experiments? Thank you.

Thanks for asking. Could you share more details about the phenomenon? :-)

I guess the phenomenon you mean is: 12L-12L Post-LN converged, but 12L-3L Post-LN diverged. Our current theory can only provide guidance of the form "deeper models are more unstable and more likely to diverge." As for case-by-case converge/diverge predictions, I believe randomness also plays a role and more research is needed (e.g., in our experiments, 12L-12L Post-LN diverges on WMT'14 En-De, but it seems to converge in your setting).

On the other hand, I believe Admin should converge well in both situations.

Thanks for your reply. In our experiments, Admin-12-12 converges, but Admin-12-1 and Admin-12-2 diverge.

Emmm, could you provide more details about your setting?

On WMT'14 En-De, I trained Admin-12-2 for 10 epochs, and it does not seem to have any problems (its dev ppl @ epoch 10 is 5.80).

The commands I used are below (some paths need to be changed to match your setup):

CUDA_VISIBLE_DEVICES=4 fairseq-train \
./data/wmt14_en_de_joined_dict/ -s en -t de \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --max-update 500000 \
--warmup-init-lr 1e-07 --warmup-updates 8000 --lr 0.001 --min-lr 1e-09  \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
--max-tokens 8192 --update-freq 1 \
--save-dir cps/wmt-admin-12-2 --restore-file x.pt --seed 1111 \
--user-dir ../radam_fairseq --log-format simple --log-interval 500 \
--init-type adaptive-profiling --fp16 --fp16-scale-window 256 \
--encoder-layers 12 --decoder-layers 2 \
--threshold-loss-scale 0.03125 

CUDA_VISIBLE_DEVICES=4,3 fairseq-train \
./data/wmt14_en_de_joined_dict/ -s en -t de \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --max-update 500000 \
--warmup-init-lr 1e-07 --warmup-updates 8000 --lr 0.001 --min-lr 1e-09  \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
--max-tokens 8192 --update-freq 2 \
--save-dir cps/wmt-admin-12-2 --restore-file x.pt --seed 1111 \
--user-dir ../radam_fairseq --log-format simple --log-interval 500 \
--init-type adaptive --fp16 --fp16-scale-window 256 \
--encoder-layers 12 --decoder-layers 2 \
--threshold-loss-scale 0.03125 | tee ./log/app_wmt/loss_admin-12-2.log
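
In case the two commands above look redundant: the first run (with --init-type adaptive-profiling) is a profiling pass that generates profile.ratio.init, and the second run (with --init-type adaptive) is the actual training. The schematic below just summarizes the commands in this thread, not an authoritative description of the workflow (please check the repository docs for that):

# Two-pass pattern used throughout this thread (summarized from the commands above):
#   pass 1: --init-type adaptive-profiling   # profiling run on a single GPU; writes profile.ratio.init
#   pass 2: --init-type adaptive             # actual training run, which uses the profiled ratios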

I also trained Admin-12-1 for 10 epochs, and it does not seem to have any problems (its dev ppl @ epoch 10 is 7.38).

The commands I used are below (some paths need to be changed to match your setup):

CUDA_VISIBLE_DEVICES=4 fairseq-train \
./data/wmt14_en_de_joined_dict/ -s en -t de \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --max-update 500000 \
--warmup-init-lr 1e-07 --warmup-updates 8000 --lr 0.001 --min-lr 1e-09  \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
--max-tokens 8192 --update-freq 1 \
--save-dir cps/wmt-admin-12-1 --restore-file x.pt --seed 1111 \
--user-dir ../radam_fairseq --log-format simple --log-interval 500 \
--init-type adaptive-profiling --fp16 --fp16-scale-window 256 \
--encoder-layers 12 --decoder-layers 1 \
--threshold-loss-scale 0.03125 

CUDA_VISIBLE_DEVICES=4,3 fairseq-train \
./data/wmt14_en_de_joined_dict/ -s en -t de \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --max-update 500000 \
--warmup-init-lr 1e-07 --warmup-updates 8000 --lr 0.001 --min-lr 1e-09  \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
--max-tokens 8192 --update-freq 2 \
--save-dir cps/wmt-admin-12-1 --restore-file x.pt --seed 1111 \
--user-dir ../radam_fairseq --log-format simple --log-interval 500 \
--init-type adaptive --fp16 --fp16-scale-window 256 \
--encoder-layers 12 --decoder-layers 1 \
--threshold-loss-scale 0.03125 | tee ./log/app_wmt/loss_admin-12-1.log

The script we used to train Admin-18-x is as follows. The 18-1 and 18-2 models both diverge.

export CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'
GPUID=1
TOKEN_NUMBER=4096
UPDATE_FREQUENCE=1

for lnum in 1 2 3 4 6
do
python ../fairseq/train.py \
../data-bin/wmt14_en_de_joined_dict/ -s en -t de \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --max-update 500000 \
--warmup-init-lr 1e-07 --warmup-updates 8000 --lr 0.001 --min-lr 1e-09 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
--max-tokens $TOKEN_NUMBER --update-freq $UPDATE_FREQUENCE \
--save-dir wmt14ende/wmt-admin-18-${lnum}l --restore-file x.pt --seed 1111 \
--user-dir ../radam_fairseq --log-format simple --log-interval 500 \
--init-type adaptive-profiling \
--encoder-layers 18 --decoder-layers $lnum \
--threshold-loss-scale 0.03125

python ../fairseq/train.py \
  ../data-bin/wmt14_en_de_joined_dict/ -s en -t de \
  --arch transformer_wmt_en_de --share-all-embeddings \
  --optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
  --lr-scheduler inverse_sqrt --max-update 500000 \
  --warmup-init-lr 1e-07 --warmup-updates 8000 --lr 0.001 --min-lr 1e-09  \
  --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
  --weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
  --max-tokens $TOKEN_NUMBER --update-freq $UPDATE_FREQUENCE \
  --save-dir wmt14ende/wmt-admin-18-${lnum}l --restore-file x.pt --seed 1111 \
  --user-dir ../radam_fairseq --log-format simple --log-interval 500 \
  --init-type adaptive  \
  --encoder-layers 18 --decoder-layers $lnum \
  --threshold-loss-scale 0.03125 | tee ./wmt14ende/log/loss_admin-18-${lnum}l.log

bash eval_wmt_en-de.sh wmt14ende/wmt-admin-18-${lnum}l $GPUID
mv profile.ratio.init wmt14ende/wmt-admin-18-${lnum}l/

done

No worries. I guess Admin-12-X also converges well in your experiments, which matches our intuition that deeper models are less stable :-)

I'm running some experiments on Admin 18-X, and will get back to you soon.

Emmm, I'm not sure why, but I did not run into any problems when training Admin-18-1/2 myself.

On WMT'14 En-De, I trained Admin-18-1 for 10 epochs; its dev ppl @ epoch 10 is 7.23, better than the 7.38 of Admin-12-1 at epoch 10.

The commands I used are below (some paths need to be changed to match your setup; both 8k and 16k warmup perform well):

CUDA_VISIBLE_DEVICES=3 fairseq-train \
./data/wmt14_en_de_joined_dict/ -s en -t de \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --max-update 500000 \
--warmup-init-lr 1e-07 --warmup-updates 16000 --lr 0.001 --min-lr 1e-09  \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
--max-tokens 8192 --update-freq 4 \
--save-dir cps/wmt-admin-18-1 --restore-file x.pt --seed 1111 \
--user-dir ../radam_fairseq --log-format simple --log-interval 500 \
--init-type adaptive-profiling --fp16 --fp16-scale-window 256 \
--encoder-layers 18 --decoder-layers 1 \
--threshold-loss-scale 0.03125 

CUDA_VISIBLE_DEVICES=3 fairseq-train \
./data/wmt14_en_de_joined_dict/ -s en -t de \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --max-update 500000 \
--warmup-init-lr 1e-07 --warmup-updates 16000 --lr 0.001 --min-lr 1e-09  \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
--max-tokens 8192 --update-freq 4 \
--save-dir cps/wmt-admin-18-1 --restore-file x.pt --seed 1111 \
--user-dir ../radam_fairseq --log-format simple --log-interval 500 \
--init-type adaptive --fp16 --fp16-scale-window 256 \
--encoder-layers 18 --decoder-layers 1 \
--threshold-loss-scale 0.03125 | tee ./log/app_wmt/loss_admin-18-1.log

CUDA_VISIBLE_DEVICES=6 fairseq-train \
./data/wmt14_en_de_joined_dict/ -s en -t de \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --max-update 500000 \
--warmup-init-lr 1e-07 --warmup-updates 8000 --lr 0.001 --min-lr 1e-09  \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
--max-tokens 8192 --update-freq 4 \
--save-dir cps/wmt-admin-18-1-8k --restore-file x.pt --seed 1111 \
--user-dir ../radam_fairseq --log-format simple --log-interval 500 \
--init-type adaptive --fp16 --fp16-scale-window 256 \
--encoder-layers 18 --decoder-layers 1 \
--threshold-loss-scale 0.03125 | tee ./log/app_wmt/loss_admin-18-1-8k.log

I also trained Admin-18-2 for 10 epochs; its dev ppl @ epoch 10 is 5.74, better than the 5.80 of Admin-12-2 at epoch 10.

The commands I used are below (some paths need to be changed to match your setup; both 8k and 16k warmup perform well):

CUDA_VISIBLE_DEVICES=4 fairseq-train \
./data/wmt14_en_de_joined_dict/ -s en -t de \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --max-update 500000 \
--warmup-init-lr 1e-07 --warmup-updates 16000 --lr 0.001 --min-lr 1e-09  \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
--max-tokens 8192 --update-freq 4 \
--save-dir cps/wmt-admin-18-2 --restore-file x.pt --seed 1111 \
--user-dir ../radam_fairseq --log-format simple --log-interval 500 \
--init-type adaptive-profiling --fp16 --fp16-scale-window 256 \
--encoder-layers 18 --decoder-layers 2 \
--threshold-loss-scale 0.03125 

CUDA_VISIBLE_DEVICES=4 fairseq-train \
./data/wmt14_en_de_joined_dict/ -s en -t de \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --max-update 500000 \
--warmup-init-lr 1e-07 --warmup-updates 16000 --lr 0.001 --min-lr 1e-09  \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
--max-tokens 8192 --update-freq 4 \
--save-dir cps/wmt-admin-18-2 --restore-file x.pt --seed 1111 \
--user-dir ../radam_fairseq --log-format simple --log-interval 500 \
--init-type adaptive --fp16 --fp16-scale-window 256 \
--encoder-layers 18 --decoder-layers 2 \
--threshold-loss-scale 0.03125 | tee ./log/app_wmt/loss_admin-18-2.log

CUDA_VISIBLE_DEVICES=5 fairseq-train \
./data/wmt14_en_de_joined_dict/ -s en -t de \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --max-update 500000 \
--warmup-init-lr 1e-07 --warmup-updates 8000 --lr 0.001 --min-lr 1e-09  \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
--max-tokens 8192 --update-freq 4 \
--save-dir cps/wmt-admin-18-2-8k --restore-file x.pt --seed 1111 \
--user-dir ../radam_fairseq --log-format simple --log-interval 500 \
--init-type adaptive --fp16 --fp16-scale-window 256 \
--encoder-layers 18 --decoder-layers 2 \
--threshold-loss-scale 0.03125 | tee ./log/app_wmt/loss_admin-18-2-8k.log

Could you plot the training curve for your setting? It is also possible that the instability happens in a later stage of training (since I stopped training at 10 epochs).
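
If it helps, here is a rough sketch for pulling a loss curve out of a fairseq log. It assumes the --log-format simple output prints "loss=<number>" fields and that gnuplot is available; the log path is only an example, and the pattern may need adjusting for your fairseq version:

# Extract the training loss from a fairseq log and plot it with gnuplot.
# Assumes --log-format simple prints " loss=<number>" fields; adjust the pattern if your
# version formats the progress lines differently.
LOG=./wmt14ende/log/loss_admin-18-2l.log   # example path; point this at your own log
grep -o ' loss=[0-9.]*' "$LOG" | cut -d= -f2 | nl -w1 -s' ' > loss.dat
gnuplot -e "set terminal pngcairo; set output 'loss_curve.png'; set xlabel 'logged step'; set ylabel 'training loss'; plot 'loss.dat' using 1:2 with lines title 'train loss'"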

That is strange. If I modify the script as follows, it runs OK.

GPUID=1
TOKEN_NUMBER=8192
UPDATE_FREQUENCE=1

for lnum in 1 2 3 4 6
do
export CUDA_VISIBLE_DEVICES='0'
python ../fairseq/train.py \
../data-bin/wmt14_en_de_joined_dict/ -s en -t de \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --max-update 500000 \
--warmup-init-lr 1e-07 --warmup-updates 8000 --lr 0.001 --min-lr 1e-09 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
--max-tokens $TOKEN_NUMBER --update-freq 1 \
--save-dir wmt14ende/wmt-admin-18-${lnum}l --restore-file x.pt --seed 1111 \
--user-dir ../radam_fairseq --log-format simple --log-interval 500 \
--init-type adaptive-profiling \
--encoder-layers 18 --decoder-layers $lnum \
--threshold-loss-scale 0.03125

export CUDA_VISIBLE_DEVICES='0,1,2,3'
python ../fairseq/train.py \
../data-bin/wmt14_en_de_joined_dict/ -s en -t de \
--arch transformer_wmt_en_de --share-all-embeddings \
--optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --max-update 500000 \
--warmup-init-lr 1e-07 --warmup-updates 8000 --lr 0.001 --min-lr 1e-09 \
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
--weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
--max-tokens $TOKEN_NUMBER --update-freq 2 \
--save-dir wmt14ende/wmt-admin-18-${lnum}l --restore-file x.pt --seed 1111 \
--user-dir ../radam_fairseq --log-format simple --log-interval 500 \
--init-type adaptive \
--encoder-layers 18 --decoder-layers $lnum \
--threshold-loss-scale 0.03125 | tee ./wmt14ende/log/loss_admin-18-${lnum}l.log

bash eval_wmt_en-de.sh wmt14ende/wmt-admin-18-${lnum}l $GPUID
mv profile.ratio.init wmt14ende/wmt-admin-18-${lnum}l/
done

Glad to hear it works now.
As for the mechanism behind this, I suspect it is related to the relationship between learning rate and batch size: we have noticed that a larger batch can afford a larger lr. We are working on some projects about this, stay tuned :-)
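
For what it's worth, here is a back-of-the-envelope comparison of your two 18-x scripts, taking effective tokens per optimizer step as roughly --max-tokens times the number of GPUs times --update-freq (this ignores padding and dynamic batching, so it is only a rough estimate):

# Rough effective batch size of the second (training) command in each version of the 18-x script.
echo "original script: $((4096 * 8 * 1)) tokens/step"   # --max-tokens 4096, 8 GPUs, --update-freq 1
echo "modified script: $((8192 * 4 * 2)) tokens/step"   # --max-tokens 8192, 4 GPUs, --update-freq 2

By that estimate the modified script takes roughly twice as many tokens per update (65536 vs. 32768) at the same lr, which would fit the larger-batch/larger-lr intuition; it also runs the adaptive-profiling pass on a single GPU instead of eight, which might matter as well.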