This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7eea6df1b6 [Relax][FRONTEND][ONNX] Support Softmax, LogSoftmax and Hardmax when opset version ≤12 (#19428)
7eea6df1b6 is described below

commit 7eea6df1b61683d5a415203332cd4b7ee1d507bf
Author: Neo Chien <[email protected]>
AuthorDate: Thu Apr 23 14:02:51 2026 +0800

    [Relax][FRONTEND][ONNX] Support Softmax, LogSoftmax and Hardmax when opset version ≤12 (#19428)
    
    Hi Committers,
    
    This PR fixes the legacy Softmax-family semantics for opset <= 12
    and hardens the Hardmax fallback.
    
    ### Summary:
    - Implement legacy ONNX semantics for Softmax / LogSoftmax / Hardmax in
    opset <= 12, including flatten-to-2D + reshape-back behavior.
    - Keep opset >= 13 behavior unchanged (_impl_v13 axis-based path).
    - Add compatibility-first fallback warnings for unknown rank/shape
    paths.
    - Harden Hardmax internals: normalize input before struct_info access
    and add backward-compatible helper signature handling.
    - Extend regression coverage for softmax-family across opset 1/11/13 and
    key axis scenarios.
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 165 ++++++++++++++++++++++--
 tests/python/relax/test_frontend_onnx.py        |  51 ++++++++
 2 files changed, 208 insertions(+), 8 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 5397f2c309..bf65434db0 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -721,9 +721,101 @@ class Sigmoid(OnnxOpConverter):
         return relax.op.sigmoid(inputs[0])


+def _normalize_legacy_softmax_axis(axis: int, rank: int, op_name: str) -> int:
+    """Normalize axis for ONNX Softmax/LogSoftmax/Hardmax opset <= 12 semantics.
+
+    Legacy semantics allow axis in [-rank, rank], where axis == rank means the
+    last dimension after flattening has extent 1.
+    """
+
+    if axis < -rank or axis > rank:
+        raise ValueError(f"{op_name} axis {axis} is out of range for rank {rank}.")
+    if axis < 0:
+        axis += rank
+    return axis
+
+
+def _shape_product(dims: list[int | tirx.PrimExpr]) -> int | tirx.PrimExpr:
+    """Compute product of a list of shape dims (supports symbolic dims)."""
+
+    prod = 1
+    for dim in dims:
+        if isinstance(dim, tirx.IntImm):
+            dim = int(dim.value)
+        prod = prod * dim
+    return prod
+
+
+def _legacy_softmax_prepare(
+    data: relax.Expr, axis: int, op_name: str
+) -> tuple[relax.Expr, tuple[int | tirx.PrimExpr, ...]] | None:
+    """Build legacy 2D view for Softmax-family opset <= 12 semantics.
+
+    Returns (reshaped_data, original_shape). If rank/shape isn't statically
+    available, returns None so caller can choose a permissive fallback.
+    """
+
+    rank = _get_known_tensor_rank(data)
+    if rank is None:
+        return None
+
+    axis = _normalize_legacy_softmax_axis(axis, rank, op_name)
+    struct_info = data.struct_info
+    if not isinstance(struct_info, relax.TensorStructInfo):
+        return None
+    if not isinstance(struct_info.shape, relax.ShapeExpr):
+        return None
+
+    original_shape = list(struct_info.shape.values)
+    if len(original_shape) != rank:
+        return None
+
+    dim0 = _shape_product(original_shape[:axis])
+    dim1 = _shape_product(original_shape[axis:])
+    flattened = relax.op.reshape(data, (dim0, dim1))
+    return flattened, tuple(original_shape)
+
+
+def _get_axis_extent(
+    data: relax.Expr, axis: int, op_name: str
+) -> tuple[int, int | tirx.PrimExpr]:
+    """Return normalized axis and axis extent when rank/shape are known."""
+
+    rank = _get_known_tensor_rank(data)
+    if rank is None:
+        raise ValueError(f"{op_name} requires a statically known input rank.")
+
+    normalized_axis = _normalize_constant_axes([axis], rank, op_name)[0]
+    struct_info = data.struct_info
+    if isinstance(struct_info, relax.TensorStructInfo) and isinstance(struct_info.shape, relax.ShapeExpr):
+        axis_extent = struct_info.shape.values[normalized_axis]
+        if isinstance(axis_extent, tirx.IntImm):
+            axis_extent = int(axis_extent.value)
+        return normalized_axis, axis_extent
+
+    raise ValueError(f"{op_name} requires a statically known axis extent.")
+
+
 class Softmax(OnnxOpConverter):
     """Converts an onnx Softmax node into an equivalent Relax expression."""

+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", 1)
+        prepared = _legacy_softmax_prepare(inputs[0], axis, "Softmax")
+        if prepared is None:
+            warnings.warn(
+                "Softmax opset<=12 fallback: static rank/shape is unavailable, "
+                "falling back to axis-based softmax semantics."
+            )
+            return relax.op.nn.softmax(inputs[0], axis=axis)
+
+        flattened, original_shape = prepared
+        out = relax.op.nn.softmax(flattened, axis=-1)
+        return relax.op.reshape(out, original_shape)
+
+    _impl_v11 = _impl_v1
+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
         axis = attr.get("axis", -1)
@@ -733,6 +825,23 @@ class Softmax(OnnxOpConverter):
 class LogSoftmax(OnnxOpConverter):
     """Converts an onnx LogSoftmax node into an equivalent Relax expression."""

+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", 1)
+        prepared = _legacy_softmax_prepare(inputs[0], axis, "LogSoftmax")
+        if prepared is None:
+            warnings.warn(
+                "LogSoftmax opset<=12 fallback: static rank/shape is unavailable, "
+                "falling back to axis-based log_softmax semantics."
+            )
+            return relax.op.nn.log_softmax(inputs[0], axis=axis)
+
+        flattened, original_shape = prepared
+        out = relax.op.nn.log_softmax(flattened, axis=-1)
+        return relax.op.reshape(out, original_shape)
+
+    _impl_v11 = _impl_v1
+
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
         axis = attr.get("axis", -1)
@@ -743,17 +852,57 @@ class Hardmax(OnnxOpConverter):
     """Converts an onnx Hardmax node into an equivalent Relax expression."""

     @classmethod
-    def _impl_v13(cls, bb, inputs, attr, params):
-        axis = attr.get("axis", -1)
-        indices = inputs[0]
-        dtype = indices.struct_info.dtype
-        axis_len = int(inputs[0].struct_info.shape[axis])
-        argmax = relax.op.argmax(indices, axis=axis)
+    def _hardmax_impl(cls, *args):
+        """Hardmax core implementation.
+
+        Compatibility note:
+        - New signature: _hardmax_impl(bb, data, axis)
+        - Legacy signature: _hardmax_impl(data, axis)
+        """
+        if len(args) == 3:
+            bb, data, axis = args
+        elif len(args) == 2:
+            bb = None
+            data, axis = args
+        else:
+            raise TypeError(
+                "Hardmax._hardmax_impl expects (bb, data, axis) or (data, axis)."
+            )
+
+        if bb is not None:
+            data = bb.normalize(data)
+        normalized_axis, axis_extent = _get_axis_extent(data, axis, "Hardmax")
+        dtype = data.struct_info.dtype
+        argmax = relax.op.argmax(data, axis=normalized_axis)
         on_value = relax.PrimValue(tvm.tirx.const(1.0, dtype))
         off_value = relax.PrimValue(tvm.tirx.const(0.0, dtype))
+        return relax.op.one_hot(argmax, on_value, off_value, axis_extent, normalized_axis)

-        one_hot = relax.op.one_hot(argmax, on_value, off_value, axis_len, axis)
-        return one_hot
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", 1)
+        prepared = _legacy_softmax_prepare(inputs[0], axis, "Hardmax")
+        if prepared is None:
+            warnings.warn(
+                "Hardmax opset<=12 fallback: static rank/shape is unavailable, "
+                "falling back to axis-based hardmax semantics."
+            )
+            hardmax_input = inputs[0]
+            hardmax_axis = axis
+            original_shape = None
+        else:
+            hardmax_input, original_shape = prepared
+            hardmax_axis = -1
+
+        out = cls._hardmax_impl(bb, hardmax_input, hardmax_axis)
+        return out if original_shape is None else relax.op.reshape(out, original_shape)
+
+    _impl_v11 = _impl_v1
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        axis = attr.get("axis", -1)
+        return cls._hardmax_impl(bb, inputs[0], axis)


 class Transpose(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 7e434d2659..84348c41ca 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -699,6 +699,57 @@ def test_unary(op_name: str):
     verify_unary(op_name, [8, 8, 8], input_dtype=input_dtype, output_dtype=output_dtype)


+@pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
+def test_softmax_family_opset11_default_axis_semantics(op_name: str):
+    verify_unary(op_name, [2, 3, 4], opset=11)
+
+
+@pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
+def test_softmax_family_opset11_negative_axis_semantics(op_name: str):
+    verify_unary(op_name, [2, 3, 4], attrs={"axis": -2}, opset=11)
+
+
+@pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
+def test_softmax_family_opset11_positive_axis_semantics(op_name: str):
+    verify_unary(op_name, [2, 3, 4], attrs={"axis": 0}, opset=11)
+
+
+@pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
+def tes_softmax_family_opset11_axis_equals_rank_semantics(op_name: str):
+    verify_unary(op_name, [2, 3, 4], attrs={"axis": 3}, opset=11)
+
+
+@pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
+def test_softmax_family_opset13_default_axis_semantics(op_name: str):
+    verify_unary(op_name, [2, 3, 4], opset=13)
+
+
+@pytest.mark.parametrize(
+    "op_name, expected_core_op",
+    [
+        ("Softmax", "relax.nn.softmax"),
+        ("LogSoftmax", "relax.nn.log_softmax"),
+        ("Hardmax", "relax.one_hot"),
+    ],
+)
+def test_softmax_family_opset1_legacy_ir_semantics(op_name: str, expected_core_op: str):
+    node = helper.make_node(op_name, ["x"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "softmax_family_opset1_ir_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4])],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3, 4])],
+    )
+    model = helper.make_model(
+        graph, producer_name="softmax_family_opset1_ir_test", opset_imports=[helper.make_opsetid("", 1)]
+    )
+    tvm_model = from_onnx(model, opset=1, keep_params_in_input=True)
+    call_ops = collect_relax_call_ops(tvm_model["main"])
+
+    assert expected_core_op in call_ops
+    assert call_ops.count("relax.reshape") >= 2
+
+
 def test_round_ties_to_even():
     """ONNX Round must use ties-to-even (banker's rounding), not ties-away-from-zero.


Reply via email to