This is an automated email from the ASF dual-hosted git repository.

mshr-h pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7ecf466e33 [S-TIR][Dlight] Add layered fall back strategy to handle missing attr `max_shared_memory_per_block` (#19453)
7ecf466e33 is described below

commit 7ecf466e338a8e50379558b6f3ff8a0b242350df
Author: Neo Chien <[email protected]>
AuthorDate: Wed Apr 29 23:36:11 2026 +0800

    [S-TIR][Dlight] Add layered fall back strategy to handle missing attr `max_shared_memory_per_block` (#19453)
    
    Hi Committers,
    
    This PR fixes issue
    https://github.com/apache/tvm/issues/19419. Any suggestions would be
    appreciated.
    
    ### Root Cause
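    - An auto-detected CUDA target may lack the `max_shared_memory_per_block`
    attribute. The GPU GEMV schedule rules (`gemv.py`, `low_batch_gemv.py`)
    read it via `target.attrs["max_shared_memory_per_block"]`, so they fail
    with a `KeyError` on such a target.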
    
    ### Solutions
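    - Add a layered fallback strategy for a missing
    `max_shared_memory_per_block` attribute:
      1. Use the target attribute when it is set.
      2. Otherwise, use a per-target-kind default (e.g. 49152 bytes for CUDA).
      3. Otherwise, use a conservative 16384-byte default.
    A warning is logged whenever a default is used.
    - Route the GEMV and low-batch GEMV rules through
    `get_max_shared_memory_per_block` instead of indexing `target.attrs`
    directly.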
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 python/tvm/s_tir/dlight/analysis/__init__.py       |  1 +
 .../tvm/s_tir/dlight/analysis/common_analysis.py   | 30 ++++++++++++++++----
 python/tvm/s_tir/dlight/gpu/gemv.py                |  4 ++-
 python/tvm/s_tir/dlight/gpu/low_batch_gemv.py      |  4 ++-
 tests/python/s_tir/dlight/test_gpu_gemv.py         | 32 ++++++++++++++++++++++
 .../python/s_tir/dlight/test_gpu_low_batch_gemv.py | 26 ++++++++++++++++++
 6 files changed, 90 insertions(+), 7 deletions(-)

diff --git a/python/tvm/s_tir/dlight/analysis/__init__.py b/python/tvm/s_tir/dlight/analysis/__init__.py
index 36d988b896..9e41003d73 100644
--- a/python/tvm/s_tir/dlight/analysis/__init__.py
+++ b/python/tvm/s_tir/dlight/analysis/__init__.py
@@ -27,6 +27,7 @@ from .common_analysis import (
     normalize_prim_func,
     get_root_block,
     get_sblock_info,
+    get_max_shared_memory_per_block,
 )
 from .gemv import (
     is_gemv,
diff --git a/python/tvm/s_tir/dlight/analysis/common_analysis.py b/python/tvm/s_tir/dlight/analysis/common_analysis.py
index 2110f55c8e..ec7a025c54 100644
--- a/python/tvm/s_tir/dlight/analysis/common_analysis.py
+++ b/python/tvm/s_tir/dlight/analysis/common_analysis.py
@@ -19,6 +19,8 @@
 # pylint: disable=unused-argument, unused-variable
 """Analysis on TIR blocks, loops and functions."""
 
+import logging
+
 from collections import namedtuple
 from typing import Literal
 
@@ -30,6 +32,7 @@ from tvm.s_tir import Schedule
 from tvm.s_tir.schedule import SBlockRV
 from tvm.target.target import Target
 
+logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 
 class IterInfo:
     """Information about a loop/iter var."""
@@ -362,14 +365,31 @@ def get_max_threads_per_block(target: Target) -> int:
     return int(max_threads_per_block)
 
 
+TARGET_KIND_TO_DEFAULT_MAX_SMEM = {
+    "cuda": 49152,
+    "rocm": 65536,
+    "metal": 32768,
+    "opencl": 16384,
+    "vulkan": 16384,
+}
+
 def get_max_shared_memory_per_block(target: Target) -> int:
     _assert_gpu_target(target)
     max_shared_memory_per_block = 
target.attrs.get("max_shared_memory_per_block", None)
-    if max_shared_memory_per_block is None:
-        raise ValueError(
-            f"Cannot find `max_shared_memory_per_block` in {target}, please 
specify it manually"
-        )
-    return int(max_shared_memory_per_block)
+    if max_shared_memory_per_block is not None:
+        return int(max_shared_memory_per_block)
+
+    # Layered fallback strategy for targets that do not carry this attribute
+    # 1) Use explicit target attrs provided (handled above).
+    # 2) Fall back to backend defaults matching target-kind defaults/tag 
defaults.
+    # 3) Use a conservative GPU default as last resort.
+    default_smem = TARGET_KIND_TO_DEFAULT_MAX_SMEM.get(target.kind.name, 16384)
+    logger.warning(
+        "Target %s missing 'max_shared_memory_per_block'; using %d bytes.",
+        target.kind.name,
+        default_smem,
+    )
+    return int(default_smem)
 
 
 def get_root_block(sch: Schedule, func_name: str = "main") -> SBlockRV:
diff --git a/python/tvm/s_tir/dlight/gpu/gemv.py b/python/tvm/s_tir/dlight/gpu/gemv.py
index 7e555ce946..efca7541e4 100644
--- a/python/tvm/s_tir/dlight/gpu/gemv.py
+++ b/python/tvm/s_tir/dlight/gpu/gemv.py
@@ -24,6 +24,7 @@ from tvm.target import Target
 
 from ..analysis import (
     SBlockInfo,
+    get_max_shared_memory_per_block,
     is_broadcast_epilogue,
     is_gemv,
     normalize,
@@ -156,10 +157,11 @@ class GEMV(GPUScheduleRule):
                     # is implemented with shared memory.
                     shared_mem_usage += TS * TR * dtype_bytes
 
+            max_smem = get_max_shared_memory_per_block(target)
             LOAD_V_SHARED = (
                 LOAD_V_SHARED
                 and isinstance(shared_mem_usage, tirx.IntImm)
-                and shared_mem_usage.value <= 
int(target.attrs["max_shared_memory_per_block"])
+                and shared_mem_usage.value <= max_smem
             )
 
             # vectorize load A
diff --git a/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py b/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py
index 197e1f897d..15a9f8f506 100644
--- a/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py
+++ b/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py
@@ -27,6 +27,7 @@ from ..analysis import (
     SBlockInfo,
     collect_block_iter_vars_used_in_access_region,
     collect_vars_used_in_prim_expr,
+    get_max_shared_memory_per_block,
     is_broadcast_epilogue,
     normalize_prim_func,
 )
@@ -354,10 +355,11 @@ class LowBatchGEMV(GPUScheduleRule):
                     lambda x, y: x * y, buf.shape, 
tirx.IntImm(buf.shape[0].dtype, 1)
                 ) * get_bytes(buf.dtype)
                 shared_mem_usage += buf_size
+            max_smem = get_max_shared_memory_per_block(target)
             LOAD_V_SHARED = (
                 LOAD_V_SHARED
                 and isinstance(shared_mem_usage, tirx.IntImm)
-                and shared_mem_usage.value <= 
int(target.attrs["max_shared_memory_per_block"])
+                and shared_mem_usage.value <= max_smem
             )
 
             # vectorize load A
diff --git a/tests/python/s_tir/dlight/test_gpu_gemv.py b/tests/python/s_tir/dlight/test_gpu_gemv.py
index 68aabbd095..cfada1bd2e 100644
--- a/tests/python/s_tir/dlight/test_gpu_gemv.py
+++ b/tests/python/s_tir/dlight/test_gpu_gemv.py
@@ -1054,5 +1054,37 @@ def test_func_to_skip():
         tvm.ir.assert_structural_equal(mod["main"], before)
 
 
+def test_gemv_cuda_target_without_max_shared_memory_per_block():
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(
+        A: T.Buffer((1, 1, 1, 128), "float16"),
+        B: T.Buffer((1, 1, 64, 128), "float16"),
+        C: T.Buffer((1, 1, 1, 64), "float16"),
+    ):
+        T.func_attr({"tirx.noalias": True})
+        for i0, i1, i2, i3, k in T.grid(1, 1, 1, 64, 128):
+            with T.sblock("NT_matmul"):
+                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, 
i2, i3, k])
+                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_i3, v_k])
+                T.writes(C[v_i0, v_i1, v_i2, v_i3])
+                with T.init():
+                    C[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
+                C[v_i0, v_i1, v_i2, v_i3] = C[v_i0, v_i1, v_i2, v_i3] + A[
+                    v_i0, v_i1, v_i2, v_k
+                ] * B[v_i0, v_i1, v_i3, v_k]
+
+    # fmt: on
+
+    target = Target({"kind": "cuda", "max_num_threads": 1024})
+    assert target.attrs.get("max_shared_memory_per_block", None) is None
+
+    mod = tvm.IRModule({"main": before})
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+
+    assert mod["main"].attrs["tirx.is_scheduled"] == 1
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py b/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py
index 83099682b0..c290720327 100644
--- a/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py
+++ b/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py
@@ -528,6 +528,32 @@ def test_outer_reduction():
         mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod)  # pylint: 
disable=not-callable
     tvm.ir.assert_structural_equal(mod["main"], expected)
 
+def test_low_batch_gemv_cuda_target_without_max_shared_memory_per_block():
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(var_A: T.handle, B: T.Buffer((T.int64(128), T.int64(128)), 
"float16"), var_C: T.handle):
+        T.func_attr({"tir.noalias": True})
+        batch_size = T.int64()
+        A = T.match_buffer(var_A, (batch_size, T.int64(1), T.int64(128)), 
"float16")
+        C = T.match_buffer(var_C, (batch_size, T.int64(1), T.int64(128)), 
"float16")
+        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(128), 
T.int64(128)):
+            with T.sblock("NT_matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k])
+                T.writes(C[v_i0, v_i1, v_i2])
+                with T.init():
+                    C[v_i0, v_i1, v_i2] = T.float16(0)
+                C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] 
* B[v_i2, v_k]
+    # fmt: on
+
+    target = Target({"kind": "cuda", "max_num_threads": 1024})
+    assert target.attrs.get("max_shared_memory_per_block", None) is None
+
+    mod = tvm.IRModule({"main": before})
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod)
+    assert mod["main"].attrs["tirx.is_scheduled"] == 1
+
 
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to