apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators

Home Page:https://tvm.apache.org/

[Bug] [Relay] [Torch] [ONNX] Robustness of `Cast` operator accepting `NaN` values

shaoyuyoung opened this issue

Description

Here is a single op: Cast

In TVM, when Cast receives a NaN input and converts it to bool, it outputs False.

However, PyTorch outputs True for the same input.

In PyTorch and ONNX, Cast to bool maps zero (+/- 0.0) to False and all other values to True, so NaN should map to True.
The evidence is here: https://onnx.ai/onnx/operators/onnx__Cast.html#l-onnx-doc-cast
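The rule can be sketched in plain Python. This is a minimal illustration, not the implementation used by any framework; `cast_to_bool` is a hypothetical helper introduced only for this example.

```python
import math


def cast_to_bool(values):
    """Hypothetical helper: a value becomes True iff it compares unequal to zero."""
    # NaN compares unequal to everything, including 0, so it maps to True.
    return [v != 0 for v in values]


samples = [math.nan, 0.0, -0.0, 1.5, math.inf]
result = cast_to_bool(samples)
print(result)  # [True, False, False, True, True]
```

Under this rule only +0.0 and -0.0 become False; NaN and infinity both become True, which matches the PyTorch result reported above.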

I am unsure how the Cast op is defined in TVM. But if its definition differs from that of other frameworks and compilers (e.g., PyTorch and ONNX), TVM's final results will diverge from theirs in more complex scenarios (i.e., a model containing more ops).
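To illustrate how such a difference could propagate, here is a small plain-Python sketch in which the cast result is used as a selection mask. All three helpers are hypothetical. In particular, `cast_nan_false` only mimics the observed False result and is not a claim about how TVM implements Cast.

```python
import math


def cast_nonzero(v):
    # Reference rule: True iff v != 0 (NaN compares unequal to 0, so True).
    return v != 0


def cast_nan_false(v):
    # Hypothetical stand-in that reproduces the observed result: NaN -> False.
    return (v != 0) and not math.isnan(v)


def select(mask, a, b):
    # Element-wise "where": take a[i] when mask[i] is True, else b[i].
    return [x if m else y for m, x, y in zip(mask, a, b)]


x = [math.nan, 0.0, 2.0]
a = [10, 20, 30]
b = [-1, -2, -3]

out_ref = select([cast_nonzero(v) for v in x], a, b)
out_alt = select([cast_nan_false(v) for v in x], a, b)
print(out_ref)  # [10, -2, 30]
print(out_alt)  # [-1, -2, 30]
```

The two outputs differ only at the NaN position, but once the mask feeds a downstream op, the mismatch shows up as a different numeric value rather than just a flipped boolean.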

Code to repro

import torch
import torch.nn as nn
import tvm
from tvm import relay
import onnx
import numpy.testing as npt


class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

    def forward(self, input_tensor):
        cast_output = input_tensor.to(torch.bool)

        return cast_output


model = Model()
input_tensor = torch.tensor([float('nan')])

torch_output = model(input_tensor).numpy()

torch.onnx.export(
    model,
    input_tensor,
    "test.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=14,
    do_constant_folding=True,
)
onnx_model = onnx.load("test.onnx")

target = "llvm"

shape_dict = {"input": input_tensor.shape}

mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)

dev = tvm.cpu(0)
with tvm.transform.PassContext(opt_level=4):
    executor = relay.build_module.create_executor(
        "graph", mod, dev, target, params
    ).evaluate()

inputs = {"input": tvm.nd.array(input_tensor.numpy())}

tvm_output = executor(**inputs).numpy()

npt.assert_allclose(torch_output, tvm_output, rtol=1e-5, atol=1e-8)

Error log

AssertionError: 
Not equal to tolerance rtol=1e-05, atol=1e-08

Mismatched elements: 1 / 1 (100%)
 x: array([ True])
 y: array([False])

Environment & Version

Ubuntu 20
TVM d1ac1c0

cc @KJlaccHoeUM9l @shingjan @yelite