splinter21 / torch2trt_dynamic

A PyTorch to TensorRT converter with dynamic shape support


torch2trt dynamic

This is a fork of torch2trt with dynamic input support.

Note that not all layers support dynamic input; torch.split() is one example that does not.


Here are some examples


from torch2trt_dynamic import torch2trt_dynamic
import torch
from torch import nn
from torchvision.models.resnet import resnet50

# create some regular pytorch model...
model = resnet50().cuda().eval()

# create example data
x = torch.ones((1, 3, 224, 224)).cuda()

# convert to TensorRT feeding sample data as input
# one [min, opt, max] triple per input tensor
opt_shape_param = [
    [
        [1, 3, 128, 128],   # min
        [1, 3, 256, 256],   # opt
        [1, 3, 512, 512]    # max
    ]
]
model_trt = torch2trt_dynamic(model, [x], fp16_mode=False, opt_shape_param=opt_shape_param)
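The min and max shapes bound what the converted engine is expected to accept, so it can help to check an input shape before running inference. The following is a minimal sketch in plain Python; `shape_in_range` is a hypothetical helper written for illustration and is not part of torch2trt_dynamic.

```python
def shape_in_range(shape, shape_range):
    """Return True if every dimension of `shape` lies within [min, max].

    `shape_range` is one (min, opt, max) triple, as in the example above.
    Hypothetical helper, not part of torch2trt_dynamic.
    """
    min_shape, _opt_shape, max_shape = shape_range
    # A shape with a different number of dimensions can never match.
    if not (len(shape) == len(min_shape) == len(max_shape)):
        return False
    return all(lo <= dim <= hi
               for lo, dim, hi in zip(min_shape, shape, max_shape))


shape_range = (
    [1, 3, 128, 128],   # min
    [1, 3, 256, 256],   # opt
    [1, 3, 512, 512],   # max
)

print(shape_in_range([1, 3, 224, 224], shape_range))  # inside the range
print(shape_in_range([1, 3, 640, 640], shape_range))  # spatial size too large
print(shape_in_range([2, 3, 224, 224], shape_range))  # batch size fixed at 1
```

With a PyTorch tensor, the shape could be passed as `list(x.shape)`.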


We can execute the returned TRTModule just like the original PyTorch model.

x = torch.rand(1,3,256,256).cuda()
with torch.no_grad():
    y = model(x)
    y_trt = model_trt(x)

# check the output against PyTorch
print(torch.max(torch.abs(y - y_trt)))
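Because the engine is built for a range of shapes, it may be worth repeating this comparison at several input sizes rather than only one. The sketch below generates evenly spaced shapes between a min and max shape; `sweep_shapes` is a hypothetical helper, and the commented loop at the end assumes `model` and `model_trt` from the example above.

```python
def sweep_shapes(min_shape, max_shape, steps=4):
    """Return `steps` shapes evenly spaced from min_shape to max_shape.

    Hypothetical helper for illustration; not part of torch2trt_dynamic.
    """
    if steps < 2:
        raise ValueError("steps must be at least 2")
    shapes = []
    for i in range(steps):
        # Integer interpolation per dimension; fixed dims stay unchanged.
        shapes.append([lo + (hi - lo) * i // (steps - 1)
                       for lo, hi in zip(min_shape, max_shape)])
    return shapes


shapes = sweep_shapes([1, 3, 128, 128], [1, 3, 512, 512], steps=4)
for shape in shapes:
    print(shape)

# Assumed usage with the objects defined earlier (requires torch and a GPU):
# for shape in shapes:
#     x = torch.rand(*shape).cuda()
#     with torch.no_grad():
#         print(shape, torch.max(torch.abs(model(x) - model_trt(x))))
```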

Save and load

We can save the model as a state_dict.

torch.save(model_trt.state_dict(), 'alexnet_trt.pth')

We can load the saved model into a TRTModule

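from torch2trt_dynamic import TRTModule

model_trt = TRTModule()

# restore the weights and engine saved above
model_trt.load_state_dict(torch.load('alexnet_trt.pth'))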



To install without compiling plugins, call the following

git clone https://github.com/grimoire/torch2trt_dynamic.git torch2trt_dynamic
cd torch2trt_dynamic
python setup.py develop

Set plugins (optional)

Some layers, such as GN (group normalization), need C++ plugins. Install the plugin project below.


DO NOT FORGET to export the environment variable AMIRSTAN_LIBRARY_PATH
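A quick way to confirm the variable is visible to the Python process before converting a model is to read it from the environment. This is a minimal sketch; `plugin_library_path` is a hypothetical helper, and what the path should point to depends on the plugin project's own instructions.

```python
import os


def plugin_library_path(environ=os.environ):
    """Return AMIRSTAN_LIBRARY_PATH if it is set and non-empty, else None.

    Hypothetical helper for illustration; not part of torch2trt_dynamic.
    """
    path = environ.get('AMIRSTAN_LIBRARY_PATH', '').strip()
    return path or None


if plugin_library_path() is None:
    print('AMIRSTAN_LIBRARY_PATH is not set; '
          'layers that need plugins may fail to convert.')
```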

How to add (or override) a converter

Here we show how to add a converter for the ReLU module using the TensorRT Python API.

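import tensorrt as trt
from torch2trt_dynamic import tensorrt_converter

# register the converter for the ReLU module's forward method
@tensorrt_converter('torch.nn.ReLU.forward')
def convert_ReLU(ctx):
    # method_args[0] is the module itself; method_args[1] is the input tensor
    input = ctx.method_args[1]
    output = ctx.method_return
    # add the matching TensorRT layer and attach its output to the PyTorch result
    layer = ctx.network.add_activation(input=input._trt, type=trt.ActivationType.RELU)
    output._trt = layer.get_output(0)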

The converter takes one argument, a ConversionContext, which contains the following:

  • ctx.network - The TensorRT network that is being constructed.

  • ctx.method_args - Positional arguments that were passed to the specified PyTorch function. The _trt attribute is set for relevant input tensors.

  • ctx.method_kwargs - Keyword arguments that were passed to the specified PyTorch function.

  • ctx.method_return - The value returned by the specified PyTorch function. The converter must set the _trt attribute where relevant.
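To make the roles of these fields concrete, here is a toy sketch that mimics the flow without TensorRT or PyTorch. Everything in it (`TOY_CONVERTERS`, `toy_converter`, `ToyTensor`, the string "layers") is a hypothetical stand-in and does not reflect how torch2trt_dynamic is implemented internally. It only shows a converter reading `_trt` from its input and setting `_trt` on the return value.

```python
from types import SimpleNamespace

# Hypothetical registry mapping a dotted method name to its converter.
TOY_CONVERTERS = {}


def toy_converter(method_name):
    """Decorator that registers a converter under `method_name`."""
    def register(fn):
        TOY_CONVERTERS[method_name] = fn
        return fn
    return register


class ToyTensor:
    """Stand-in for a tensor that can carry a `_trt` attribute."""
    def __init__(self, name):
        self.name = name


@toy_converter('torch.nn.ReLU.forward')
def convert_relu(ctx):
    # Index 0 holds the module (self); index 1 holds the input tensor.
    input = ctx.method_args[1]
    output = ctx.method_return
    # A real converter would add a TensorRT layer to ctx.network here.
    ctx.network.append(('relu', input._trt))
    output._trt = 'relu(' + input._trt + ')'


# Build a fake context the way a conversion pass might populate it.
x = ToyTensor('x')
x._trt = 'x_trt'          # set beforehand for relevant input tensors
y = ToyTensor('y')        # value "returned" by the PyTorch call
ctx = SimpleNamespace(network=[],
                      method_args=('relu_module', x),
                      method_kwargs={},
                      method_return=y)

TOY_CONVERTERS['torch.nn.ReLU.forward'](ctx)
print(y._trt)
print(ctx.network)
```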

Please see this folder for more examples.


License: MIT License

Languages: Python 95.0%, Jupyter Notebook 4.1%, Shell 0.8%