This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 216be63e57 [REFACTOR][CODEGEN] Phase out tvm_global_barrier_state and tvm_prepare_global_barrier (#19454)
216be63e57 is described below

commit 216be63e57bf1a547f12e8f43f733e106b8fd73b
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Apr 27 16:36:31 2026 -0400

    [REFACTOR][CODEGEN] Phase out tvm_global_barrier_state and tvm_prepare_global_barrier (#19454)
    
    Phase out the legacy spin-on-global-memory CUDA barrier machinery
    (`tvm_global_barrier_state` / `__tvm_prepare_global_barrier` / the
    `tvm_global_barrier_kinit()` builtin and the `tirx.detect_global_barrier`
    pass-config option). CUDA's native cooperative groups / grid sync
    primitives cover this use case better. The bespoke implementation was
    effectively dead code: the active pipelines only reached it through the
    opt-in `tirx.detect_global_barrier` flag, which defaults to off.
    
    This is an almost entirely deletion-only refactor across 10 files
    (267 deletions, 6 insertions; ~261 lines net removed):
    
    - Public symbol constants in `include/tvm/runtime/device_api.h`
    - TIR builtin op `tvm_global_barrier_kinit()`
    (`include/tvm/tirx/builtin.h`,
      `src/tirx/op/builtin.cc`)
    - Pass config `tirx.detect_global_barrier` (`src/tirx/ir/transform.cc`)
    - Entire kGlobal branch of `s_tir::ThreadSync` including
      `InitGlobalBarrier`, `MakeGlobalBarrier`, and supporting state in
      `ThreadSyncInserter`
    - `CUDAPrepGlobalBarrier` runtime class + `CUDAModuleNode::GetGlobal()`
    - `CodeGenCUDA::PrintStorageSync` "global" branch (now replaced by an
      `InternalError` stating that global barriers are no longer supported),
      the `VisitStmt_(EvaluateNode*)` override, and 3 member fields
    - Two Python pipeline opt-in blocks (s_tir/pipeline.py + adreno mirror)
    
    No Python or test references to these symbols remain. Build clean
    (260 targets), CUDA codegen 50/50 passed, tirx-base + tirx-transform
    621 passed, IRF cpptests 8/8 passed.
---
 include/tvm/runtime/device_api.h            |   4 -
 include/tvm/tirx/builtin.h                  |   6 --
 python/tvm/s_tir/backend/adreno/pipeline.py |   2 -
 python/tvm/s_tir/pipeline.py                |   2 -
 src/runtime/cuda/cuda_module.cc             |  53 -----------
 src/s_tir/transform/thread_storage_sync.cc  | 139 +---------------------------
 src/target/source/codegen_cuda.cc           |  56 +----------
 src/target/source/codegen_cuda.h            |   7 --
 src/tirx/ir/transform.cc                    |   1 -
 src/tirx/op/builtin.cc                      |   3 -
 10 files changed, 6 insertions(+), 267 deletions(-)

diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 20abff8c46..47607c5b88 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -416,10 +416,6 @@ TVM_RUNTIME_DLL bool RuntimeEnabled(const ffi::String& 
target);
 namespace symbol {
 /*! \brief global function to set device */
 constexpr const char* tvm_set_device = "__tvm_set_device";
-/*! \brief Auxiliary counter to global barrier. */
-constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
-/*! \brief Prepare the global barrier before kernels that uses global barrier. 
*/
-constexpr const char* tvm_prepare_global_barrier = 
"__tvm_prepare_global_barrier";
 }  // namespace symbol
 
 }  // namespace runtime
diff --git a/include/tvm/tirx/builtin.h b/include/tvm/tirx/builtin.h
index d0d5b3d57e..3339b1aa49 100644
--- a/include/tvm/tirx/builtin.h
+++ b/include/tvm/tirx/builtin.h
@@ -507,12 +507,6 @@ TVM_DLL const Op& tvm_warp_shuffle_up();
 TVM_DLL const Op& tvm_warp_shuffle_down();
 TVM_DLL const Op& tvm_warp_activemask();
 
-/*!
- * \brief Initialize the global barrier.
- *  Call this at beginning of kernel that need global barrier.
- */
-TVM_DLL const Op& tvm_global_barrier_kinit();
-
 /*!
  * \brief See pesudo code
  *
diff --git a/python/tvm/s_tir/backend/adreno/pipeline.py 
b/python/tvm/s_tir/backend/adreno/pipeline.py
index a63fb4346d..51510f2113 100644
--- a/python/tvm/s_tir/backend/adreno/pipeline.py
+++ b/python/tvm/s_tir/backend/adreno/pipeline.py
@@ -93,8 +93,6 @@ def default_tir_pipeline():
                 tirx.transform.AnnotateEntryFunc(),
             ]
         )
-        if bool(config.get("tirx.detect_global_barrier", False)):
-            passes.append(s_tir.transform.ThreadSync("global"))
         passes.extend(
             [
                 s_tir.transform.ThreadSync("shared"),
diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py
index f775c0dd1e..85f586660d 100644
--- a/python/tvm/s_tir/pipeline.py
+++ b/python/tvm/s_tir/pipeline.py
@@ -91,8 +91,6 @@ def default_s_tir_pipeline():
                 tirx.transform.AnnotateEntryFunc(),
             ]
         )
-        if bool(config.get("tirx.detect_global_barrier", False)):
-            passes.append(s_tir.transform.ThreadSync("global"))
         passes.extend(
             [
                 s_tir.transform.ThreadSync("shared"),
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index 691e44f0a0..29a04eb185 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -27,7 +27,6 @@
 #include <tvm/ffi/extra/c_env_api.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
-#include <tvm/runtime/device_api.h>
 #include <tvm/support/io.h>
 
 #include <array>
@@ -134,30 +133,6 @@ class CUDAModuleNode : public ffi::ModuleObj {
     }
     return func;
   }
-  // get a global var from primary context in device_id
-  CUdeviceptr GetGlobal(int device_id, const std::string& global_name, size_t 
expect_nbytes) {
-    std::lock_guard<std::mutex> lock(mutex_);
-    // must recheck under the lock scope
-    if (module_[device_id] == nullptr) {
-      CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str()));
-      static auto nvshmem_init_hook = 
ffi::Function::GetGlobal("runtime.nvshmem.cumodule_init");
-      if (nvshmem_init_hook.has_value()) {
-        (*nvshmem_init_hook)(static_cast<void*>(module_[device_id]));
-      }
-    }
-    CUdeviceptr global;
-    size_t nbytes;
-
-    CUresult result = cuModuleGetGlobal(&global, &nbytes, module_[device_id], 
global_name.c_str());
-    TVM_FFI_ICHECK_EQ(nbytes, expect_nbytes);
-    if (result != CUDA_SUCCESS) {
-      const char* msg;
-      cuGetErrorName(result, &msg);
-      TVM_FFI_THROW(CUDAError) << "cuModuleGetGlobal " << global_name
-                               << " failed with error: " << msg;
-    }
-    return global;
-  }
 
  private:
   // the binary data
@@ -269,37 +244,9 @@ class CUDAWrappedFunc {
   LaunchParamConfig launch_param_config_;
 };
 
-class CUDAPrepGlobalBarrier {
- public:
-  CUDAPrepGlobalBarrier(CUDAModuleNode* m, ObjectPtr<Object> sptr) : m_(m), 
sptr_(sptr) {
-    std::fill(pcache_.begin(), pcache_.end(), 0);
-  }
-
-  void operator()(const ffi::PackedArgs& args, ffi::Any* rv) const {
-    int device_id;
-    CUDA_CALL(cudaGetDevice(&device_id));
-    if (pcache_[device_id] == 0) {
-      pcache_[device_id] =
-          m_->GetGlobal(device_id, runtime::symbol::tvm_global_barrier_state, 
sizeof(unsigned));
-    }
-    CUDA_DRIVER_CALL(cuMemsetD32(pcache_[device_id], 0, 1));
-  }
-
- private:
-  // internal module
-  CUDAModuleNode* m_;
-  // the resource holder
-  ObjectPtr<Object> sptr_;
-  // mark as mutable, to enable lazy initialization
-  mutable std::array<CUdeviceptr, kMaxNumGPUs> pcache_;
-};
-
 ffi::Optional<ffi::Function> CUDAModuleNode::GetFunction(const ffi::String& 
name) {
   ObjectPtr<Object> sptr_to_self = ffi::GetObjectPtr<Object>(this);
   TVM_FFI_ICHECK_EQ(sptr_to_self.get(), this);
-  if (name == symbol::tvm_prepare_global_barrier) {
-    return ffi::Function(CUDAPrepGlobalBarrier(this, sptr_to_self));
-  }
   auto opt_info = fmap_.Get(name);
   if (!opt_info.has_value()) return ffi::Function();
   FunctionInfo info = opt_info.value();
diff --git a/src/s_tir/transform/thread_storage_sync.cc 
b/src/s_tir/transform/thread_storage_sync.cc
index 3ec44d7a3d..b9b9e5f1dd 100644
--- a/src/s_tir/transform/thread_storage_sync.cc
+++ b/src/s_tir/transform/thread_storage_sync.cc
@@ -22,19 +22,17 @@
  */
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
-#include <tvm/runtime/device_api.h>
 #include <tvm/s_tir/stmt.h>
 #include <tvm/s_tir/transform.h>
 #include <tvm/tirx/analysis.h>
 #include <tvm/tirx/builtin.h>
 #include <tvm/tirx/expr.h>
+#include <tvm/tirx/op.h>
 #include <tvm/tirx/stmt_functor.h>
 
-#include <unordered_map>
 #include <unordered_set>
 
 #include "../../runtime/thread_storage_scope.h"
-#include "../../tirx/transform/ir_utils.h"
 #include "storage_access.h"
 
 namespace tvm {
@@ -320,13 +318,8 @@ class ThreadSyncInserter : public StmtExprMutator {
   Stmt VisitStmt(const Stmt& stmt) final {
     if (syncs_.size() == 0) return stmt;
     if (syncs_.count(stmt.get())) {
-      Stmt barrier;
-      if (sync_scope_.rank == StorageRank::kGlobal) {
-        barrier = MakeGlobalBarrier();
-      } else {
-        barrier = Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
-                                {StringImm(sync_scope_.to_string())}));
-      }
+      Stmt barrier = Evaluate(Call(DataType::Int(32), 
builtin::tvm_storage_sync(),
+                                   {StringImm(sync_scope_.to_string())}));
       // Mutate after query, to avoid stmt change.
       auto ret = StmtExprMutator::VisitStmt(stmt);
       ret = SeqStmt({barrier, ret});
@@ -335,137 +328,11 @@ class ThreadSyncInserter : public StmtExprMutator {
       return StmtExprMutator::VisitStmt(stmt);
     }
   }
-  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
-    if (sync_scope_.rank == StorageRank::kGlobal &&
-        GetScope(op->buffer->data).rank == StorageRank::kGlobal) {
-      ++rw_stats_[op->buffer->data].read_count;
-    }
-    return StmtExprMutator::VisitExpr_(op);
-  }
-  Stmt VisitStmt_(const BufferStoreNode* op) final {
-    if (sync_scope_.rank == StorageRank::kGlobal &&
-        GetScope(op->buffer->data).rank == StorageRank::kGlobal) {
-      ++rw_stats_[op->buffer->data].write_count;
-    }
-    return StmtExprMutator::VisitStmt_(op);
-  }
-  Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == tirx::attr::thread_extent) {
-      bool temp = true;
-      std::swap(temp, in_thread_env_);
-      thread_extents_.push_back(op);
-      Stmt ret = StmtExprMutator::VisitStmt_(op);
-      thread_extents_.pop_back();
-      std::swap(temp, in_thread_env_);
-      // first thread scope.
-      if (!in_thread_env_ && sync_scope_.rank == StorageRank::kGlobal) {
-        ret = InitGlobalBarrier(ret.as<AttrStmtNode>());
-        num_blocks_ = PrimExpr();
-        is_lead_ = PrimExpr();
-      }
-      return ret;
-    } else {
-      return StmtExprMutator::VisitStmt_(op);
-    }
-  }
-
-  Stmt VisitStmt_(const AllocBufferNode* op) final {
-    auto node = Downcast<AllocBuffer>(StmtExprMutator::VisitStmt_(op));
-    if (volatile_vars_.count(op->buffer->data.get())) {
-      auto* cow = node.CopyOnWrite();
-      auto annotations = cow->annotations;
-      annotations.Set(tirx::attr::kVolatile, Bool(true));
-      cow->annotations = annotations;
-    }
-    return node;
-  }
-
-  PrimExpr VisitExpr_(const CallNode* op) final {
-    if (op->op.same_as(builtin::tvm_access_ptr())) {
-      PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-      op = expr.as<CallNode>();
-      TVM_FFI_ICHECK_EQ(op->args.size(), 5U);
-      Var buffer_var(Downcast<Var>(op->args[1]));
-      const IntImmNode* flag = op->args[4].as<IntImmNode>();
-      if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal &&
-          GetScope(buffer_var).rank == StorageRank::kGlobal) {
-        ++rw_stats_[buffer_var].read_count;
-      }
-      if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal &&
-          GetScope(buffer_var).rank == StorageRank::kGlobal) {
-        ++rw_stats_[buffer_var].write_count;
-      }
-      return expr;
-    } else {
-      return StmtExprMutator::VisitExpr_(op);
-    }
-  }
 
  private:
