This is an automated email from the ASF dual-hosted git repository.

guan404ming pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ef1904a3e5 [CI] Pin GitHub Actions to SHA for ASF INFRA compliance (#19793)
ef1904a3e5 is described below

commit ef1904a3e599530fdfeb819df966877bcdfc864d
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Jun 17 00:29:10 2026 +0800

    [CI] Pin GitHub Actions to SHA for ASF INFRA compliance (#19793)
    
    ## Why
    
    ASF INFRA requires external GitHub Actions to be pinned to a commit SHA
    that appears on the approved allowlist; workflows that do not comply
    fail with "not allowed in apache/tvm". See the
    [policy](https://infra.apache.org/github-actions-policy.html) and the
    [approved allowlist](https://github.com/apache/infrastructure-actions/blob/main/approved_patterns.yml).
    
    ## How
    
    - Pin `pre-commit/action` to `2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd`
    (v3.0.1)
    - Pin `pypa/cibuildwheel` to `294735312765b09d24a2fbec22660ce817587d55`
    (v4.1.0)
    - Pin `pypa/gh-action-pypi-publish` to
    `ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e` (v1.13.0)
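    - Leave GitHub-owned `actions/*` and the allowlisted
    `conda-incubator/setup-miniconda@*` pattern untouched
    - Apply formatting-only cleanups (whitespace, line wrapping, import
    grouping, `isinstance(x, list | tuple)`) to several C++ and Python
    sources; no intended behavior changes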
    
    ---------
    
    Signed-off-by: Guan-Ming (Wesley) Chiu <[email protected]>
---
 .github/actions/build-wheel-for-publish/action.yml |   2 +-
 .github/workflows/lint.yml                         |   2 +-
 .github/workflows/publish_wheel.yml                |   2 +-
 apps/cpp_rpc/rpc_env.cc                            |  14 ++-
 apps/cpp_rpc/rpc_server.cc                         |   5 +-
 python/tvm/backend/cuda/op.py                      |   1 -
 python/tvm/backend/cuda/script.py                  |   4 +-
 .../extra/contrib/tensorrt/tensorrt_builder.cc     |   5 +-
 src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc |   6 +-
 .../meta_schedule/task_scheduler/task_scheduler.cc |   8 +-
 src/target/llvm/codegen_llvm.cc                    |   7 +-
 src/target/llvm/codegen_params.cc                  |   4 +-
 src/tirx/transform/vectorize_loop.cc               |  10 +-
 tests/python/codegen/test_target_codegen_riscv.py  |   1 +
 tests/python/relax/test_frontend_onnx.py           |   2 +-
 .../tile_primitive/cuda/elementwise/test_unary.py  | 104 +++++++++++++++++----
 16 files changed, 115 insertions(+), 62 deletions(-)

diff --git a/.github/actions/build-wheel-for-publish/action.yml 
b/.github/actions/build-wheel-for-publish/action.yml
index e718442379..d4a3ca14c2 100644
--- a/.github/actions/build-wheel-for-publish/action.yml
+++ b/.github/actions/build-wheel-for-publish/action.yml
@@ -108,7 +108,7 @@ runs:
 
     # ---- Build and test wheels ----
     - name: Build and test wheels
-      uses: pypa/[email protected]
+      uses: pypa/cibuildwheel@294735312765b09d24a2fbec22660ce817587d55  # 
v4.1.0
       with:
         package-dir: .
         output-dir: wheelhouse
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 6c17e0f149..16fa502aaa 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -35,4 +35,4 @@ jobs:
         with:
           fetch-depth: 0
           fetch-tags: true
-      - uses: pre-commit/[email protected]
+      - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd  # 
v3.0.1
diff --git a/.github/workflows/publish_wheel.yml 
b/.github/workflows/publish_wheel.yml
index 63375e6063..a2fedda9e0 100644
--- a/.github/workflows/publish_wheel.yml
+++ b/.github/workflows/publish_wheel.yml
@@ -213,7 +213,7 @@ jobs:
 
       - name: Publish package distributions to PyPI
         if: ${{ inputs.publish_repository == 'pypi' }}
-        uses: pypa/[email protected]
+        uses: 
pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e  # v1.13.0
         with:
           attestations: true
           verbose: true
diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc
index b0b1fe4064..4df5f87024 100644
--- a/apps/cpp_rpc/rpc_env.cc
+++ b/apps/cpp_rpc/rpc_env.cc
@@ -158,8 +158,7 @@ std::string RPCEnv::GetPath(const std::string& file_name) 
const {
  */
 void RPCEnv::CleanUp() const {
   CleanDir(base_);
-  if (!CheckPath(base_))
-    return;
+  if (!CheckPath(base_)) return;
   const int ret = rmdir(base_.c_str());
   if (ret != 0) {
     LOG(WARNING) << "Remove directory " << base_ << " failed";
@@ -325,11 +324,11 @@ std::string BuildSharedLibrary(std::string file) {
  */
 bool CheckPath(const std::string& pathname) {
 #if defined(_WIN32)
-    DWORD attribs = GetFileAttributesA(pathname.c_str());
-    return (attribs != INVALID_FILE_ATTRIBUTES);
+  DWORD attribs = GetFileAttributesA(pathname.c_str());
+  return (attribs != INVALID_FILE_ATTRIBUTES);
 #else
-    struct stat info;
-    return (stat(pathname.c_str(), &info) == 0);
+  struct stat info;
+  return (stat(pathname.c_str(), &info) == 0);
 #endif
 }
 
@@ -338,8 +337,7 @@ bool CheckPath(const std::string& pathname) {
  * \param dirname The name of the directory
  */
 void CleanDir(const std::string& dirname) {
-  if (!CheckPath(dirname))
-    return;
+  if (!CheckPath(dirname)) return;
   auto files = ListDir(dirname);
   for (const auto& filename : files) {
     std::string file_path = dirname + "/";
diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc
index b601478f43..88971cc34c 100644
--- a/apps/cpp_rpc/rpc_server.cc
+++ b/apps/cpp_rpc/rpc_server.cc
@@ -210,7 +210,7 @@ class RPCServer {
                     << ", status = " << status_second;
         } else if (finished_first == worker_pid) {
           LOG(INFO) << "Child pid=" << worker_pid << " finished"
-                    << ", status = "<< status_first;
+                    << ", status = " << status_first;
         }
       } else {
         auto pid = fork();
@@ -334,8 +334,7 @@ class RPCServer {
     RPCServerLoop(int(sock.sockfd));
     const auto e_time = std::chrono::high_resolution_clock::now();
     std::chrono::duration<double> elapsed = e_time - s_time;
-    LOG(INFO) << "Finished serving " << addr.AsString()
-              << " after " << elapsed.count() << " sec";
+    LOG(INFO) << "Finished serving " << addr.AsString() << " after " << 
elapsed.count() << " sec";
     env.CleanUp();
   }
 
diff --git a/python/tvm/backend/cuda/op.py b/python/tvm/backend/cuda/op.py
index 9570e26662..bb3c59599e 100644
--- a/python/tvm/backend/cuda/op.py
+++ b/python/tvm/backend/cuda/op.py
@@ -694,7 +694,6 @@ def ptx_mbarrier_arrive_cluster_count(bar, cta_id, count):
     return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, True, 
count)
 
 
-
 def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None):
     """TVM intrinsic to call
         mbarrier.arrive_expect_tx.shared::cta.b64
diff --git a/python/tvm/backend/cuda/script.py 
b/python/tvm/backend/cuda/script.py
index a46aa7e7e4..76ba87344b 100644
--- a/python/tvm/backend/cuda/script.py
+++ b/python/tvm/backend/cuda/script.py
@@ -278,9 +278,7 @@ class MbarrierNamespace:
         self.init = _op_wrapper(_cuda_op.ptx_mbarrier_init)
         self.try_wait = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait)
         self.try_wait_once = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait_once)
-        self.try_wait_acquire_cluster = _op_wrapper(
-            _cuda_op.ptx_mbarrier_try_wait_acquire_cluster
-        )
+        self.try_wait_acquire_cluster = 
_op_wrapper(_cuda_op.ptx_mbarrier_try_wait_acquire_cluster)
         self.arrive = MbarrierArriveNamespace()
 
 
diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc 
b/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc
index f0c2a26b2e..281d64cfbc 100644
--- a/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc
@@ -201,9 +201,8 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
     delete runtime;
     TVM_FFI_THROW(InternalError) << "Failed to deserialize the TensorRT 
engine.";
   }
-  TVM_FFI_ICHECK_EQ(
-      engine->getNbIOTensors(),
-      static_cast<int32_t>(network_input_names_.size() + 
network_output_names_.size()));
+  TVM_FFI_ICHECK_EQ(engine->getNbIOTensors(), 
static_cast<int32_t>(network_input_names_.size() +
+                                                                   
network_output_names_.size()));
   nvinfer1::IExecutionContext* context = engine->createExecutionContext();
   CleanUp();
 
diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc 
b/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc
index d3e68778fd..00ca3cea96 100644
--- a/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc
+++ b/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc
@@ -422,9 +422,9 @@ class DenseOpConverter : public TensorRTOpConverter {
                               
->addConstant(VectorToTrtDims(params->inputs.at(1).weight_shape),
                                             params->inputs.at(1).weight)
                               ->getOutput(0);
-    auto* matmul_layer = params->network->addMatrixMultiply(
-        *input_tensor, nvinfer1::MatrixOperation::kNONE, *weight_tensor,
-        nvinfer1::MatrixOperation::kTRANSPOSE);
+    auto* matmul_layer =
+        params->network->addMatrixMultiply(*input_tensor, 
nvinfer1::MatrixOperation::kNONE,
+                                           *weight_tensor, 
nvinfer1::MatrixOperation::kTRANSPOSE);
     TVM_FFI_ICHECK(matmul_layer != nullptr);
     params->outputs.push_back(matmul_layer->getOutput(0));
   }
diff --git a/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc 
b/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc
index 76b407b5cf..3d7fadd40a 100644
--- a/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc
@@ -208,15 +208,13 @@ void TaskSchedulerNode::Tune(ffi::Array<TuneContext> 
ctxs, ffi::Array<FloatImm>
       int n_build_errs = 0;
       const ffi::Array<BuilderResult>& builder_results = 
task->builder_results.value();
       for (int i = 0; i < num_candidates; i++) {
-        if (builder_results[i]->error_msg.has_value())
-          ++n_build_errs;
+        if (builder_results[i]->error_msg.has_value()) ++n_build_errs;
       }
       if (n_build_errs > 0) {
         TVM_PY_LOG(INFO, this->logger) << "Build errors: " << n_build_errs << 
" sample(s)";
       }
-      TVM_PY_LOG(INFO, this->logger) << "Sending "
-                                     << num_candidates - n_build_errs
-                                     << " valid sample(s) to runner";
+      TVM_PY_LOG(INFO, this->logger)
+          << "Sending " << num_candidates - n_build_errs << " valid sample(s) 
to runner";
       SendToRunner(task, runner);
     } else {
       TerminateTask(task_id);
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index f32dcdde11..912f8ec8c0 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -269,10 +269,9 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const 
GlobalVar& gvar, cons
       user_symbol = 
user_symbol.substr(std::char_traits<char>::length(kFFISymbolPrefix));
     }
     TVM_FFI_THROW(InternalError) << "Duplicate PrimFunc global_symbol '" << 
user_symbol
-                                 << "' in LLVM codegen: IRModule keys '" << 
it->second
-                                 << "' and '" << gvar->name_hint
-                                 << "' both lower to the same exported symbol 
'" << symbol_name
-                                 << "'. "
+                                 << "' in LLVM codegen: IRModule keys '" << 
it->second << "' and '"
+                                 << gvar->name_hint << "' both lower to the 
same exported symbol '"
+                                 << symbol_name << "'. "
                                  << "Each exposed PrimFunc in one IRModule 
must have a unique "
                                     "global_symbol.";
   }
diff --git a/src/target/llvm/codegen_params.cc 
b/src/target/llvm/codegen_params.cc
index 6d8684a87e..0633c4fcb3 100644
--- a/src/target/llvm/codegen_params.cc
+++ b/src/target/llvm/codegen_params.cc
@@ -61,8 +61,8 @@ struct LLVMConstantGetter<T, 
std::enable_if_t<std::is_floating_point<T>::value>>
   static llvm::Constant* getElement(llvm::Type* ty, T t) { return 
llvm::ConstantFP::get(ty, t); }
 };
 
-template <typename T,
-          typename = std::enable_if_t<std::is_standard_layout<T>::value && 
std::is_trivial<T>::value>>
+template <typename T, typename = 
std::enable_if_t<std::is_standard_layout<T>::value &&
+                                                  std::is_trivial<T>::value>>
 void BuildLLVMVector(llvm::Type* element_type, void* tensor_data, size_t 
num_elements,
                      std::vector<llvm::Constant*>* elements) {
   elements->resize(num_elements, nullptr);
diff --git a/src/tirx/transform/vectorize_loop.cc 
b/src/tirx/transform/vectorize_loop.cc
index fe6734863b..e746c6ac95 100644
--- a/src/tirx/transform/vectorize_loop.cc
+++ b/src/tirx/transform/vectorize_loop.cc
@@ -71,9 +71,8 @@ bool TargetHasVLA(Target target) {
 }
 
 bool ContainsCallNode(const Stmt& stmt) {
-  return CheckContains::StmtContains(stmt, [](const PrimExpr& expr) {
-    return expr.as<CallNode>() != nullptr;
-  });
+  return CheckContains::StmtContains(
+      stmt, [](const PrimExpr& expr) { return expr.as<CallNode>() != nullptr; 
});
 }
 }  // namespace
 
@@ -1067,9 +1066,8 @@ class LoopVectorizer : public StmtMutator {
     PrimExpr index = outer * scalable_lanes_index + inner_index;
     Stmt body = Substitute(op->body, {{op->loop_var, index}});
     Stmt guarded_body = IfThenElse(index < fixed_extent, body, std::nullopt, 
op->span);
-    Stmt vector_loop =
-        For(inner, make_const(lane_dtype, 0), scalable_lanes, 
ForKind::kVectorized, guarded_body,
-            std::nullopt, op->annotations, std::nullopt, op->span);
+    Stmt vector_loop = For(inner, make_const(lane_dtype, 0), scalable_lanes, 
ForKind::kVectorized,
+                           guarded_body, std::nullopt, op->annotations, 
std::nullopt, op->span);
     Stmt loop = For(outer, zero, num_chunks, ForKind::kSerial, vector_loop, 
std::nullopt, {},
                     std::nullopt, op->span);
 
diff --git a/tests/python/codegen/test_target_codegen_riscv.py 
b/tests/python/codegen/test_target_codegen_riscv.py
index 3ac75dc337..5447fddcde 100644
--- a/tests/python/codegen/test_target_codegen_riscv.py
+++ b/tests/python/codegen/test_target_codegen_riscv.py
@@ -17,6 +17,7 @@
 # ruff: noqa: E501, F841
 
 import re
+
 import pytest
 
 import tvm
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index db8b977efc..414c3d5bbf 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -837,7 +837,7 @@ def test_reduce_min_max_nan_preserve(op_name, x):
     ref_out = (np.max if op_name == "ReduceMax" else np.min)(x)
 
     tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18)
-    out_np = (tvm_out[0] if isinstance(tvm_out, (list, tuple)) else 
tvm_out).numpy()
+    out_np = (tvm_out[0] if isinstance(tvm_out, list | tuple) else 
tvm_out).numpy()
 
     np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ref_out))
     if not np.isnan(ref_out):
diff --git 
a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py 
b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py
index fb70b37541..97a1be256e 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py
@@ -1326,14 +1326,28 @@ def test_cast_wg_rejects_thread_local_view():
 
     @T.prim_func
     def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
-        A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
-        B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+        A = T.match_buffer(
+            A_ptr, (_SL_ROWS, _SL_COLS), "float32", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+        )
+        B = T.match_buffer(
+            B_ptr, (_SL_ROWS, _SL_COLS), "float16", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+        )
         T.device_entry()
         _bx = T.cta_id([1])
         _wg = T.warpgroup_id([1])
         tid = T.thread_id_in_wg([_SL_ROWS])
-        src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
-        dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+        src = T.alloc_buffer(
+            (_SL_ROWS, _SL_COLS),
+            "float32",
+            scope="local",
+            layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]),
+        )
+        dst = T.alloc_buffer(
+            (_SL_ROWS, _SL_COLS),
+            "float16",
+            scope="local",
+            layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]),
+        )
         src_row = src.local(_SL_COLS)
         for i in T.serial(_SL_COLS):
             src_row[i] = A[tid, i]
@@ -1351,13 +1365,27 @@ def test_cast_cta_rejects_thread_local_view():
 
     @T.prim_func
     def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
-        A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
-        B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+        A = T.match_buffer(
+            A_ptr, (_SL_ROWS, _SL_COLS), "float32", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+        )
+        B = T.match_buffer(
+            B_ptr, (_SL_ROWS, _SL_COLS), "float16", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+        )
         T.device_entry()
         _bx = T.cta_id([1])
         tx_var = T.thread_id([_SL_ROWS])
-        src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)]))
-        dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)]))
+        src = T.alloc_buffer(
+            (_SL_ROWS, _SL_COLS),
+            "float32",
+            scope="local",
+            layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)]),
+        )
+        dst = T.alloc_buffer(
+            (_SL_ROWS, _SL_COLS),
+            "float16",
+            scope="local",
+            layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)]),
+        )
         src_row = src.local(_SL_COLS)
         for i in T.serial(_SL_COLS):
             src_row[i] = A[tx_var, i]
@@ -1376,14 +1404,28 @@ def test_cast_wg_rejects_partial_thread_coverage():
 
     @T.prim_func
     def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
-        A = T.match_buffer(A_ptr, (half, _SL_COLS), "float32", 
layout=TileLayout(S[(half, _SL_COLS)]))
-        B = T.match_buffer(B_ptr, (half, _SL_COLS), "float16", 
layout=TileLayout(S[(half, _SL_COLS)]))
+        A = T.match_buffer(
+            A_ptr, (half, _SL_COLS), "float32", layout=TileLayout(S[(half, 
_SL_COLS)])
+        )
+        B = T.match_buffer(
+            B_ptr, (half, _SL_COLS), "float16", layout=TileLayout(S[(half, 
_SL_COLS)])
+        )
         T.device_entry()
         _bx = T.cta_id([1])
         _wg = T.warpgroup_id([1])
         tid = T.thread_id_in_wg([_SL_ROWS])
-        src = T.alloc_buffer((half, _SL_COLS), "float32", scope="local", 
layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)]))
-        dst = T.alloc_buffer((half, _SL_COLS), "float16", scope="local", 
layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+        src = T.alloc_buffer(
+            (half, _SL_COLS),
+            "float32",
+            scope="local",
+            layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)]),
+        )
+        dst = T.alloc_buffer(
+            (half, _SL_COLS),
+            "float16",
+            scope="local",
+            layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)]),
+        )
         src_row = src.local(_SL_COLS)
         for i in T.serial(_SL_COLS):
             src_row[i] = A[tid, i]
@@ -1401,14 +1443,28 @@ def test_cast_wg_accepts_wg_level_layout():
 
     @T.prim_func
     def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
-        A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
-        B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+        A = T.match_buffer(
+            A_ptr, (_SL_ROWS, _SL_COLS), "float32", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+        )
+        B = T.match_buffer(
+            B_ptr, (_SL_ROWS, _SL_COLS), "float16", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+        )
         T.device_entry()
         _bx = T.cta_id([1])
         _wg = T.warpgroup_id([1])
         tid = T.thread_id_in_wg([_SL_ROWS])
-        src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
-        dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+        src = T.alloc_buffer(
+            (_SL_ROWS, _SL_COLS),
+            "float32",
+            scope="local",
+            layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]),
+        )
+        dst = T.alloc_buffer(
+            (_SL_ROWS, _SL_COLS),
+            "float16",
+            scope="local",
+            layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]),
+        )
         src_row = src.local(_SL_COLS)
         for i in T.serial(_SL_COLS):
             src_row[i] = A[tid, i]
@@ -1425,13 +1481,21 @@ def test_cast_thread_accepts_local_view():
 
     @T.prim_func
     def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
-        A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
-        B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+        A = T.match_buffer(
+            A_ptr, (_SL_ROWS, _SL_COLS), "float32", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+        )
+        B = T.match_buffer(
+            B_ptr, (_SL_ROWS, _SL_COLS), "float16", 
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+        )
         T.device_entry()
         _bx = T.cta_id([1])
         tx_var = T.thread_id([_SL_ROWS])
-        src = T.alloc_buffer((_SL_COLS,), "float32", scope="local", 
layout=TileLayout(S[(_SL_COLS,)]))
-        dst = T.alloc_buffer((_SL_COLS,), "float16", scope="local", 
layout=TileLayout(S[(_SL_COLS,)]))
+        src = T.alloc_buffer(
+            (_SL_COLS,), "float32", scope="local", 
layout=TileLayout(S[(_SL_COLS,)])
+        )
+        dst = T.alloc_buffer(
+            (_SL_COLS,), "float16", scope="local", 
layout=TileLayout(S[(_SL_COLS,)])
+        )
         for i in T.serial(_SL_COLS):
             src[i] = A[tx_var, i]
         Tx.cast(dst, src)

Reply via email to