This is an automated email from the ASF dual-hosted git repository.
mshr-h pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7ecf466e33 [S-TIR][Dlight] Add layered fall back strategy to handle
missing attr `max_shared_memory_per_block` (#19453)
7ecf466e33 is described below
commit 7ecf466e338a8e50379558b6f3ff8a0b242350df
Author: Neo Chien <[email protected]>
AuthorDate: Wed Apr 29 23:36:11 2026 +0800
[S-TIR][Dlight] Add layered fall back strategy to handle missing attr
`max_shared_memory_per_block` (#19453)
Hi Committers,
This PR aims to fix issue
https://github.com/apache/tvm/issues/19419. Any suggestions would be
appreciated.
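### Root Cause
- An auto-detected CUDA target may lack the `max_shared_memory_per_block`
attribute. The GEMV and LowBatchGEMV rules index
`target.attrs["max_shared_memory_per_block"]` directly, which raises a `KeyError`.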
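### Solution
- Add a layered fallback strategy for a missing `max_shared_memory_per_block`
attribute. Use the target attribute when present. Otherwise use a per-target-kind
default, and otherwise a conservative 16 KB default.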
---------
Co-authored-by: cchung100m <[email protected]>
---
python/tvm/s_tir/dlight/analysis/__init__.py | 1 +
.../tvm/s_tir/dlight/analysis/common_analysis.py | 30 ++++++++++++++++----
python/tvm/s_tir/dlight/gpu/gemv.py | 4 ++-
python/tvm/s_tir/dlight/gpu/low_batch_gemv.py | 4 ++-
tests/python/s_tir/dlight/test_gpu_gemv.py | 32 ++++++++++++++++++++++
.../python/s_tir/dlight/test_gpu_low_batch_gemv.py | 26 ++++++++++++++++++
6 files changed, 90 insertions(+), 7 deletions(-)
diff --git a/python/tvm/s_tir/dlight/analysis/__init__.py b/python/tvm/s_tir/dlight/analysis/__init__.py
index 36d988b896..9e41003d73 100644
--- a/python/tvm/s_tir/dlight/analysis/__init__.py
+++ b/python/tvm/s_tir/dlight/analysis/__init__.py
@@ -27,6 +27,7 @@ from .common_analysis import (
normalize_prim_func,
get_root_block,
get_sblock_info,
+ get_max_shared_memory_per_block,
)
from .gemv import (
is_gemv,
diff --git a/python/tvm/s_tir/dlight/analysis/common_analysis.py b/python/tvm/s_tir/dlight/analysis/common_analysis.py
index 2110f55c8e..ec7a025c54 100644
--- a/python/tvm/s_tir/dlight/analysis/common_analysis.py
+++ b/python/tvm/s_tir/dlight/analysis/common_analysis.py
@@ -19,6 +19,8 @@
# pylint: disable=unused-argument, unused-variable
"""Analysis on TIR blocks, loops and functions."""
+import logging
+
from collections import namedtuple
from typing import Literal
@@ -30,6 +32,7 @@ from tvm.s_tir import Schedule
from tvm.s_tir.schedule import SBlockRV
from tvm.target.target import Target
+logger = logging.getLogger(__name__) # pylint: disable=invalid-name
class IterInfo:
"""Information about a loop/iter var."""
@@ -362,14 +365,31 @@ def get_max_threads_per_block(target: Target) -> int:
return int(max_threads_per_block)
+TARGET_KIND_TO_DEFAULT_MAX_SMEM = {
+    "cuda": 49152,
+    "rocm": 65536,
+    "metal": 32768,
+    "opencl": 16384,
+    "vulkan": 16384,
+}
+
 def get_max_shared_memory_per_block(target: Target) -> int:
     _assert_gpu_target(target)
     max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None)
-    if max_shared_memory_per_block is None:
-        raise ValueError(
-            f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually"
-        )
-    return int(max_shared_memory_per_block)
+    if max_shared_memory_per_block is not None:
+        return int(max_shared_memory_per_block)
+
+    # Layered fallback strategy for targets that do not carry this attribute
+    # 1) Use explicit target attrs provided (handled above).
+    # 2) Fall back to backend defaults matching target-kind defaults/tag defaults.
+    # 3) Use a conservative GPU default as last resort.
+    default_smem = TARGET_KIND_TO_DEFAULT_MAX_SMEM.get(target.kind.name, 16384)
+    logger.warning(
+        "Target %s missing 'max_shared_memory_per_block'; using %d bytes.",
+        target.kind.name,
+        default_smem,
+    )
+    return int(default_smem)
def get_root_block(sch: Schedule, func_name: str = "main") -> SBlockRV:
diff --git a/python/tvm/s_tir/dlight/gpu/gemv.py b/python/tvm/s_tir/dlight/gpu/gemv.py
index 7e555ce946..efca7541e4 100644
--- a/python/tvm/s_tir/dlight/gpu/gemv.py
+++ b/python/tvm/s_tir/dlight/gpu/gemv.py
@@ -24,6 +24,7 @@ from tvm.target import Target
from ..analysis import (
SBlockInfo,
+ get_max_shared_memory_per_block,
is_broadcast_epilogue,
is_gemv,
normalize,
@@ -156,10 +157,11 @@ class GEMV(GPUScheduleRule):
# is implemented with shared memory.
shared_mem_usage += TS * TR * dtype_bytes
+ max_smem = get_max_shared_memory_per_block(target)
LOAD_V_SHARED = (
LOAD_V_SHARED
and isinstance(shared_mem_usage, tirx.IntImm)
- and shared_mem_usage.value <=
int(target.attrs["max_shared_memory_per_block"])
+ and shared_mem_usage.value <= max_smem
)
# vectorize load A
diff --git a/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py b/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py
index 197e1f897d..15a9f8f506 100644
--- a/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py
+++ b/python/tvm/s_tir/dlight/gpu/low_batch_gemv.py
@@ -27,6 +27,7 @@ from ..analysis import (
SBlockInfo,
collect_block_iter_vars_used_in_access_region,
collect_vars_used_in_prim_expr,
+ get_max_shared_memory_per_block,
is_broadcast_epilogue,
normalize_prim_func,
)
@@ -354,10 +355,11 @@ class LowBatchGEMV(GPUScheduleRule):
lambda x, y: x * y, buf.shape,
tirx.IntImm(buf.shape[0].dtype, 1)
) * get_bytes(buf.dtype)
shared_mem_usage += buf_size
+ max_smem = get_max_shared_memory_per_block(target)
LOAD_V_SHARED = (
LOAD_V_SHARED
and isinstance(shared_mem_usage, tirx.IntImm)
- and shared_mem_usage.value <=
int(target.attrs["max_shared_memory_per_block"])
+ and shared_mem_usage.value <= max_smem
)
# vectorize load A
diff --git a/tests/python/s_tir/dlight/test_gpu_gemv.py b/tests/python/s_tir/dlight/test_gpu_gemv.py
index 68aabbd095..cfada1bd2e 100644
--- a/tests/python/s_tir/dlight/test_gpu_gemv.py
+++ b/tests/python/s_tir/dlight/test_gpu_gemv.py
@@ -1054,5 +1054,37 @@ def test_func_to_skip():
tvm.ir.assert_structural_equal(mod["main"], before)
+def test_gemv_cuda_target_without_max_shared_memory_per_block():
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(
+        A: T.Buffer((1, 1, 1, 128), "float16"),
+        B: T.Buffer((1, 1, 64, 128), "float16"),
+        C: T.Buffer((1, 1, 1, 64), "float16"),
+    ):
+        T.func_attr({"tirx.noalias": True})
+        for i0, i1, i2, i3, k in T.grid(1, 1, 1, 64, 128):
+            with T.sblock("NT_matmul"):
+                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
+                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_i3, v_k])
+                T.writes(C[v_i0, v_i1, v_i2, v_i3])
+                with T.init():
+                    C[v_i0, v_i1, v_i2, v_i3] = T.float16(0)
+                C[v_i0, v_i1, v_i2, v_i3] = C[v_i0, v_i1, v_i2, v_i3] + A[
+                    v_i0, v_i1, v_i2, v_k
+                ] * B[v_i0, v_i1, v_i3, v_k]
+
+    # fmt: on
+
+    target = Target({"kind": "cuda", "max_num_threads": 1024})
+    assert target.attrs.get("max_shared_memory_per_block", None) is None
+
+    mod = tvm.IRModule({"main": before})
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+
+    assert mod["main"].attrs["tirx.is_scheduled"] == 1
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py b/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py
index 83099682b0..c290720327 100644
--- a/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py
+++ b/tests/python/s_tir/dlight/test_gpu_low_batch_gemv.py
@@ -528,6 +528,32 @@ def test_outer_reduction():
mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod) # pylint:
disable=not-callable
tvm.ir.assert_structural_equal(mod["main"], expected)
+def test_low_batch_gemv_cuda_target_without_max_shared_memory_per_block():
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(var_A: T.handle, B: T.Buffer((T.int64(128), T.int64(128)), "float16"), var_C: T.handle):
+        T.func_attr({"tir.noalias": True})
+        batch_size = T.int64()
+        A = T.match_buffer(var_A, (batch_size, T.int64(1), T.int64(128)), "float16")
+        C = T.match_buffer(var_C, (batch_size, T.int64(1), T.int64(128)), "float16")
+        for i0, i1, i2, k in T.grid(batch_size, T.int64(1), T.int64(128), T.int64(128)):
+            with T.sblock("NT_matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k])
+                T.writes(C[v_i0, v_i1, v_i2])
+                with T.init():
+                    C[v_i0, v_i1, v_i2] = T.float16(0)
+                C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k]
+    # fmt: on
+
+    target = Target({"kind": "cuda", "max_num_threads": 1024})
+    assert target.attrs.get("max_shared_memory_per_block", None) is None
+
+    mod = tvm.IRModule({"main": before})
+    with target:
+        mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod)
+    assert mod["main"].attrs["tirx.is_scheduled"] == 1
+
if __name__ == "__main__":
tvm.testing.main()