This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6569cf0ac9 [Relax][ONNX] Fix CumSum axis handling: support runtime
axis tensor, error on multi-element axis (#19467)
6569cf0ac9 is described below
commit 6569cf0ac968f237ca6ad5f19216577cd5cb6230
Author: Neo Chien <[email protected]>
AuthorDate: Fri May 1 19:01:42 2026 +0800
[Relax][ONNX] Fix CumSum axis handling: support runtime axis tensor, error
on multi-element axis (#19467)
Hi Committers,
This PR fixes issue
https://github.com/apache/tvm/issues/19437. Any suggestions would be
appreciated.
### Root Cause
The original CumSum converter silently defaulted to axis=0 whenever the
axis input was a relax.Var (i.e., a runtime tensor), ignoring the actual
runtime value. As a result, the cumulative sum could be computed along
the wrong axis, which does not comply with the ONNX specification.
### Solutions
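Update CumSum._impl_v14 to:
- If the axis input is a Constant, require it to be a scalar (0-D) or a
single-element 1-D tensor; otherwise raise a ValueError.
- If the axis input is a relax.Var, raise a ValueError instead of
silently defaulting to axis=0. ONNX permits runtime axis tensors, but
they are not supported yet because Relax/TE currently requires a
compile-time constant axis for cumsum/flip.
- Raise a TypeError for any other kind of axis input.
- Pass the resolved axis directly to topi.flip when reverse is set.
The tests now supply the axis of test_cumsum1 as a graph initializer and
add coverage for the rejected runtime-axis and multi-element-axis cases.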
---------
Co-authored-by: cchung100m <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 32 ++++++++++++-----
tests/python/relax/test_frontend_onnx.py | 47 +++++++++++++++++++++++--
2 files changed, 69 insertions(+), 10 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 9f7c09a953..9d65fe0e52 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1748,22 +1748,38 @@ class CumSum(OnnxOpConverter):
@classmethod
def _impl_v14(cls, bb, inputs, attr, params):
data = inputs[0]
- axis = get_constant(inputs[1], params)
+ axis_input = get_constant(inputs[1], params)
assert not attr.get("exclusive", False), "Exclusive option not yet
supported."
- if isinstance(axis, relax.Constant):
- axis = int(axis.data.numpy())
- elif isinstance(axis, relax.Var):
- axis = 0
-
+ if isinstance(axis_input, relax.Constant):
+ axis_data = axis_input.data.numpy()
+ if axis_data.ndim == 0:
+ axis = int(axis_data.item())
+ elif axis_data.ndim == 1 and axis_data.shape[0] == 1:
+ axis = int(axis_data.item())
+ else:
+ raise ValueError(
+ "CumSum axis input must be a scalar (0-D) or a
single-element 1-D tensor, "
+ "got shape {}".format(axis_data.shape)
+ )
+ elif isinstance(axis_input, relax.Var):
+ axis_shape = axis_input.struct_info.shape if
hasattr(axis_input.struct_info, "shape") else None
+ raise ValueError(
+ "CumSum with non-constant axis input is not supported yet. "
+ "ONNX permits runtime axis tensors, but Relax/TE currently
requires a compile-time "
+ "constant axis for cumsum/flip. Got axis shape
{}".format(axis_shape)
+ )
+ else:
+ raise TypeError("CumSum axis input must be a Constant or Var")
+
if attr.get("reverse", 0) != 0:
- data = bb.emit_te(topi.flip, data, axis=axis if axis else 0)
+ data = bb.emit_te(topi.flip, data, axis=axis)
data = relax.op.cumsum(data, axis)
data = bb.normalize(data)
if attr.get("reverse", 0) != 0:
- data = bb.emit_te(topi.flip, data, axis=axis if axis else 0)
+ data = bb.emit_te(topi.flip, data, axis=axis)
return data
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 136016e823..db68476609 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1687,13 +1687,56 @@ def test_cumsum1():
"cumsum_graph",
inputs=[
helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE,
input_shape),
- helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1],
"axis"),
],
+ initializer=[helper.make_tensor("axis", onnx.TensorProto.INT32, [1],
[0])],
outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE,
input_shape)],
)
model = helper.make_model(graph, producer_name="cumsum_graph")
- check_correctness(model, inputs={"axis": np.array([0], dtype=np.int32)})
+ check_correctness(model)
+
+
+def test_cumsum_dynamic_axis_not_supported():
+ input_shape = [2, 3]
+
+ graph = helper.make_graph(
+ [
+ helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]),
+ ],
+ "cumsum_dynamic_axis_graph",
+ inputs=[
+ helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE,
input_shape),
+ helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1],
"axis"),
+ ],
+ outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE,
input_shape)],
+ )
+
+ model = helper.make_model(graph, producer_name="cumsum_dynamic_axis_graph")
+ with pytest.raises(ValueError, match="non-constant axis input is not
supported"):
+ from_onnx(model, opset=14, keep_params_in_input=True)
+
+
+def test_cumsum_axis_shape_validation():
+ input_shape = [2, 3]
+
+ graph = helper.make_graph(
+ [
+ helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]),
+ ],
+ "cumsum_invalid_axis_shape_graph",
+ inputs=[
+ helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE,
input_shape),
+ ],
+ initializer=[helper.make_tensor("axis", onnx.TensorProto.INT64, [2],
[0, 1])],
+ outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE,
input_shape)],
+ )
+
+ model = helper.make_model(graph,
producer_name="cumsum_invalid_axis_shape_graph")
+ with pytest.raises(
+ ValueError,
+ match="axis input must be a scalar \(0-D\) or a single-element 1-D
tensor",
+ ):
+ from_onnx(model, opset=14, keep_params_in_input=True)
@pytest.mark.parametrize("axis", [[0, 2], None])