Issue with Rotary Embedding Initialization when the number of devices is > 1
diegodoimo opened this issue
This issue belongs in the lit-gpt repo; I raised it here by mistake. Apologies.
Hi,
I found an issue while working with the finetune/lora.py script. When the number of devices is greater than 1, the rotary embeddings are not initialized correctly. After investigating, I found that the reset_parameters function might not be triggering the proper initialization of the RoPE embeddings.
Indeed, inspecting the cos and sin attributes of the GPT model after they are materialized on multiple GPUs, I observed a discrepancy in their values compared to the single-GPU setup.
For instance, printing model.cos after this line, with more than one GPU:
model = fabric.setup_module(model)
fabric.print("model.cos:", model.cos)
I get:
model.cos: tensor([[-0.0062, 0.0113, -0.0073, ..., 0.0138, 0.0008, 0.0071],
[ 0.0012, 0.0030, -0.0079, ..., 0.0090, -0.0039, 0.0082],
[-0.0086, -0.0106, -0.0128, ..., 0.0115, 0.0030, 0.0096],
...,
[ 0.0071, 0.0103, 0.0019, ..., -0.0016, -0.0002, 0.0115]],
device='cuda:0', dtype=torch.bfloat16)
However, on a single GPU, the initialization is correct:
model.cos: tensor([[ 1.0000, 1.0000, 1.0000, ..., 1.0000, 1.0000, 1.0000],
[ 0.5403, 0.6479, 0.7318, ..., 1.0000, 1.0000, 1.0000],
[-0.4161, -0.1604, 0.0709, ..., 1.0000, 1.0000, 1.0000],
...,
[-0.0660, -0.7424, -0.0900, ..., 0.8077, 0.8546, 0.8903]],
device='cuda:0')
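For reference, the healthy values above follow directly from the RoPE definition: position 0 has angle 0 at every frequency, so the first row of cos must be all ones. A minimal sketch of the computation (build_rope_cache here is a simplified stand-in, not lit-gpt's actual implementation):

```python
import torch

def build_rope_cache(seq_len: int, n_elem: int, base: int = 10000):
    # Standard RoPE frequencies: theta_i = base^(-2i / n_elem)
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
    positions = torch.arange(seq_len).float()
    # angles[p, i] = p * theta_i, shape (seq_len, n_elem // 2)
    angles = torch.outer(positions, theta)
    return torch.cos(angles), torch.sin(angles)

cos, sin = build_rope_cache(seq_len=4, n_elem=8)
# Row 0 is all ones (cos of zero angles); row 1, first frequency,
# is cos(1.0) ≈ 0.5403, matching the correct single-GPU output above.
print(cos)
```

The random bfloat16 values printed in the multi-GPU case cannot come from this computation, which is what suggests the cache was never actually built.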
Potential cause
The issue seems to stem from this line: the model.sin and model.cos tensors seem to be already materialized on a CUDA device by the time the reset_parameters function is called for initialization, so the re-initialization logic is skipped.
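To illustrate the suspected mechanism with a minimal, hypothetical sketch (TinyRope is an illustrative stand-in, not a lit-gpt class): when a module is instantiated on the meta device and later materialized with to_empty(), the buffer math from __init__ is never actually executed, so plain buffers like cos/sin come back with arbitrary values unless something recomputes them:

```python
import torch
import torch.nn as nn

# Hypothetical module mimicking lit-gpt's GPT, which stores the RoPE
# cache in plain (non-parameter) buffers such as self.cos.
class TinyRope(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("cos", torch.cos(torch.arange(4).float()))

    def reset_parameters(self) -> None:
        # Explicitly recompute the cache instead of relying on the
        # values computed in __init__ (the kind of fix proposed below).
        self.cos.copy_(torch.cos(torch.arange(4, device=self.cos.device).float()))

# Instantiate on the meta device, as Fabric does with empty-weight init:
with torch.device("meta"):
    m = TinyRope()

# to_empty() allocates real storage but does not rerun the __init__ math,
# so m.cos now holds uninitialized memory, not cosine values.
m = m.to_empty(device="cpu")

# Only an explicit re-initialization restores the correct values:
m.reset_parameters()
```

This matches the symptom: the multi-GPU path materializes the model from the meta device, and nothing recomputes the RoPE buffers afterwards.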
Possible fix
A possible fix could be to reinitialize the RoPE embeddings directly in the reset_parameters function:
def reset_parameters(self) -> None:
    # Trigger resetting the rope-cache
    # self.max_seq_length = self.config.block_size
    self.cos, self.sin = self.rope_cache()
Can someone confirm this issue? If it's valid, I am willing to open a pull request with the proposed fix.
Environment
I am working in a conda environment and installed the lit-gpt package as recommended in the README.md, with the following command:
pip install -r requirements-all.txt
Relevant dependencies include:
lightning 2.2.0.dev0 pypi_0 pypi
python 3.11.7 h955ad1f_0
pytorch-lightning 2.1.3 pypi_0 pypi
torch 2.1.2 pypi_0 pypi