TylerYep / torchinfo

View model summaries in PyTorch!

Summary failed for transformers.GPT2Model

apivovarov opened this issue

Versions

torch.__version__
'1.13.1+cu117'

transformers.__version__
'4.20.0'

torchinfo.__version__
'1.7.2'

To Reproduce

import torch
from transformers import GPT2Model, GPT2Tokenizer
from torchinfo import summary

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
model.eval()
input0 = torch.tensor(
    [[tokenizer.encode("Here is some text to encode Hello World", add_special_tokens=True)]]
)
outputs = model(input0)
summary(model, input0.shape)

Error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torchinfo/torchinfo.py", line 288, in forward_pass
    _ = model.to(device)(*x, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/gpt2/modeling_gpt2.py", line 834, in forward
    inputs_embeds = self.wte(input_ids)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/sparse.py", line 160, in forward
    return F.embedding(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.8/dist-packages/torchinfo/torchinfo.py", line 218, in summary
    summary_list = forward_pass(
  File "/usr/local/lib/python3.8/dist-packages/torchinfo/torchinfo.py", line 297, in forward_pass
    raise RuntimeError(
RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []

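OK, found the cause: input0 has dtype int64, but summary(model, input0.shape) only receives a shape, so torchinfo builds its own test input with the default dtype float32, which the embedding layer rejects.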
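The dtype mismatch can be reproduced without GPT-2 or torchinfo. This is a minimal sketch, assuming a small stand-in `nn.Embedding` (the sizes 100 and 8 are arbitrary and are not taken from the GPT-2 config):

```python
import torch
import torch.nn as nn

# Stand-in for GPT-2's token embedding (model.wte); sizes are arbitrary.
emb = nn.Embedding(num_embeddings=100, embedding_dim=8)

float_ids = torch.rand(1, 8)                     # float32, like an input built from a bare shape
long_ids = torch.zeros(1, 8, dtype=torch.long)   # int64, like tokenizer output

try:
    emb(float_ids)
    float_failed = False
except RuntimeError as err:
    # Same kind of error as in the traceback: indices must be Long or Int.
    float_failed = True
    print("float indices rejected:", err)

out = emb(long_ids)  # integer indices are accepted
print(out.shape)
```

An embedding lookup indexes into a weight table, so only integer indices make sense; a float tensor fails before any layer has run, which would explain the empty "Executed layers up to: []" list.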

Passing the tensor itself instead of its shape works fine:

summary(model, input_data=input0)