pykeio / ort

A Rust wrapper for ONNX Runtime

Home page: https://ort.pyke.io/

Output order is not preserved with .iter()

gregszumel opened this issue

Thanks for your work building/maintaining ort!

Our application iterates over the outputs returned by session.run(...) and expects them to be in the same order as defined in the ONNX file. However, after updating to ort 2.0, we noticed that the output order is not preserved.

Here's a minimal working example. Take the following ONNX file, generated by:

import torch

class MyModel(torch.nn.Module):
    def __init__(self, a):
        super().__init__()
        self.a = a

    def forward(self, X):
        return X, X[0, 0, 0, :], X[0,0,0,0]
    

dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
model = MyModel(True)

input_names = [ "actual_input1" ]
output_names = [ "output1", "output2", "output3" ]

torch.onnx.export(
    model, dummy_input, 
    "test.onnx", verbose=True, 
    input_names=input_names, 
    output_names=output_names)

We would expect the three outputs of the ONNX file to have shapes [10, 3, 224, 224], [224], and []. But when we run the following ort 2.0 snippet:

use ort::{GraphOptimizationLevel, Session};

fn main() {

    let model = Session::builder().unwrap()
        .with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
        .with_model_from_file("test.onnx").unwrap();

    let image: ndarray::Array4<f32> = ndarray::Array4::zeros([10, 3, 224, 224]);
    let outputs = model.run(ort::inputs![image].unwrap()).unwrap();

    let vals: Vec<_> = outputs.iter().map(|(name, tensor)| {
        let val: ort::Tensor<f32> = tensor.extract_tensor().unwrap();
        println!("{}, {:?}", name, val.view().shape());
        val
    }).collect();
}

we get a different order:

output3, []
output1, [10, 3, 224, 224]
output2, [224]
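To catch a regression like this in a test, one option is to compare the observed iteration order against the output names the model declares. The following is a minimal std-only sketch; `first_order_mismatch` is a hypothetical helper, and both name lists are hard-coded from this example rather than read from the model or from ort.

```rust
/// Returns `None` if `observed` lists the same names as `declared`, in the same
/// order; otherwise returns the first index at which they differ. If one list is
/// a prefix of the other, the index is the length of the shorter list.
fn first_order_mismatch(observed: &[&str], declared: &[&str]) -> Option<usize> {
    let common = observed.len().min(declared.len());
    match (0..common).find(|&i| observed[i] != declared[i]) {
        Some(i) => Some(i),
        None if observed.len() != declared.len() => Some(common),
        None => None,
    }
}

fn main() {
    // Declared order, taken from output_names in the export script.
    let declared = ["output1", "output2", "output3"];
    // Order printed by the ort 2.0 snippet in this issue.
    let observed = ["output3", "output1", "output2"];

    match first_order_mismatch(&observed, &declared) {
        None => println!("outputs are in declared order"),
        Some(i) => println!("order differs at position {i}"),
    }
}
```

In a real test, `observed` would be collected from the names yielded while iterating the run's outputs.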

Do you have any thoughts on how we can iterate through the outputs while preserving the original output order? The examples I looked at either hard-code the output name (e.g., output["output_1"]) or index directly into the outputs (e.g., output[0]).
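Independent of any library fix, one order-safe pattern is to drive the loop from the model's declared output names and look each result up by name, rather than relying on the result container's own iteration order. This is a minimal std-only sketch: `Outputs` is a hypothetical stand-in (name to shape) for ort's session outputs, and the declared names are hard-coded. With ort you would presumably take them from the session's output metadata, but how that is exposed is an assumption that may vary by version.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a run result: output name -> tensor shape.
/// Like any std HashMap, it iterates in an unspecified order.
type Outputs = HashMap<String, Vec<usize>>;

/// Yields outputs in the order given by `declared`, ignoring the map's own
/// iteration order. Names missing from `outputs` are skipped.
fn in_declared_order<'a>(
    outputs: &'a Outputs,
    declared: &'a [&'a str],
) -> impl Iterator<Item = (&'a str, &'a Vec<usize>)> + 'a {
    declared
        .iter()
        .filter_map(move |&name| outputs.get(name).map(|shape| (name, shape)))
}

fn main() {
    let mut outputs = Outputs::new();
    outputs.insert("output3".to_string(), vec![]);
    outputs.insert("output1".to_string(), vec![10, 3, 224, 224]);
    outputs.insert("output2".to_string(), vec![224]);

    // Declared order, hard-coded here for illustration.
    let declared = ["output1", "output2", "output3"];
    for (name, shape) in in_declared_order(&outputs, &declared) {
        println!("{name}, {shape:?}");
    }
}
```

This combines the two approaches from the examples: names are looked up explicitly, as with output["output_1"], but the list of names is data, so the loop itself never hard-codes an output.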

> Do you have any thoughts on how we can iterate through the output such that we can preserve the original output order?

No, we get the order of outputs directly from ONNX Runtime. Maybe something changed with ONNX Runtime v1.17, or torch.onnx.export is not exporting the outputs in the correct order.

I'm not seeing the issue with onnxruntime 1.17.1 and onnx 1.15. Exporting the model with the Python script above and then running

import onnxruntime as ort
import numpy as np

ort_session = ort.InferenceSession("test.onnx")

outputs = ort_session.run(
    None,
    {"actual_input1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
)

print(ort.__version__)

for output in outputs:
    print(output.shape)

prints the outputs in the correct order:

1.17.1
(10, 3, 224, 224)
(224,)
()

Indeed, I forgot that SessionOutputs used a HashMap, which doesn't guarantee iteration order. Commit 39bd0d0 refactors it to use a BTreeMap, which should preserve the order.
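One caveat: a BTreeMap iterates in sorted key order, not insertion order. Assuming the map is keyed by output name (the commit itself is not reproduced here), iteration follows graph order only when the names happen to sort that way, as output1, output2, output3 do. The following std-only sketch shows the difference; the extra name lists are hypothetical.

```rust
use std::collections::BTreeMap;

/// Inserts output names in graph order (value = graph position) and returns the
/// names in the order a BTreeMap iterates them, i.e. sorted by key.
fn btree_iteration_order<'a>(graph_order: &[&'a str]) -> Vec<&'a str> {
    let map: BTreeMap<&'a str, usize> = graph_order
        .iter()
        .enumerate()
        .map(|(position, &name)| (name, position))
        .collect();
    map.keys().copied().collect()
}

fn main() {
    // Names from this issue: sorted order happens to match graph order.
    println!("{:?}", btree_iteration_order(&["output1", "output2", "output3"]));
    // Hypothetical names whose sorted order differs from graph order.
    println!("{:?}", btree_iteration_order(&["scores", "boxes", "labels"]));
    // Numeric suffixes compare as strings, so "output10" sorts before "output2".
    println!("{:?}", btree_iteration_order(&["output1", "output2", "output10"]));
}
```

If graph order must hold for arbitrary names, an insertion-ordered structure (for example a Vec of name/value pairs, or a lookup driven by the declared output list) avoids depending on how the names sort.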

That fixed the issue, thanks!

Would it be possible to cut a release with these changes?