-  // RW statistics about data
-  struct Entry {
-    int read_count{0};
-    int write_count{0};
-  };
-
-  // Get current storage scope.
-  StorageScope GetScope(Var buffer_var) const {
-    return StorageScope::Create(GetPtrStorageScope(buffer_var));
-  }
-
-  // private functions.
-  Stmt InitGlobalBarrier(const AttrStmtNode* op) {
-    TVM_FFI_ICHECK(op != nullptr);
-    ffi::Array<PrimExpr> pargs = 
{StringImm(runtime::symbol::tvm_prepare_global_barrier)};
-    Stmt prep = Evaluate(Call(DataType::Int(32), builtin::tvm_call_packed(), 
pargs));
-    Stmt body = op->body;
-    for (const auto& kv : rw_stats_) {
-      const auto& e = kv.second;
-      if (e.read_count != 0 && e.write_count != 0) {
-        volatile_vars_.insert(kv.first.get());
-      }
-    }
-    rw_stats_.clear();
-    Stmt kinit = Evaluate(Call(DataType::Int(32), 
builtin::tvm_global_barrier_kinit(), {}));
-    body = SeqStmt({kinit, body});
-    body = AttrStmt(op->node, op->attr_key, op->value, body);
-    return SeqStmt({prep, body});
-  }
-  Stmt MakeGlobalBarrier() {
-    TVM_FFI_ICHECK(sync_scope_.rank == StorageRank::kGlobal);
-    if (!num_blocks_.defined()) {
-      TVM_FFI_ICHECK(!is_lead_.defined());
-      num_work_dim_ = thread_extents_.size();
-      for (const AttrStmtNode* attr : thread_extents_) {
-        IterVar iv = Downcast<IterVar>(attr->node);
-        runtime::ThreadScope s = runtime::ThreadScope::Create(iv->thread_tag);
-        if (s.rank == 0) {
-          num_blocks_ = (num_blocks_.defined() ? attr->value * num_blocks_ : 
attr->value);
-        } else if (s.rank == 1) {
-          PrimExpr cond = iv->var == make_zero(iv->var.dtype());
-          is_lead_ = is_lead_.defined() ? (is_lead_ && cond) : cond;
-        }
-      }
-    } else {
-      TVM_FFI_ICHECK_EQ(num_work_dim_, thread_extents_.size());
-    }
-    return Evaluate(Call(DataType::Int(32), builtin::tvm_storage_sync(),
-                         {StringImm(sync_scope_.to_string()), is_lead_, 
num_blocks_}));
-  }
   // data structure.
   StorageScope sync_scope_;
   const std::unordered_set<const Object*>& syncs_;
