This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7eea6df1b6 [Relax][FRONTEND][ONNX] Support Softmax, LogSoftmax and
Hardmax when opset version ≤12 (#19428)
7eea6df1b6 is described below
commit 7eea6df1b61683d5a415203332cd4b7ee1d507bf
Author: Neo Chien <[email protected]>
AuthorDate: Thu Apr 23 14:02:51 2026 +0800
[Relax][FRONTEND][ONNX] Support Softmax, LogSoftmax and Hardmax when opset
version ≤12 (#19428)
Hi Committers,
This PR fixes the legacy Softmax-family semantics for opset <= 12
and hardens the Hardmax fallback.
### Summary:
- Implement legacy ONNX semantics for Softmax / LogSoftmax / Hardmax in
opset <= 12, including flatten-to-2D + reshape-back behavior.
- Keep opset >= 13 behavior unchanged (_impl_v13 axis-based path).
- Add compatibility-first fallback warnings for unknown rank/shape
paths.
- Harden Hardmax internals: normalize input before struct_info access
and add backward-compatible helper signature handling.
- Extend regression coverage for softmax-family across opset 1/11/13 and
key axis scenarios.
---------
Co-authored-by: cchung100m <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 165 ++++++++++++++++++++++--
tests/python/relax/test_frontend_onnx.py | 51 ++++++++
2 files changed, 208 insertions(+), 8 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 5397f2c309..bf65434db0 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -721,9 +721,101 @@ class Sigmoid(OnnxOpConverter):
return relax.op.sigmoid(inputs[0])
+def _normalize_legacy_softmax_axis(axis: int, rank: int, op_name: str) -> int:
+ """Normalize axis for ONNX Softmax/LogSoftmax/Hardmax opset <= 12
semantics.
+
+ Legacy semantics allow axis in [-rank, rank], where axis == rank means the
+ last dimension after flattening has extent 1.
+ """
+
+ if axis < -rank or axis > rank:
+ raise ValueError(f"{op_name} axis {axis} is out of range for rank
{rank}.")
+ if axis < 0:
+ axis += rank
+ return axis
+
+
+def _shape_product(dims: list[int | tirx.PrimExpr]) -> int | tirx.PrimExpr:
+ """Compute product of a list of shape dims (supports symbolic dims)."""
+
+ prod = 1
+ for dim in dims:
+ if isinstance(dim, tirx.IntImm):
+ dim = int(dim.value)
+ prod = prod * dim
+ return prod
+
+
+def _legacy_softmax_prepare(
+ data: relax.Expr, axis: int, op_name: str
+) -> tuple[relax.Expr, tuple[int | tirx.PrimExpr, ...]] | None:
+ """Build legacy 2D view for Softmax-family opset <= 12 semantics.
+
+ Returns (reshaped_data, original_shape). If rank/shape isn't statically
+ available, returns None so caller can choose a permissive fallback.
+ """
+
+ rank = _get_known_tensor_rank(data)
+ if rank is None:
+ return None
+
+ axis = _normalize_legacy_softmax_axis(axis, rank, op_name)
+ struct_info = data.struct_info
+ if not isinstance(struct_info, relax.TensorStructInfo):
+ return None
+ if not isinstance(struct_info.shape, relax.ShapeExpr):
+ return None
+
+ original_shape = list(struct_info.shape.values)
+ if len(original_shape) != rank:
+ return None
+
+ dim0 = _shape_product(original_shape[:axis])
+ dim1 = _shape_product(original_shape[axis:])
+ flattened = relax.op.reshape(data, (dim0, dim1))
+ return flattened, tuple(original_shape)
+
+
+def _get_axis_extent(
+ data: relax.Expr, axis: int, op_name: str
+) -> tuple[int, int | tirx.PrimExpr]:
+ """Return normalized axis and axis extent when rank/shape are known."""
+
+ rank = _get_known_tensor_rank(data)
+ if rank is None:
+ raise ValueError(f"{op_name} requires a statically known input rank.")
+
+ normalized_axis = _normalize_constant_axes([axis], rank, op_name)[0]
+ struct_info = data.struct_info
+ if isinstance(struct_info, relax.TensorStructInfo) and
isinstance(struct_info.shape, relax.ShapeExpr):
+ axis_extent = struct_info.shape.values[normalized_axis]
+ if isinstance(axis_extent, tirx.IntImm):
+ axis_extent = int(axis_extent.value)
+ return normalized_axis, axis_extent
+
+ raise ValueError(f"{op_name} requires a statically known axis extent.")
+
+
class Softmax(OnnxOpConverter):
"""Converts an onnx Softmax node into an equivalent Relax expression."""
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ axis = attr.get("axis", 1)
+ prepared = _legacy_softmax_prepare(inputs[0], axis, "Softmax")
+ if prepared is None:
+ warnings.warn(
+ "Softmax opset<=12 fallback: static rank/shape is unavailable,
"
+ "falling back to axis-based softmax semantics."
+ )
+ return relax.op.nn.softmax(inputs[0], axis=axis)
+
+ flattened, original_shape = prepared
+ out = relax.op.nn.softmax(flattened, axis=-1)
+ return relax.op.reshape(out, original_shape)
+
+ _impl_v11 = _impl_v1
+
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
axis = attr.get("axis", -1)
@@ -733,6 +825,23 @@ class Softmax(OnnxOpConverter):
class LogSoftmax(OnnxOpConverter):
"""Converts an onnx LogSoftmax node into an equivalent Relax expression."""
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ axis = attr.get("axis", 1)
+ prepared = _legacy_softmax_prepare(inputs[0], axis, "LogSoftmax")
+ if prepared is None:
+ warnings.warn(
+ "LogSoftmax opset<=12 fallback: static rank/shape is
unavailable, "
+ "falling back to axis-based log_softmax semantics."
+ )
+ return relax.op.nn.log_softmax(inputs[0], axis=axis)
+
+ flattened, original_shape = prepared
+ out = relax.op.nn.log_softmax(flattened, axis=-1)
+ return relax.op.reshape(out, original_shape)
+
+ _impl_v11 = _impl_v1
+
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
axis = attr.get("axis", -1)
@@ -743,17 +852,57 @@ class Hardmax(OnnxOpConverter):
"""Converts an onnx Hardmax node into an equivalent Relax expression."""
@classmethod
- def _impl_v13(cls, bb, inputs, attr, params):
- axis = attr.get("axis", -1)
- indices = inputs[0]
- dtype = indices.struct_info.dtype
- axis_len = int(inputs[0].struct_info.shape[axis])
- argmax = relax.op.argmax(indices, axis=axis)
+ def _hardmax_impl(cls, *args):
+ """Hardmax core implementation.
+
+ Compatibility note:
+ - New signature: _hardmax_impl(bb, data, axis)
+ - Legacy signature: _hardmax_impl(data, axis)
+ """
+ if len(args) == 3:
+ bb, data, axis = args
+ elif len(args) == 2:
+ bb = None
+ data, axis = args
+ else:
+ raise TypeError(
+ "Hardmax._hardmax_impl expects (bb, data, axis) or (data,
axis)."
+ )
+
+ if bb is not None:
+ data = bb.normalize(data)
+ normalized_axis, axis_extent = _get_axis_extent(data, axis, "Hardmax")
+ dtype = data.struct_info.dtype
+ argmax = relax.op.argmax(data, axis=normalized_axis)
on_value = relax.PrimValue(tvm.tirx.const(1.0, dtype))
off_value = relax.PrimValue(tvm.tirx.const(0.0, dtype))
+ return relax.op.one_hot(argmax, on_value, off_value, axis_extent,
normalized_axis)
- one_hot = relax.op.one_hot(argmax, on_value, off_value, axis_len, axis)
- return one_hot
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ axis = attr.get("axis", 1)
+ prepared = _legacy_softmax_prepare(inputs[0], axis, "Hardmax")
+ if prepared is None:
+ warnings.warn(
+ "Hardmax opset<=12 fallback: static rank/shape is unavailable,
"
+ "falling back to axis-based hardmax semantics."
+ )
+ hardmax_input = inputs[0]
+ hardmax_axis = axis
+ original_shape = None
+ else:
+ hardmax_input, original_shape = prepared
+ hardmax_axis = -1
+
+ out = cls._hardmax_impl(bb, hardmax_input, hardmax_axis)
+ return out if original_shape is None else relax.op.reshape(out,
original_shape)
+
+ _impl_v11 = _impl_v1
+
+ @classmethod
+ def _impl_v13(cls, bb, inputs, attr, params):
+ axis = attr.get("axis", -1)
+ return cls._hardmax_impl(bb, inputs[0], axis)
class Transpose(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 7e434d2659..84348c41ca 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -699,6 +699,57 @@ def test_unary(op_name: str):
verify_unary(op_name, [8, 8, 8], input_dtype=input_dtype,
output_dtype=output_dtype)
[email protected]("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
+def test_softmax_family_opset11_default_axis_semantics(op_name: str):
+ verify_unary(op_name, [2, 3, 4], opset=11)
+
+
[email protected]("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
+def test_softmax_family_opset11_negative_axis_semantics(op_name: str):
+ verify_unary(op_name, [2, 3, 4], attrs={"axis": -2}, opset=11)
+
+
[email protected]("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
+def test_softmax_family_opset11_positive_axis_semantics(op_name: str):
+ verify_unary(op_name, [2, 3, 4], attrs={"axis": 0}, opset=11)
+
+
[email protected]("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
+def tes_softmax_family_opset11_axis_equals_rank_semantics(op_name: str):
+ verify_unary(op_name, [2, 3, 4], attrs={"axis": 3}, opset=11)
+
+
[email protected]("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
+def test_softmax_family_opset13_default_axis_semantics(op_name: str):
+ verify_unary(op_name, [2, 3, 4], opset=13)
+
+
[email protected](
+ "op_name, expected_core_op",
+ [
+ ("Softmax", "relax.nn.softmax"),
+ ("LogSoftmax", "relax.nn.log_softmax"),
+ ("Hardmax", "relax.one_hot"),
+ ],
+)
+def test_softmax_family_opset1_legacy_ir_semantics(op_name: str,
expected_core_op: str):
+ node = helper.make_node(op_name, ["x"], ["y"])
+ graph = helper.make_graph(
+ [node],
+ "softmax_family_opset1_ir_test",
+ inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3,
4])],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3,
4])],
+ )
+ model = helper.make_model(
+ graph, producer_name="softmax_family_opset1_ir_test",
opset_imports=[helper.make_opsetid("", 1)]
+ )
+ tvm_model = from_onnx(model, opset=1, keep_params_in_input=True)
+ call_ops = collect_relax_call_ops(tvm_model["main"])
+
+ assert expected_core_op in call_ops
+ assert call_ops.count("relax.reshape") >= 2
+
+
def test_round_ties_to_even():
"""ONNX Round must use ties-to-even (banker's rounding), not
ties-away-from-zero.