This is an automated email from the ASF dual-hosted git repository.
guan404ming pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bddfcadcfb [Relax] Fix matmul and reductions with zero-size dimension
return uninitialized memory (#19680)
bddfcadcfb is described below
commit bddfcadcfbb70bef2196b857db6f4ede0aa057ca
Author: Neo Chien <[email protected]>
AuthorDate: Fri Jun 19 23:14:55 2026 +0800
[Relax] Fix matmul and reductions with zero-size dimension return
uninitialized memory (#19680)
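Hi Committers,
This PR fixes issue https://github.com/apache/tvm/issues/19578: matmul with a
zero-size contraction dimension, and sum/prod reducing over a zero-size axis,
returned uninitialized memory. These cases are now legalized to an output filled
with the reduction identity (0 for matmul and sum; 1, or True for bool, for prod).
Any suggestions would be appreciated.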
---------
Co-authored-by: cchung100m <[email protected]>
---
.../relax/transform/legalize_ops/linear_algebra.py | 6 ++
.../relax/transform/legalize_ops/statistical.py | 53 ++++++++++++++++--
..._transform_legalize_ops_index_linear_algebra.py | 16 ++++++
...st_transform_legalize_ops_search_statistical.py | 64 ++++++++++++++++++++++
4 files changed, 135 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py
b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
index d8dd8aa3b0..2b4d1efd10 100644
--- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py
+++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
@@ -45,6 +45,12 @@ def _matmul(bb: BlockBuilder, call: Call) -> Expr:
b_relax = relax.Var("b", relax.TensorStructInfo(b.shape))
f_infer_sinfo = call.op.get_attr("FInferStructInfo")
output_shape = f_infer_sinfo(relax.op.matmul(a_relax, b_relax),
bb).shape
+ if isinstance(a_shape[-1], tirx.IntImm) and a_shape[-1] == 0:
+ return te.compute(
+ output_shape,
+ lambda *_: tirx.const(0, call.struct_info.dtype),
+ name="matmul",
+ )
def matmul_compute(*idx_spatial):
k = te.reduce_axis((0, a_shape[-1]), name="k")
diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py
b/python/tvm/relax/transform/legalize_ops/statistical.py
index cbad62e448..168cd71399 100644
--- a/python/tvm/relax/transform/legalize_ops/statistical.py
+++ b/python/tvm/relax/transform/legalize_ops/statistical.py
@@ -17,15 +17,57 @@
# pylint: disable=invalid-name
"""Default legalization function for statistical operators."""
+from collections.abc import Callable
+
from tvm import te, tirx, topi
from ...block_builder import BlockBuilder
-from ...expr import Call, Expr
+from ...expr import Call, Expr, ShapeExpr
from .common import LegalizeFunc, TEFunc, register_legalize
-def _statistical(te_func: TEFunc) -> LegalizeFunc:
+def _normalize_reduction_axes(axis: list[int] | None, ndim: int) -> list[int]:
+ if axis is None:
+ return list(range(ndim))
+
+ axes = []
+ for dim in axis:
+ if isinstance(dim, tirx.IntImm):
+ dim = dim.value
+ dim = int(dim)
+ axes.append(dim + ndim if dim < 0 else dim)
+ return axes
+
+
+def _has_const_zero_reduction_dim(call: Call) -> bool:
+    """Return True if ``call`` reduces over an axis whose extent is a constant 0.
+
+    Only a literal ``ShapeExpr`` input shape is inspected; any other shape
+    yields False. A reduced axis counts only when its extent is a
+    ``tirx.IntImm`` equal to 0, so symbolic extents never match.
+    """
+    input_shape = call.args[0].struct_info.shape
+    if not isinstance(input_shape, ShapeExpr):
+        # No literal shape to inspect; report no zero-size reduction.
+        return False
+
+    axes = _normalize_reduction_axes(call.attrs.axis, len(input_shape.values))
+    return any(
+        isinstance(input_shape.values[dim], tirx.IntImm) and input_shape.values[dim] == 0
+        for dim in axes
+    )
+
+
+def _statistical(
+    te_func: TEFunc,
+    zero_dim_identity: int | float | bool | Callable[[str], int | float | bool] | None = None,
+) -> LegalizeFunc:
+    """Create the legalization function for a reduction operator.
+
+    Parameters
+    ----------
+    te_func : TEFunc
+        TOPI reduction, invoked as ``te_func(data, axis, keepdims)``.
+    zero_dim_identity : int | float | bool | Callable[[str], int | float | bool] | None
+        Value written to every output element when a reduced axis has a constant
+        extent of 0. A callable is given the output dtype string and returns the
+        value. ``None`` (the default) disables this special case, so ``te_func``
+        is always used.
+    """
     def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr:
+        if zero_dim_identity is not None and _has_const_zero_reduction_dim(call):
+            # Reducing over an empty axis combines no elements, so emit a
+            # constant fill with the identity instead of the TOPI reduction
+            # (per the commit subject, that path returned uninitialized memory).
+            fill_value = (
+                zero_dim_identity(call.struct_info.dtype)
+                if callable(zero_dim_identity)
+                else zero_dim_identity
+            )
+            return bb.call_te(
+                topi.full,
+                call.struct_info.shape.values,
+                call.struct_info.dtype,
+                fill_value,
+            )
         return bb.call_te(te_func, call.args[0], call.attrs.axis, call.attrs.keepdims)
     return statistical_call_te
@@ -129,5 +171,8 @@ def _median(bb: BlockBuilder, call: Call) -> Expr:
register_legalize("relax.max", _statistical(topi.max))
register_legalize("relax.min", _statistical(topi.min))
-register_legalize("relax.prod", _statistical(topi.prod))
-register_legalize("relax.sum", _statistical(topi.sum))
+# Identity used when reducing over a constant zero-size axis: prod -> 1 (True
+# for bool outputs), sum -> 0. max/min are registered without an identity and
+# therefore always use the plain TOPI reduction.
+register_legalize(
+    "relax.prod",
+    _statistical(topi.prod, zero_dim_identity=lambda dtype: True if dtype == "bool" else 1),
+)
+register_legalize("relax.sum", _statistical(topi.sum, zero_dim_identity=0))
diff --git
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index dbd92ba6d3..9b905dd3da 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -1136,6 +1136,22 @@ def test_matmul_batching_dim_1():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_matmul_zero_k_no_reduction():
+    """Matmul with a zero-size contraction dim legalizes without a reduce axis."""
+    # fmt: off
+    @tvm.script.ir_module
+    class Matmul:
+        @R.function
+        def main(x: R.Tensor((2, 0), "float32"), y: R.Tensor((0, 3), "float32")) -> R.Tensor((2, 3), "float32"):
+            gv: R.Tensor((2, 3), "float32") = R.matmul(x, y)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Matmul)
+    script = mod.script()
+    # The printed module must contain no reduce axis and a float32 zero constant.
+    assert "T.axis.reduce" not in script
+    assert "T.float32(0)" in script or "T.float32(0.0)" in script
+
+
def test_einsum():
# fmt: off
@I.ir_module(s_tir=True)
diff --git
a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 82c478bd51..4a707352d3 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -629,6 +629,70 @@ def test_prod_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_sum_zero_dim_axis_identity():
+    """Sum over a constant zero-size axis legalizes to a zero fill, not a reduction."""
+    # fmt: off
+    @tvm.script.ir_module
+    class Sum:
+        @R.function
+        def main(x: R.Tensor((2, 0, 4), "float32")) -> R.Tensor((2, 4), "float32"):
+            gv: R.Tensor((2, 4), "float32") = R.sum(x, axis=[1], keepdims=False)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Sum)
+    script = mod.script()
+    # The printed module must contain no reduce axis and a float32 zero constant.
+    assert "T.axis.reduce" not in script
+    assert "T.float32(0)" in script or "T.float32(0.0)" in script
+
+
+def test_sum_zero_dim_negative_axis_identity():
+    """A zero-size axis given as a negative index (-1) is also detected for sum."""
+    # fmt: off
+    @tvm.script.ir_module
+    class Sum:
+        @R.function
+        def main(x: R.Tensor((2, 3, 0), "float32")) -> R.Tensor((2, 3), "float32"):
+            gv: R.Tensor((2, 3), "float32") = R.sum(x, axis=[-1], keepdims=False)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Sum)
+    script = mod.script()
+    # The printed module must contain no reduce axis and a float32 zero constant.
+    assert "T.axis.reduce" not in script
+    assert "T.float32(0)" in script or "T.float32(0.0)" in script
+
+
+def test_prod_zero_dim_axis_identity():
+    """Prod over a constant zero-size axis legalizes to a fill with 1, not a reduction."""
+    # fmt: off
+    @tvm.script.ir_module
+    class Prod:
+        @R.function
+        def main(x: R.Tensor((2, 0, 4), "float32")) -> R.Tensor((2, 4), "float32"):
+            gv: R.Tensor((2, 4), "float32") = R.prod(x, axis=[1], keepdims=False)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Prod)
+    script = mod.script()
+    # The printed module must contain no reduce axis and a float32 one constant.
+    assert "T.axis.reduce" not in script
+    assert "T.float32(1)" in script or "T.float32(1.0)" in script
+
+
+def test_prod_bool_zero_dim_axis_identity():
+    """For bool prod over a zero-size axis, the fill value is the bool identity True."""
+    # fmt: off
+    @tvm.script.ir_module
+    class Prod:
+        @R.function
+        def main(x: R.Tensor((2, 0, 4), "bool")) -> R.Tensor((2, 4), "bool"):
+            gv: R.Tensor((2, 4), "bool") = R.prod(x, axis=[1], keepdims=False)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Prod)
+    script = mod.script()
+    # The printed module must contain no reduce axis and a bool true constant.
+    assert "T.axis.reduce" not in script
+    assert "T.bool(1)" in script or "T.bool(True)" in script
+
+
def test_mean():
# fmt: off
@tvm.script.ir_module