jayelm / gisting

Learning to Compress Prompts with Gist Tokens - https://arxiv.org/abs/2304.08467

Unable to reproduce LLaMA-7B results when training from scratch

Xiuyu-Li opened this issue

Hi,

I was trying to reproduce the LLaMA-7B results with 1 gist token from scratch, following the training instructions in the README. I ran the script below on 4 A100-80GB GPUs:

TAG="train80g"

port=$(shuf -i25000-30000 -n1)

deepspeed --master_port $port --num_gpus=4 --no_local_rank \
    --module src.train \
    +model=llama-7b wandb.tag=$TAG \
    training.deepspeed=ds_configs/stage3.json \
    training.gist.condition=gist \
    training.gist.num_gist_tokens=1

However, the final results after 3 epochs are much lower than those reported in the paper: I got ROUGE-L of 51.24 (seen), 42.01 (unseen), and 19.00 (human). Training for more epochs did not help the unseen and human ROUGE-L. I did not change anything in the training config other than the wandb account.

I also evaluated the 3 provided checkpoints (gist, pos_control, neg_control), and the results are consistent with the paper (< 0.1 difference in ROUGE-L) for all of them, so the evaluation code appears to be working correctly. Could you help double-check whether the training setup above is correct, and do you have any suggestions for reproducing the LLaMA results in the paper?

Hello, this is unrelated to your issue but I was hoping you could point me in the right direction since I'm an ML beginner trying to leverage this repo for another organization.

Did you try running compress.py on your GPU setup? And if you were successful, did you have to make any changes to the file's code apart from the device_map="auto" / device_map="balanced" parameter passed into the model initializer? I've tried that as well as quantization, but I'm unable to actually load the model prior to using it (running 4 A100-40GB GPUs).

Thanks.

Hi, I tried the example command in the README and it works without an issue:

python -m src.compress --model_name_or_path llama-7b-gist-1-recovered \
    --instruction "Name the top cities in France that should not be missed. Include the best aspects of each place as well."

where llama-7b-gist-1-recovered is the path of the llama checkpoint after applying the llama-7b-gist-1 weight_diff. I did not modify the code at all.

Also, I would recommend opening a separate issue for any follow-up or new questions.

Hi, give me a few days to look into this. Have you tried training the positive control model? Curious if the issue is just the gist model or all of the models.

Also, are you using the decapoda-research llama checkpoint? Can you check whether it might be due to an incorrect tokenization config with the decapoda models? See

Hello, this is unrelated to your issue but I was hoping you could point me in the right direction since I'm an ML beginner trying to leverage this repo for another organization.
Did you try running compress.py on your GPU setup? Did you have to make any changes to the file's code apart from the device_map="auto" / device_map="balanced" parameter passed into the model initializer?
Thanks.

Hi, I tried the example command in the README and it works without an issue:

python -m src.compress --model_name_or_path llama-7b-gist-1-recovered \
    --instruction "Name the top cities in France that should not be missed. Include the best aspects of each place as well."

where llama-7b-gist-1-recovered is the path of the llama checkpoint after applying the llama-7b-gist-1 weight_diff. I did not modify the code at all.

Also, I would recommend opening a separate issue for any follow-up or new questions.

I'm guessing it worked without modification because your GPU specs fit the required ones perfectly, haha. Thanks!

Hi, give me a few days to look into this. Have you tried training the positive control model? Curious if the issue is just the gist model or all of the models.

Thank you for the reply. I haven't tried training the positive control model but let me do it now. Just another two quick follow-ups:

  1. Is there anything else that needs to be changed other than setting training.gist.condition=pos_control to train the positive control model using the script above?
  2. I noticed that although the config specifies a cosine lr scheduler, the logged wandb lr was always unchanged. I am not sure if this is a wandb issue (as I have encountered similar logging issues in the past) or the lr scheduler was truly not working properly.

Also, are you using the decapoda-research llama checkpoint? Can you check whether it might be due to an incorrect tokenization config with the decapoda models? See

This is a good point. All my previous experiments were done using the decapoda-research checkpoint, and I was actually redoing experiments with the official checkpoint today. I'll let you know how it goes.

I also have a question w.r.t. the provided gist checkpoint: its generation_config.json looks like this:

{
  "_from_model_config": true,
  "bos_token_id": 0,
  "eos_token_id": 1,
  "pad_token_id": -1,
  "transformers_version": "4.28.0.dev0"
}

While in the decapoda-research checkpoint it is:

{
  "_from_model_config": true, 
  "bos_token_id": 0, 
  "eos_token_id": 1, 
  "pad_token_id": 0, 
  "transformers_version": "4.27.0.dev0"
}

and in the official checkpoint:

{
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "pad_token_id": 0,
  "transformers_version": "4.28.0.dev0"
}

Could you explain the differences in your gist checkpoint's generation config? Thank you.
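
For comparing these side by side, here is a quick sketch that just greps the token IDs out of each file (the directory names are placeholders for wherever each checkpoint lives locally):

for d in llama-7b-gist-1 decapoda-llama-7b official-llama-7b; do
    echo "== $d =="
    grep token_id "$d/generation_config.json"
done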

It seems like there may be some versioning issues here... I think I wrote this code and converted the official Llama checkpoint back when there were still some outstanding tokenization issues in the huggingface repo (e.g. prior to huggingface/transformers#22402); note the pinned transformers version in this codebase's requirements.txt is fb366b9a: https://github.com/huggingface/transformers/blob/fb366b9a2a94b38171896f6ba9fb9ae8bffd77af/src/transformers/models/llama/tokenization_llama.py

The Llama checkpoint I used has the following configs:

generation_config.json

{"_from_model_config": true, "bos_token_id": 0, "eos_token_id": 1, "pad_token_id": -1, "transformers_version": "4.27.0.dev0"}

config.json

{"architectures": ["LLaMAForCausalLM"], "bos_token_id": 0, "eos_token_id": 1, "hidden_act": "silu", "hidden_size": 4096, "intermediate_size": 11008, "initializer_range": 0.02, "max_sequence_length": 2048, "model_type": "llama", "num_attention_heads": 32, "num_hidden_layers": 32, "pad_token_id": -1, "rms_norm_eps": 1e-06, "torch_dtype": "float16", "transformers_version": "4.27.0.dev0", "use_cache": true, "vocab_size": 32000, "_name_or_path": "llama-7b"}

which seems to be incorrect. However, in a debug session after loading the LlamaTokenizer from this checkpoint, I get:

ipdb> p tokenizer
LlamaTokenizer(name_or_path='llama-7b', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='left', truncation_side='right', special_tokens={})
ipdb> tokenizer.bos_token_id
1
ipdb> tokenizer.eos_token_id
2

so I'm definitely a bit confused here; the pinned version of transformers I used might be overwriting the checkpoint values. In fact, I recall having to manually overwrite the generation config here:

logger.warning(
    "Overwriting existing generation config due to "
    "DeepSpeed bug. If model is not LLAMA, check this."
)
gen_kwargs["generation_config"] = GenerationConfig(
    max_length=512,
    do_sample=False,
    bos_token_id=1,
    eos_token_id=2,
    pad_token_id=0,
)

I previously attributed this to a DeepSpeed bug (for some reason I remember not seeing this issue without deepspeed), but maybe it's related to tokenization.

Some answers that would help clarify things:

  • Can you confirm you are using the same version of Transformers I am using?
  • Can you check what your bos and eos token IDs are set to after loading the tokenizer in the training script, for the model that showed degraded performance? Are they 1 and 2 (as they should be), or 0 and 1 (as in the decapoda-research checkpoint)? E.g. if you're using a more recent version of transformers, I wonder whether loading the LlamaTokenizer no longer overwrites the bos and eos token IDs as shown above, but instead loads whatever is in the checkpoint (which, if 0 and 1, is wrong). A quick way to check this (and the transformers version above) is sketched right after this list.
  • I suspect this is a tokenization issue, but it might not be. If your pos control training run also exhibits degraded performance, then that would lean in favor of tokenization being the issue. But if it doesn't, then the issue is probably somewhere else, so I'm interested in what you find.
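
A minimal way to run both checks from the command line (a sketch; 'llama-7b' stands for whatever local path your converted checkpoint lives at):

# Installed transformers version (to compare against the pinned fb366b9a commit).
python -c "import transformers; print(transformers.__version__)"

# Special token IDs after loading the tokenizer, outside of the training script.
python -c "
from transformers import LlamaTokenizer
tok = LlamaTokenizer.from_pretrained('llama-7b')
print('bos:', tok.bos_token_id, 'eos:', tok.eos_token_id)  # expect 1 and 2
"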

Appreciate your help debugging this!

  1. Is there anything else that needs to be changed other than setting training.gist.condition=pos_control to train the positive control model using the script above?

No, that is the only change required!
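
Concretely, a sketch of that run would just be the script from the top of this thread with the condition swapped (the wandb tag is arbitrary):

TAG="train80g-pos"

port=$(shuf -i25000-30000 -n1)

deepspeed --master_port $port --num_gpus=4 --no_local_rank \
    --module src.train \
    +model=llama-7b wandb.tag=$TAG \
    training.deepspeed=ds_configs/stage3.json \
    training.gist.condition=pos_control \
    training.gist.num_gist_tokens=1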

  2. I noticed that although the config specifies a cosine lr scheduler, the logged wandb lr was always unchanged. I am not sure if this is a wandb issue (as I have encountered similar logging issues in the past) or the lr scheduler was truly not working properly.

Aha, this was also the case in the original Alpaca codebase: it specified a cosine LR scheduler, but the LR never actually changed, at least according to wandb. I didn't look too closely into this. So even if not entirely correct, this is expected, and I observed the same in my experiments.
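
For reference, here is a minimal sketch of what the schedule itself would produce if it were stepping, using HF's scheduler helper directly (this is not the repo's training code; the warmup step count is just an illustrative guess):

python - <<'EOF'
import torch
from transformers import get_cosine_schedule_with_warmup

# Dummy optimizer with the same base LR as the config (2e-5).
opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=2e-5)

# ~3000 training steps as in the runs in this thread; 90 warmup steps is a guess.
sched = get_cosine_schedule_with_warmup(opt, num_warmup_steps=90, num_training_steps=3000)

for step in range(1, 3001):
    opt.step()
    sched.step()
    if step % 500 == 0:
        print(step, sched.get_last_lr()[0])
EOF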

Thank you for putting the effort into such a detailed response! Regarding your questions:

  1. I was using 4.28.0.dev0, which was installed using requirements.txt. There is also a check_min_version("4.28.0.dev0") in your src/train.py. It should be the same as yours based on the commit number in requirements.txt.
  2. My bos and eos token IDs are correct -- they are 1 and 2, even though generation_config.json and config.json are wrong. I am not entirely sure, but I think this is because when loading the tokenizer by

    gisting/src/train.py

    Lines 140 to 142 in acd78b4

    tokenizer = LlamaTokenizer.from_pretrained(
        args.model.model_name_or_path, **tokenizer_kwargs
    )

    it will not look at the files containing the wrong IDs (generation_config.json and config.json), and will initialize bos_token="<s>" and eos_token="</s>", which essentially correspond to 1 and 2 in the vocab (see huggingface/transformers#22312). Thus the tokenizer seems to be loaded properly, and your manual overwrite fixes generation & prediction. This makes the situation quite frustrating, as I am not sure whether the issue is with the tokenizer.
  3. I am now running pos control training with the hf checkpoint that I previously used. I am also running gist and pos control training using the converted official checkpoint with the correct tokenizer config. I'll let you know how it goes!

I had no luck with those new experiments:

  • pos control training with the decapoda-research checkpoint got seen 54.04, unseen 43.58, human 20.42.
  • pos control training with the converted official checkpoint got seen 53.77, unseen 43.75, human 18.80.
  • gist training with the converted official checkpoint got seen 51.26, unseen 42.32, human 18.89.

Basically, there is no significant difference between the decapoda-research checkpoint and the converted checkpoint, and both the gist and pos control experiments give worse results than those reported in the paper. Could you kindly consider rerunning the training on your side to verify whether you're able to reproduce the results presented in the paper?

Ah, this could be the reason: I noticed that the python env I used for training is actually different from the env described in requirements.txt. This is because I benchmarked the trained models on a different machine with a different python installation, and created requirements.txt from the env I used for benchmarking rather than the one I used for training.

I downloaded a fresh copy of the repo, reran the gist model training command with the python env below, and reproduced the paper results exactly (see the end of this post). So it's possible this is the reason for the discrepancy (perhaps torch 2.0 -> torch 1.13)? If it's not too much trouble, could you try retraining with the following updated env:

Python version 3.10.8

requirements.txt generated from pipreqs (you may have to install torch 1.13.1 from pytorch.org; let me know if you have other trouble reproducing this env):

datasets==2.8.0
deepspeed==0.8.3
evaluate==0.4.0
fire==0.5.0
hydra-core==1.3.2
nltk==3.8
numpy==1.23.5
omegaconf==2.3.0
pandas==1.5.3
torch==1.13.1+cu116
torch_xla==1.0
tqdm==4.64.1
git+https://github.com/huggingface/transformers@fb366b9a
wandb==0.13.11
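
One way to handle the torch pin is to point pip at the official cu116 wheel index (a sketch of the "install from pytorch.org" step, not something verified in this thread):

pip install torch==1.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install -r requirements.txt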

Somewhat confused that versioning could lead to such discrepancies, but if this is indeed the cause, sorry for the confusion! Even if this is not the issue, the fact that positive control performance is also degraded on your end (and the gist model still performs about as well as the pos control) indicates that the issue is not with the gist models specifically, but with some mismatch in the general llama finetuning setup that we'd still need to nail down.

Thanks again for your help!


Reproduction:

[image: screenshot of the reproduction run]

My job timed out before it finished eval on the unseen split at 3k steps, but I observed

  • seen 57.8,
  • unseen 46.2 (at 1.5k steps)
  • human 23.9

which is in line with the paper.


Also, on the off chance that we are still using different llama-7b checkpoints somehow, here are the shasums of the checkpoint I am using:

da1ef16b048f5d759149c7472b6e4858f19bc287  config.json
e37700792e352efa5dba3c83478ff31d60b50e02  config_original.json
a2badb096591d9e0ca3f1c1151a4ac85063a6325  generation_config.json
0096c5b09f6403a1a1a211a50090d37d24acbd12  pytorch_model-00001-of-00033.bin
83af20eb30fb7b52e62e80a07cb8a7de2c35612e  pytorch_model-00002-of-00033.bin
1b73432827a81a77a5548652ee6fe1f127d53aeb  pytorch_model-00003-of-00033.bin
98cde125792b2cf172581f26db3e3583765bd07f  pytorch_model-00004-of-00033.bin
3eecc843ced51c6b8f24ba5e81a33c28c0028abb  pytorch_model-00005-of-00033.bin
1cbd688ddd5411443fad88a6ac32656f74ba66d8  pytorch_model-00006-of-00033.bin
021af212e30f8294e24996eb9c3710ab5963d93f  pytorch_model-00007-of-00033.bin
099abf28dc554ac030db312e518335f47dbc630d  pytorch_model-00008-of-00033.bin
4d9af68d697c3a9a6624d69f32dfc5eaa4f02b89  pytorch_model-00009-of-00033.bin
7358c017a646ac18331a1b318a9bea6f87c505e3  pytorch_model-00010-of-00033.bin
eae98416c3f8ff31a532ac1b017183c610c07916  pytorch_model-00011-of-00033.bin
f4fd5e8087ddb9b46ecd729a7e3970711233b2d3  pytorch_model-00012-of-00033.bin
f3a296be997c73e5f39c7b6bc5b534840a8ed8f0  pytorch_model-00013-of-00033.bin
e98055c932e250352b0660b9c47889af2829a645  pytorch_model-00014-of-00033.bin
072af56043c00bf9c1d30b84d79612feaf4660b4  pytorch_model-00015-of-00033.bin
8ac62fb6884080851c3cfffc151b8aa13b6cb919  pytorch_model-00016-of-00033.bin
e3679a8b6e64dff5b8163e1428658a4ed1428e67  pytorch_model-00017-of-00033.bin
2fb9c4413cbcfc6be5b94229b38c40781f43a41b  pytorch_model-00018-of-00033.bin
084b9637913276c7ce9f8d92f1a80a5b89eb9146  pytorch_model-00019-of-00033.bin
7158b438698dc10235212bf609fa5ccb74145109  pytorch_model-00020-of-00033.bin
bfbc357f4c5db9e07ef0baa29563c648a17d463d  pytorch_model-00021-of-00033.bin
cdcf5c860eeab5632113dd39260e06ae5aa8e301  pytorch_model-00022-of-00033.bin
39a24a8576ea63fc238f26dd6dc599a3e9efcd11  pytorch_model-00023-of-00033.bin
bd8222ced6951461781c5fb45c7c7c63ae0a6e6b  pytorch_model-00024-of-00033.bin
2168e992c91dcadbc5ba8c9de9f1f5c302971ae4  pytorch_model-00025-of-00033.bin
7e657df1760b10db893f76835ec524c633f1da53  pytorch_model-00026-of-00033.bin
aed1f083e57dc8e711ee3f50b06c2d36a466bc18  pytorch_model-00027-of-00033.bin
8e0bcdb0d1b974fe9967c638e4c24d6172a09402  pytorch_model-00028-of-00033.bin
c77a739a98c877616d7d5e65991cc893d5988bc2  pytorch_model-00029-of-00033.bin
c4b9df5137840cafafd26de980f55bf0f907c281  pytorch_model-00030-of-00033.bin
37022a5d1d8e71a0b5a84f6e51d3f1454ec48beb  pytorch_model-00031-of-00033.bin
9e789f6bbe3fb919c9eee6f6ad47818f1747ea74  pytorch_model-00032-of-00033.bin
1444445a36623906176a8e636dc1c51763f3c05d  pytorch_model-00033-of-00033.bin
16b8de6e0ac1e2b75d1032c8315c3df2a115f0d0  pytorch_model.bin.index.json
bf21a9e8fbc5a3846fb05b4fa0859e0917b2202f  special_tokens_map.json
39d4ad6eced3ea31d27f762010a5ba67deb8b3ff  tokenizer_config.json
7a4e789beca293352e60b6fad5eef1908070cee0  tokenizer.model

Since you've tried a few llama models, it seems more likely that the discrepancy is due to the codebase, but just in case, you can check these hashes. I can also email you the llama checkpoint directly if you'd like—you can reach out to my academic email (available on my website) and we can sort something out. But it's probably worth synchronizing python envs first.
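
If it helps, these look like SHA-1 sums, so a sketch for generating the same list on your side would be:

cd /path/to/llama-7b   # wherever your converted checkpoint lives
shasum *.json *.bin *.model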

Thank you for the thorough investigation. Let me run experiments with your newly provided environment first. If that doesn't work I'll dig into the checkpoint.

Just a side note: I was actually running most of my experiments on 8 A100-40GB GPUs, since I have relatively limited A100-80GB resources. I changed per_device_train_batch_size and per_device_eval_batch_size from 4 to 2, so the per-step total is still 16 sequences (8 GPUs x 2 vs. 4 GPUs x 4) and the effective batch size is still 128; theoretically there shouldn't be any difference between this setup and the suggested one. Nonetheless, I will verify whether I can reproduce the results with either setup.

Switching to the provided python envs indeed fixed the training! I got seen 57.39, unseen 46.53, human 25.83 for the 1 gist token setting trained using the converted official checkpoint. These are pretty much consistent with the paper results.

Setting up the environment required a few more changes to requirements.txt. torch==1.13.1+cu116 could not be found by pip, so I first installed pytorch via

conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia

and then slightly updated requirements.txt to

accelerate==0.18.0
datasets==2.8.0
deepspeed==0.8.3
evaluate==0.4.0
fire==0.5.0
hydra-core==1.3.2
nltk==3.8
numpy==1.23.5
omegaconf==2.3.0
openai==0.27.2
rouge_score==0.1.2
sentencepiece==0.1.98
pandas==1.5.3
tqdm==4.64.1
git+https://github.com/huggingface/transformers@fb366b9a
wandb==0.13.11
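
Putting the steps together, the setup sequence looks roughly like this (the env name is arbitrary):

conda create -n gisting-repro python=3.10
conda activate gisting-repro
conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
pip install -r requirements.txt   # the updated requirements above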

Using this environment, I could successfully reproduce the gist results when training from scratch on my end. I think it's worth investigating further which package's version mismatch caused the performance degradation in my previous experiments, but this can serve as a temporary solution for now. Thanks again for your time looking into this!

I ran a few more experiments and located the issue -- the performance discrepancy is not caused by the torch version, but by the deepspeed version. I was still able to reproduce the paper results after changing the torch version back to 2.0.0, so I did some further debugging and realized that, since I did not create a new env from scratch previously, my older experiments were run with deepspeed 0.9.4 instead of deepspeed 0.8.3. Somehow pip install -r requirements.txt did not overwrite the existing deepspeed installation on either of my A100-40G or A100-80G servers...
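
A quick way to catch this kind of stale install (a sketch, nothing repo-specific):

python -c "import deepspeed; print(deepspeed.__version__)"   # should print 0.8.3
pip install --force-reinstall --no-deps deepspeed==0.8.3     # pin it explicitly if it doesn't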

I then ran experiments using the same env with only the deepspeed version differing, and verified that only 0.9.4 produced the worse performance I encountered earlier in this issue.

I haven't figured out why the newer deepspeed causes such performance degradation. And I apologize for my oversight (again!) -- your currently listed dependencies should be all good.

Apologies for the late response, I've been on vacation 😄

Glad to hear the results replicate, and thanks so much for looking so closely into this! This will help a lot with reproducibility.

Super weird that the DeepSpeed version leads to such drastic performance differences. I'll make a note of this in the repo when I get back from vacation.

Hi @Xiuyu-Li, could you please share your reproduced results? I also tried to replicate LLaMA-7B from scratch, and these are my results:

wandb: Run summary:
wandb:                   eval/gen_len 40.231
wandb:             eval/human_gen_len 88.93254
wandb:                eval/human_loss 2.18223
wandb:              eval/human_rouge1 28.2652
wandb:              eval/human_rouge2 10.6288
wandb:              eval/human_rougeL 23.9176
wandb:           eval/human_rougeLsum 26.8962
wandb:             eval/human_runtime 10497.9282
wandb:  eval/human_samples_per_second 0.024
wandb:    eval/human_steps_per_second 0.002
wandb:                      eval/loss 1.25835
wandb:                    eval/rouge1 52.6178
wandb:                    eval/rouge2 29.9502
wandb:                    eval/rougeL 46.6212
wandb:                 eval/rougeLsum 49.1622
wandb:                   eval/runtime 5599.2181
wandb:        eval/samples_per_second 0.179
wandb:              eval/seen_gen_len 22.35
wandb:                 eval/seen_loss 0.58948
wandb:               eval/seen_rouge1 59.6561
wandb:               eval/seen_rouge2 34.5262
wandb:               eval/seen_rougeL 57.779
wandb:            eval/seen_rougeLsum 58.4709
wandb:              eval/seen_runtime 2058.5075
wandb:   eval/seen_samples_per_second 0.486
wandb:     eval/seen_steps_per_second 0.031
wandb:          eval/steps_per_second 0.011
wandb:            eval/unseen_gen_len 40.231
wandb:               eval/unseen_loss 1.25835
wandb:             eval/unseen_rouge1 52.6178
wandb:             eval/unseen_rouge2 29.9502
wandb:             eval/unseen_rougeL 46.6212
wandb:          eval/unseen_rougeLsum 49.1622
wandb:            eval/unseen_runtime 5602.1366
wandb: eval/unseen_samples_per_second 0.179
wandb:   eval/unseen_steps_per_second 0.011
wandb:                    train/epoch 3.0
wandb:              train/global_step 3000
wandb:            train/learning_rate 2e-05
wandb:                     train/loss 0.3331
wandb:               train/total_flos 85213543661568.0
wandb:               train/train_loss 0.6486
wandb:            train/train_runtime 69016.7382
wandb: train/train_samples_per_second 5.567
wandb:   train/train_steps_per_second 0.043

I trained the model on 4 A100 (80GB) GPUs, which took about 24 hours, using deepspeed==0.8.3 and torch==1.13.1+cu116.

Hi @Hannibal046, the results you just reported align with the paper's, namely these lines:

wandb:              eval/human_rougeL 23.9176
wandb:               eval/seen_rougeL 57.779
wandb:             eval/unseen_rougeL 46.6212

and the table:

[image: results table from the paper]

so it seems like you've successfully reproduced the results. Did you have other questions?

Hi @jayelm,
I don't have further questions, and I appreciate the great idea and the great code, from which I learned a lot!

Hi @jayelm, sorry to bother you.
I found a weird chart in the wandb panel for train/learning_rate:
[image: wandb train/learning_rate chart]
I expected it to be a cosine shape with warmup, given this config:
[image: training config]

Hi @Hannibal046, see this comment earlier in the thread:

  2. I noticed that although the config specifies a cosine lr scheduler, the logged wandb lr was always unchanged. I am not sure if this is a wandb issue (as I have encountered similar logging issues in the past) or the lr scheduler was truly not working properly.

Aha, this was also the case in the original Alpaca codebase: it specified a cosine LR scheduler, but the LR never actually changed, at least according to wandb. I didn't look too closely into this. So even if not entirely correct, this is expected, and I observed the same in my experiments.

Aha, got it! My mistake, and I appreciate your help!

Also, if I would like to help migrate the current codebase to the latest deepspeed and huggingface transformers, what would you suggest taking care of? Currently, these are on the to-do list:

  • modifying the RoPE function
  • adjusting generation
  • debugging with the latest deepspeed

Hi @Hannibal046, those are the things I'm aware of! The first step would be migrating to a newer huggingface version while keeping the DeepSpeed version the same; you should be able to exactly reproduce the results. Second, you can try debugging what's going on with DeepSpeed; I really have no clue and haven't had time to look into it myself.

If you do figure out either of these, please let me know and/or submit a PR; it would be really helpful!