daquexian / onnx-simplifier

Simplify your onnx model



onnxsim shape inference failed

xiaotongnii opened this issue

Shape inference fails for a Conv op:
x [1, 8, 80] -> Unsqueeze(axis 1) -> Conv(kernel [8, 1, 3, 3]) -> [1, 8, x, 80]

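Input x shape: [1, 1, 8, 80]
Kernel shape: [8, 1, 3, 3]
Pads: [0, 1, 0, 1] (ONNX order [H_begin, W_begin, H_end, W_end], i.e. no padding on H and 1 on each side of W)
Strides: [1, 1]
Dilations: [1, 1]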
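
ONNX stores `pads` as all begin values followed by all end values, so [0, 1, 0, 1] for a 2-D Conv means H is padded by 0 and W by 1 on each side. A minimal Python sketch of that layout (`split_pads` is a hypothetical helper, not an onnx or onnxsim function):

```python
def split_pads(pads):
    """Split an ONNX-style pads list into per-axis (begin, end) pairs.

    ONNX lays pads out as [x1_begin, x2_begin, ..., x1_end, x2_end, ...]:
    all begin values first, then all end values.
    """
    if len(pads) % 2 != 0:
        raise ValueError("pads must have an even number of entries")
    n = len(pads) // 2
    return [(pads[i], pads[i + n]) for i in range(n)]


# 2-D conv from this issue: H gets (0, 0), W gets (1, 1).
print(split_pads([0, 1, 0, 1]))  # [(0, 0), (1, 1)]
```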

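First, compute the output spatial size from the padding, stride and dilation. Let the input spatial size be H×W, the kernel size KH×KW, the padding on each side PH×PW, the stride SH×SW, and the dilation DH×DW.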

The output spatial size is:
\[ OH = \left\lfloor \frac{H + 2 \times PH - DH \times (KH - 1) - 1}{SH} \right\rfloor + 1 \]
\[ OW = \left\lfloor \frac{W + 2 \times PW - DW \times (KW - 1) - 1}{SW} \right\rfloor + 1 \]

Substituting the parameters:
\[ OH = \left\lfloor \frac{8 + 2 \times 0 - 1 \times (3 - 1) - 1}{1} \right\rfloor + 1 = 6 \]
\[ OW = \left\lfloor \frac{80 + 2 \times 1 - 1 \times (3 - 1) - 1}{1} \right\rfloor + 1 = 80 \]

Therefore, the Conv output shape should be [1, 8, 6, 80].
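
The per-axis formula can be written as a small function. This sketch (the name `conv_out_dim` is hypothetical) accepts separate begin/end padding, which reduces to the 2×PH form when both sides are equal, and uses integer floor division for the rounding needed when the stride does not divide evenly:

```python
def conv_out_dim(size, kernel, pad_begin, pad_end, stride=1, dilation=1):
    """Output length of one spatial axis of a convolution.

    With pad_begin == pad_end == P this equals
    floor((size + 2*P - dilation*(kernel - 1) - 1) / stride) + 1.
    """
    effective_kernel = dilation * (kernel - 1) + 1  # span covered by a dilated kernel
    return (size + pad_begin + pad_end - effective_kernel) // stride + 1


print(conv_out_dim(8, 3, 0, 0))   # 6  (H in this issue)
print(conv_out_dim(80, 3, 1, 1))  # 80 (W in this issue)
```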

Shape inference invoked through the Python API:

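  // Alternate OptAndShape (optimization + shape inference) and FoldConstant
  // until the model stops changing or fixed_point_iters is reached;
  // `converged` reports whether a fixed point was reached.
  auto OptAndShapeAndFold =
      FixedPointFn(std::function{OptAndShape}, std::function{FoldConstant},
                   fixed_point_iters, &converged);
  auto sim_model = OptAndShapeAndFold(model);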
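
`FixedPointFn` appears to alternate the two passes until the model stops changing, with `fixed_point_iters` as the upper bound and `converged` reporting whether that happened. The Python sketch below illustrates that general fixed-point pattern under this assumption; `fixed_point` and the toy integer passes are hypothetical stand-ins, not onnxsim's implementation:

```python
def fixed_point(model, passes, max_iters):
    """Run every pass in order, repeatedly, until the model stops changing.

    Returns (model, converged). `passes` is a list of functions model -> model.
    """
    for _ in range(max_iters):
        new_model = model
        for run_pass in passes:
            new_model = run_pass(new_model)
        if new_model == model:
            return new_model, True  # nothing changed: fixed point reached
        model = new_model
    return model, False  # iteration budget exhausted before converging


# Toy "passes" on an int, standing in for optimize+infer and constant folding.
def shrink(x):
    return x // 2 if x > 10 else x


def decrement(x):
    return x - 1 if x > 5 else x


print(fixed_point(40, [shrink, decrement], max_iters=10))  # (5, True)
print(fixed_point(40, [shrink, decrement], max_iters=3))   # (7, False)
```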

C++ shape inference implementation (onnxsim.cpp):

onnx::ModelProto _InferShapes(const onnx::ModelProto& model) {
  onnx::ModelProto result;
  result.CopyFrom(model);
  onnx::shape_inference::InferShapes(result);
  return result;
}

ONNX `InferShapes` implementation (onnx package, Lib\site-packages\onnx\shape_inference):

void InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry,
    const ShapeInferenceOptions& options,
    std::unordered_map<std::string, TensorShapeProto>* generated_shape_data_by_name) {
  auto opset_imports = GetOpsetImportsFromProto(m);
  SymbolTableImpl symbol_table;
  ModelLocalFunctionsMap model_local_functions_by_id;
  for (const auto& function_proto : m.functions()) {
    model_local_functions_by_id.insert(
        {GetModelLocalFunctionsMapIdentifier(function_proto.domain(), function_proto.name()), &function_proto});
  }
  InferShapesImpl(
      m.mutable_graph(),
      std::unordered_map<std::string, TypeProto*>(0),
      opset_imports,
      options,
      &symbol_table,
      model_local_functions_by_id,
      schema_registry,
      generated_shape_data_by_name,
      m.ir_version());
}

  // Per-node type and shape inference (member of ONNX's shape-inference implementation class)
  void process(NodeProto& n) {
    // Resolve domain for node
    auto dit = opset_imports.find(n.domain());
    if (dit == opset_imports.end()) {
      // Both "" and "ai.onnx" refer to the default ONNX domain
      if (n.domain() == "") {
        dit = opset_imports.find("ai.onnx");
      }
      if (dit == opset_imports.end()) {
        fail_type_inference(
            "Cannot infer type and shape for node name ",
            n.name(),
            ". No opset import for domain",
            n.domain(),
            " optype ",
            n.op_type());
      }
    }
    auto domain_version = dit->second;
    const auto schema = schema_registry->GetSchema(n.op_type(), domain_version, n.domain());
    InferenceContextImpl ctx(
        n,
        value_types_by_name,
        input_data_by_name,
        input_sparse_data_by_name,
        generated_shape_data_by_name,
        &graph_inference_context);

    ONNX_TRY {
      if (schema) {
        if (schema->has_type_and_shape_inference_function()) {
          schema->GetTypeAndShapeInferenceFunction()(ctx);
        } else if (schema->HasFunction()) {
          InferShapeForFunctionNode(
              *(schema->GetFunction()),
              schema_registry,
              ctx,
              options,
              model_local_functions_map,
              symbol_table,
              generated_shape_data_by_name);
        } else {
          // Continue with inference for remaining nodes
          return;
        }
      } else if (model_local_functions_map.size() > 0) {
        auto iter = model_local_functions_map.find(GetModelLocalFunctionsMapIdentifier(n.domain(), n.op_type()));
        if (iter != model_local_functions_map.end()) {
          InferShapeForFunctionNode(
              *(iter->second),
              schema_registry,
              ctx,
              options,
              model_local_functions_map,
              symbol_table,
              generated_shape_data_by_name);
        } else {
          has_unsupported_op = true;
          return;
        }
      } else {
        has_unsupported_op = true;
        return;
      }
    }
    ONNX_CATCH(const ONNX_NAMESPACE::InferenceError& ex) {
      ONNX_HANDLE_EXCEPTION([&]() {
        // onnx does not support unsupported/experimental operators
        // so it won't consider it as an error
        if (!has_unsupported_op && !has_experimental_op) {
          inference_errors.push_back(GetErrorWithNodeInfo(n, ex));
        }
      });
      // Continue with inference for remaining nodes
      return;
    }

    ONNX_TRY {
      // check the type-equality for input and output
      if (options.check_type && schema) {
        schema->CheckInputOutputType(ctx);
      }

      for (int i = 0; i < n.output_size(); ++i) {
        // skip type and shape propagation for missing optional outputs.
        if (!n.output(i).empty())
          updateType(n.output(i), ctx.getOutputType(i));
      }

      preprocess(n);

      // If data propagation is enabled, propagate shape data if it exists.
      if (options.enable_data_propagation && schema && schema->has_data_propagation_function()) {
        if (generated_shape_data_by_name == nullptr) {
          fail_shape_inference(
              "Container for generated shape data cannot be nullptr when enable_data_propagation option is set.");
        }
        DataPropagationContextImpl data_propagation_ctx(
            n, value_types_by_name, input_data_by_name, *generated_shape_data_by_name);
        schema->GetDataPropagationFunction()(data_propagation_ctx);
      }
    }
    ONNX_CATCH(const std::runtime_error& err) {
      ONNX_HANDLE_EXCEPTION([&]() { fail_shape_inference(GetErrorWithNodeInfo(n, err)); });
    }
  }
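
The control flow of `process` can be summarized with a toy model: resolve the domain (treating "" as "ai.onnx"), look up a schema, fall back to a model-local function, skip unsupported ops, and collect inference errors instead of aborting. Everything below (`process_node`, the dict-based registries, `InferenceError`) is a simplified, hypothetical stand-in for the C++ code, which also handles function bodies, type checks and data propagation:

```python
class InferenceError(Exception):
    """Stand-in for ONNX's InferenceError."""


def process_node(node, opset_imports, schemas, local_functions, errors):
    """Toy per-node dispatch (sketch). Returns "inferred", "skipped" or "error".

    node: {"op_type", "domain", "name"}; opset_imports: {domain: version}.
    schemas: {(op_type, domain): infer_fn}; simplification: the real lookup
      also takes the opset version.
    local_functions: {(domain, op_type): infer_fn} for model-local functions.
    """
    domain = node["domain"]
    version = opset_imports.get(domain)
    if version is None and domain == "":
        version = opset_imports.get("ai.onnx")  # "" and "ai.onnx" are the same domain
    if version is None:
        # A missing opset import is a hard failure, raised outside the try block.
        raise InferenceError(f"No opset import for domain {domain!r}")

    infer = schemas.get((node["op_type"], domain))
    if infer is None:
        infer = local_functions.get((domain, node["op_type"]))  # model-local fallback
    if infer is None:
        return "skipped"  # unsupported op: skipped, not reported as an error

    try:
        infer(node)
    except InferenceError as exc:
        errors.append(f"{node['name']}: {exc}")  # recorded; other nodes continue
        return "error"
    return "inferred"


errors = []
schemas = {("Conv", ""): lambda node: None}
print(process_node({"op_type": "Conv", "domain": "", "name": "c0"}, {"": 17}, schemas, {}, errors))
```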

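A schema is the specification of an ONNX operator (op). Every op has a schema that defines its inputs and outputs (names, allowed types, shape constraints), its attributes, and usually a type-and-shape inference function. An op is identified by its op_type together with its domain and opset version, and its schema can be found in the official ONNX operator documentation or queried through the ONNX API (for example `onnx.defs.get_schema` in Python). In the `process` function above, `schema` is the schema of the node (NodeProto) currently being processed; it is obtained by calling `schema_registry->GetSchema(n.op_type(), domain_version, n.domain())`.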
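
`GetSchema` does not require an exact version match: an op's schema only changes at specific opset versions (its `since_version` values), and the registry returns the newest one that does not exceed the model's opset import. A sketch of that selection rule, with a hypothetical helper name and version numbers chosen for illustration:

```python
def select_schema_version(since_versions, opset_version):
    """Return the since_version in effect for a given opset import.

    The applicable schema is the newest since_version <= opset_version;
    None means the op did not exist yet at that opset.
    """
    candidates = [v for v in since_versions if v <= opset_version]
    return max(candidates) if candidates else None


# An op whose schema changed at opsets 1, 11 and 22, used by an opset-17 model.
print(select_schema_version([1, 11, 22], 17))  # 11
```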

name: "Conv"
since_version: "1"
description: """
Performs a convolution operation on the input tensor. The output is a tensor with the same rank as the input.
The convolution operation can be performed on any number of dimensions, but it is most commonly used for 2D images.
"""

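input [
{
name: "X"
description: "The input tensor."
type: T
shape: [N, C, D1, ..., Dn]
},
{
name: "W"
description: "The weights tensor."
type: T
shape: [M, C/group, K1, ..., Kn]
},
{
name: "B"
description: "Optional 1-D bias, one value per output channel."
type: T
shape: [M]
}
]

output [
{
name: "Y"
description: "The output tensor."
type: T
shape: [N, M, O1, ..., On]
}
]

where:
T = {tensor(float16), tensor(float), tensor(double)}
n >= 1 (number of spatial dimensions)
N is the batch size and C the number of input channels
M is the number of output channels
D1, ..., Dn are the spatial dimensions of the input tensor
K1, ..., Kn are the kernel sizes
O1, ..., On are the output spatial sizes, computed per axis from Di, Ki, pads, strides and dilations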

attribute {
name: "strides"
type: INTS
description: "The strides of the convolution operation."
default: [1, ..., 1]
}

attribute {
name: "pads"
type: INTS
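description: "The paddings of the convolution operation, laid out as [x1_begin, x2_begin, ..., x1_end, x2_end, ...]: all begin paddings first, then all end paddings."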
default: [0, ..., 0]
}

attribute {
name: "dilations"
type: INTS
description: "The dilations of the convolution operation."
default: [1, ..., 1]
}

attribute {
name: "group"
type: INT
description: "The number of groups to split the input and output channels into."
default: 1
}
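
As a rough sketch of what Conv shape inference has to compute from the shapes and attributes above, the hypothetical `conv_output_shape` below checks channel/group consistency and derives each spatial size with the per-axis formula. It is an illustration written against the ONNX Conv spec, not a port of onnx's C++ inference function:

```python
def conv_output_shape(x_shape, w_shape, pads=None, strides=None, dilations=None, group=1):
    """Infer [N, M, O1, ..., On] for X [N, C, D1..Dn] and W [M, C/group, K1..Kn].

    pads uses the ONNX layout [x1_begin, ..., xn_begin, x1_end, ..., xn_end].
    """
    if len(w_shape) != len(x_shape):
        raise ValueError("X and W must have the same rank")
    if x_shape[1] != w_shape[1] * group:
        raise ValueError("input channels must equal W.shape[1] * group")
    if w_shape[0] % group != 0:
        raise ValueError("output channels must be divisible by group")
    n_spatial = len(x_shape) - 2
    pads = pads or [0] * (2 * n_spatial)
    strides = strides or [1] * n_spatial
    dilations = dilations or [1] * n_spatial

    out = [x_shape[0], w_shape[0]]  # batch size, output channels
    for i in range(n_spatial):
        padded = x_shape[2 + i] + pads[i] + pads[i + n_spatial]
        effective_kernel = dilations[i] * (w_shape[2 + i] - 1) + 1
        out.append((padded - effective_kernel) // strides[i] + 1)
    return out


print(conv_output_shape([1, 1, 8, 80], [8, 1, 3, 3], pads=[0, 1, 0, 1]))  # [1, 8, 6, 80]
```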

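Network shape information: input x has shape [1, 8, 80], so after the Unsqueeze op it should become [1, 1, 8, 80]. In the dumped network, however, the Unsqueeze output shape is [0, 1, 0, 80]. This suggests that shape inference for Unsqueeze went wrong and that the error then propagated into the Conv output shape.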
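
For reference, ONNX Unsqueeze inserts a size-1 dimension at each requested axis, where the axes index into the output rank and may be negative (from opset 13 on, `axes` is an input rather than an attribute). The hypothetical helper below applies that rule; it works the same whether the other dimensions are concrete sizes or symbolic names:

```python
def unsqueeze_shape(shape, axes):
    """Output shape of an ONNX-style Unsqueeze: insert 1s at `axes`.

    Axes refer to positions in the output (rank len(shape) + len(axes));
    negative axes count from the end of the output.
    """
    out_rank = len(shape) + len(axes)
    norm = sorted(a + out_rank if a < 0 else a for a in axes)
    if len(set(norm)) != len(norm) or any(not 0 <= a < out_rank for a in norm):
        raise ValueError("axes must be unique and within the output rank")
    dims = iter(shape)  # remaining input dims, consumed in order
    return [1 if i in norm else next(dims) for i in range(out_rank)]


print(unsqueeze_shape([1, 8, 80], [1]))      # [1, 1, 8, 80]
print(unsqueeze_shape(["N", "T", 80], [1]))  # ['N', 1, 'T', 80]
```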

Node Name: /encoder_embed/Unsqueeze
Node OpType: Unsqueeze
Input Shapes: []
Output Shapes: [(0, 1, 0, 80)]

Node Name: /encoder_embed/conv/0/Conv
Node OpType: Conv
Input Shapes: [(0, 1, 0, 80)]
Output Shapes: [(0, 8, 0, 80)]
Weights Shape: (8, 1, 3, 3)

Node Name: /encoder_embed/conv/3/Sub
Node OpType: Sub
Input Shapes: [(0, 8, 0, 80)]
Output Shapes: [(0, 8, 0, 80)]

Node Name: /encoder_embed/conv/3/Max
Node OpType: Max
Input Shapes: [(0, 8, 0, 80)]
Output Shapes: [(0, 8, 0, 80)]

Node Name: /encoder_embed/conv/3/Sub_1
Node OpType: Sub
Input Shapes: [(0, 8, 0, 80)]
Output Shapes: [(0, 8, 0, 80)]

Node Name: /encoder_embed/conv/3/Abs
Node OpType: Abs
Input Shapes: [(0, 8, 0, 80)]
Output Shapes: [(0, 8, 0, 80)]

Node Name: /encoder_embed/conv/3/Neg
Node OpType: Neg
Input Shapes: [(0, 8, 0, 80)]
Output Shapes: [(0, 8, 0, 80)]

Node Name: /encoder_embed/conv/3/Exp
Node OpType: Exp
Input Shapes: [(0, 8, 0, 80)]
Output Shapes: [(0, 8, 0, 80)]
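
In ONNX a tensor dimension holds either a concrete `dim_value` or a symbolic `dim_param`, and a script that reads only `dim_value` gets 0 for a symbolic dimension. That makes it worth checking whether the zeros in (0, 1, 0, 80) are symbolic batch/time dimensions rather than a wrong inference result. The mock below (`Dim` is a hypothetical stand-in, not the onnx class) shows the difference; with the real onnx package the equivalent check would use `dim.HasField("dim_param")`:

```python
from dataclasses import dataclass


@dataclass
class Dim:
    """Mock of a shape dimension: either dim_value or dim_param is set.

    Reading the unset field yields its default (0 or ""), as with the proto.
    """
    dim_value: int = 0
    dim_param: str = ""


def naive_shape(dims):
    # What a dump script that reads only dim_value would print.
    return tuple(d.dim_value for d in dims)


def readable_shape(dims):
    out = []
    for d in dims:
        if d.dim_param:
            out.append(d.dim_param)  # symbolic dimension, e.g. batch or time
        elif d.dim_value > 0:
            out.append(d.dim_value)  # concrete size
        else:
            out.append("?")  # nothing set (or a genuine zero-size dimension)
    return tuple(out)


dims = [Dim(dim_param="N"), Dim(1), Dim(dim_param="T"), Dim(80)]
print(naive_shape(dims))     # (0, 1, 0, 80)
print(readable_shape(dims))  # ('N', 1, 'T', 80)
```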