[BUG] Shape mismatch error while splitting mixed_x_layer
RockMiin opened this issue · comments
I am currently pretraining GPT and encountered an error in the split_tensor function in megatron/model/transformer.py. The function is documented as transforming [sq, b, nkv, (nq // nkv + 2), hn] into three [sq, b, np, hn] tensors. When reshaping the query_layer, I believe the correct expression is mixed_x_layer.shape[:-2] rather than mixed_x_layer.shape[:-1]:
def split_tensor(self, mixed_x_layer):
    query_layer = mixed_x_layer[:, :, :, :-2, :].reshape(mixed_x_layer.shape[:-2] + (-1, self.hidden_size_per_attention_head))
    key_layer = mixed_x_layer[:, :, :, -2, :]
    value_layer = mixed_x_layer[:, :, :, -1, :]
    return query_layer, key_layer, value_layer
With the current code I get the following error:
[default0]: query_layer = mixed_x_layer[:, :, :, :-2, :].reshape(mixed_x_layer.shape[:-1] + (-1, self.hidden_size_per_attention_head))
[default0]:RuntimeError: shape '[512, 1, 2, 3, -1, 4]' is invalid for input of size 4096
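The arithmetic behind the error can be reproduced standalone. This is a minimal sketch using NumPy in place of PyTorch, with the exact dimensions from the error message ([sq, b, nkv, nq // nkv + 2, hn] = [512, 1, 2, 3, 4]); it only demonstrates the element-count mismatch, not the full Megatron attention path.

```python
import numpy as np

# Shapes taken from the error message: [sq, b, nkv, nq // nkv + 2, hn].
sq, b, nkv, groups, hn = 512, 1, 2, 3, 4
mixed = np.zeros((sq, b, nkv, groups, hn), dtype=np.float32)

# Slicing off the key and value rows leaves 512 * 1 * 2 * 1 * 4 = 4096 elements.
query = mixed[:, :, :, :-2, :]

# shape[:-1] keeps the (nq // nkv + 2) group dimension in the target shape,
# so the element counts cannot match and the reshape is rejected.
try:
    query.reshape(mixed.shape[:-1] + (-1, hn))
    raise AssertionError("expected the reshape to fail")
except ValueError:
    pass

# shape[:-2] drops both trailing dimensions before appending (-1, hn),
# so the element counts line up again.
fixed = query.reshape(mixed.shape[:-2] + (-1, hn))
print(fixed.shape)  # (512, 1, 2, 1, 4)
```

With shape[:-1] the target asks for 512 * 1 * 2 * 3 * 4 = 12288 elements per unit of the inferred -1 dimension, which cannot divide the 4096 elements actually present; with shape[:-2] the inferred dimension resolves cleanly.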
I hit the same issue with the code on the latest main branch and am unable to fix it there. However, the code works after switching to commit 2348eed from Nov 17, 2023.
I also encountered the same issue with the code in the latest main branch
    CheckpointFunction.apply(function, all_outputs, *args)
  File "/home/anaconda3/envs/zbx1/lib/python3.10/site-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 566, in forward
    outputs = run_function(*inputs_cuda)
  File "/home/zbx/Megatron-DeepSpeed/megatron/model/transformer.py", line 1729, in custom_forward
    output = layer(x_, *args, **kwargs)
  File "/home/anaconda3/envs/zbx1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/wangzhigangcs/zbx/Megatron-DeepSpeed/megatron/model/transformer.py", line 1222, in forward
    self.self_attention(
  File "/home/anaconda3/envs/zbx1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/zbx/Megatron-DeepSpeed/megatron/model/transformer.py", line 694, in forward
    value_layer) = self.split_tensor(mixed_x_layer)
  File "/home/zbx/Megatron-DeepSpeed/megatron/model/transformer.py", line 647, in split_tensor
    query_layer = mixed_x_layer[:, :, :, :-2, :].reshape(mixed_x_layer.shape[:-1] + (-1, self.hidden_size_per_attention_head))
RuntimeError: shape '[2048, 2, 2, 3, -1, 64]' is invalid for input of size 524288