This is an automated email from the ASF dual-hosted git repository.
guan404ming pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fbb9102334 [Relax][ONNX] Preserve NaN in Relu to align with ONNX
Runtime (#19750)
fbb9102334 is described below
commit fbb9102334b27e7137d69766803f185b8cab4439
Author: Neo Chien <[email protected]>
AuthorDate: Fri Jun 19 23:12:40 2026 +0800
[Relax][ONNX] Preserve NaN in Relu to align with ONNX Runtime (#19750)
Hi Committers,
This PR fixes issue https://github.com/apache/tvm/issues/19572: the ONNX
Relu converter now preserves NaN for floating-point inputs, matching ONNX
Runtime. Any suggestions would be appreciated if you are available.
---------
Co-authored-by: cchung100m <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 7 +++++-
tests/python/relax/test_frontend_onnx.py | 29 +++++++++++++++++++++++++
2 files changed, 35 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 3cfe7c892c..7d38916926 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1530,7 +1530,12 @@ class Relu(OnnxOpConverter):
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
- return relax.op.nn.relu(inputs[0])
+ x = inputs[0]
+ x_dtype = x.struct_info.dtype if isinstance(x.struct_info,
relax.TensorStructInfo) else None
+ y = relax.op.nn.relu(x)
+ if x_dtype is not None and _relax_dtype_is_floating_point(x_dtype):
+ return relax.op.where(relax.op.isnan(x), x, y)
+ return y
class Elu(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 57f780868c..791526b584 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -804,6 +804,35 @@ def test_sign_nan_preserve():
)
+def test_relu_nan_preserve():
+ relu_node = helper.make_node("Relu", ["x"], ["y"])
+ graph = helper.make_graph(
+ [relu_node],
+ "relu_nan_test",
+ inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [5])],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [5])],
+ )
+ model = helper.make_model(graph, producer_name="relu_nan_test")
+ model.ir_version = 8
+ for opset_import in model.opset_import:
+ if opset_import.domain in ["", "ai.onnx"]:
+ opset_import.version = 18
+ break
+ x = np.array([np.nan, 9.0, -9.0, 0.0, np.nan], dtype=np.float32)
+
+ ort_out = onnxruntime.InferenceSession(
+ model.SerializeToString(), providers=["CPUExecutionProvider"]
+ ).run([], {"x": x})[0]
+
+ tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18)
+ out_np = (tvm_out[0] if isinstance(tvm_out, list | tuple) else
tvm_out).numpy()
+
+ np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ort_out))
+ np.testing.assert_allclose(
+ out_np[~np.isnan(ort_out)], ort_out[~np.isnan(ort_out)], rtol=1e-7,
atol=1e-5
+ )
+
+
@pytest.mark.parametrize("op_name", ["ReduceMax", "ReduceMin"])
@pytest.mark.parametrize(
"x",