gangweiX / Fast-ACVNet

[TPAMI 2023] Fast-ACV: Fast Attention Concatenation Volume for Accurate and Real-time Stereo Matching


How to export onnx model?

ForestWang opened this issue · comments

Hi,
I'm converting the PyTorch model to ONNX with the script below. The export itself seems to succeed, but when I run inference with TensorRT the result is wrong. Could you help me figure out what's going on? Thanks a lot.


import torch
import onnx

from models import models


def main():
    # Build the network; attention_weights_only=False makes the model
    # return disparity predictions rather than only attention weights.
    attention_weights_only = False
    model = models['Fast_ACVNet_plus'](192, attention_weights_only)

    # Load the pretrained parameters, keeping only the keys that exist
    # in the current model definition.
    model_path = './weights/generalization.ckpt'
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model_dict = model.state_dict()
    pre_dict = {k: v for k, v in state_dict['model'].items() if k in model_dict}
    model_dict.update(pre_dict)
    model.load_state_dict(model_dict)
    model.eval()

    # Export to ONNX using fixed-size dummy inputs.
    in_h, in_w = (480, 640)
    t1 = torch.randn(1, 3, in_h, in_w)
    t2 = torch.randn(1, 3, in_h, in_w)
    output = model(t1, t2)
    print(output[0].shape)
    torch.onnx.export(model,
                      (t1, t2),
                      "fast_acvplus.onnx",       # where to save the model (file or file-like object)
                      export_params=True,        # store the trained weights inside the model file
                      opset_version=16,          # the ONNX opset version to export to
                      do_constant_folding=True,  # fold constants for optimization
                      input_names=['left_image', 'right_image'],   # the model's input names
                      output_names=['output'])

    # Load the exported model and verify that it is well formed.
    # check_model() returns None; it raises an exception on failure.
    onnx_model = onnx.load("fast_acvplus.onnx")
    onnx.checker.check_model(onnx_model)
    print('ONNX model check passed')


if __name__ == '__main__':
    main()
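
To narrow down where the error creeps in, it helps to compare the PyTorch output against ONNX Runtime before involving TensorRT: if the two agree, the exported graph is faithful and the discrepancy must come from the TensorRT side. Below is a minimal sketch of that parity check, assuming the export script above has already written fast_acvplus.onnx, that the input names match those passed to torch.onnx.export, and that output[0] is the prediction you care about:

import numpy as np
import onnxruntime as ort
import torch

from models import models

# Rebuild and load the PyTorch model exactly as in the export script.
model = models['Fast_ACVNet_plus'](192, False)
state_dict = torch.load('./weights/generalization.ckpt',
                        map_location=torch.device('cpu'))
model_dict = model.state_dict()
model_dict.update({k: v for k, v in state_dict['model'].items()
                   if k in model_dict})
model.load_state_dict(model_dict)
model.eval()

# Identical inputs for both runtimes.
left = torch.randn(1, 3, 480, 640)
right = torch.randn(1, 3, 480, 640)

with torch.no_grad():
    torch_out = model(left, right)[0].numpy()

# Run the same inputs through ONNX Runtime; the feed names must match
# the input_names given to torch.onnx.export.
sess = ort.InferenceSession('fast_acvplus.onnx',
                            providers=['CPUExecutionProvider'])
ort_out = sess.run(None, {'left_image': left.numpy(),
                          'right_image': right.numpy()})[0]

# If this assertion passes, the export is fine and the wrong result
# is introduced by TensorRT, not by torch.onnx.export.
np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-3)
print('PyTorch and ONNX Runtime outputs match')

If ONNX Runtime matches but TensorRT still disagrees, the usual suspects are ops that TensorRT's ONNX parser handles differently (or not at all) at the chosen opset; building the engine with trtexec --onnx=fast_acvplus.onnx and inspecting its warnings is a quick way to spot them.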