This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 59af03c718 [Relax][ONNX] Drop NaN-preservation isnan-where wrappers 
(#19847)
59af03c718 is described below

commit 59af03c7188e3e25bfcecc9dc3c193b09574aa3b
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Jun 19 20:32:55 2026 -0400

    [Relax][ONNX] Drop NaN-preservation isnan-where wrappers (#19847)
    
    This PR removes the explicit NaN-preservation guards added in the ONNX
    frontend for Relu (#19750), Sign (#19674), Clip's input (#19535), and
    ReduceMax/ReduceMin (the _reduce_min_max_preserve_nan helper, #19750).
    Each guard cost an extra isnan + where; the reduce helper additionally
    cost a full sum(isnan) pass.
    
    This PR also drops the corresponding tests: test_relu_nan_preserve,
    test_sign_nan_preserve, test_reduce_min_max_nan_preserve, and the
    NaN-input case of test_clip_v13.
    
    Follow-up PRs will handle NaN preservation in the backend.
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py |  50 ++----------
 tests/python/relax/test_frontend_onnx.py        | 101 +-----------------------
 2 files changed, 9 insertions(+), 142 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 7d38916926..cdb213f10d 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1474,8 +1474,6 @@ class Clip(OnnxOpConverter):
         if inputs[2] is not None:
             hi = cls._sanitize_nan_clip_bound(bb, inputs[2], for_min=False)
             results = bb.emit_te(topi.minimum, results, hi)
-        if _relax_dtype_is_floating_point(x.struct_info.dtype):
-            results = bb.emit(relax.op.where(relax.op.isnan(x), x, results))
         return results
 
 
@@ -1530,12 +1528,7 @@ class Relu(OnnxOpConverter):
 
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
-        x = inputs[0]
-        x_dtype = x.struct_info.dtype if isinstance(x.struct_info, 
relax.TensorStructInfo) else None
-        y = relax.op.nn.relu(x)
-        if x_dtype is not None and _relax_dtype_is_floating_point(x_dtype):
-            return relax.op.where(relax.op.isnan(x), x, y)
-        return y
+        return relax.op.nn.relu(inputs[0])
 
 
 class Elu(OnnxOpConverter):
@@ -3892,28 +3885,6 @@ class RMSNormalization(OnnxOpConverter):
         return output
 
 
-def _reduce_min_max_preserve_nan(reduce_op, data, axes, keepdims):
-    """Apply a min/max reduction with well-defined, order-independent NaN 
propagation.
-
-    relax.op.max/min legalize to a max/min fold implemented as select(x > y, 
x, y) with an
-    ordered float comparison, so NaN propagation depends on the fold position 
(a later non-NaN
-    element silently overwrites an earlier NaN). ONNX Runtime is also 
order-independent (it only
-    yields NaN when the first reduced element is NaN), which is an 
implementation artifact rather
-    than a defined semantics and is impractical to replicate portably. We 
instead adopt the
-    numpy/IEEE convention used by numpy.max/min and torch.amax/amin: for 
floating pint inputs,
-    detect NaN along the reduced axes and force the output to NaN whenever any 
reduced element is
-    NaN.
-    """
-    y = reduce_op(data, axes, keepdims)
-    dtype = data.struct_info.dtype if isinstance(data.struct_info, 
relax.TensorStructInfo) else None
-    if dtype is None or not _relax_dtype_is_floating_point(dtype):
-        return y
-    nan_count = relax.op.sum(relax.op.astype(relax.op.isnan(data), dtype), 
axes, keepdims)
-    has_nan = relax.op.greater(nan_count, relax.const(0, dtype))
-    nan_filled = relax.op.full_like(y, relax.const(float("nan"), dtype))
-    return relax.op.where(has_nan, nan_filled, y)
-
-
 class ReduceMax(OnnxOpConverter):
     """Converts an onnx ReduceMax node into an equivalent Relax expression."""
 
@@ -3922,7 +3893,7 @@ class ReduceMax(OnnxOpConverter):
         data = inputs[0]
         axes = attr.get("axes", None)
         keepdims = attr.get("keepdims", 1)
-        return _reduce_min_max_preserve_nan(relax.op.max, data, axes, keepdims)
+        return relax.op.max(data, axes, keepdims)
 
     @classmethod
     def _impl_v18(cls, bb, inputs, attr, params):
@@ -3939,13 +3910,13 @@ class ReduceMax(OnnxOpConverter):
 
         # If axes is empty and noop_with_empty_axes is False, reduce all dims
         if not axes and not noop_with_empty_axes:
-            return _reduce_min_max_preserve_nan(relax.op.max, data, None, 
keepdims)
+            return relax.op.max(data, None, keepdims)
         # If axes is empty and noop_with_empty_axes is True, return input 
unchanged
         elif not axes and noop_with_empty_axes:
             return data
         # Otherwise reduce over specified axes
         else:
-            return _reduce_min_max_preserve_nan(relax.op.max, data, axes, 
keepdims)
+            return relax.op.max(data, axes, keepdims)
 
 
 class ReduceMin(OnnxOpConverter):
@@ -3956,7 +3927,7 @@ class ReduceMin(OnnxOpConverter):
         data = inputs[0]
         axes = attr.get("axes", None)
         keepdims = attr.get("keepdims", 1)
-        return _reduce_min_max_preserve_nan(relax.op.min, data, axes, keepdims)
+        return relax.op.min(data, axes, keepdims)
 
     @classmethod
     def _impl_v18(cls, bb, inputs, attr, params):
@@ -3973,13 +3944,13 @@ class ReduceMin(OnnxOpConverter):
 
         # If axes is empty and noop_with_empty_axes is False, reduce all dims
         if not axes and not noop_with_empty_axes:
-            return _reduce_min_max_preserve_nan(relax.op.min, data, None, 
keepdims)
+            return relax.op.min(data, None, keepdims)
         # If axes is empty and noop_with_empty_axes is True, return input 
unchanged
         elif not axes and noop_with_empty_axes:
             return data
         # Otherwise reduce over specified axes
         else:
-            return _reduce_min_max_preserve_nan(relax.op.min, data, axes, 
keepdims)
+            return relax.op.min(data, axes, keepdims)
 
 
 class ReduceSum(OnnxOpConverter):
@@ -4621,12 +4592,7 @@ class Sign(OnnxOpConverter):
 
     @classmethod
     def _impl_v9(cls, bb, inputs, attr, params):
-        x = inputs[0]
-        x_dtype = x.struct_info.dtype if isinstance(x.struct_info, 
relax.TensorStructInfo) else None
-        y = relax.op.sign(x)
-        if x_dtype is not None and _relax_dtype_is_floating_point(x_dtype):
-            return relax.op.where(relax.op.isnan(x), x, y)
-        return y
+        return relax.op.sign(inputs[0])
 
 
 class Not(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 791526b584..6e9d4c9d95 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -775,104 +775,6 @@ def test_unary(op_name: str):
     verify_unary(op_name, [8, 8, 8], input_dtype=input_dtype, 
output_dtype=output_dtype)
 
 
-def test_sign_nan_preserve():
-    sign_node = helper.make_node("Sign", ["x"], ["y"])
-    graph = helper.make_graph(
-        [sign_node],
-        "sign_nan_test",
-        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [4])],
-        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [4])],
-    )
-    model = helper.make_model(graph, producer_name="sign_nan_test")
-    model.ir_version = 8
-    for opset_import in model.opset_import:
-        if opset_import.domain in ["", "ai.onnx"]:
-            opset_import.version = 18
-            break
-    x = np.array([np.nan, 9.0, -9.0, np.nan], dtype=np.float32)
-
-    ort_out = onnxruntime.InferenceSession(
-        model.SerializeToString(), providers=["CPUExecutionProvider"]
-    ).run([], {"x": x})[0]
-
-    tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18)
-    out_np = (tvm_out[0] if isinstance(tvm_out, list | tuple) else 
tvm_out).numpy()
-
-    np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ort_out))
-    np.testing.assert_allclose(
-        out_np[~np.isnan(ort_out)], ort_out[~np.isnan(ort_out)], rtol=1e-7, 
atol=1e-5
-    )
-
-
-def test_relu_nan_preserve():
-    relu_node = helper.make_node("Relu", ["x"], ["y"])
-    graph = helper.make_graph(
-        [relu_node],
-        "relu_nan_test",
-        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [5])],
-        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [5])],
-    )
-    model = helper.make_model(graph, producer_name="relu_nan_test")
-    model.ir_version = 8
-    for opset_import in model.opset_import:
-        if opset_import.domain in ["", "ai.onnx"]:
-            opset_import.version = 18
-            break
-    x = np.array([np.nan, 9.0, -9.0, 0.0, np.nan], dtype=np.float32)
-
-    ort_out = onnxruntime.InferenceSession(
-        model.SerializeToString(), providers=["CPUExecutionProvider"]
-    ).run([], {"x": x})[0]
-
-    tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18)
-    out_np = (tvm_out[0] if isinstance(tvm_out, list | tuple) else 
tvm_out).numpy()
-
-    np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ort_out))
-    np.testing.assert_allclose(
-        out_np[~np.isnan(ort_out)], ort_out[~np.isnan(ort_out)], rtol=1e-7, 
atol=1e-5
-    )
-
-
[email protected]("op_name", ["ReduceMax", "ReduceMin"])
[email protected](
-    "x",
-    [
-        # NaN in different positions. TVM's max/min fold previously dropped 
NaN depending on
-        # position, ONNX Runtime only propagates NaN when it is the first 
reduced element, which
-        # is an order-dependent implementation artifact. We instead adopt the 
well-defined,
-        # order-independent numpy/IEEE semantics: any NaN in the reduced range 
yields NaN.
-        np.array([np.nan, 1.0, 2.0], dtype=np.float32),
-        np.array([2.0, 1.0, np.nan], dtype=np.float32),
-        np.array([1.0, np.nan, 2.0], dtype=np.float32),
-        np.array([1.0, 2.0, 3.0], dtype=np.float32),
-    ],
-)
-def test_reduce_min_max_nan_preserve(op_name, x):
-    reduce_node = helper.make_node(op_name, ["x"], ["y"], keepdims=0)
-    graph = helper.make_graph(
-        [reduce_node],
-        "reduce_nan_test",
-        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, 
list(x.shape))],
-        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [])],
-    )
-    model = helper.make_model(graph, producer_name="reduce_nan_test")
-    model.ir_version = 8
-    for opset_import in model.opset_import:
-        if opset_import.domain in ["", "ai.onnx"]:
-            opset_import.version = 18
-            break
-
-    # Reference is numpy (NaN propagates if any element is NaN), not ONNX 
Runtime.
-    ref_out = (np.max if op_name == "ReduceMax" else np.min)(x)
-
-    tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18)
-    out_np = (tvm_out[0] if isinstance(tvm_out, list | tuple) else 
tvm_out).numpy()
-
-    np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ref_out))
-    if not np.isnan(ref_out):
-        np.testing.assert_allclose(out_np, ref_out, rtol=1e-7, atol=1e-5)
-
-
 @pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
 def test_softmax_family_opset11_default_axis_semantics(op_name: str):
     verify_unary(op_name, [2, 3, 4], opset=11)
@@ -1775,11 +1677,10 @@ def test_clip_v6(max, min):
     "input",
     [
         np.array([0.5, -3.0, 4.5, 11.0, 7.0], dtype=np.float32),
-        np.array([0.5, -3.0, 4.5, 11.0, np.nan], dtype=np.float32),
     ],
 )
 def test_clip_v13(input, min, max):
-    # Opset 13: tensor min/max. NaN bound => unbounded on that side (ORT); 
input NaN preserved.
+    # Opset 13: tensor min/max. NaN bound => unbounded on that side (ORT).
     clip_node = helper.make_node("Clip", ["input", "min", "max"], ["output"])
     graph = helper.make_graph(
         [clip_node],

Reply via email to