facebookresearch / fairscale

PyTorch extensions for high performance and large scale training.

Issue in `ParallelEmbedding` constructor - scale_grad_by_freq being assigned to norm_type

gtamer2 opened this issue

Overview

In the init function of ParallelEmbedding, both self.norm_type and self.scale_grad_by_freq are assigned to scale_grad_by_freq.

self.norm_type = scale_grad_by_freq
self.scale_grad_by_freq = scale_grad_by_freq
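
The intended assignment is presumably self.norm_type = norm_type. A minimal sketch of the corrected attribute setup follows. EmbeddingArgs is a hypothetical stand-in class, not fairscale's actual code; the parameter names and defaults are taken from the F.embedding signature quoted in the error message in this issue.

```python
from typing import Optional


class EmbeddingArgs:
    """Hypothetical stand-in for the attribute setup in ParallelEmbedding.__init__."""

    def __init__(
        self,
        max_norm: Optional[float] = None,
        norm_type: float = 2.0,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
    ) -> None:
        self.max_norm = max_norm
        # Presumed fix: store norm_type itself rather than scale_grad_by_freq.
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse


args = EmbeddingArgs()
print(type(args.norm_type).__name__, args.norm_type)
print(type(args.scale_grad_by_freq).__name__, args.scale_grad_by_freq)
```

With this assignment each attribute keeps the type that its matching F.embedding argument expects.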

Impact

With the default arguments, the embedding norm_type is incorrectly set to False (the default value of scale_grad_by_freq), which is interpreted as 0.0 wherever a float is expected. Any norm_type passed by the caller is ignored. This does not cause a runtime error in eager mode, because the boolean can still be interpreted as a float. However, it prevents compiling the model with TorchScript (torch.jit.script), because the type mismatch causes a compilation error.
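
The difference between the two modes can be illustrated without PyTorch. The sketch below is only an analogy, not TorchScript's actual type checker: lenient_float and strict_float are hypothetical helpers, the first coercing like dynamic Python code and the second accepting only an exact float.

```python
def lenient_float(value) -> float:
    """Dynamic behaviour: anything numeric, including bool, is coerced."""
    return float(value)


def strict_float(value) -> float:
    """Hypothetical static-style check: only an exact float is accepted."""
    if type(value) is not float:
        raise TypeError(f"expected float, found {type(value).__name__}")
    return value


norm_type = False  # what the buggy assignment stores by default

# The lenient path silently turns the boolean into a float.
print(lenient_float(norm_type))

# The strict path refuses the boolean outright.
try:
    strict_float(norm_type)
except TypeError as err:
    rejected = str(err)
    print(rejected)
```

The same value passes silently on the lenient path and is refused on the strict one, which mirrors why the bug only surfaces when the model is scripted.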

Impact Example

Steps to Reproduce

Code

Create a new file script_model.py at the root of https://github.com/facebookresearch/llama

import fire
import torch
from llama import Llama


def script_transformer(ckpt_dir, tokenizer_path, max_seq_len, max_batch_size):
    # Build the Llama wrapper, which loads the checkpoint and tokenizer.
    llama = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )
    transformer = llama.model
    transformer.eval()
    with torch.no_grad():
        print("Attempting to script model...")
        # Scripting fails here with the error shown under "Error Message".
        scripted_transformer = torch.jit.script(transformer)


if __name__ == "__main__":
    fire.Fire(script_transformer)

and run torchrun script_model.py --ckpt_dir llama-2-7b/ --tokenizer_path tokenizer.model --max_seq_len 512 --max_batch_size 6

Error Message

RuntimeError: 

embedding(Tensor input, Tensor weight, int? padding_idx=None, float? max_norm=None, float norm_type=2., bool scale_grad_by_freq=False, bool sparse=False) -> Tensor:
Expected a value of type 'bool' for argument 'scale_grad_by_freq' but instead found type 'float'.
:
  File "/path/to/venv/lib/python3.9/site-packages/fairscale/nn/model_parallel/layers.py", line 205
    def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
        input_parallel = copy_to_model_parallel_region(input_)
        output_parallel = F.embedding(
                          ~~~~~~~~~~~ <--- HERE
            input_parallel,
            self.weight,

Thanks @gtamer2, also create a PR here.