NVIDIA-AI-IOT / torch2trt

An easy to use PyTorch to TensorRT converter

Inconsistent inference results between PyTorch and converted TensorRT model with BatchNorm or InstanceNorm operator

hongliyu0716 opened this issue

Description:

I'm experiencing a discrepancy between the inference results of a PyTorch model and the TensorRT model obtained by converting it with the torch2trt tool.

Reproduce

  1. BatchNorm
import torch
from torch.nn import Module
from torch2trt import torch2trt

para_0 = torch.randn([2, 3, 4, 4, 4], dtype=torch.float32).cuda()  # input
para_1 = torch.randn([3], dtype=torch.float32).cuda()  # running_mean
para_2 = torch.randn([3], dtype=torch.float32).cuda()  # running_var
para_3 = torch.randn([3], dtype=torch.float32).cuda()  # weight (gamma)
para_4 = torch.randn([3], dtype=torch.float32).cuda()  # bias (beta)
para_5 = True                    # training
para_6 = 0.00026441036488630354  # momentum
para_7 = 0.001                   # eps

class batch_norm(Module):
    def forward(self, *args):
        return torch.nn.functional.batch_norm(
            args[0], para_1, para_2, para_3, para_4, para_5, para_6, para_7)
model = batch_norm().float().eval().cuda()
model_trt = torch2trt(model, [para_0])
output = model(para_0)
output_trt = model_trt(para_0)
print(torch.max(torch.abs(output - output_trt)))
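
A note on this repro (my annotation, not part of the original script): para_5 maps to the training argument of torch.nn.functional.batch_norm, so PyTorch normalizes with the statistics of the current batch, while, as far as I can tell, torch2trt bakes the supplied running statistics into the engine; some divergence is therefore expected. Also, a randn-initialized running_var contains negative entries, which makes 1/sqrt(var + eps) NaN wherever the running variance is actually used. A minimal sketch under those assumptions, with hypothetical *_eval names, that pins both sides to the same running statistics:

# Hedged sketch (hypothetical names): rerun with training=False and a
# non-negative running variance so that PyTorch and TensorRT normalize
# with the same statistics.
running_var = para_2.abs() + 1.0  # variance must be non-negative

class batch_norm_eval(Module):
    def forward(self, *args):
        # training=False -> normalize with running_mean / running_var
        return torch.nn.functional.batch_norm(
            args[0], para_1, running_var, para_3, para_4, False, para_6, para_7)

model_eval = batch_norm_eval().float().eval().cuda()
model_trt_eval = torch2trt(model_eval, [para_0])
print(torch.max(torch.abs(model_eval(para_0) - model_trt_eval(para_0))))

If the outputs agree here, the discrepancy in the script above comes from training-mode batch statistics rather than from the converter itself.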
  2. InstanceNorm
import torch
from torch.nn import Module
from torch2trt import torch2trt

para_0 = torch.randn([2, 3, 4, 4, 4], dtype=torch.float32).cuda()  # input
para_1 = torch.randn([3], dtype=torch.float32).cuda()  # running_mean
para_2 = torch.randn([3], dtype=torch.float32).cuda()  # running_var
para_3 = None   # weight
para_4 = None   # bias
para_5 = False  # use_input_stats
para_6 = 0.3    # momentum
para_7 = 0.001  # eps

class instance_norm(Module):
    def forward(self, *args):
        return torch.nn.functional.instance_norm(
            args[0], para_1, para_2, para_3, para_4, para_5, para_6, para_7)
model = instance_norm().float().eval().cuda()
model_trt = torch2trt(model, [para_0])

output = model(para_0)
output_trt = model_trt(para_0)
print(torch.max(torch.abs(output - output_trt)))
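
Similarly for InstanceNorm (again my annotation): para_5 is use_input_stats, and with use_input_stats=False PyTorch normalizes with the provided running statistics; the randn-initialized running_var again has negative entries, so NaNs are possible on the PyTorch side as well. A hedged sketch, assuming torch2trt's instance_norm converter supports this 5-D input, that switches to use_input_stats=True, the usual inference setting where both sides derive per-instance statistics from the input (the *_is names are hypothetical):

# Hedged sketch (hypothetical names): repeat the comparison with
# use_input_stats=True so statistics are computed from the input itself.
class instance_norm_input_stats(Module):
    def forward(self, *args):
        return torch.nn.functional.instance_norm(
            args[0], None, None, para_3, para_4, True, para_6, para_7)

model_is = instance_norm_input_stats().float().eval().cuda()
model_trt_is = torch2trt(model_is, [para_0])
print(torch.max(torch.abs(model_is(para_0) - model_trt_is(para_0))))

If this variant matches, the mismatch in the original script is confined to the use_input_stats=False (running statistics) path.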

Environment

  • torch: 2.1.1
  • torch2trt: 0.4.0
  • tensorrt: 8.6.1

Hi @hongliyu0716,
Have you solved this issue yet?