Issue in `ParallelEmbedding` constructor - `scale_grad_by_freq` being assigned to `norm_type`
gtamer2 opened this issue
Overview
In the `__init__` function of `ParallelEmbedding`, both `self.norm_type` and `self.scale_grad_by_freq` are assigned the value of `scale_grad_by_freq`.
`fairscale/fairscale/nn/model_parallel/layers.py`, lines 181 to 182 in `164cc0f`:
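Based on the description above, the two assignments presumably read as follows (a reconstruction, not a verbatim copy of the pinned lines):

```python
self.norm_type = scale_grad_by_freq          # bug: should be assigned norm_type
self.scale_grad_by_freq = scale_grad_by_freq
```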
Impact
The embedding's `norm_type` will be incorrectly set to 0.0 (the float interpretation of `scale_grad_by_freq`'s default boolean `False`). This does not cause a runtime error in eager mode, where the mistyped value can still be interpreted. However, it prevents scripting into graph mode with Torch JIT, because the type mismatch is a compilation error.
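A minimal, hypothetical sketch of this failure class (independent of fairscale; the names below are invented for illustration): eager Python ignores type annotations, while `torch.jit.script` enforces them at compile time.

```python
import torch

def takes_bool(flag: bool) -> bool:
    # TorchScript compiles this annotation into a hard type requirement.
    return flag

class Mistyped(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.flag = 2.0  # float attribute that ends up in a bool-typed slot

    def forward(self) -> bool:
        return takes_bool(self.flag)

m = Mistyped()
m()                  # eager mode: runs fine, annotations are not checked
torch.jit.script(m)  # RuntimeError: expected type 'bool', found type 'float'
```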
Impact Example
- Facebook Research's Llama model uses `ParallelEmbedding` without `norm_type` or `scale_grad_by_freq` defined, here: https://github.com/facebookresearch/llama/blob/ef351e9cd9496c579bf9f2bb036ef11bdc5ca3d2/llama/model.py#L437-L439 (the call site is reconstructed below the list).
- This prevents the Llama transformer block from being scripted by Torch JIT, because the type mismatch is a compiler error.
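For reference, the linked construction presumably looks like the following (a reconstruction of the pinned lines, not a verbatim copy). Because neither `norm_type` nor `scale_grad_by_freq` is passed, the constructor defaults are used and the mis-assignment takes effect silently:

```python
# llama/model.py (reconstructed): ParallelEmbedding built with default
# norm_type and scale_grad_by_freq.
self.tok_embeddings = ParallelEmbedding(
    params.vocab_size, params.dim, init_method=lambda x: x
)
```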
Steps to Reproduce
Code
Create a new file `script_model.py` at the root of https://github.com/facebookresearch/llama:
import torch
import fire

from llama import Llama


def script_transformer(ckpt_dir: str, tokenizer_path: str, max_seq_len: int, max_batch_size: int) -> None:
    llama = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )
    transformer = llama.model
    transformer.eval()

    with torch.no_grad():
        print("Attempting to script model...")
        scripted_transformer = torch.jit.script(transformer)  # fails with the error below


if __name__ == "__main__":
    fire.Fire(script_transformer)
Then run `torchrun script_model.py --ckpt_dir llama-2-7b/ --tokenizer_path tokenizer.model --max_seq_len 512 --max_batch_size 6`.
Error Message
RuntimeError:
embedding(Tensor input, Tensor weight, int? padding_idx=None, float? max_norm=None, float norm_type=2., bool scale_grad_by_freq=False, bool sparse=False) -> Tensor:
Expected a value of type 'bool' for argument 'scale_grad_by_freq' but instead found type 'float'.
:
File "/path/to/venv/lib/python3.9/site-packages/fairscale/nn/model_parallel/layers.py", line 205
def forward(self, input_: torch.Tensor) -> torch.Tensor: # type: ignore
input_parallel = copy_to_model_parallel_region(input_)
output_parallel = F.embedding(
~~~~~~~~~~~ <--- HERE
input_parallel,
self.weight,
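The fix implied by the overview is a one-line change in the `ParallelEmbedding` constructor. A sketch, assuming each attribute should simply mirror the constructor argument of the same name (the actual patch may differ):

```python
self.norm_type = norm_type                    # was: scale_grad_by_freq
self.scale_grad_by_freq = scale_grad_by_freq
```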
Thanks @gtamer2, could you also create a PR here?
Thanks @brad-mengchi!