This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 82a37dac1d [BugFix][Relax][ONNX] Resolve param Vars in Concat to handle mixed Shape/Tensor inputs (#19498)
82a37dac1d is described below

commit 82a37dac1dd404f2cdec75ae74cb62a3f73f11e0
Author: Soowon Jeong <[email protected]>
AuthorDate: Mon May 4 17:34:55 2026 +0900

    [BugFix][Relax][ONNX] Resolve param Vars in Concat to handle mixed Shape/Tensor inputs (#19498)
    
    ## Description
    
    When `from_onnx(model, keep_params_in_input=True)` is used, every ONNX
    initializer becomes a `relax.Var` instead of a `relax.Constant`. The
    `Concat` handler's `is_shape_like()` check only recognizes
    `relax.ShapeExpr` and 1D-int64 `relax.Constant`, so a 1D-int64 shape
    value loaded as a Var is no longer recognized.
    
    When such a Var is concatenated with a `ShapeExpr` — the standard
    pattern for dynamic-batch `Reshape` in PyTorch-exported ONNX models —
    the heterogeneous `Tuple(ShapeExpr, Tensor)` is rejected by
    `relax.op.concat` with:
    
    ```
    InternalError: Op(relax.concat) expects the input to be a Tuple of Tensors.
    However, the given input is R.Tuple(R.Shape([N]), R.Tensor((1,), dtype="int64"))
    ```
    
    This effectively breaks `keep_params_in_input=True` for any model with
    dynamic-batch `Reshape` (extremely common in PyTorch ONNX exports).
    
    ## Fix
    
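    Before the `is_shape_like` check, resolve each `Concat` input that is a
    `Var` naming a known 1D-int64 param back to a `relax.const` of its value.
    This restores the all-shape-like fast path. The resolved inputs are used
    only on that fast path. The tensor fallback keeps the original `Var`s, so
    runtime weights are not folded into constants under
    `keep_params_in_input=True`.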
    
    ## Minimal repro
    
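    A four-node ONNX graph (`Shape` → `Slice` → `Concat([dyn_n, [12]])` →
    `Reshape`) fails with `keep_params_in_input=True` before this PR and
    passes after. A regression test (`test_concat_with_param_shape_value`)
    covers this pattern. A second test
    (`test_concat_with_param_tensor_keeps_runtime_param`) checks that a float
    tensor initializer concatenated with a graph input stays a runtime param
    instead of being folded into a constant.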
    
    ## Testing
    
    ```
    pytest tests/python/relax/test_frontend_onnx.py -k concat
    ```
    
    10 passed (2 new + 8 existing).
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 17 +++++++-
 tests/python/relax/test_frontend_onnx.py        | 53 ++++++++++++++++++++++++-
 2 files changed, 67 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 9d65fe0e52..268d91b750 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1014,6 +1014,7 @@ class Concat(OnnxOpConverter):
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
         axis = attr.get("axis", 0)
+        _, param_dict = params
 
         def is_shape_like(x: Any) -> bool:
             if isinstance(x, relax.ShapeExpr):
@@ -1023,10 +1024,22 @@ class Concat(OnnxOpConverter):
             else:
                 return False
 
+        # Resolve 1D-int64 param Vars to constants only for the shape-like
+        # fast path; tensor fallback keeps the original Vars so runtime
+        # weights aren't folded under keep_params_in_input=True.
+        def resolve(x):
+            if isinstance(x, relax.Var) and x.name_hint in param_dict:
+                arr = param_dict[x.name_hint][1].numpy()
+                if arr.ndim == 1 and arr.dtype == _np.int64:
+                    return relax.const(arr, "int64")
+            return x
+
+        resolved = [resolve(inp) for inp in inputs]
+
         # If all inputs are shape expr, perform computation directly.
-        if all([is_shape_like(inp) for inp in inputs]):
+        if all([is_shape_like(inp) for inp in resolved]):
             const_inputs = []
-            for inp in inputs:
+            for inp in resolved:
                 if isinstance(inp, relax.ShapeExpr):
                     const_inputs.extend(inp.values)
                 elif isinstance(inp, relax.Constant):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index db68476609..5a8d84b090 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -29,7 +29,7 @@ import onnx
 import onnxruntime
 import pytest
 import tvm_ffi
-from onnx import ModelProto, TensorProto, helper
+from onnx import ModelProto, TensorProto, helper, numpy_helper
 
 import tvm
 import tvm.testing
@@ -533,6 +533,57 @@ def test_concat():
     verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0})
 
 
+def test_concat_with_param_shape_value():
+    """Concat must handle a 1D-int64 initializer mixed with a ShapeExpr when
+    keep_params_in_input=True. Standard pattern in PyTorch-exported ONNX
+    models for dynamic-batch Reshape: Reshape(x, Concat(Shape(x)[:1], 
[12]))."""
+    inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, ["N", 3, 4])
+    out = helper.make_tensor_value_info("y", TensorProto.FLOAT, ["N", 12])
+    twelve = numpy_helper.from_array(np.array([12], dtype=np.int64), "twelve")
+    starts = numpy_helper.from_array(np.array([0], dtype=np.int64), "starts")
+    ends = numpy_helper.from_array(np.array([1], dtype=np.int64), "ends")
+    nodes = [
+        helper.make_node("Shape", ["x"], ["x_shape"]),
+        helper.make_node("Slice", ["x_shape", "starts", "ends"], ["dyn_n"]),
+        helper.make_node("Concat", ["dyn_n", "twelve"], ["new_shape"], axis=0),
+        helper.make_node("Reshape", ["x", "new_shape"], ["y"]),
+    ]
+    graph = helper.make_graph(
+        nodes, "concat_param_shape", [inp], [out],
+        initializer=[twelve, starts, ends],
+    )
+    model = helper.make_model(
+        graph, opset_imports=[helper.make_opsetid("", 13)]
+    )
+    model.ir_version = 8
+    onnx.checker.check_model(model)
+    # Both modes should succeed; previously True crashed with
+    # "Op(relax.concat) expects the input to be a Tuple of Tensors".
+    from_onnx(model, keep_params_in_input=False)
+    from_onnx(model, keep_params_in_input=True)
+
+
+def test_concat_with_param_tensor_keeps_runtime_param():
+    """Concat(input, weight) under keep_params_in_input=True must keep `weight`
+    as a runtime param, not fold it into a constant."""
+    weight_np = np.arange(8, dtype=np.float32).reshape(2, 4)
+    graph = helper.make_graph(
+        [helper.make_node("Concat", ["x", "w"], ["y"], axis=0)],
+        "concat_param_tensor",
+        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4])],
+        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [4, 4])],
+        initializer=[numpy_helper.from_array(weight_np, "w")],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 
13)])
+    model.ir_version = 8
+    onnx.checker.check_model(model)
+
+    mod, params = relax.frontend.detach_params(from_onnx(model, 
keep_params_in_input=True))
+    assert "w" in [p.name_hint for p in mod["main"].params]
+    assert len(params["main"]) == 1
+    np.testing.assert_array_equal(params["main"][0].numpy(), weight_np)
+
+
 @pytest.mark.parametrize("op_name", ["Add", "Sub", "Mul", "Div", "Pow"])
 def test_binary(op_name: str):
     verify_binary(op_name, [1, 32], [1, 32], [1, 32])

Reply via email to