microsoft / simulated-trial-and-error

Error Occurring at loss.backward() Despite Loss Being Calculable

aidarikako opened this issue

Hello,

I am encountering an error when running the following command:

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py \
--enable_fsdp \
--model_name <your_model_directory> \
--num_epochs 2 \
--batch_size_training 16 \
--micro_batch_size 1 \
--val_batch_size 8 \
--lr 2e-5 \
--num_workers_dataloader 1 \
--seed 42 \
--data_path <your_data_directory> \
--max_words_dataset 2048 \
--checkpoint_folder <your_directory_to_save> \
--save_with_hf \
--warmup_ratio 0.03 \
--save_epoch_interval 1 \
--add_token_list ft_datasets/toolken_list_50.json

This results in the following error:

  File "/mnt/home/xlh/code/simulated-trial-and-error-main/simulated-trial-and-error-main/llama-recipes/utils/train_utils.py", line 94, in train
    loss.backward()
  File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
  File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1112, in unpack_hook
    frame.recompute_fn(*args)
  File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 1401, in recompute_fn
    fn(*args, **kwargs)
  File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 741, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/miniconda3/envs/ste/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 671, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
RuntimeError: The expanded size of the tensor (4096) must match the existing size (2048) at non-singleton dimension 3. Target sizes: [1, 32, 2048, 4096]. Tensor sizes: [1, 1, 2048, 2048]

The error is raised at loss.backward(), even though the loss itself is computed successfully and can be printed. I would appreciate any insights or suggestions about what might be causing it.

Thank you for your help!
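The RuntimeError itself appears to be a broadcasting failure: scaled_dot_product_attention tries to expand the attention mask to (batch, heads, query length, key length), and a size can only be expanded if it is 1 or already equal to the target size. The plain-Python sketch below applies that rule to the shapes in the message; `first_expand_mismatch` is a hypothetical helper, not a PyTorch API, and it ignores the `-1` "keep this size" convention of `Tensor.expand`. The head dimension (1 to 32) is fine because it is a singleton, but dimension 3, the key length, is 2048 in the mask and 4096 in the target.

```python
def first_expand_mismatch(existing, target):
    """Return the index (in `target`) of the first dimension, scanning from the
    right, at which a tensor of shape `existing` cannot be expanded to `target`,
    or None if the expand is valid.

    Rule modelled here (PyTorch-style expand, ignoring the -1 convention):
    dimensions are aligned from the right, extra leading dims may be added, and
    each existing size must either equal the target size or be 1 (a singleton).
    """
    if len(target) < len(existing):
        raise ValueError("target must have at least as many dimensions as existing")
    offset = len(target) - len(existing)
    for dim in range(len(target) - 1, offset - 1, -1):
        size = existing[dim - offset]
        if size != 1 and size != target[dim]:
            return dim
    return None


# Shapes taken from the RuntimeError in the traceback.
mask_shape = (1, 1, 2048, 2048)     # mask built for 2048 query x 2048 key positions
target_shape = (1, 32, 2048, 4096)  # batch, heads, query length, key length
print(first_expand_mismatch(mask_shape, target_shape))         # 3
print(first_expand_mismatch(mask_shape, (1, 32, 2048, 2048)))  # None
```

In other words, the mask is consistent with a key length of 2048, so the question is why the keys reaching attention have length 4096.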

Sorry for the late reply. I ran the code on my side and did not observe this issue. It looks like a tensor shape mismatch; have you tried tracing back from the tensor that triggers the error?

After further investigation, I found that the problem actually originates a few lines before the one where the error is reported, in modeling_llama.py:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py

if past_key_value is not None:
    # sin and cos are specific to RoPE models; cache_position needed for the static cache
    cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

Specifically, after past_key_value.update, the sequence length of key_states changes from 2048 to 4096. Print statements showed that this does not happen on every call: at first the length stays at 2048, but starting from a certain sample, key_states becomes 4096 long after past_key_value.update, which triggers the error and terminates the program. The problem therefore lies in the past_key_value.update call.
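One possible explanation, not confirmed in this thread: in the traceback, the failing attention call runs inside recompute_fn from torch/utils/checkpoint.py, which suggests activation checkpointing is re-running the decoder layer's forward pass during loss.backward(). If the same cache object is still attached, past_key_value.update would append the recomputed keys to the ones stored during the original forward pass, so the key length doubles (2048 + 2048 = 4096) while the mask stays 2048 x 2048. That would also be consistent with the loss being computable and printable. The toy simulation below tracks only sequence lengths; `ToyKVCache` and the helper functions are hypothetical stand-ins, not transformers classes, and the sketch does not explain why the first samples were unaffected.

```python
class ToyKVCache:
    """Hypothetical stand-in for a key/value cache. Only sequence lengths are
    tracked; update() behaves like concatenating new keys along the sequence axis."""

    def __init__(self):
        self.seq_len = 0

    def update(self, new_len):
        self.seq_len += new_len  # concatenation: old length + new length
        return self.seq_len


def toy_attention_forward(q_len, cache=None):
    """Return (query length, key length) as the attention call would see them."""
    kv_len = cache.update(q_len) if cache is not None else q_len
    return q_len, kv_len


def forward_then_recompute(q_len, cache=None):
    """Mimic activation checkpointing: run the layer once in the forward pass,
    then run the same layer again during backward, with the same cache object."""
    first = toy_attention_forward(q_len, cache)   # forward pass: loss can be computed
    second = toy_attention_forward(q_len, cache)  # recomputation inside loss.backward()
    return first, second


print(forward_then_recompute(2048, ToyKVCache()))  # ((2048, 2048), (2048, 4096))
print(forward_then_recompute(2048, None))          # ((2048, 2048), (2048, 2048))
```

Under this reading, the recomputed keys are 4096 long while the mask is still built for 2048 keys, which matches the shapes in the RuntimeError; without a cache, both passes agree.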

As a temporary workaround, I commented out the past_key_value.update block, and fine-tuning then ran without errors. Do you have any insight into what might be causing this bug?
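A less invasive workaround than editing modeling_llama.py may be to switch the key/value cache off for training, since it is only needed for incremental decoding. The sketch below assumes a Hugging Face model whose config exposes `use_cache` (LlamaConfig does); `disable_kv_cache_for_training` is a hypothetical helper, where to call it in llama_finetuning.py is an assumption, and whether it resolves this particular error was not confirmed in the thread. Passing `use_cache=False` to the model's forward call is a per-call alternative.

```python
from types import SimpleNamespace


def disable_kv_cache_for_training(model):
    """Turn off the key/value cache before fine-tuning.

    Assumes a Hugging Face-style model whose configuration exposes a boolean
    `use_cache` attribute (as LlamaConfig does). For full-sequence training
    forwards the cache is not needed, so disabling it should not change the loss.
    """
    model.config.use_cache = False
    return model


# Usage with a stand-in object; with a real setup you would pass the model
# loaded in llama_finetuning.py instead (assumed call site, not verified).
dummy_model = SimpleNamespace(config=SimpleNamespace(use_cache=True))
disable_kv_cache_for_training(dummy_model)
print(dummy_model.config.use_cache)  # False
```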