-  // The read write statistics of storage
-  std::unordered_map<Var, Entry> rw_stats_;
-  // Set of buffer data vars that should be marked volatile.
-  std::unordered_set<const VarNode*> volatile_vars_;
-  // The statistics for global barrier
-  bool in_thread_env_{false};
-  // memorized results
-  std::vector<const AttrStmtNode*> thread_extents_;
-  size_t num_work_dim_{0};
-  PrimExpr num_blocks_;
-  PrimExpr is_lead_;
 };
 
 Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
diff --git a/src/target/source/codegen_cuda.cc 
b/src/target/source/codegen_cuda.cc
index 4144326959..fec1c302d4 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -25,7 +25,6 @@
 
 #include <tvm/arith/analyzer.h>
 #include <tvm/ffi/function.h>
-#include <tvm/runtime/device_api.h>
 #include <tvm/s_tir/stmt.h>
 #include <tvm/tirx/index_map.h>
 #include <tvm/tirx/stmt_functor.h>
@@ -136,12 +135,7 @@ std::string GetFP4Type(DataType type) {
 
 CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }
 
-void CodeGenCUDA::Init(bool output_ssa) {
-  CodeGenC::Init(output_ssa);
-  vid_global_barrier_state_ = 
name_supply_->FreshName(runtime::symbol::tvm_global_barrier_state);
-  vid_global_barrier_expect_ = name_supply_->FreshName("__barrier_expect");
-  TVM_FFI_ICHECK_EQ(vid_global_barrier_state_, 
runtime::symbol::tvm_global_barrier_state);
-}
+void CodeGenCUDA::Init(bool output_ssa) { CodeGenC::Init(output_ssa); }
 
 void CodeGenCUDA::PrintFunctionSignature(const ffi::String& function_name, 
const PrimFunc& func,
                                          std::ostream& os) {
@@ -759,35 +753,8 @@ void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
     this->PrintIndent();
     this->stream << "__syncthreads();\n";
   } else if (sync == "global") {
-    if (!need_global_barrier_) {
-      need_global_barrier_ = true;
-      this->decl_stream << "extern \"C\" __device__ unsigned " << 
vid_global_barrier_state_
-                        << ";\n";
-    }
-    // global synchronizer
-    std::string is_load = PrintExpr(op->args[1]);
-    std::string num_blocks = PrintExpr(op->args[2]);
-    this->PrintIndent();
-    // In theory only threadfence is needed
-    // but we observed problems with only threadfence
-    this->stream << "__threadfence_system();\n";
-    this->PrintIndent();
-    this->stream << "if (" << is_load << ") {\n";
-    int wb = this->BeginScope();
-    this->PrintIndent();
-    this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
-    this->PrintIndent();
-    std::string ptr = name_supply_->FreshName("pf");
-    this->stream << "volatile unsigned* " << ptr << " = &" << 
vid_global_barrier_state_ << ";\n";
-    this->PrintIndent();
-    this->stream << vid_global_barrier_expect_ << " += " << num_blocks << 
";\n";
-    this->PrintIndent();
-    this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_ 
<< ");\n";
-    this->EndScope(wb);
-    this->PrintIndent();
-    this->stream << "}\n";
-    this->PrintIndent();
-    this->stream << "__syncthreads();\n";
+    TVM_FFI_THROW(InternalError)
+        << "Global barrier is no longer supported. Use device-native 
synchronization primitives.";
   }
 }
 
