Problems when exporting with torch.onnx.export()
Longfei-Jiao opened this issue · comments
Longfei-Jiao commented
Hello, author! I want to convert the released CasA-V model to ONNX format for some testing. I downloaded the CasA-V.pth file and placed it in the tools folder, deleted the original lines 190-195 of tools/test.py, and added the following code after line 189:
model.load_params_from_file(filename=args.ckpt, logger=logger, to_cpu=dist_test)
# Set the model to inference mode
model.eval()
# Create a dummy input tensor
dummy_input = torch.randn(1, 3, 244, 244, requires_grad=True)
# Export the model
torch.onnx.export(model=model,                # model being run
                  args=dummy_input,           # model input (or a tuple for multiple inputs)
                  f="Output-CasA.onnx",       # where to save the model
                  export_params=True,         # store the trained parameter weights inside the model file
                  opset_version=10,           # the ONNX opset version to export the model to
                  do_constant_folding=True,   # whether to execute constant folding for optimization
                  input_names=['modelInput'],    # the model's input names
                  output_names=['modelOutput'],  # the model's output names
                  dynamic_axes={'modelInput': {0: 'batch_size'},    # variable-length axes
                                'modelOutput': {0: 'batch_size'}})
PyCharm error output:
Traceback (most recent call last):
File "/home/jlf/PycharmProjects/CasA/tools/pthToOnnx.py", line 148, in <module>
main()
File "/home/jlf/PycharmProjects/CasA/tools/pthToOnnx.py", line 123, in main
torch.onnx.export(model=model, # model being run
File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/onnx/__init__.py", line 203, in export
return utils.export(model, args, f, export_params, verbose, training,
File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/onnx/utils.py", line 86, in export
_export(model, args, f, export_params, verbose, training, input_names, output_names,
File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/onnx/utils.py", line 526, in _export
graph, params_dict, torch_out = _model_to_graph(model, args, verbose, input_names,
File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/onnx/utils.py", line 366, in _model_to_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/onnx/utils.py", line 319, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/jit/__init__.py", line 338, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/jit/__init__.py", line 421, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/jit/__init__.py", line 412, in wrapper
outs.append(self.inner(*trace_inputs))
File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 720, in _call_impl
result = self._slow_forward(*input, **kwargs)
File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 704, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/jlf/PycharmProjects/CasA/pcdet/models/detectors/voxel_rcnn.py", line 11, in forward
batch_dict = cur_module(batch_dict)
File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 720, in _call_impl
result = self._slow_forward(*input, **kwargs)
File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 704, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/jlf/PycharmProjects/CasA/pcdet/models/backbones_3d/vfe/mean_vfe.py", line 27, in forward
if 'semi_test' in batch_dict:
File "/home/jlf/miniconda3/lib/python3.8/site-packages/torch/tensor.py", line 502, in __contains__
raise RuntimeError(
RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <class 'str'>.
Process finished with exit code 1
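The traceback points at the root cause: during tracing, `dummy_input` is passed directly as `batch_dict`, so the `'semi_test' in batch_dict` check in mean_vfe.py runs against a Tensor rather than a dict, and `Tensor.__contains__` rejects strings. A minimal sketch reproducing the error, followed by one possible workaround (the `DictInputWrapper` class and its dict keys are assumptions for illustration, not part of CasA's API):

```python
import torch

# Reproduce the root cause: a string membership test against a Tensor
# raises RuntimeError, because Tensor.__contains__ only accepts a
# Tensor or a scalar.
t = torch.randn(2, 3)
try:
    'semi_test' in t
    raised = False
except RuntimeError:
    raised = True
print(raised)  # True


# Hypothetical workaround sketch: wrap the detector so that tracing sees
# a plain tensor input while the detector still receives the dict its
# forward() expects. The keys below are illustrative only; the real
# batch_dict built by the CasA dataloader contains more entries.
class DictInputWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, points):
        batch_dict = {'points': points, 'batch_size': 1}
        return self.model(batch_dict)
```

With such a wrapper, `torch.onnx.export` would trace `DictInputWrapper(model)` with a tensor input instead of the raw detector, though the dummy input would still need the shape the point-cloud pipeline actually expects rather than an image-like `(1, 3, 244, 244)`.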
Hailanyi commented
Sorry, I have no experience with this and cannot solve your problem.