This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 59af03c718 [Relax][ONNX] Drop NaN-preservation isnan-where wrappers
(#19847)
59af03c718 is described below
commit 59af03c7188e3e25bfcecc9dc3c193b09574aa3b
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Jun 19 20:32:55 2026 -0400
[Relax][ONNX] Drop NaN-preservation isnan-where wrappers (#19847)
This PR removes the explicit NaN-preservation guards added in the ONNX
frontend for Relu (#19750), Sign (#19674), Clip's input (#19535), and
ReduceMax/ReduceMin (the _reduce_min_max_preserve_nan helper, #19750).
Each guard cost an extra isnan + where, and the reduce helper additionally
paid for a full sum(isnan) reduction pass.
This PR also drops the corresponding tests: test_relu_nan_preserve,
test_sign_nan_preserve, test_reduce_min_max_nan_preserve, and the
NaN-input case of test_clip_v13.
NaN preservation will be handled in the backend in follow-up PRs.
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 50 ++----------
tests/python/relax/test_frontend_onnx.py | 101 +-----------------------
2 files changed, 9 insertions(+), 142 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 7d38916926..cdb213f10d 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1474,8 +1474,6 @@ class Clip(OnnxOpConverter):
if inputs[2] is not None:
hi = cls._sanitize_nan_clip_bound(bb, inputs[2], for_min=False)
results = bb.emit_te(topi.minimum, results, hi)
- if _relax_dtype_is_floating_point(x.struct_info.dtype):
- results = bb.emit(relax.op.where(relax.op.isnan(x), x, results))
return results
@@ -1530,12 +1528,7 @@ class Relu(OnnxOpConverter):
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
- x = inputs[0]
- x_dtype = x.struct_info.dtype if isinstance(x.struct_info,
relax.TensorStructInfo) else None
- y = relax.op.nn.relu(x)
- if x_dtype is not None and _relax_dtype_is_floating_point(x_dtype):
- return relax.op.where(relax.op.isnan(x), x, y)
- return y
+ return relax.op.nn.relu(inputs[0])
class Elu(OnnxOpConverter):
@@ -3892,28 +3885,6 @@ class RMSNormalization(OnnxOpConverter):
return output
-def _reduce_min_max_preserve_nan(reduce_op, data, axes, keepdims):
- """Apply a min/max reduction with well-defined, order-independent NaN
propagation.
-
- relax.op.max/min legalize to a max/min fold implemented as select(x > y,
x, y) with an
- ordered float comparison, so NaN propagation depends on the fold position
(a later non-NaN
- element silently overwrites an earlier NaN). ONNX Runtime is also
order-independent (it only
- yields NaN when the first reduced element is NaN), which is an
implementation artifact rather
- than a defined semantics and is impractical to replicate portably. We
instead adopt the
- numpy/IEEE convention used by numpy.max/min and torch.amax/amin: for
floating pint inputs,
- detect NaN along the reduced axes and force the output to NaN whenever any
reduced element is
- NaN.
- """
- y = reduce_op(data, axes, keepdims)
- dtype = data.struct_info.dtype if isinstance(data.struct_info,
relax.TensorStructInfo) else None
- if dtype is None or not _relax_dtype_is_floating_point(dtype):
- return y
- nan_count = relax.op.sum(relax.op.astype(relax.op.isnan(data), dtype),
axes, keepdims)
- has_nan = relax.op.greater(nan_count, relax.const(0, dtype))
- nan_filled = relax.op.full_like(y, relax.const(float("nan"), dtype))
- return relax.op.where(has_nan, nan_filled, y)
-
-
class ReduceMax(OnnxOpConverter):
"""Converts an onnx ReduceMax node into an equivalent Relax expression."""
@@ -3922,7 +3893,7 @@ class ReduceMax(OnnxOpConverter):
data = inputs[0]
axes = attr.get("axes", None)
keepdims = attr.get("keepdims", 1)
- return _reduce_min_max_preserve_nan(relax.op.max, data, axes, keepdims)
+ return relax.op.max(data, axes, keepdims)
@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
@@ -3939,13 +3910,13 @@ class ReduceMax(OnnxOpConverter):
# If axes is empty and noop_with_empty_axes is False, reduce all dims
if not axes and not noop_with_empty_axes:
- return _reduce_min_max_preserve_nan(relax.op.max, data, None,
keepdims)
+ return relax.op.max(data, None, keepdims)
# If axes is empty and noop_with_empty_axes is True, return input
unchanged
elif not axes and noop_with_empty_axes:
return data
# Otherwise reduce over specified axes
else:
- return _reduce_min_max_preserve_nan(relax.op.max, data, axes,
keepdims)
+ return relax.op.max(data, axes, keepdims)
class ReduceMin(OnnxOpConverter):
@@ -3956,7 +3927,7 @@ class ReduceMin(OnnxOpConverter):
data = inputs[0]
axes = attr.get("axes", None)
keepdims = attr.get("keepdims", 1)
- return _reduce_min_max_preserve_nan(relax.op.min, data, axes, keepdims)
+ return relax.op.min(data, axes, keepdims)
@classmethod
def _impl_v18(cls, bb, inputs, attr, params):
@@ -3973,13 +3944,13 @@ class ReduceMin(OnnxOpConverter):
# If axes is empty and noop_with_empty_axes is False, reduce all dims
if not axes and not noop_with_empty_axes:
- return _reduce_min_max_preserve_nan(relax.op.min, data, None,
keepdims)
+ return relax.op.min(data, None, keepdims)
# If axes is empty and noop_with_empty_axes is True, return input
unchanged
elif not axes and noop_with_empty_axes:
return data
# Otherwise reduce over specified axes
else:
- return _reduce_min_max_preserve_nan(relax.op.min, data, axes,
keepdims)
+ return relax.op.min(data, axes, keepdims)
class ReduceSum(OnnxOpConverter):
@@ -4621,12 +4592,7 @@ class Sign(OnnxOpConverter):
@classmethod
def _impl_v9(cls, bb, inputs, attr, params):
- x = inputs[0]
- x_dtype = x.struct_info.dtype if isinstance(x.struct_info,
relax.TensorStructInfo) else None
- y = relax.op.sign(x)
- if x_dtype is not None and _relax_dtype_is_floating_point(x_dtype):
- return relax.op.where(relax.op.isnan(x), x, y)
- return y
+ return relax.op.sign(inputs[0])
class Not(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 791526b584..6e9d4c9d95 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -775,104 +775,6 @@ def test_unary(op_name: str):
verify_unary(op_name, [8, 8, 8], input_dtype=input_dtype,
output_dtype=output_dtype)
-def test_sign_nan_preserve():
- sign_node = helper.make_node("Sign", ["x"], ["y"])
- graph = helper.make_graph(
- [sign_node],
- "sign_nan_test",
- inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [4])],
- outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [4])],
- )
- model = helper.make_model(graph, producer_name="sign_nan_test")
- model.ir_version = 8
- for opset_import in model.opset_import:
- if opset_import.domain in ["", "ai.onnx"]:
- opset_import.version = 18
- break
- x = np.array([np.nan, 9.0, -9.0, np.nan], dtype=np.float32)
-
- ort_out = onnxruntime.InferenceSession(
- model.SerializeToString(), providers=["CPUExecutionProvider"]
- ).run([], {"x": x})[0]
-
- tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18)
- out_np = (tvm_out[0] if isinstance(tvm_out, list | tuple) else
tvm_out).numpy()
-
- np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ort_out))
- np.testing.assert_allclose(
- out_np[~np.isnan(ort_out)], ort_out[~np.isnan(ort_out)], rtol=1e-7,
atol=1e-5
- )
-
-
-def test_relu_nan_preserve():
- relu_node = helper.make_node("Relu", ["x"], ["y"])
- graph = helper.make_graph(
- [relu_node],
- "relu_nan_test",
- inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [5])],
- outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [5])],
- )
- model = helper.make_model(graph, producer_name="relu_nan_test")
- model.ir_version = 8
- for opset_import in model.opset_import:
- if opset_import.domain in ["", "ai.onnx"]:
- opset_import.version = 18
- break
- x = np.array([np.nan, 9.0, -9.0, 0.0, np.nan], dtype=np.float32)
-
- ort_out = onnxruntime.InferenceSession(
- model.SerializeToString(), providers=["CPUExecutionProvider"]
- ).run([], {"x": x})[0]
-
- tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18)
- out_np = (tvm_out[0] if isinstance(tvm_out, list | tuple) else
tvm_out).numpy()
-
- np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ort_out))
- np.testing.assert_allclose(
- out_np[~np.isnan(ort_out)], ort_out[~np.isnan(ort_out)], rtol=1e-7,
atol=1e-5
- )
-
-
[email protected]("op_name", ["ReduceMax", "ReduceMin"])
[email protected](
- "x",
- [
- # NaN in different positions. TVM's max/min fold previously dropped
NaN depending on
- # position, ONNX Runtime only propagates NaN when it is the first
reduced element, which
- # is an order-dependent implementation artifact. We instead adopt the
well-defined,
- # order-independent numpy/IEEE semantics: any NaN in the reduced range
yields NaN.
- np.array([np.nan, 1.0, 2.0], dtype=np.float32),
- np.array([2.0, 1.0, np.nan], dtype=np.float32),
- np.array([1.0, np.nan, 2.0], dtype=np.float32),
- np.array([1.0, 2.0, 3.0], dtype=np.float32),
- ],
-)
-def test_reduce_min_max_nan_preserve(op_name, x):
- reduce_node = helper.make_node(op_name, ["x"], ["y"], keepdims=0)
- graph = helper.make_graph(
- [reduce_node],
- "reduce_nan_test",
- inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT,
list(x.shape))],
- outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [])],
- )
- model = helper.make_model(graph, producer_name="reduce_nan_test")
- model.ir_version = 8
- for opset_import in model.opset_import:
- if opset_import.domain in ["", "ai.onnx"]:
- opset_import.version = 18
- break
-
- # Reference is numpy (NaN propagates if any element is NaN), not ONNX
Runtime.
- ref_out = (np.max if op_name == "ReduceMax" else np.min)(x)
-
- tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18)
- out_np = (tvm_out[0] if isinstance(tvm_out, list | tuple) else
tvm_out).numpy()
-
- np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ref_out))
- if not np.isnan(ref_out):
- np.testing.assert_allclose(out_np, ref_out, rtol=1e-7, atol=1e-5)
-
-
@pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
def test_softmax_family_opset11_default_axis_semantics(op_name: str):
verify_unary(op_name, [2, 3, 4], opset=11)
@@ -1775,11 +1677,10 @@ def test_clip_v6(max, min):
"input",
[
np.array([0.5, -3.0, 4.5, 11.0, 7.0], dtype=np.float32),
- np.array([0.5, -3.0, 4.5, 11.0, np.nan], dtype=np.float32),
],
)
def test_clip_v13(input, min, max):
- # Opset 13: tensor min/max. NaN bound => unbounded on that side (ORT);
input NaN preserved.
+ # Opset 13: tensor min/max. NaN bound => unbounded on that side (ORT).
clip_node = helper.make_node("Clip", ["input", "min", "max"], ["output"])
graph = helper.make_graph(
[clip_node],