@@ -1440,23 +1407,6 @@ void CodeGenCUDA::VisitStmt_(const AllocBufferNode* op) {
   }
 }
 
-void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {
-  if (is_const_int(op->value)) return;
-  const CallNode* call = op->value.as<CallNode>();
-  if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) {
-    PrintIndent();
-    stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
-    PrintIndent();
-    stream << "if (threadIdx.x == 0) {\n";
-    PrintIndent();
-    stream << "  " << vid_global_barrier_expect_ << " = 0;\n";
-    PrintIndent();
-    stream << "}\n";
-  } else {
-    CodeGenC::VisitStmt_(op);
-  }
-}
-
 void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
   int lanes = op->dtype.lanes();
   TVM_FFI_CHECK_LE(lanes, 4, ValueError) << "Ramp of more than 4 lanes is not 
allowed.";
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 4a384ffe16..39a15def56 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -69,7 +69,6 @@ class CodeGenCUDA final : public CodeGenC {
   void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
   void VisitExpr_(const CallNode* op, std::ostream& os) final;
   void VisitExpr_(const CastNode* op, std::ostream& os) final;
-  void VisitStmt_(const EvaluateNode* op) final;
   void VisitStmt_(const AllocBufferNode* op) final;
   void VisitStmt_(const AttrStmtNode* op) final;
 
@@ -85,12 +84,6 @@ class CodeGenCUDA final : public CodeGenC {
   // Whether scope such as "__shared__" or "__constant__"  is part of type.
   bool IsScopePartOfType() const final { return false; }
 
-  // Whether global barrier is needed.
-  bool need_global_barrier_{false};
-  // Global barrier state
-  std::string vid_global_barrier_state_;
-  // Global barrier expected node.
-  std::string vid_global_barrier_expect_;
   // whether enable fp16
   bool enable_fp16_{false};
   // whether enable bf16
diff --git a/src/tirx/ir/transform.cc b/src/tirx/ir/transform.cc
index 14b0cf6b11..f5a99c454c 100644
--- a/src/tirx/ir/transform.cc
+++ b/src/tirx/ir/transform.cc
@@ -33,7 +33,6 @@ namespace transform {
 
 // Register build pipeline related options
 TVM_REGISTER_PASS_CONFIG_OPTION("tirx.noalias", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tirx.detect_global_barrier", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tirx.instrument_bound_checkers", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tirx.disable_assert", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tirx.disable_vectorize", Bool);
diff --git a/src/tirx/op/builtin.cc b/src/tirx/op/builtin.cc
index 68f9ce219b..4355583d79 100644
--- a/src/tirx/op/builtin.cc
+++ b/src/tirx/op/builtin.cc
@@ -253,9 +253,6 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_warp_shuffle_down)
 TIR_DEFINE_BUILTIN_FUNC(tvm_warp_activemask)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
 
-TIR_DEFINE_BUILTIN_FUNC(tvm_global_barrier_kinit)
-    .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
-
 TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce)
     .set_attr<TCallEffectKind>("TCallEffectKind", 
Integer(CallEffectKind::kOpaque));
 

Reply via email to