This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6569cf0ac9 [Relax][ONNX] Fix CumSum axis handling: support runtime axis tensor, error on multi-element axis (#19467)
6569cf0ac9 is described below

commit 6569cf0ac968f237ca6ad5f19216577cd5cb6230
Author: Neo Chien <[email protected]>
AuthorDate: Fri May 1 19:01:42 2026 +0800

    [Relax][ONNX] Fix CumSum axis handling: support runtime axis tensor, error on multi-element axis (#19467)
    
    Hi Committers,
    
    This PR attempts to fix issue
    https://github.com/apache/tvm/issues/19437. Any suggestions would be
    appreciated.
    
    ### Root Cause
    The original CumSum converter always defaulted to axis=0 when the axis
    input was a relax.Var (i.e., a runtime tensor), ignoring the actual
    runtime value. This silently produced wrong results for any non-zero
    runtime axis and did not comply with the ONNX specification.
    
    ### Solutions
    Update CumSum._impl_v14 to:
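    - If the axis input is a Constant: require it to be a scalar (0-D) or a
    single-element 1-D tensor; otherwise raise an error.
    - If the axis input is a relax.Var (a runtime tensor), raise an error
    instead of silently defaulting to axis=0, because Relax/TE currently
    requires a compile-time constant axis for cumsum/flip.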
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 32 ++++++++++++-----
 tests/python/relax/test_frontend_onnx.py        | 47 +++++++++++++++++++++++--
 2 files changed, 69 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 9f7c09a953..9d65fe0e52 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1748,22 +1748,38 @@ class CumSum(OnnxOpConverter):
     @classmethod
     def _impl_v14(cls, bb, inputs, attr, params):
         data = inputs[0]
-        axis = get_constant(inputs[1], params)
+        axis_input = get_constant(inputs[1], params)
         assert not attr.get("exclusive", False), "Exclusive option not yet 
supported."
 
-        if isinstance(axis, relax.Constant):
-            axis = int(axis.data.numpy())
-        elif isinstance(axis, relax.Var):
-            axis = 0
-
+        if isinstance(axis_input, relax.Constant):
+            axis_data = axis_input.data.numpy()
+            if axis_data.ndim == 0:
+                axis = int(axis_data.item())
+            elif axis_data.ndim == 1 and axis_data.shape[0] == 1:
+                axis = int(axis_data.item())
+            else:
+                raise ValueError(
+                    "CumSum axis input must be a scalar (0-D) or a 
single-element 1-D tensor, "
+                    "got shape {}".format(axis_data.shape)
+                )
+        elif isinstance(axis_input, relax.Var):
+            axis_shape = axis_input.struct_info.shape if 
hasattr(axis_input.struct_info, "shape") else None
+            raise ValueError(
+                "CumSum with non-constant axis input is not supported yet. "
+                "ONNX permits runtime axis tensors, but Relax/TE currently 
requires a compile-time "
+                "constant axis for cumsum/flip. Got axis shape 
{}".format(axis_shape)
+            )
+        else:
+            raise TypeError("CumSum axis input must be a Constant or Var")
+            
         if attr.get("reverse", 0) != 0:
-            data = bb.emit_te(topi.flip, data, axis=axis if axis else 0)
+            data = bb.emit_te(topi.flip, data, axis=axis)
 
         data = relax.op.cumsum(data, axis)
         data = bb.normalize(data)
 
         if attr.get("reverse", 0) != 0:
-            data = bb.emit_te(topi.flip, data, axis=axis if axis else 0)
+            data = bb.emit_te(topi.flip, data, axis=axis)
 
         return data
 
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 136016e823..db68476609 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1687,13 +1687,56 @@ def test_cumsum1():
         "cumsum_graph",
         inputs=[
             helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, 
input_shape),
-            helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1], 
"axis"),
         ],
+        initializer=[helper.make_tensor("axis", onnx.TensorProto.INT32, [1], 
[0])],
         outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, 
input_shape)],
     )
 
     model = helper.make_model(graph, producer_name="cumsum_graph")
-    check_correctness(model, inputs={"axis": np.array([0], dtype=np.int32)})
+    check_correctness(model)
+
+
+def test_cumsum_dynamic_axis_not_supported():
+    input_shape = [2, 3]
+
+    graph = helper.make_graph(
+        [
+            helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]),
+        ],
+        "cumsum_dynamic_axis_graph",
+        inputs=[
+            helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, 
input_shape),
+            helper.make_tensor_value_info("axis", onnx.TensorProto.INT32, [1], 
"axis"),
+        ],
+        outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, 
input_shape)],
+    )
+
+    model = helper.make_model(graph, producer_name="cumsum_dynamic_axis_graph")
+    with pytest.raises(ValueError, match="non-constant axis input is not 
supported"):
+        from_onnx(model, opset=14, keep_params_in_input=True)
+
+
+def test_cumsum_axis_shape_validation():
+    input_shape = [2, 3]
+
+    graph = helper.make_graph(
+        [
+            helper.make_node("CumSum", inputs=["X", "axis"], outputs=["Y"]),
+        ],
+        "cumsum_invalid_axis_shape_graph",
+        inputs=[
+            helper.make_tensor_value_info("X", onnx.TensorProto.DOUBLE, 
input_shape),
+        ],
+        initializer=[helper.make_tensor("axis", onnx.TensorProto.INT64, [2], 
[0, 1])],
+        outputs=[helper.make_tensor_value_info("Y", onnx.TensorProto.DOUBLE, 
input_shape)],
+    )
+
+    model = helper.make_model(graph, 
producer_name="cumsum_invalid_axis_shape_graph")
+    with pytest.raises(
+        ValueError, 
+        match="axis input must be a scalar \(0-D\) or a single-element 1-D 
tensor",
+    ):
+        from_onnx(model, opset=14, keep_params_in_input=True)
 
 
 @pytest.mark.parametrize("axis", [[0, 2], None])

Reply via email to