microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator

Home Page: https://onnxruntime.ai

[Feature request] Support bfloat16/float8 inputs in `session.run()`

justinchuby opened this issue · comments

session.run() currently does not support bfloat16 inputs because numpy has no bfloat16 dtype. Would it be possible to support the input via np.uint16? session.run() could accept a uint16 array holding the bit representation of the bfloat16 values; ORT should be able to interpret the bits correctly because it knows the expected input type of the graph. The same could be done for the float8* types.
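For illustration, a minimal sketch of the proposed usage is below. The bit-level conversion is standard (bfloat16 is the upper 16 bits of an IEEE-754 float32); the final run() call and the model/input names are hypothetical, since passing the raw uint16 bits is exactly what this issue requests and is not supported today.

```python
import numpy as np
import onnxruntime as ort

x = np.random.rand(2, 3).astype(np.float32)

# bfloat16 keeps the upper 16 bits of a float32
# (plain truncation here, no round-to-nearest-even).
bf16_bits = (x.view(np.uint32) >> 16).astype(np.uint16)

# Proposed usage (NOT supported today). "model_bf16.onnx" and the input name
# "x" are placeholders; ORT would reinterpret the uint16 bits according to the
# graph's declared bfloat16 input type.
# sess = ort.InferenceSession("model_bf16.onnx")
# outputs = sess.run(None, {"x": bf16_bits})
```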

ORT supports bfloat16 internally. I would think a separate per-type approach is not the best way to go. Rather, I would implement a new universal method that does not depend on numpy so much and also covers other types, such as the quantized 8-bit and 4-bit types.

I would use the built-in array module and provide shape and type as separate arguments for maximum flexibility (see the sketch below).
However, along with that, we would need the ability to return those types as results, and right now we only return numpy arrays.
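As a rough illustration of what a numpy-free entry point could look like, the sketch below uses Python's built-in array module for the raw buffer and a hypothetical OrtValue.ortvalue_from_buffer constructor; the name, signature, and element-type string are assumptions, not an existing API.

```python
import array
import onnxruntime as ort

# Raw bfloat16 payload as uint16 bit patterns in a plain Python array:
# 0x3F80, 0x4000, 0x4040, 0x4080 are 1.0, 2.0, 3.0, 4.0 in bfloat16.
buf = array.array("H", [0x3F80, 0x4000, 0x4040, 0x4080])

# Hypothetical API: shape and ONNX element type are passed separately, so the
# buffer carries no dtype information and numpy is not involved.
# value = ort.OrtValue.ortvalue_from_buffer(buf, shape=(2, 2),
#                                           element_type="tensor(bfloat16)")
# outputs = sess.run_with_ort_values(["y"], {"x": value})
```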

I think OrtValue-based interfaces can help us. Plain run() keeps returning numpy.
However, we can add methods to OrtValue that accept these other types from Python arrays and convert the results back to Python arrays as well.
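For context, onnxruntime already exposes an OrtValue path for dtypes numpy does support; a minimal float32 example is below (the model path and input/output names are placeholders). The proposal above would extend OrtValue so an equivalent flow works for bfloat16/float8 without requiring a numpy dtype.

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")  # placeholder model path

# Wrap a numpy array in an OrtValue and run through the OrtValue interface.
x = np.ones((2, 3), dtype=np.float32)
x_ort = ort.OrtValue.ortvalue_from_numpy(x)

outputs = sess.run_with_ort_values(["y"], {"x": x_ort})  # placeholder I/O names
y = outputs[0].numpy()  # results come back as OrtValue; convert as needed
print(y.shape)
```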