NVIDIA-AI-IOT / torch2trt

An easy to use PyTorch to TensorRT converter

TypeError: 'int' object is not iterable when using torch2trt to convert a PyTorch model with Transformer operator

Thrsu opened this issue

Description:

I encountered an error while using torch2trt to convert a PyTorch model that contains a Transformer operator. The conversion fails with the following traceback:

Traceback (most recent call last):
  ...
    model_trt = torch2trt(model, input_data)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 778, in torch2trt
    outputs = module(*inputs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 204, in forward
    memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask,
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 387, in forward
    output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 707, in forward
    x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/transformer.py", line 715, in _sa_block
    x = self.self_attn(x, x, x,
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1241, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 300, in wrapper
    outputs = method(*args, **kwargs)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/functional.py", line 5300, in multi_head_attention_forward
    q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch/nn/functional.py", line 4827, in _in_projection_packed
    return proj[0], proj[1], proj[2]
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/torch2trt.py", line 309, in wrapper
    converter["converter"](ctx)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/converters/getitem.py", line 40, in convert_tensor_getitem
    num_ellipsis = len(input.shape) - num_slice_types(slices)
  File "/root/miniconda3/envs/nnsmith/lib/python3.9/site-packages/torch2trt-0.4.0-py3.9.egg/torch2trt/converters/getitem.py", line 22, in num_slice_types
    for s in slices:
TypeError: 'int' object is not iterable
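
The last frames of the traceback point at the likely cause: `proj[0]` hands a bare integer index to `convert_tensor_getitem`, and `num_slice_types` then tries to iterate over it. The sketch below reproduces that failure mode in plain Python and shows one possible guard. The names `count_index_entries` and `normalize_index` are hypothetical and not part of torch2trt, and the real body of `num_slice_types` may differ.

```python
def count_index_entries(slices):
    """Hypothetical stand-in for the loop in num_slice_types (getitem.py, line 22)."""
    count = 0
    for s in slices:  # raises TypeError when `slices` is a bare int such as 0
        if isinstance(s, (slice, int)):
            count += 1
    return count


def normalize_index(index):
    """Assumed guard: wrap a non-tuple index (int, slice, ...) in a 1-tuple."""
    return index if isinstance(index, tuple) else (index,)


# tensor[0] passes the int 0 as the index, not the tuple (0,)
try:
    count_index_entries(0)
    message = None
except TypeError as exc:
    message = str(exc)

print(message)
print(count_index_entries(normalize_index(0)))
```

Wrapping the index in a tuple before iterating is only an assumption about what a fix could look like, not a confirmed patch for torch2trt.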

Reproduce:

Here is a minimal script to reproduce the issue:

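import torch
from torch2trt import torch2trt

model = torch.nn.Transformer(
    d_model=4, nhead=2, num_encoder_layers=2, num_decoder_layers=2,
    dim_feedforward=8, dropout=0.0, activation='relu',
).eval().cuda()

input_data = [  # passed positionally to forward() as src, tgt, src_mask
    torch.randn([3, 3, 4], dtype=torch.float32).cuda(),  # src
    torch.randn([2, 3, 4], dtype=torch.float32).cuda(),  # tgt
    torch.randn([3, 3], dtype=torch.float32).cuda(),     # src_mask
]
model_trt = torch2trt(model, input_data)  # raises the TypeError shown above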

Environment:

  • torch: 2.1.1
  • torch2trt: 0.4.0
  • tensorrt: 8.6.1
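
Versions like the ones listed above can be collected with the standard library. This is a minimal sketch that assumes the packages are installed under the distribution names `torch`, `torch2trt`, and `tensorrt`. The helper `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(name):
    """Return the installed version of a distribution, or a placeholder string."""
    try:
        return version(name)
    except PackageNotFoundError:
        return "not installed"


# Distribution names are assumed to match the names used in the list above.
versions = {name: installed_version(name) for name in ("torch", "torch2trt", "tensorrt")}
for name, ver in versions.items():
    print(f"{name}: {ver}")
```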

I would appreciate it if you could look into this issue and provide any guidance or potential solutions. Let me know if you need any further information. Thank you for your assistance!

Hi,
Is there any update on this issue?
I am running into the same error.