onnxsim shape inference failed
xiaotongnii opened this issue
Conv op shape inference fails
x [1,8,80] -> Unsqueeze(axis 1) -> Conv(kernel [8,1,3,3]) -> [1,8,x,80]
Input x shape: [1, 1, 8, 80]
Kernel shape: [8, 1, 3, 3]
pads: [0, 1, 0, 1]
strides: [1, 1]
dilations: [1, 1]
First, compute the output spatial size from the padding and stride. Let the input spatial size be H×W, the kernel size KH×KW, the padding PH×PW, the stride SH×SW, and the dilation DH×DW.
The output spatial size is:
OH = (H + 2×PH - DH×(KH - 1) - 1) / SH + 1
OW = (W + 2×PW - DW×(KW - 1) - 1) / SW + 1
Substituting the parameters:
OH = (8 + 2×0 - 1×(3 - 1) - 1) / 1 + 1 = 6
OW = (80 + 2×1 - 1×(3 - 1) - 1) / 1 + 1 = 80
So the expected Conv output shape is [1, 8, 6, 80].
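The arithmetic above can be checked with a small helper (a sketch; `conv_out_dim` is not part of onnx, and the pads follow ONNX order, begin values before end values per axis):

```python
def conv_out_dim(size, kernel, pad_begin, pad_end, stride=1, dilation=1):
    """Standard ONNX Conv output-size formula for one spatial axis."""
    return (size + pad_begin + pad_end - dilation * (kernel - 1) - 1) // stride + 1

# Input [1, 1, 8, 80], kernel [8, 1, 3, 3], pads [0, 1, 0, 1]
# i.e. H_begin=0, W_begin=1, H_end=0, W_end=1
oh = conv_out_dim(8, 3, pad_begin=0, pad_end=0)   # -> 6
ow = conv_out_dim(80, 3, pad_begin=1, pad_end=1)  # -> 80
print([1, 8, oh, ow])  # -> [1, 8, 6, 80]
```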
Shape inference path behind the Python API (onnxsim C++ core):
auto OptAndShapeAndFold =
FixedPointFn(std::function{OptAndShape}, std::function{FoldConstant},
fixed_point_iters, &converged);
auto sim_model = OptAndShapeAndFold(model);
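FixedPointFn composes the optimize-and-infer pass with constant folding and repeats them until the model stops changing (or the iteration cap is hit). The idea in a minimal Python sketch (function names hypothetical, not onnxsim APIs):

```python
def fixed_point(f, g, x, max_iters=50):
    """Apply f then g repeatedly until the result stops changing.
    Mirrors the idea of onnxsim's FixedPointFn: optimization and
    constant folding are interleaved until a fixed point is reached."""
    for _ in range(max_iters):
        y = g(f(x))
        if y == x:
            return y, True   # converged
        x = y
    return x, False          # hit the iteration cap without converging

# Toy example: integer halving reaches the fixed point 0.
result, converged = fixed_point(lambda v: v // 2, lambda v: v, 40)
print(result, converged)  # -> 0 True
```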
C++ shape inference implementation (onnxsim.cpp):
onnx::ModelProto _InferShapes(const onnx::ModelProto& model) {
onnx::ModelProto result;
result.CopyFrom(model);
onnx::shape_inference::InferShapes(result);
return result;
}
ONNX InferShapes implementation:
Lib\site-packages\onnx\shape_inference
void InferShapes(
ModelProto& m,
const ISchemaRegistry* schema_registry,
const ShapeInferenceOptions& options,
std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name) {
auto opset_imports = GetOpsetImportsFromProto(m);
SymbolTableImpl symbol_table;
ModelLocalFunctionsMap model_local_functions_by_id;
for (const auto& function_proto : m.functions()) {
model_local_functions_by_id.insert(
{GetModelLocalFunctionsMapIdentifier(function_proto.domain(), function_proto.name()), &function_proto});
}
InferShapesImpl(
m.mutable_graph(),
std::unordered_map<std::string, TypeProto*>(0),
opset_imports,
options,
&symbol_table,
model_local_functions_by_id,
schema_registry,
generated_shape_data_by_name,
m.ir_version());
}
void process(NodeProto& n) {
// Resolve domain for node
auto dit = opset_imports.find(n.domain());
if (dit == opset_imports.end()) {
// Both "" and "ai.onnx" refer to the default ONNX domain
if (n.domain() == "") {
dit = opset_imports.find("ai.onnx");
}
if (dit == opset_imports.end()) {
fail_type_inference(
"Cannot infer type and shape for node name ",
n.name(),
". No opset import for domain",
n.domain(),
" optype ",
n.op_type());
}
}
auto domain_version = dit->second;
const auto schema = schema_registry->GetSchema(n.op_type(), domain_version, n.domain());
InferenceContextImpl ctx(
n,
value_types_by_name,
input_data_by_name,
input_sparse_data_by_name,
generated_shape_data_by_name,
&graph_inference_context);
ONNX_TRY {
if (schema) {
if (schema->has_type_and_shape_inference_function()) {
schema->GetTypeAndShapeInferenceFunction()(ctx);
} else if (schema->HasFunction()) {
InferShapeForFunctionNode(
*(schema->GetFunction()),
schema_registry,
ctx,
options,
model_local_functions_map,
symbol_table,
generated_shape_data_by_name);
} else {
// Continue with inference for remaining nodes
return;
}
} else if (model_local_functions_map.size() > 0) {
auto iter = model_local_functions_map.find(GetModelLocalFunctionsMapIdentifier(n.domain(), n.op_type()));
if (iter != model_local_functions_map.end()) {
InferShapeForFunctionNode(
*(iter->second),
schema_registry,
ctx,
options,
model_local_functions_map,
symbol_table,
generated_shape_data_by_name);
} else {
has_unsupported_op = true;
return;
}
} else {
has_unsupported_op = true;
return;
}
}
ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) {
ONNX_HANDLE_EXCEPTION([&]() {
// onnx does not support unsupported/experimental operators
// so it won't consider it as an error
if (!has_unsupported_op && !has_experimental_op) {
inference_errors.push_back(GetErrorWithNodeInfo(n, ex));
}
});
// Continue with inference for remaining nodes
return;
}
ONNX_TRY {
// check the type-equality for input and output
if (options.check_type && schema) {
schema->CheckInputOutputType(ctx);
}
for (int i = 0; i < n.output_size(); ++i) {
// skip type and shape propagation for missing optional outputs.
if (!n.output(i).empty())
updateType(n.output(i), ctx.getOutputType(i));
}
preprocess(n);
// If data propagation is enabled, propagate shape data if it exists.
if (options.enable_data_propagation && schema && schema->has_data_propagation_function()) {
if (generated_shape_data_by_name == nullptr) {
fail_shape_inference(
"Container for generated shape data cannot be nullptr when enable_data_propagation option is set.");
}
DataPropagationContextImpl data_propagation_ctx(
n, value_types_by_name, input_data_by_name, *generated_shape_data_by_name);
schema->GetDataPropagationFunction()(data_propagation_ctx);
}
}
ONNX_CATCH(const std::runtime_error& err) {
ONNX_HANDLE_EXCEPTION([&]() { fail_shape_inference(GetErrorWithNodeInfo(n, err)); });
}
}
A schema is the specification of an op in an ONNX model. Every op has a corresponding schema that defines the types, names, and shapes of its inputs and outputs. Each ONNX op has a unique name, and its schema can be looked up in the official ONNX documentation or through the ONNX API. In the process function above, schema refers to the schema of the node (NodeProto) currently being processed, obtained by calling schema_registry->GetSchema(n.op_type(), domain_version, n.domain()).
name: "Conv"
since_version: "1"
description: """
Performs a convolution operation on the input tensor. The output is a tensor with the same rank as the input.
The convolution operation can be performed on any number of dimensions, but it is most commonly used for 2D images.
"""input [
{
name: "X"
description: "The input tensor."
type: T
shape: [D1, ..., Dn]
},
{
name: "W"
description: "The weights tensor."
type: T
shape: [M1, ..., Mm, K1, ..., Kn]
}
]
output [
{
name: "Y"
description: "The output tensor."
type: T
shape: [D1, ..., Dn]
}
]
where:
T = {tensor(float), tensor(double)}
n >= 2
m >= 2
D1, ..., Dn are the dimensions of the input tensor
M1, ..., Mm are the dimensions of the weights tensor
K1, ..., Kn are the kernel sizes
attribute {
name: "strides"
type: INTS
description: "The strides of the convolution operation."
default: [1, ..., 1]
}
attribute {
name: "pads"
type: INTS
description: "The paddings of the convolution operation."
default: [0, ..., 0]
}
attribute {
name: "dilations"
type: INTS
description: "The dilations of the convolution operation."
default: [1, ..., 1]
}
attribute {
name: "group"
type: INT
description: "The number of groups to split the input and output channels into."
default: 1
}
Network shape info: the input x has shape [1, 8, 80]; after the Unsqueeze op it should expand to [1, 1, 8, 80], but the network's Unsqueeze output shape is [0, 1, 0, 80]. Evidently shape inference for the Unsqueeze went wrong, which in turn corrupts the Conv output shape.
Node Name: /encoder_embed/Unsqueeze
Node OpType: Unsqueeze
Input Shapes: []
Output Shapes: [(0, 1, 0, 80)]

Node Name: /encoder_embed/conv/0/Conv
Node OpType: Conv
Input Shapes: [(0, 1, 0, 80)]
Output Shapes: [(0, 8, 0, 80)]
Weights Shape: (8, 1, 3, 3)

Node Name: /encoder_embed/conv/3/Sub
Node OpType: Sub
Input Shapes: [(0, 8, 0, 80)]
Output Shapes: [(0, 8, 0, 80)]

Node Name: /encoder_embed/conv/3/Max
Node OpType: Max
Input Shapes: [(0, 8, 0, 80)]
Output Shapes: [(0, 8, 0, 80)]

Node Name: /encoder_embed/conv/3/Sub_1
Node OpType: Sub
Input Shapes: [(0, 8, 0, 80)]
Output Shapes: [(0, 8, 0, 80)]

Node Name: /encoder_embed/conv/3/Abs
Node OpType: Abs
Input Shapes: [(0, 8, 0, 80)]
Output Shapes: [(0, 8, 0, 80)]

Node Name: /encoder_embed/conv/3/Neg
Node OpType: Neg
Input Shapes: [(0, 8, 0, 80)]
Output Shapes: [(0, 8, 0, 80)]

Node Name: /encoder_embed/conv/3/Exp
Node OpType: Exp
Input Shapes: [(0, 8, 0, 80)]
Output Shapes: [(0, 8, 0, 80)]