Error Occurring at loss.backward() Despite Loss Being Calculable
aidarikako opened this issue · comments
Hello,
I am encountering an issue when running the following code snippet:
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py \
    --enable_fsdp \
    --model_name <your_model_directory> \
    --num_epochs 2 \
    --batch_size_training 16 \
    --micro_batch_size 1 \
    --val_batch_size 8 \
    --lr 2e-5 \
    --num_workers_dataloader 1 \
    --seed 42 \
    --data_path <your_data_directory> \
    --max_words_dataset 2048 \
    --checkpoint_folder <your_directory_to_save> \
    --save_with_hf \
    --warmup_ratio 0.03 \
    --save_epoch_interval 1 \
    --add_token_list ft_datasets/toolken_list_50.json
This results in the following error:
File "/mnt/home/xlh/code/simulated-trial-and-error-main/simulated-trial-and-error-main/llama-recipes/utils/train_utils.py", line 94, in train
loss.backward()
File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
torch.autograd.backward(
File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1112, in unpack_hook
frame.recompute_fn(*args)
File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1401, in recompute_fn
fn(*args, **kwargs)
File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 741, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 671, in forward
attn_output = torch.nn.functional.scaled_dot_product_attention(
RuntimeError: The expanded size of the tensor (4096) must match the existing size (2048) at non-singleton dimension 3. Target sizes: [1, 32, 2048, 4096]. Tensor sizes: [1, 1, 2048, 2048]
The error occurs at loss.backward(), even though the loss value is computed successfully and can be printed out. I would appreciate any insights or suggestions on possible causes for this error.
Thank you for your help!
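For readers trying to parse the RuntimeError above, here is a minimal pure-Python sketch of the expand rule it refers to. The helper name `can_expand` is made up for illustration, and reading the [1, 1, 2048, 2048] tensor as an attention mask is an assumption; the traceback does not say which tensor it is.

```python
def can_expand(tensor_shape, target_shape):
    """Return True if a tensor of tensor_shape can be expanded to target_shape.

    Shapes are compared from the trailing dimension backwards; each existing
    dimension must either equal the target size or be 1 (a singleton).
    """
    if len(tensor_shape) > len(target_shape):
        return False
    for have, want in zip(reversed(tensor_shape), reversed(target_shape)):
        if have != want and have != 1:
            return False
    return True


# "Tensor sizes" from the error (assumed here to be the attention mask)
mask_shape = [1, 1, 2048, 2048]
# "Target sizes" from the error: the last dimension is 4096 instead of 2048
target_shape = [1, 32, 2048, 4096]

print(can_expand(mask_shape, target_shape))         # False: 2048 vs 4096 in the last dim
print(can_expand(mask_shape, [1, 32, 2048, 2048]))  # True: only singleton dims grow
```

The singleton dimension 1 -> 32 is fine; it is the last dimension (2048 vs 4096) that cannot be expanded, which matches "non-singleton dimension 3" in the message.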
Sorry for the late reply. I ran the code on my side and did not observe this issue. Seems to be some kind of mismatch on the tensor shapes; have you tried to trace back from the error tensor?
After thorough investigation, I discovered that the issue actually arose from a line a bit before the one where the error was reported.
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
if past_key_value is not None:
    # sin and cos are specific to RoPE models; cache_position needed for the static cache
    cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
Specifically, after the past_key_value.update call, the last dimension of key_states changed from 2048 to 4096. Print statements showed that this did not happen on every call: initially the dimension remained at 2048, but starting from a certain sample, key_states came back from past_key_value.update with a size of 4096, which caused the error and terminated the program. Therefore, the problem lies within the past_key_value.update operation.
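One possible reading, not confirmed anywhere in this thread: the traceback passes through recompute_fn in torch/utils/checkpoint.py, so the attention layer's forward is being run a second time during backward. If the cache still holds the keys from the first run, a second update for the same layer would append to them and double the length. The toy class below is a hypothetical stand-in written for illustration, not the transformers cache implementation.

```python
class ToyLayerCache:
    """Hypothetical per-layer key cache (illustration only, not the library class)."""

    def __init__(self):
        self.keys = {}  # layer_idx -> list of cached key positions

    def update(self, new_keys, layer_idx):
        # The first call for a layer stores the keys; later calls append to them.
        cached = self.keys.setdefault(layer_idx, [])
        cached.extend(new_keys)
        return list(cached)


def attention_forward(cache, seq_len, layer_idx=0):
    """Return the key length the attention step would see for one forward call."""
    new_keys = list(range(seq_len))
    return len(cache.update(new_keys, layer_idx))


cache = ToyLayerCache()
first_pass = attention_forward(cache, 2048)  # original forward pass
recomputed = attention_forward(cache, 2048)  # same layer run again, e.g. on recomputation
print(first_pass, recomputed)  # 2048 4096
```

In this sketch the first call sees 2048 keys and the repeated call sees 4096, the same jump reported above. It would also fit with the loss being computed fine and the failure only showing up inside loss.backward().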
As a temporary workaround, I commented out the section involving past_key_value.update, and fine-tuning proceeded without errors. Do you have any insights regarding this bug?
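A less invasive alternative to editing the library source might be to turn off the key/value cache on the model config before training. Hugging Face model configs, Llama's included, have a use_cache flag, and on a loaded model this would be model.config.use_cache = False. Whether that fixes this particular setup has not been verified in this thread. The sketch below uses a SimpleNamespace as a stand-in for the real config so it runs without the library; disable_kv_cache is a hypothetical helper name.

```python
from types import SimpleNamespace


def disable_kv_cache(config):
    """Set use_cache to False on a config-like object, if it has that attribute."""
    if hasattr(config, "use_cache"):
        config.use_cache = False
    return config


# Stand-in for a real model config (assumption: the real one exposes use_cache)
toy_config = SimpleNamespace(use_cache=True)
disable_kv_cache(toy_config)
print(toy_config.use_cache)  # False
```

With the cache disabled, the training forward pass should not receive a past_key_value object at all, so the guarded block quoted above would simply be skipped rather than commented out.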