This is an automated email from the ASF dual-hosted git repository.
guan404ming pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ef1904a3e5 [CI] Pin GitHub Actions to SHA for ASF INFRA compliance
(#19793)
ef1904a3e5 is described below
commit ef1904a3e599530fdfeb819df966877bcdfc864d
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Jun 17 00:29:10 2026 +0800
[CI] Pin GitHub Actions to SHA for ASF INFRA compliance (#19793)
## Why
ASF INFRA requires every external GitHub Action to be pinned to a
commit SHA that is on the approved allowlist; a workflow that uses an
unpinned or unlisted action fails with "not allowed in apache/tvm". See the
[policy](https://infra.apache.org/github-actions-policy.html) and the
[approved
allowlist](https://github.com/apache/infrastructure-actions/blob/main/approved_patterns.yml).
## How
- Pin `pre-commit/action` to `2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd`
(v3.0.1)
- Pin `pypa/cibuildwheel` to `294735312765b09d24a2fbec22660ce817587d55`
(v4.1.0)
- Pin `pypa/gh-action-pypi-publish` to
`ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e` (v1.13.0)
- Leave GitHub-owned `actions/*` and the allowlisted
`conda-incubator/setup-miniconda@*` pattern untouched
---------
Signed-off-by: Guan-Ming (Wesley) Chiu
<[email protected]>
---
.github/actions/build-wheel-for-publish/action.yml | 2 +-
.github/workflows/lint.yml | 2 +-
.github/workflows/publish_wheel.yml | 2 +-
apps/cpp_rpc/rpc_env.cc | 14 ++-
apps/cpp_rpc/rpc_server.cc | 5 +-
python/tvm/backend/cuda/op.py | 1 -
python/tvm/backend/cuda/script.py | 4 +-
.../extra/contrib/tensorrt/tensorrt_builder.cc | 5 +-
src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc | 6 +-
.../meta_schedule/task_scheduler/task_scheduler.cc | 8 +-
src/target/llvm/codegen_llvm.cc | 7 +-
src/target/llvm/codegen_params.cc | 4 +-
src/tirx/transform/vectorize_loop.cc | 10 +-
tests/python/codegen/test_target_codegen_riscv.py | 1 +
tests/python/relax/test_frontend_onnx.py | 2 +-
.../tile_primitive/cuda/elementwise/test_unary.py | 104 +++++++++++++++++----
16 files changed, 115 insertions(+), 62 deletions(-)
diff --git a/.github/actions/build-wheel-for-publish/action.yml
b/.github/actions/build-wheel-for-publish/action.yml
index e718442379..d4a3ca14c2 100644
--- a/.github/actions/build-wheel-for-publish/action.yml
+++ b/.github/actions/build-wheel-for-publish/action.yml
@@ -108,7 +108,7 @@ runs:
# ---- Build and test wheels ----
- name: Build and test wheels
- uses: pypa/[email protected]
+ uses: pypa/cibuildwheel@294735312765b09d24a2fbec22660ce817587d55 #
v4.1.0
with:
package-dir: .
output-dir: wheelhouse
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 6c17e0f149..16fa502aaa 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -35,4 +35,4 @@ jobs:
with:
fetch-depth: 0
fetch-tags: true
- - uses: pre-commit/[email protected]
+ - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd #
v3.0.1
diff --git a/.github/workflows/publish_wheel.yml
b/.github/workflows/publish_wheel.yml
index 63375e6063..a2fedda9e0 100644
--- a/.github/workflows/publish_wheel.yml
+++ b/.github/workflows/publish_wheel.yml
@@ -213,7 +213,7 @@ jobs:
- name: Publish package distributions to PyPI
if: ${{ inputs.publish_repository == 'pypi' }}
- uses: pypa/[email protected]
+ uses:
pypa/gh-action-pypi-publish@ed0c53931b1dc9bd32cbe73a98c7f6766f8a527e # v1.13.0
with:
attestations: true
verbose: true
diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc
index b0b1fe4064..4df5f87024 100644
--- a/apps/cpp_rpc/rpc_env.cc
+++ b/apps/cpp_rpc/rpc_env.cc
@@ -158,8 +158,7 @@ std::string RPCEnv::GetPath(const std::string& file_name)
const {
*/
void RPCEnv::CleanUp() const {
CleanDir(base_);
- if (!CheckPath(base_))
- return;
+ if (!CheckPath(base_)) return;
const int ret = rmdir(base_.c_str());
if (ret != 0) {
LOG(WARNING) << "Remove directory " << base_ << " failed";
@@ -325,11 +324,11 @@ std::string BuildSharedLibrary(std::string file) {
*/
bool CheckPath(const std::string& pathname) {
#if defined(_WIN32)
- DWORD attribs = GetFileAttributesA(pathname.c_str());
- return (attribs != INVALID_FILE_ATTRIBUTES);
+ DWORD attribs = GetFileAttributesA(pathname.c_str());
+ return (attribs != INVALID_FILE_ATTRIBUTES);
#else
- struct stat info;
- return (stat(pathname.c_str(), &info) == 0);
+ struct stat info;
+ return (stat(pathname.c_str(), &info) == 0);
#endif
}
@@ -338,8 +337,7 @@ bool CheckPath(const std::string& pathname) {
* \param dirname The name of the directory
*/
void CleanDir(const std::string& dirname) {
- if (!CheckPath(dirname))
- return;
+ if (!CheckPath(dirname)) return;
auto files = ListDir(dirname);
for (const auto& filename : files) {
std::string file_path = dirname + "/";
diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc
index b601478f43..88971cc34c 100644
--- a/apps/cpp_rpc/rpc_server.cc
+++ b/apps/cpp_rpc/rpc_server.cc
@@ -210,7 +210,7 @@ class RPCServer {
<< ", status = " << status_second;
} else if (finished_first == worker_pid) {
LOG(INFO) << "Child pid=" << worker_pid << " finished"
- << ", status = "<< status_first;
+ << ", status = " << status_first;
}
} else {
auto pid = fork();
@@ -334,8 +334,7 @@ class RPCServer {
RPCServerLoop(int(sock.sockfd));
const auto e_time = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed = e_time - s_time;
- LOG(INFO) << "Finished serving " << addr.AsString()
- << " after " << elapsed.count() << " sec";
+ LOG(INFO) << "Finished serving " << addr.AsString() << " after " <<
elapsed.count() << " sec";
env.CleanUp();
}
diff --git a/python/tvm/backend/cuda/op.py b/python/tvm/backend/cuda/op.py
index 9570e26662..bb3c59599e 100644
--- a/python/tvm/backend/cuda/op.py
+++ b/python/tvm/backend/cuda/op.py
@@ -694,7 +694,6 @@ def ptx_mbarrier_arrive_cluster_count(bar, cta_id, count):
return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, True,
count)
-
def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None):
"""TVM intrinsic to call
mbarrier.arrive_expect_tx.shared::cta.b64
diff --git a/python/tvm/backend/cuda/script.py
b/python/tvm/backend/cuda/script.py
index a46aa7e7e4..76ba87344b 100644
--- a/python/tvm/backend/cuda/script.py
+++ b/python/tvm/backend/cuda/script.py
@@ -278,9 +278,7 @@ class MbarrierNamespace:
self.init = _op_wrapper(_cuda_op.ptx_mbarrier_init)
self.try_wait = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait)
self.try_wait_once = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait_once)
- self.try_wait_acquire_cluster = _op_wrapper(
- _cuda_op.ptx_mbarrier_try_wait_acquire_cluster
- )
+ self.try_wait_acquire_cluster =
_op_wrapper(_cuda_op.ptx_mbarrier_try_wait_acquire_cluster)
self.arrive = MbarrierArriveNamespace()
diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc
b/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc
index f0c2a26b2e..281d64cfbc 100644
--- a/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc
@@ -201,9 +201,8 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
delete runtime;
TVM_FFI_THROW(InternalError) << "Failed to deserialize the TensorRT
engine.";
}
- TVM_FFI_ICHECK_EQ(
- engine->getNbIOTensors(),
- static_cast<int32_t>(network_input_names_.size() +
network_output_names_.size()));
+ TVM_FFI_ICHECK_EQ(engine->getNbIOTensors(),
static_cast<int32_t>(network_input_names_.size() +
+
network_output_names_.size()));
nvinfer1::IExecutionContext* context = engine->createExecutionContext();
CleanUp();
diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc
b/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc
index d3e68778fd..00ca3cea96 100644
--- a/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc
+++ b/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc
@@ -422,9 +422,9 @@ class DenseOpConverter : public TensorRTOpConverter {
->addConstant(VectorToTrtDims(params->inputs.at(1).weight_shape),
params->inputs.at(1).weight)
->getOutput(0);
- auto* matmul_layer = params->network->addMatrixMultiply(
- *input_tensor, nvinfer1::MatrixOperation::kNONE, *weight_tensor,
- nvinfer1::MatrixOperation::kTRANSPOSE);
+ auto* matmul_layer =
+ params->network->addMatrixMultiply(*input_tensor,
nvinfer1::MatrixOperation::kNONE,
+ *weight_tensor,
nvinfer1::MatrixOperation::kTRANSPOSE);
TVM_FFI_ICHECK(matmul_layer != nullptr);
params->outputs.push_back(matmul_layer->getOutput(0));
}
diff --git a/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc
b/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc
index 76b407b5cf..3d7fadd40a 100644
--- a/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc
+++ b/src/s_tir/meta_schedule/task_scheduler/task_scheduler.cc
@@ -208,15 +208,13 @@ void TaskSchedulerNode::Tune(ffi::Array<TuneContext>
ctxs, ffi::Array<FloatImm>
int n_build_errs = 0;
const ffi::Array<BuilderResult>& builder_results =
task->builder_results.value();
for (int i = 0; i < num_candidates; i++) {
- if (builder_results[i]->error_msg.has_value())
- ++n_build_errs;
+ if (builder_results[i]->error_msg.has_value()) ++n_build_errs;
}
if (n_build_errs > 0) {
TVM_PY_LOG(INFO, this->logger) << "Build errors: " << n_build_errs <<
" sample(s)";
}
- TVM_PY_LOG(INFO, this->logger) << "Sending "
- << num_candidates - n_build_errs
- << " valid sample(s) to runner";
+ TVM_PY_LOG(INFO, this->logger)
+ << "Sending " << num_candidates - n_build_errs << " valid sample(s)
to runner";
SendToRunner(task, runner);
} else {
TerminateTask(task_id);
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index f32dcdde11..912f8ec8c0 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -269,10 +269,9 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const
GlobalVar& gvar, cons
user_symbol =
user_symbol.substr(std::char_traits<char>::length(kFFISymbolPrefix));
}
TVM_FFI_THROW(InternalError) << "Duplicate PrimFunc global_symbol '" <<
user_symbol
- << "' in LLVM codegen: IRModule keys '" <<
it->second
- << "' and '" << gvar->name_hint
- << "' both lower to the same exported symbol
'" << symbol_name
- << "'. "
+ << "' in LLVM codegen: IRModule keys '" <<
it->second << "' and '"
+ << gvar->name_hint << "' both lower to the
same exported symbol '"
+ << symbol_name << "'. "
<< "Each exposed PrimFunc in one IRModule
must have a unique "
"global_symbol.";
}
diff --git a/src/target/llvm/codegen_params.cc
b/src/target/llvm/codegen_params.cc
index 6d8684a87e..0633c4fcb3 100644
--- a/src/target/llvm/codegen_params.cc
+++ b/src/target/llvm/codegen_params.cc
@@ -61,8 +61,8 @@ struct LLVMConstantGetter<T,
std::enable_if_t<std::is_floating_point<T>::value>>
static llvm::Constant* getElement(llvm::Type* ty, T t) { return
llvm::ConstantFP::get(ty, t); }
};
-template <typename T,
- typename = std::enable_if_t<std::is_standard_layout<T>::value &&
std::is_trivial<T>::value>>
+template <typename T, typename =
std::enable_if_t<std::is_standard_layout<T>::value &&
+ std::is_trivial<T>::value>>
void BuildLLVMVector(llvm::Type* element_type, void* tensor_data, size_t
num_elements,
std::vector<llvm::Constant*>* elements) {
elements->resize(num_elements, nullptr);
diff --git a/src/tirx/transform/vectorize_loop.cc
b/src/tirx/transform/vectorize_loop.cc
index fe6734863b..e746c6ac95 100644
--- a/src/tirx/transform/vectorize_loop.cc
+++ b/src/tirx/transform/vectorize_loop.cc
@@ -71,9 +71,8 @@ bool TargetHasVLA(Target target) {
}
bool ContainsCallNode(const Stmt& stmt) {
- return CheckContains::StmtContains(stmt, [](const PrimExpr& expr) {
- return expr.as<CallNode>() != nullptr;
- });
+ return CheckContains::StmtContains(
+ stmt, [](const PrimExpr& expr) { return expr.as<CallNode>() != nullptr;
});
}
} // namespace
@@ -1067,9 +1066,8 @@ class LoopVectorizer : public StmtMutator {
PrimExpr index = outer * scalable_lanes_index + inner_index;
Stmt body = Substitute(op->body, {{op->loop_var, index}});
Stmt guarded_body = IfThenElse(index < fixed_extent, body, std::nullopt,
op->span);
- Stmt vector_loop =
- For(inner, make_const(lane_dtype, 0), scalable_lanes,
ForKind::kVectorized, guarded_body,
- std::nullopt, op->annotations, std::nullopt, op->span);
+ Stmt vector_loop = For(inner, make_const(lane_dtype, 0), scalable_lanes,
ForKind::kVectorized,
+ guarded_body, std::nullopt, op->annotations,
std::nullopt, op->span);
Stmt loop = For(outer, zero, num_chunks, ForKind::kSerial, vector_loop,
std::nullopt, {},
std::nullopt, op->span);
diff --git a/tests/python/codegen/test_target_codegen_riscv.py
b/tests/python/codegen/test_target_codegen_riscv.py
index 3ac75dc337..5447fddcde 100644
--- a/tests/python/codegen/test_target_codegen_riscv.py
+++ b/tests/python/codegen/test_target_codegen_riscv.py
@@ -17,6 +17,7 @@
# ruff: noqa: E501, F841
import re
+
import pytest
import tvm
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index db8b977efc..414c3d5bbf 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -837,7 +837,7 @@ def test_reduce_min_max_nan_preserve(op_name, x):
ref_out = (np.max if op_name == "ReduceMax" else np.min)(x)
tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18)
- out_np = (tvm_out[0] if isinstance(tvm_out, (list, tuple)) else
tvm_out).numpy()
+ out_np = (tvm_out[0] if isinstance(tvm_out, list | tuple) else
tvm_out).numpy()
np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ref_out))
if not np.isnan(ref_out):
diff --git
a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py
b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py
index fb70b37541..97a1be256e 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py
@@ -1326,14 +1326,28 @@ def test_cast_wg_rejects_thread_local_view():
@T.prim_func
def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
- A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
- B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+ A = T.match_buffer(
+ A_ptr, (_SL_ROWS, _SL_COLS), "float32",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+ )
+ B = T.match_buffer(
+ B_ptr, (_SL_ROWS, _SL_COLS), "float16",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+ )
T.device_entry()
_bx = T.cta_id([1])
_wg = T.warpgroup_id([1])
tid = T.thread_id_in_wg([_SL_ROWS])
- src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
- dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+ src = T.alloc_buffer(
+ (_SL_ROWS, _SL_COLS),
+ "float32",
+ scope="local",
+ layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]),
+ )
+ dst = T.alloc_buffer(
+ (_SL_ROWS, _SL_COLS),
+ "float16",
+ scope="local",
+ layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]),
+ )
src_row = src.local(_SL_COLS)
for i in T.serial(_SL_COLS):
src_row[i] = A[tid, i]
@@ -1351,13 +1365,27 @@ def test_cast_cta_rejects_thread_local_view():
@T.prim_func
def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
- A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
- B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+ A = T.match_buffer(
+ A_ptr, (_SL_ROWS, _SL_COLS), "float32",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+ )
+ B = T.match_buffer(
+ B_ptr, (_SL_ROWS, _SL_COLS), "float16",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+ )
T.device_entry()
_bx = T.cta_id([1])
tx_var = T.thread_id([_SL_ROWS])
- src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)]))
- dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)]))
+ src = T.alloc_buffer(
+ (_SL_ROWS, _SL_COLS),
+ "float32",
+ scope="local",
+ layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)]),
+ )
+ dst = T.alloc_buffer(
+ (_SL_ROWS, _SL_COLS),
+ "float16",
+ scope="local",
+ layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)]),
+ )
src_row = src.local(_SL_COLS)
for i in T.serial(_SL_COLS):
src_row[i] = A[tx_var, i]
@@ -1376,14 +1404,28 @@ def test_cast_wg_rejects_partial_thread_coverage():
@T.prim_func
def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
- A = T.match_buffer(A_ptr, (half, _SL_COLS), "float32",
layout=TileLayout(S[(half, _SL_COLS)]))
- B = T.match_buffer(B_ptr, (half, _SL_COLS), "float16",
layout=TileLayout(S[(half, _SL_COLS)]))
+ A = T.match_buffer(
+ A_ptr, (half, _SL_COLS), "float32", layout=TileLayout(S[(half,
_SL_COLS)])
+ )
+ B = T.match_buffer(
+ B_ptr, (half, _SL_COLS), "float16", layout=TileLayout(S[(half,
_SL_COLS)])
+ )
T.device_entry()
_bx = T.cta_id([1])
_wg = T.warpgroup_id([1])
tid = T.thread_id_in_wg([_SL_ROWS])
- src = T.alloc_buffer((half, _SL_COLS), "float32", scope="local",
layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)]))
- dst = T.alloc_buffer((half, _SL_COLS), "float16", scope="local",
layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+ src = T.alloc_buffer(
+ (half, _SL_COLS),
+ "float32",
+ scope="local",
+ layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)]),
+ )
+ dst = T.alloc_buffer(
+ (half, _SL_COLS),
+ "float16",
+ scope="local",
+ layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)]),
+ )
src_row = src.local(_SL_COLS)
for i in T.serial(_SL_COLS):
src_row[i] = A[tid, i]
@@ -1401,14 +1443,28 @@ def test_cast_wg_accepts_wg_level_layout():
@T.prim_func
def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
- A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
- B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+ A = T.match_buffer(
+ A_ptr, (_SL_ROWS, _SL_COLS), "float32",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+ )
+ B = T.match_buffer(
+ B_ptr, (_SL_ROWS, _SL_COLS), "float16",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+ )
T.device_entry()
_bx = T.cta_id([1])
_wg = T.warpgroup_id([1])
tid = T.thread_id_in_wg([_SL_ROWS])
- src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
- dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+ src = T.alloc_buffer(
+ (_SL_ROWS, _SL_COLS),
+ "float32",
+ scope="local",
+ layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]),
+ )
+ dst = T.alloc_buffer(
+ (_SL_ROWS, _SL_COLS),
+ "float16",
+ scope="local",
+ layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]),
+ )
src_row = src.local(_SL_COLS)
for i in T.serial(_SL_COLS):
src_row[i] = A[tid, i]
@@ -1425,13 +1481,21 @@ def test_cast_thread_accepts_local_view():
@T.prim_func
def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
- A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
- B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+ A = T.match_buffer(
+ A_ptr, (_SL_ROWS, _SL_COLS), "float32",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+ )
+ B = T.match_buffer(
+ B_ptr, (_SL_ROWS, _SL_COLS), "float16",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)])
+ )
T.device_entry()
_bx = T.cta_id([1])
tx_var = T.thread_id([_SL_ROWS])
- src = T.alloc_buffer((_SL_COLS,), "float32", scope="local",
layout=TileLayout(S[(_SL_COLS,)]))
- dst = T.alloc_buffer((_SL_COLS,), "float16", scope="local",
layout=TileLayout(S[(_SL_COLS,)]))
+ src = T.alloc_buffer(
+ (_SL_COLS,), "float32", scope="local",
layout=TileLayout(S[(_SL_COLS,)])
+ )
+ dst = T.alloc_buffer(
+ (_SL_COLS,), "float16", scope="local",
layout=TileLayout(S[(_SL_COLS,)])
+ )
for i in T.serial(_SL_COLS):
src[i] = A[tx_var, i]
Tx.cast(dst, src)