Lightning-AI / litgpt

Pretrain, finetune, deploy 20+ LLMs on your own data. Uses state-of-the-art techniques: flash attention, FSDP, 4-bit, LoRA, and more.

Home Page: https://lightning.ai

OOM Error: CUDA out of memory when finetuning llama3-8b

zhaosheng-thu opened this issue · comments

When I fine-tune Llama 3-8B with finetune/lora.py, I get an OOM error.
My training and dataset parameters:

The parameters and the config:

```
  --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B \
  --precision 'bf16-true' \
  --train.global_batch_size 8 \
  --train.max_seq_length 2048 \
  --data JSON \
  --data.prompt_style 'llama3' \
  --data.json_path /root/szhao/ES-Lora/litllama/ExTES/ExTES.json \
  --data.val_split_fraction 0.1 \
  --data.mask_prompt True \
  --out_dir out/llama3-esconv-test
```

In this command, the prompt_style 'llama3' is one I defined myself.
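Roughly, it is a `PromptStyle` subclass that wraps the input in the Llama 3 chat template and is registered under the name 'llama3'. A simplified sketch (the registration via a `prompt_styles` dict is an assumption and may differ between litgpt versions):

```python
# Simplified sketch of the custom 'llama3' prompt style.
# Assumption: litgpt exposes PromptStyle and a prompt_styles registry dict;
# the exact hook may differ depending on the litgpt version.
from litgpt.prompts import PromptStyle, prompt_styles


class Llama3(PromptStyle):
    def apply(self, prompt: str, **kwargs) -> str:
        # Wrap the user turn in the Llama 3 chat template special tokens.
        return (
            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
            f"{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )


prompt_styles["llama3"] = Llama3  # makes --data.prompt_style 'llama3' resolvable
```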
But then I encountered the following error:

The error shown in the terminal:

```
{'checkpoint_dir': PosixPath('checkpoints/meta-llama/Meta-Llama-3-8B'), 'data': JSON(json_path=PosixPath('/root/szhao/ES-Lora/litllama/ExTES/ExTES.json'), mask_prompt=True, val_split_fraction=0.1, prompt_style=, ignore_index=-100, seed=42, num_workers=4), 'devices': 1, 'eval': EvalArgs(interval=100, max_new_tokens=100, max_iters=100), 'logger_name': 'csv', 'lora_alpha': 16, 'lora_dropout': 0.05, 'lora_head': False, 'lora_key': False, 'lora_mlp': False, 'lora_projection': False, 'lora_query': True, 'lora_r': 8, 'lora_value': True, 'out_dir': PosixPath('out/llama3-esconv-test'), 'precision': 'bf16-true', 'quantize': None, 'seed': 1337, 'train': TrainArgs(save_interval=1000, log_interval=1, global_batch_size=8, micro_batch_size=1, lr_warmup_steps=100, lr_warmup_fraction=None, epochs=5, max_tokens=None, max_steps=None, max_seq_length=2048, tie_embeddings=None, learning_rate=0.0003, weight_decay=0.02, beta1=0.9, beta2=0.95, max_norm=None, min_lr=6e-05)}
Seed set to 1337
Number of trainable parameters: 3,407,872
Number of non-trainable parameters: 8,030,261,248
The longest sequence length in the train data is 1923, the model's maximum sequence length is 1923 and context length is 8192
Validating ...
Epoch 1 | iter 1 step 0 | loss train: 2.891, val: n/a | iter time: 956.90 ms
Epoch 1 | iter 2 step 0 | loss train: 2.950, val: n/a | iter time: 524.16 ms
Epoch 1 | iter 3 step 0 | loss train: 3.036, val: n/a | iter time: 547.41 ms
Epoch 1 | iter 4 step 0 | loss train: 2.932, val: n/a | iter time: 653.34 ms
Epoch 1 | iter 5 step 0 | loss train: 2.988, val: n/a | iter time: 395.61 ms
Epoch 1 | iter 6 step 0 | loss train: 3.014, val: n/a | iter time: 516.08 ms
Epoch 1 | iter 7 step 0 | loss train: 3.029, val: n/a | iter time: 741.14 ms
Epoch 1 | iter 8 step 1 | loss train: 3.025, val: n/a | iter time: 513.20 ms (step)
Epoch 1 | iter 9 step 1 | loss train: 3.058, val: n/a | iter time: 645.27 ms
Epoch 1 | iter 10 step 1 | loss train: 3.028, val: n/a | iter time: 693.90 ms
Epoch 1 | iter 11 step 1 | loss train: 2.986, val: n/a | iter time: 656.02 ms
Epoch 1 | iter 12 step 1 | loss train: 2.994, val: n/a | iter time: 643.70 ms
Epoch 1 | iter 13 step 1 | loss train: 2.952, val: n/a | iter time: 469.85 ms
Epoch 1 | iter 14 step 1 | loss train: 2.910, val: n/a | iter time: 649.19 ms
Epoch 1 | iter 15 step 1 | loss train: 2.868, val: n/a | iter time: 500.48 ms
Epoch 1 | iter 16 step 2 | loss train: 2.869, val: n/a | iter time: 547.06 ms (step)
Epoch 1 | iter 17 step 2 | loss train: 2.841, val: n/a | iter time: 638.36 ms
Epoch 1 | iter 18 step 2 | loss train: 2.899, val: n/a | iter time: 398.46 ms
Epoch 1 | iter 19 step 2 | loss train: 2.903, val: n/a | iter time: 649.55 ms
Epoch 1 | iter 20 step 2 | loss train: 2.925, val: n/a | iter time: 520.43 ms
Epoch 1 | iter 21 step 2 | loss train: 2.927, val: n/a | iter time: 689.26 ms
Epoch 1 | iter 22 step 2 | loss train: 2.942, val: n/a | iter time: 525.35 ms
Epoch 1 | iter 23 step 2 | loss train: 2.933, val: n/a | iter time: 462.21 ms
Epoch 1 | iter 24 step 3 | loss train: 2.916, val: n/a | iter time: 654.28 ms (step)
Epoch 1 | iter 25 step 3 | loss train: 2.930, val: n/a | iter time: 537.02 ms
Epoch 1 | iter 26 step 3 | loss train: 2.911, val: n/a | iter time: 476.17 ms
Epoch 1 | iter 27 step 3 | loss train: 2.913, val: n/a | iter time: 545.19 ms
Epoch 1 | iter 28 step 3 | loss train: 2.882, val: n/a | iter time: 528.47 ms
Epoch 1 | iter 29 step 3 | loss train: 2.921, val: n/a | iter time: 463.99 ms
Epoch 1 | iter 30 step 3 | loss train: 2.899, val: n/a | iter time: 484.63 ms
Epoch 1 | iter 31 step 3 | loss train: 2.927, val: n/a | iter time: 390.68 ms
Epoch 1 | iter 32 step 4 | loss train: 2.922, val: n/a | iter time: 691.96 ms (step)
Epoch 1 | iter 33 step 4 | loss train: 2.919, val: n/a | iter time: 461.07 ms
Epoch 1 | iter 34 step 4 | loss train: 2.867, val: n/a | iter time: 690.43 ms
Epoch 1 | iter 35 step 4 | loss train: 2.828, val: n/a | iter time: 540.38 ms
Traceback (most recent call last):
  File "/root/szhao/ES-Lora/litgpt/litgpt/finetune/lora.py", line 432, in <module>
    CLI(setup)
  File "/root/szhao/ES-Lora/litgpt/litgpt/utils.py", line 412, in CLI
    return CLI(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/jsonargparse/_cli.py", line 96, in CLI
    return _run_component(components, cfg_init)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/jsonargparse/_cli.py", line 196, in _run_component
    return component(**cfg)
  File "/root/szhao/ES-Lora/litgpt/litgpt/finetune/lora.py", line 143, in setup
    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 866, in launch
    return self._wrap_and_launch(function, self, *args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 952, in _wrap_and_launch
    return to_run(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/lightning/fabric/fabric.py", line 957, in _wrap_with_setup
    return to_run(*args, **kwargs)
  File "/root/szhao/ES-Lora/litgpt/litgpt/finetune/lora.py", line 196, in main
    fit(
  File "/root/szhao/ES-Lora/litgpt/litgpt/finetune/lora.py", line 276, in fit
    logits = model(input_ids, lm_head_chunk_size=128)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/lightning/fabric/wrappers.py", line 143, in forward
    output = self._forward_module(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/szhao/ES-Lora/litgpt/litgpt/lora.py", line 545, in forward
    x = block(x, cos, sin, mask, input_pos)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/szhao/ES-Lora/litgpt/litgpt/model.py", line 187, in forward
    x = self.mlp(self.norm_2(x)) + x
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/szhao/ES-Lora/litgpt/litgpt/model.py", line 311, in forward
    x_fc_1 = self.fc_1(x)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/szhao/ES-Lora/litgpt/litgpt/lora.py", line 168, in forward
    pretrained = self.linear(x)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/anaconda3/envs/ligpt/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 48.00 MiB. GPU
```

I find this very strange because I previously fine-tuned llama2-7b with the lit-llama repository on the same dataset with almost the same training config, and everything went smoothly back then. Can you help me? Thanks.

Hm, it could be related to the slightly larger model size.
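Back-of-the-envelope, most of the difference comes from the much larger vocabulary (and hence larger embedding and LM-head matrices). A rough sketch of the weight memory alone, using the 8B parameter count from your log and the commonly quoted ~6.7B figure for Llama 2 7B (an assumption, so treat these as estimates):

```python
# Rough bf16 weight-memory comparison (2 bytes per parameter).
# The 8B count is taken from the training log above; the 7B count
# is the commonly quoted Llama 2 7B figure (assumption).
GiB = 1024 ** 3
models = {
    "Llama 3-8B": 8_030_261_248,
    "Llama 2-7B": 6_738_415_616,  # approximate
}
for name, n_params in models.items():
    print(f"{name}: ~{n_params * 2 / GiB:.1f} GiB of weights in bf16")
# ~15.0 GiB vs ~12.6 GiB before activations, LoRA gradients, and optimizer
# state -- enough of a gap to tip a borderline setup into OOM.
```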

> llama2-7b by the lit-llama

I think Llama 2 is not supported by lit-llama. Did you perhaps mean Llama 7B in lit-llama, or Llama 2 7B in LitGPT?

If you meant lit-llama, I am curious: does the 7B Llama 2 model work for you in LitGPT?

In any case, you could perhaps try QLoRA (--quantize bnb.nf4) or a smaller --train.max_seq_length to make it work.

With --quantize bnb.nf4, I am able to fine-tune the Llama 3-8B without any problem on a single A10 GPU.
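Roughly, that is because 4-bit weights take a quarter of the bf16 footprint; a quick estimate (ignoring quantization scales, activations, and optimizer state, so these are ballpark numbers only):

```python
# Ballpark weight memory for the 8B model in bf16 vs 4-bit NF4.
# Parameter count is from the training log above; per-block quantization
# overhead and all activation/optimizer memory are ignored here.
GiB = 1024 ** 3
params = 8_030_261_248
print(f"bf16 weights: ~{params * 2 / GiB:.1f} GiB")    # ~15.0 GiB
print(f"nf4 weights:  ~{params * 0.5 / GiB:.1f} GiB")  # ~3.7 GiB, fits a 24 GB A10
```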

Thanks for all the help. I found that the OOM error vanishes when I choose a smaller max_seq_length. I believe it's because my dataset samples are too long, which leads to OOM. When I tried LoRA with the Alpaca-2k dataset, it consumed 20.5 GB of memory. When I used my own dataset without limiting max_seq_length, it would OOM regardless of whether I used --quantize bnb.nf4 or not. The issue was resolved when I limited --max-seq-length 512.
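For anyone who hits the same thing: a quick way to confirm that long samples are the culprit is to tokenize the dataset and look at the length distribution before training. A minimal sketch, assuming the JSON file uses the instruction/input/output fields that litgpt's JSON data module expects and that litgpt.tokenizer.Tokenizer is available (adjust paths and field names as needed):

```python
# Tokenize each JSON sample with the model's tokenizer and report how long
# the samples are, to see whether a few outliers exceed max_seq_length.
import json
from pathlib import Path

from litgpt.tokenizer import Tokenizer  # assumption: matches your litgpt version

checkpoint_dir = Path("checkpoints/meta-llama/Meta-Llama-3-8B")
json_path = Path("/root/szhao/ES-Lora/litllama/ExTES/ExTES.json")

tokenizer = Tokenizer(checkpoint_dir)
lengths = sorted(
    len(tokenizer.encode(s["instruction"] + s.get("input", "") + s["output"]))
    for s in json.loads(json_path.read_text())
)
print("max length:", lengths[-1])
print("95th percentile:", lengths[int(0.95 * len(lengths))])
print("samples over 512 tokens:", sum(l > 512 for l in lengths))
```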