lm-sys / FastChat

An open platform for training, serving, and evaluating large language models. Release repo for Vicuna and Chatbot Arena.

Command to run train_flant5.py

samarthsarin opened this issue

What are the arguments and command line for running the fine-tuning code for Flan-T5?

Hi, these are the commands we use on an 8-GPU cluster. Please fill in data_path with your data, e.g. playground/data/dummy.json. The script will preprocess the data and store it at preprocessed_path so that future runs can load it directly; you can also specify any path you like. Let me know if there is any issue:
torchrun --nnodes=1 --nproc_per_node=8 --master_port=9778 fastchat/train/train_flant5.py \
    --model_name_or_path google/flan-t5-xl \
    --data_path data_path \
    --bf16 True \
    --output_dir ./checkpoints_flant5_3b \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 300 \
    --save_total_limit 1 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap T5Block \
    --tf32 True \
    --model_max_length 2048 \
    --preprocessed_path ./preprocessed_data/processed.json \
    --gradient_checkpointing True
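
For anyone unsure about the data format: a small hand-made file in the style of playground/data/dummy.json can be written roughly like this (a minimal sketch; the conversation text and output filename are made up, and dummy.json itself is the authoritative reference):

import json

# Hypothetical minimal data file in the conversation format used by
# playground/data/dummy.json; replace the content with your own data.
dummy = [
    {
        "id": "identity_0",
        "conversations": [
            {"from": "human", "value": "What is the capital of France?"},
            {"from": "gpt", "value": "The capital of France is Paris."},
        ],
    }
]

with open("playground/data/my_data.json", "w") as f:
    json.dump(dummy, f, indent=2)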

Hi, @DachengLi1

Thanks for providing the commands for fine-tuning FlanT5.

I used the commands you provided, and the training phase went fine.

But when it saves the model, it encounters a CUDA error.

I found there is a comment "# potential bug for T5 model" in the code:

# potential bug for T5 model

Is this a known issue with saving the model?

Traceback (most recent call last):
  File "xxx/FastChat/fastchat/train/train_flant5.py", line 428, in <module>
    train()
  File "xxx/FastChat/fastchat/train/train_flant5.py", line 424, in train
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
  File "xxx/FastChat/fastchat/train/train_flant5.py", line 78, in safe_save_model_for_hf_trainer
    cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
  File "xxx/FastChat/fastchat/train/train_flant5.py", line 78, in <dictcomp>
    cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
RuntimeError: CUDA error: invalid argument

@gary9630 I think this is likely due to a PyTorch FSDP bug that causes OOM when saving. Are you able to save intermediate checkpoints? (PyTorch will continue training even if saving an intermediate checkpoint causes OOM, so you may need to actually watch a saving step to check this.) If that is what's happening, it can be solved by the solution mentioned here.
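
The linked solution is not quoted in this thread, but the usual workaround is to gather the full FSDP state dict with CPU offloading before saving. A minimal sketch of that idea (not necessarily the exact patch behind the link; trainer here is the transformers Trainer already used by the training script):

import torch.distributed.fsdp as fsdp

# Sketch: stream the full state dict to CPU on rank 0 instead of materializing
# it on the GPU, so that saving no longer runs out of device memory.
save_policy = fsdp.FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with fsdp.FullyShardedDataParallel.state_dict_type(
    trainer.model, fsdp.StateDictType.FULL_STATE_DICT, save_policy
):
    cpu_state_dict = trainer.model.state_dict()  # tensors are already on CPU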

And thanks for checking the comment. There is another (smaller) issue with Flan-T5 + FSDP saving. After you resolve this issue and can save the model correctly, use this function to clean up the weight path. Under the hood, FSDP seems to have trouble saving shared weights, and this function manually corrects them (the T5 encoder embedding, decoder embedding, and shared embedding are actually the same tensor but carry three names). Then you should be able to load the model!
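
For reference, the cleanup boils down to re-tying those three embedding names inside the saved checkpoint. A rough sketch of the idea (the repo's clean_flant5_ckpt is the authoritative version; the helper name below is made up, and the index filename assumes a standard sharded Hugging Face checkpoint):

import json
import os

import torch

def fix_shared_embeddings(ckpt_path):
    # Illustrative helper (hypothetical name): copy the shared embedding over the
    # encoder/decoder embedding entries that FSDP saved incorrectly.
    with open(os.path.join(ckpt_path, "pytorch_model.bin.index.json")) as f:
        weight_map = json.load(f)["weight_map"]
    shared_file = os.path.join(ckpt_path, weight_map["shared.weight"])
    shared = torch.load(shared_file, map_location="cpu")["shared.weight"]
    for name in ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]:
        shard_file = os.path.join(ckpt_path, weight_map[name])
        shard = torch.load(shard_file, map_location="cpu")
        shard[name] = shared  # overwrite with the shared embedding
        torch.save(shard, shard_file)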

@DachengLi1 Thank you for such a fast response.

Regarding the OOM issue, I had already solved it with the method you provided when fine-tuning the Vicuna model a few days ago.

It was really helpful and saved my day. I appreciate all your hard work and this excellent GitHub community.

I am not sure whether it is correct, but I modified the code from:

cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}

to

cpu_state_dict = {key: value for key, value in state_dict.items()}

With that change, I can get the model output successfully.

But when I try to load the model, it fails with the following error:

OSError: Unable to load weights from pytorch checkpoint file for 'xxx/Llama_models/Finetune_t5_3B_FT/pytorch_model-00001-of-00002.bin' at 
'xxx/Llama_models/Finetune_t5_3B_FT/pytorch_model-00001-of-00002.bin'. If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True.

I am not sure whether this error is related to the second part you mentioned, where I have to clean up the weight path.

For the function you provided, I just need to give it the current checkpoint path and it will do the magic cleaning job for me, right?

Again, thank you for such great work, I really enjoy it!

Nice to hear that! Yes, this is exactly the second issue. Can you try calling the function above? (You may want to copy the weights first in case anything unexpected happens; this function will rewrite some of the weights in the path you provide.)
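
Assuming the signature def clean_flant5_ckpt(ckpt_path) shown later in this thread, the call takes just the checkpoint directory. The import path in this sketch is a guess; adjust it to wherever the function lives in your checkout:

from fastchat.train.train_flant5 import clean_flant5_ckpt  # module path is a guess

clean_flant5_ckpt("./checkpoints_flant5_3b")  # the --output_dir used for training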

I can now successfully load the fine-tuned FlanT5 model! Thank you @DachengLi1

For the record, I've done the following things:

  1. Modify the code from key: value.cpu() to key: value in: https://github.com/lm-sys/FastChat/blob/cad445eb510989a0adc376e8c56eb1c1f6f6fee0/fastchat/train/train_flant5.py#LL78C1-L78C1
  2. After the fine-tuned FlanT5 model is written to the given model path, run the script with the cleaning function provided in:
    def clean_flant5_ckpt(ckpt_path):

Then, you should be able to load the fine-tuned FlanT5 model for inference.
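
As a sanity check outside the FastChat CLI, the checkpoint can also be loaded directly with transformers; since T5 is a sequence-to-sequence model, the Seq2SeqLM auto class is the right one. A minimal sketch (the path and prompt are placeholders):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

ckpt = "./checkpoints_flant5_3b"  # placeholder: the cleaned checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt)

inputs = tokenizer("Who are you?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))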

@gary9630 Glad it helped! Thanks for the summary; we will probably redirect other related issues to this solution. Closing this issue now, feel free to re-open it if you find other issues.

@DachengLi1 could you send a PR and add some docs for using the T5 training scripts?

After cleaning the checkpoint, when I try to use the weights I get the following error:

ValueError: Unrecognized configuration class <class 'transformers.models.t5.configuration_t5.T5Config'> for this kind of
AutoModel: AutoModelForCausalLM.
Model type should be one of BartConfig, BertConfig, BertGenerationConfig, BigBirdConfig, BigBirdPegasusConfig, BioGptConfig,
BlenderbotConfig, BlenderbotSmallConfig, BloomConfig, CamembertConfig, CodeGenConfig, CpmAntConfig, CTRLConfig,
Data2VecTextConfig, ElectraConfig, ErnieConfig, GitConfig, GPT2Config, GPT2Config, GPTBigCodeConfig, GPTNeoConfig,
GPTNeoXConfig, GPTNeoXJapaneseConfig, GPTJConfig, LlamaConfig, MarianConfig, MBartConfig, MegaConfig, MegatronBertConfig,
MvpConfig, OpenAIGPTConfig, OPTConfig, PegasusConfig, PLBartConfig, ProphetNetConfig, QDQBertConfig, ReformerConfig,
RemBertConfig, RobertaConfig, RobertaPreLayerNormConfig, RoCBertConfig, RoFormerConfig, Speech2Text2Config, TransfoXLConfig,
TrOCRConfig, XGLMConfig, XLMConfig, XLMProphetNetConfig, XLMRobertaConfig, XLMRobertaXLConfig, XLNetConfig, XmodConfig.

Does someone know how to solve it?

I guess the folder name you used to save the model weights does not contain the "t5" string.

It should be OK once you rename the model path.

For more details: I encountered this error before, and based on the code I traced for inference:

  1. When you call the inference function, it uses the load_model function to determine which model and tokenizer to use:
    model, tokenizer = load_model(
  2. The load_model function uses the model adapter to decide which config should be used for inference, and for the T5 case it requires the model path to contain "t5" (a simplified sketch follows below):
    def match(self, model_path: str):
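
A simplified illustration of that adapter logic (not verbatim from FastChat's model-loading code): because the fallback adapter loads models with AutoModelForCausalLM, a checkpoint directory without "t5" in its name triggers the ValueError above.

class T5Adapter:
    # Simplified illustration; the real adapter in the repo is authoritative.
    def match(self, model_path: str):
        # This substring test is why the checkpoint folder name must contain "t5".
        return "t5" in model_path

    def load_model(self, model_path: str, **from_pretrained_kwargs):
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path, **from_pretrained_kwargs)
        return model, tokenizer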