This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 216be63e57 [REFACTOR][CODEGEN] Phase out tvm_global_barrier_state and
tvm_prepare_global_barrier (#19454)
216be63e57 is described below
commit 216be63e57bf1a547f12e8f43f733e106b8fd73b
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Apr 27 16:36:31 2026 -0400
[REFACTOR][CODEGEN] Phase out tvm_global_barrier_state and
tvm_prepare_global_barrier (#19454)
Phase out the legacy spin-on-global-memory CUDA barrier machinery
(`tvm_global_barrier_state` / `__tvm_prepare_global_barrier` / the
`tvm_global_barrier_kinit()` builtin and the `tirx.detect_global_barrier`
pass-config option). CUDA's native cooperative groups / grid sync
primitives cover the use case better, and the bespoke implementation
has effectively been dead code in the active codegen pipelines
(reachable only through the opt-in, default-off pass config).
This is an almost entirely deletion-only refactor across 10 files
(6 insertions, 267 deletions; −261 lines net):
- Public symbol constants in `include/tvm/runtime/device_api.h`
- TIR builtin op `tvm_global_barrier_kinit()`
(`include/tvm/tirx/builtin.h`,
`src/tirx/op/builtin.cc`)
- Pass config `tirx.detect_global_barrier` (`src/tirx/ir/transform.cc`)
- Entire kGlobal branch of `s_tir::ThreadSync` including
`InitGlobalBarrier`, `MakeGlobalBarrier`, and supporting state in
`ThreadSyncInserter`
- `CUDAPrepGlobalBarrier` runtime class, its dispatch in
  `CUDAModuleNode::GetFunction()`, and `CUDAModuleNode::GetGlobal()`
- The body of the "global" branch of `CodeGenCUDA::PrintStorageSync`
  (that branch now raises an `InternalError`), the
  `VisitStmt_(EvaluateNode*)` override, and 3 related member fields
- Two Python pipeline opt-in blocks (s_tir/pipeline.py + adreno mirror)
Apart from the two pipeline blocks above, no Python or test code
references these symbols. Build clean (260 targets), CUDA codegen 50/50
passed, tirx-base + tirx-transform 621 passed, IRF cpptests 8/8 passed.
---
include/tvm/runtime/device_api.h | 4 -
include/tvm/tirx/builtin.h | 6 --
python/tvm/s_tir/backend/adreno/pipeline.py | 2 -
python/tvm/s_tir/pipeline.py | 2 -
src/runtime/cuda/cuda_module.cc | 53 -----------
src/s_tir/transform/thread_storage_sync.cc | 139 +---------------------------
src/target/source/codegen_cuda.cc | 56 +----------
src/target/source/codegen_cuda.h | 7 --
src/tirx/ir/transform.cc | 1 -
src/tirx/op/builtin.cc | 3 -
10 files changed, 6 insertions(+), 267 deletions(-)
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 20abff8c46..47607c5b88 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -416,10 +416,6 @@ TVM_RUNTIME_DLL bool RuntimeEnabled(const ffi::String&
target);
namespace symbol {
/*! \brief global function to set device */
constexpr const char* tvm_set_device = "__tvm_set_device";
-/*! \brief Auxiliary counter to global barrier. */
-constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
-/*! \brief Prepare the global barrier before kernels that uses global barrier.
*/
-constexpr const char* tvm_prepare_global_barrier =
"__tvm_prepare_global_barrier";
} // namespace symbol
} // namespace runtime
diff --git a/include/tvm/tirx/builtin.h b/include/tvm/tirx/builtin.h
index d0d5b3d57e..3339b1aa49 100644
--- a/include/tvm/tirx/builtin.h
+++ b/include/tvm/tirx/builtin.h
@@ -507,12 +507,6 @@ TVM_DLL const Op& tvm_warp_shuffle_up();
TVM_DLL const Op& tvm_warp_shuffle_down();
TVM_DLL const Op& tvm_warp_activemask();
-/*!
- * \brief Initialize the global barrier.
- * Call this at beginning of kernel that need global barrier.
- */
-TVM_DLL const Op& tvm_global_barrier_kinit();
-
/*!
* \brief See pesudo code
*
diff --git a/python/tvm/s_tir/backend/adreno/pipeline.py
b/python/tvm/s_tir/backend/adreno/pipeline.py
index a63fb4346d..51510f2113 100644
--- a/python/tvm/s_tir/backend/adreno/pipeline.py
+++ b/python/tvm/s_tir/backend/adreno/pipeline.py
@@ -93,8 +93,6 @@ def default_tir_pipeline():
tirx.transform.AnnotateEntryFunc(),
]
)
- if bool(config.get("tirx.detect_global_barrier", False)):
- passes.append(s_tir.transform.ThreadSync("global"))
passes.extend(
[
s_tir.transform.ThreadSync("shared"),
diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py
index f775c0dd1e..85f586660d 100644
--- a/python/tvm/s_tir/pipeline.py
+++ b/python/tvm/s_tir/pipeline.py
@@ -91,8 +91,6 @@ def default_s_tir_pipeline():
tirx.transform.AnnotateEntryFunc(),
]
)
- if bool(config.get("tirx.detect_global_barrier", False)):
- passes.append(s_tir.transform.ThreadSync("global"))
passes.extend(
[
s_tir.transform.ThreadSync("shared"),
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index 691e44f0a0..29a04eb185 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -27,7 +27,6 @@
#include <tvm/ffi/extra/c_env_api.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/runtime/device_api.h>
#include <tvm/support/io.h>
#include <array>
@@ -134,30 +133,6 @@ class CUDAModuleNode : public ffi::ModuleObj {
}
return func;
}
- // get a global var from primary context in device_id
- CUdeviceptr GetGlobal(int device_id, const std::string& global_name, size_t
expect_nbytes) {
- std::lock_guard<std::mutex> lock(mutex_);
- // must recheck under the lock scope
- if (module_[device_id] == nullptr) {
- CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str()));
- static auto nvshmem_init_hook =
ffi::Function::GetGlobal("runtime.nvshmem.cumodule_init");
- if (nvshmem_init_hook.has_value()) {
- (*nvshmem_init_hook)(static_cast<void*>(module_[device_id]));
- }
- }
- CUdeviceptr global;
- size_t nbytes;
-
- CUresult result = cuModuleGetGlobal(&global, &nbytes, module_[device_id],
global_name.c_str());
- TVM_FFI_ICHECK_EQ(nbytes, expect_nbytes);
- if (result != CUDA_SUCCESS) {
- const char* msg;
- cuGetErrorName(result, &msg);
- TVM_FFI_THROW(CUDAError) << "cuModuleGetGlobal " << global_name
- << " failed with error: " << msg;
- }
- return global;
- }
private:
// the binary data
@@ -269,37 +244,9 @@ class CUDAWrappedFunc {
LaunchParamConfig launch_param_config_;
};
-class CUDAPrepGlobalBarrier {
- public:
- CUDAPrepGlobalBarrier(CUDAModuleNode* m, ObjectPtr<Object> sptr) : m_(m),
sptr_(sptr) {
- std::fill(pcache_.begin(), pcache_.end(), 0);
- }
-
- void operator()(const ffi::PackedArgs& args, ffi::Any* rv) const {
- int device_id;
- CUDA_CALL(cudaGetDevice(&device_id));
- if (pcache_[device_id] == 0) {
- pcache_[device_id] =
- m_->GetGlobal(device_id, runtime::symbol::tvm_global_barrier_state,
sizeof(unsigned));
- }
- CUDA_DRIVER_CALL(cuMemsetD32(pcache_[device_id], 0, 1));
- }
-
- private:
- // internal module
- CUDAModuleNode* m_;
- // the resource holder
- ObjectPtr<Object> sptr_;
- // mark as mutable, to enable lazy initialization
- mutable std::array<CUdeviceptr, kMaxNumGPUs> pcache_;
-};
-
ffi::Optional<ffi::Function> CUDAModuleNode::GetFunction(const ffi::String&
name) {
ObjectPtr<Object> sptr_to_self = ffi::GetObjectPtr<Object>(this);
TVM_FFI_ICHECK_EQ(sptr_to_self.get(), this);
- if (name == symbol::tvm_prepare_global_barrier) {
- return ffi::Function(CUDAPrepGlobalBarrier(this, sptr_to_self));
- }
auto opt_info = fmap_.Get(name);
if (!opt_info.has_value()) return ffi::Function();
FunctionInfo info = opt_info.value();
diff --git a/src/s_tir/transform/thread_storage_sync.cc
b/src/s_tir/transform/thread_storage_sync.cc
index 3ec44d7a3d..b9b9e5f1dd 100644
--- a/src/s_tir/transform/thread_storage_sync.cc
+++ b/src/s_tir/transform/thread_storage_sync.cc
@@ -22,19 +22,17 @@
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
-#include <tvm/runtime/device_api.h>
#include <tvm/s_tir/stmt.h>
#include <tvm/s_tir/transform.h>
#include <tvm/tirx/analysis.h>
#include <tvm/tirx/builtin.h>
#include <tvm/tirx/expr.h>
+#include <tvm/tirx/op.h>
#include <tvm/tirx/stmt_functor.h>
-#include <unordered_map>
#include <unordered_set>
#include "../../runtime/thread_storage_scope.h"
-#include "../../tirx/transform/ir_utils.h"
#include "storage_access.h"
namespace tvm {
@@ -320,13 +318,8 @@ class ThreadSyncInserter : public StmtExprMutator {
Stmt VisitStmt(const Stmt& stmt) final {
if (syncs_.size() == 0) return stmt;
if (syncs_.count(stmt.get())) {
- Stmt barrier;
- if (sync_scope_.rank == StorageRank::kGlobal) {
- barrier = MakeGlobalBarrier();
- } else {
- barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
- {StringImm(sync_scope_.to_string())}));
- }
+ Stmt barrier = Evaluate(Call(DataType::Int(32),
builtin::tvm_storage_sync(),
+ {StringImm(sync_scope_.to_string())}));
// Mutate after query, to avoid stmt change.
auto ret = StmtExprMutator::VisitStmt(stmt);
ret = SeqStmt({barrier, ret});
@@ -335,137 +328,11 @@ class ThreadSyncInserter : public StmtExprMutator {
return StmtExprMutator::VisitStmt(stmt);
}
}
- PrimExpr VisitExpr_(const BufferLoadNode* op) final {
- if (sync_scope_.rank == StorageRank::kGlobal &&
- GetScope(op->buffer->data).rank == StorageRank::kGlobal) {
- ++rw_stats_[op->buffer->data].read_count;
- }
- return StmtExprMutator::VisitExpr_(op);
- }
- Stmt VisitStmt_(const BufferStoreNode* op) final {
- if (sync_scope_.rank == StorageRank::kGlobal &&
- GetScope(op->buffer->data).rank == StorageRank::kGlobal) {
- ++rw_stats_[op->buffer->data].write_count;
- }
- return StmtExprMutator::VisitStmt_(op);
- }
- Stmt VisitStmt_(const AttrStmtNode* op) final {
- if (op->attr_key == tirx::attr::thread_extent) {
- bool temp = true;
- std::swap(temp, in_thread_env_);
- thread_extents_.push_back(op);
- Stmt ret = StmtExprMutator::VisitStmt_(op);
- thread_extents_.pop_back();
- std::swap(temp, in_thread_env_);
- // first thread scope.
- if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
- ret = InitGlobalBarrier(ret.as<AttrStmtNode>());
- num_blocks_ = PrimExpr();
- is_lead_ = PrimExpr();
- }
- return ret;
- } else {
- return StmtExprMutator::VisitStmt_(op);
- }
- }
-
- Stmt VisitStmt_(const AllocBufferNode* op) final {
- auto node = Downcast<AllocBuffer>(StmtExprMutator::VisitStmt_(op));
- if (volatile_vars_.count(op->buffer->data.get())) {
- auto* cow = node.CopyOnWrite();
- auto annotations = cow->annotations;
- annotations.Set(tirx::attr::kVolatile, Bool(true));
- cow->annotations = annotations;
- }
- return node;
- }
-
- PrimExpr VisitExpr_(const CallNode* op) final {
- if (op->op.same_as(builtin::tvm_access_ptr())) {
- PrimExpr expr = StmtExprMutator::VisitExpr_(op);
- op = expr.as<CallNode>();
- TVM_FFI_ICHECK_EQ(op->args.size(), 5U);
- Var buffer_var(Downcast<Var>(op->args[1]));
- const IntImmNode* flag = op->args[4].as<IntImmNode>();
- if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal &&
- GetScope(buffer_var).rank == StorageRank::kGlobal) {
- ++rw_stats_[buffer_var].read_count;
- }
- if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal &&
- GetScope(buffer_var).rank == StorageRank::kGlobal) {
- ++rw_stats_[buffer_var].write_count;
- }
- return expr;
- } else {
- return StmtExprMutator::VisitExpr_(op);
- }
- }
private:
- // RW statistics about data
- struct Entry {
- int read_count{0};
- int write_count{0};
- };
-
- // Get current storage scope.
- StorageScope GetScope(Var buffer_var) const {
- return StorageScope::Create(GetPtrStorageScope(buffer_var));
- }
-
- // private functions.
- Stmt InitGlobalBarrier(const AttrStmtNode* op) {
- TVM_FFI_ICHECK(op != nullptr);
- ffi::Array<PrimExpr> pargs =
{StringImm(runtime::symbol::tvm_prepare_global_barrier)};
- Stmt prep = Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(),
pargs));
- Stmt body = op->body;
- for (const auto& kv : rw_stats_) {
- const auto& e = kv.second;
- if (e.read_count != 0 && e.write_count != 0) {
- volatile_vars_.insert(kv.first.get());
- }
- }
- rw_stats_.clear();
- Stmt kinit = Evaluate(Call(DataType::Int(32),
builtin::tvm_global_barrier_kinit(), {}));
- body = SeqStmt({kinit, body});
- body = AttrStmt(op->node, op->attr_key, op->value, body);
- return SeqStmt({prep, body});
- }
- Stmt MakeGlobalBarrier() {
- TVM_FFI_ICHECK(sync_scope_.rank == StorageRank::kGlobal);
- if (!num_blocks_.defined()) {
- TVM_FFI_ICHECK(!is_lead_.defined());
- num_work_dim_ = thread_extents_.size();
- for (const AttrStmtNode* attr : thread_extents_) {
- IterVar iv = Downcast<IterVar>(attr->node);
- runtime::ThreadScope s = runtime::ThreadScope::Create(iv->thread_tag);
- if (s.rank == 0) {
- num_blocks_ = (num_blocks_.defined() ? attr->value * num_blocks_ :
attr->value);
- } else if (s.rank == 1) {
- PrimExpr cond = iv->var == make_zero(iv->var.dtype());
- is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond;
- }
- }
- } else {
- TVM_FFI_ICHECK_EQ(num_work_dim_, thread_extents_.size());
- }
- return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
- {StringImm(sync_scope_.to_string()), is_lead_,
num_blocks_}));
- }
// data structure.
StorageScope sync_scope_;
const std::unordered_set<const Object*>& syncs_;
- // The read write statistics of storage
- std::unordered_map<Var, Entry> rw_stats_;
- // Set of buffer data vars that should be marked volatile.
- std::unordered_set<const VarNode*> volatile_vars_;
- // The statistics for global barrier
- bool in_thread_env_{false};
- // memorized results
- std::vector<const AttrStmtNode*> thread_extents_;
- size_t num_work_dim_{0};
- PrimExpr num_blocks_;
- PrimExpr is_lead_;
};
Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
diff --git a/src/target/source/codegen_cuda.cc
b/src/target/source/codegen_cuda.cc
index 4144326959..fec1c302d4 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -25,7 +25,6 @@
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/function.h>
-#include <tvm/runtime/device_api.h>
#include <tvm/s_tir/stmt.h>
#include <tvm/tirx/index_map.h>
#include <tvm/tirx/stmt_functor.h>
@@ -136,12 +135,7 @@ std::string GetFP4Type(DataType type) {
CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }
-void CodeGenCUDA::Init(bool output_ssa) {
- CodeGenC::Init(output_ssa);
- vid_global_barrier_state_ =
name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state);
- vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect");
- TVM_FFI_ICHECK_EQ(vid_global_barrier_state_,
runtime::symbol::tvm_global_barrier_state);
-}
+void CodeGenCUDA::Init(bool output_ssa) { CodeGenC::Init(output_ssa); }
void CodeGenCUDA::PrintFunctionSignature(const ffi::String& function_name,
const PrimFunc& func,
std::ostream& os) {
@@ -759,35 +753,8 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
this->PrintIndent();
this->stream << "__syncthreads();\n";
} else if (sync == "global") {
- if (!need_global_barrier_) {
- need_global_barrier_ = true;
- this->decl_stream << "extern \"C\" __device__ unsigned " <<
vid_global_barrier_state_
- << ";\n";
- }
- // global synchronizer
- std::string is_load = PrintExpr(op->args[1]);
- std::string num_blocks = PrintExpr(op->args[2]);
- this->PrintIndent();
- // In theory only threadfence is needed
- // but we observed problems with only threadfence
- this->stream << "__threadfence_system();\n";
- this->PrintIndent();
- this->stream << "if (" << is_load << ") {\n";
- int wb = this->BeginScope();
- this->PrintIndent();
- this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
- this->PrintIndent();
- std::string ptr = name_supply_->FreshName("pf");
- this->stream << "volatile unsigned* " << ptr << " = &" <<
vid_global_barrier_state_ << ";\n";
- this->PrintIndent();
- this->stream << vid_global_barrier_expect_ << " += " << num_blocks <<
";\n";
- this->PrintIndent();
- this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_
<< ");\n";
- this->EndScope(wb);
- this->PrintIndent();
- this->stream << "}\n";
- this->PrintIndent();
- this->stream << "__syncthreads();\n";
+ TVM_FFI_THROW(InternalError)
+ << "Global barrier is no longer supported. Use device-native
synchronization primitives.";
}
}
@@ -1440,23 +1407,6 @@ void CodeGenCUDA::VisitStmt_(const AllocBufferNode* op) {
}
}
-void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {
- if (is_const_int(op->value)) return;
- const CallNode* call = op->value.as<CallNode>();
- if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) {
- PrintIndent();
- stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
- PrintIndent();
- stream << "if (threadIdx.x == 0) {\n";
- PrintIndent();
- stream << " " << vid_global_barrier_expect_ << " = 0;\n";
- PrintIndent();
- stream << "}\n";
- } else {
- CodeGenC::VisitStmt_(op);
- }
-}
-
void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
int lanes = op->dtype.lanes();
TVM_FFI_CHECK_LE(lanes, 4, ValueError) << "Ramp of more than 4 lanes is not
allowed.";
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 4a384ffe16..39a15def56 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -69,7 +69,6 @@ class CodeGenCUDA final : public CodeGenC {
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
void VisitExpr_(const CallNode* op, std::ostream& os) final;
void VisitExpr_(const CastNode* op, std::ostream& os) final;
- void VisitStmt_(const EvaluateNode* op) final;
void VisitStmt_(const AllocBufferNode* op) final;
void VisitStmt_(const AttrStmtNode* op) final;
@@ -85,12 +84,6 @@ class CodeGenCUDA final : public CodeGenC {
// Whether scope such as "__shared__" or "__constant__" is part of type.
bool IsScopePartOfType() const final { return false; }
- // Whether global barrier is needed.
- bool need_global_barrier_{false};
- // Global barrier state
- std::string vid_global_barrier_state_;
- // Global barrier expected node.
- std::string vid_global_barrier_expect_;
// whether enable fp16
bool enable_fp16_{false};
// whether enable bf16
diff --git a/src/tirx/ir/transform.cc b/src/tirx/ir/transform.cc
index 14b0cf6b11..f5a99c454c 100644
--- a/src/tirx/ir/transform.cc
+++ b/src/tirx/ir/transform.cc
@@ -33,7 +33,6 @@ namespace transform {
// Register build pipeline related options
TVM_REGISTER_PASS_CONFIG_OPTION("tirx.noalias", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tirx.detect_global_barrier", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tirx.instrument_bound_checkers", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tirx.disable_assert", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tirx.disable_vectorize", Bool);
diff --git a/src/tirx/op/builtin.cc b/src/tirx/op/builtin.cc
index 68f9ce219b..4355583d79 100644
--- a/src/tirx/op/builtin.cc
+++ b/src/tirx/op/builtin.cc
@@ -253,9 +253,6 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_down)
TIR_DEFINE_BUILTIN_FUNC(tvm_warp_activemask)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
-TIR_DEFINE_BUILTIN_FUNC(tvm_global_barrier_kinit)
- .set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
-
TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));