This is an automated email from the ASF dual-hosted git repository.

guan404ming pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bddfcadcfb [Relax] Fix matmul and reductions with zero-size dimension return uninitialized memory (#19680)
bddfcadcfb is described below

commit bddfcadcfbb70bef2196b857db6f4ede0aa057ca
Author: Neo Chien <[email protected]>
AuthorDate: Fri Jun 19 23:14:55 2026 +0800

    [Relax] Fix matmul and reductions with zero-size dimension return uninitialized memory (#19680)
    
    Hi Committers,
    
    This PR fixes issue https://github.com/apache/tvm/issues/19578. Any
    suggestions would be appreciated if you are available.
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 .../relax/transform/legalize_ops/linear_algebra.py |  6 ++
 .../relax/transform/legalize_ops/statistical.py    | 53 ++++++++++++++++--
 ..._transform_legalize_ops_index_linear_algebra.py | 16 ++++++
 ...st_transform_legalize_ops_search_statistical.py | 64 ++++++++++++++++++++++
 4 files changed, 135 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/linear_algebra.py b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
index d8dd8aa3b0..2b4d1efd10 100644
--- a/python/tvm/relax/transform/legalize_ops/linear_algebra.py
+++ b/python/tvm/relax/transform/legalize_ops/linear_algebra.py
@@ -45,6 +45,12 @@ def _matmul(bb: BlockBuilder, call: Call) -> Expr:
         b_relax = relax.Var("b", relax.TensorStructInfo(b.shape))
         f_infer_sinfo = call.op.get_attr("FInferStructInfo")
         output_shape = f_infer_sinfo(relax.op.matmul(a_relax, b_relax), 
bb).shape
+        if isinstance(a_shape[-1], tirx.IntImm) and a_shape[-1] == 0:
+            return te.compute(
+                output_shape,
+                lambda *_: tirx.const(0, call.struct_info.dtype),
+                name="matmul",
+            )
 
         def matmul_compute(*idx_spatial):
             k = te.reduce_axis((0, a_shape[-1]), name="k")
diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py
index cbad62e448..168cd71399 100644
--- a/python/tvm/relax/transform/legalize_ops/statistical.py
+++ b/python/tvm/relax/transform/legalize_ops/statistical.py
@@ -17,15 +17,57 @@
 # pylint: disable=invalid-name
 """Default legalization function for statistical operators."""
 
+from collections.abc import Callable
+
 from tvm import te, tirx, topi
 
 from ...block_builder import BlockBuilder
-from ...expr import Call, Expr
+from ...expr import Call, Expr, ShapeExpr
 from .common import LegalizeFunc, TEFunc, register_legalize
 
 
-def _statistical(te_func: TEFunc) -> LegalizeFunc:
+def _normalize_reduction_axes(axis: list[int] | None, ndim: int) -> list[int]:
+    if axis is None:
+        return list(range(ndim))
+
+    axes = []
+    for dim in axis:
+        if isinstance(dim, tirx.IntImm):
+            dim = dim.value
+        dim = int(dim)
+        axes.append(dim + ndim if dim < 0 else dim)
+    return axes
+
+
+def _has_const_zero_reduction_dim(call: Call) -> bool:
+    input_shape = call.args[0].struct_info.shape
+    if not isinstance(input_shape, ShapeExpr):
+        return False
+
+    axes = _normalize_reduction_axes(call.attrs.axis, len(input_shape.values))
+    return any(
+        isinstance(input_shape.values[dim], tirx.IntImm) and 
input_shape.values[dim] == 0
+        for dim in axes
+    )
+
+
+def _statistical(
+    te_func: TEFunc,
+    zero_dim_identity: int | float | bool | Callable[[str], int | float | 
bool] | None = None,
+) -> LegalizeFunc:
     def statistical_call_te(bb: BlockBuilder, call: Call) -> Expr:
+        if zero_dim_identity is not None and 
_has_const_zero_reduction_dim(call):
+            fill_value = (
+                zero_dim_identity(call.struct_info.dtype)
+                if callable(zero_dim_identity)
+                else zero_dim_identity
+            )
+            return bb.call_te(
+                topi.full,
+                call.struct_info.shape.values,
+                call.struct_info.dtype,
+                fill_value,
+            )
         return bb.call_te(te_func, call.args[0], call.attrs.axis, 
call.attrs.keepdims)
 
     return statistical_call_te
@@ -129,5 +171,8 @@ def _median(bb: BlockBuilder, call: Call) -> Expr:
 
 register_legalize("relax.max", _statistical(topi.max))
 register_legalize("relax.min", _statistical(topi.min))
-register_legalize("relax.prod", _statistical(topi.prod))
-register_legalize("relax.sum", _statistical(topi.sum))
+register_legalize(
+    "relax.prod",
+    _statistical(topi.prod, zero_dim_identity=lambda dtype: True if dtype == 
"bool" else 1),
+)
+register_legalize("relax.sum", _statistical(topi.sum, zero_dim_identity=0))
diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index dbd92ba6d3..9b905dd3da 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -1136,6 +1136,22 @@ def test_matmul_batching_dim_1():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_matmul_zero_k_no_reduction():
+    # fmt: off
+    @tvm.script.ir_module
+    class Matmul:
+        @R.function
+        def main(x: R.Tensor((2, 0), "float32"), y: R.Tensor((0, 3), 
"float32")) -> R.Tensor((2, 3), "float32"):
+            gv: R.Tensor((2, 3), "float32") = R.matmul(x, y)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Matmul)
+    script = mod.script()
+    assert "T.axis.reduce" not in script
+    assert "T.float32(0)" in script or "T.float32(0.0)" in script
+
+
 def test_einsum():
     # fmt: off
     @I.ir_module(s_tir=True)
diff --git a/tests/python/relax/test_transform_legalize_ops_search_statistical.py b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
index 82c478bd51..4a707352d3 100644
--- a/tests/python/relax/test_transform_legalize_ops_search_statistical.py
+++ b/tests/python/relax/test_transform_legalize_ops_search_statistical.py
@@ -629,6 +629,70 @@ def test_prod_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_sum_zero_dim_axis_identity():
+    # fmt: off
+    @tvm.script.ir_module
+    class Sum:
+        @R.function
+        def main(x: R.Tensor((2, 0, 4), "float32")) -> R.Tensor((2, 4), 
"float32"):
+            gv: R.Tensor((2, 4), "float32") = R.sum(x, axis=[1], 
keepdims=False)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Sum)
+    script = mod.script()
+    assert "T.axis.reduce" not in script
+    assert "T.float32(0)" in script or "T.float32(0.0)" in script
+
+
+def test_sum_zero_dim_negative_axis_identity():
+    # fmt: off
+    @tvm.script.ir_module
+    class Sum:
+        @R.function
+        def main(x: R.Tensor((2, 3, 0), "float32")) -> R.Tensor((2, 3), 
"float32"):
+            gv: R.Tensor((2, 3), "float32") = R.sum(x, axis=[-1], 
keepdims=False)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Sum)
+    script = mod.script()
+    assert "T.axis.reduce" not in script
+    assert "T.float32(0)" in script or "T.float32(0.0)" in script
+
+
+def test_prod_zero_dim_axis_identity():
+    # fmt: off
+    @tvm.script.ir_module
+    class Prod:
+        @R.function
+        def main(x: R.Tensor((2, 0, 4), "float32")) -> R.Tensor((2, 4), 
"float32"):
+            gv: R.Tensor((2, 4), "float32") = R.prod(x, axis=[1], 
keepdims=False)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Prod)
+    script = mod.script()
+    assert "T.axis.reduce" not in script
+    assert "T.float32(1)" in script or "T.float32(1.0)" in script
+
+
+def test_prod_bool_zero_dim_axis_identity():
+    # fmt: off
+    @tvm.script.ir_module
+    class Prod:
+        @R.function
+        def main(x: R.Tensor((2, 0, 4), "bool")) -> R.Tensor((2, 4), "bool"):
+            gv: R.Tensor((2, 4), "bool") = R.prod(x, axis=[1], keepdims=False)
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Prod)
+    script = mod.script()
+    assert "T.axis.reduce" not in script
+    assert "T.bool(1)" in script or "T.bool(True)" in script
+
+
 def test_mean():
     # fmt: off
     @tvm.script.ir_module

Reply via email to