This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch tvm-splithostdevice-staged-pass-consolidation in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 85cb3e77f899e600de948864629978d397fb6f17 Author: Tianqi Chen <[email protected]> AuthorDate: Wed Jun 3 14:18:55 2026 +0000 [REFACTOR][TIRX] Consolidate split host device stages The host/device split flow already depends on annotation, extraction, and kernel-launch lowering running as one sequence, so exposing the three stages separately leaves an unnecessary public surface and makes their ordering easier to misuse. This consolidates the stage implementations into the single SplitHostDevice pass, removes the old public wrappers and registrations, and updates the lowering pipelines and tests to exercise the one-pass API. --- include/tvm/tirx/transform.h | 33 +- python/tvm/s_tir/backend/adreno/pipeline.py | 2 - python/tvm/s_tir/pipeline.py | 2 - python/tvm/tirx/compilation_pipeline.py | 6 - python/tvm/tirx/transform/transform.py | 44 +- src/tirx/transform/annotate_device_regions.cc | 85 ---- src/tirx/transform/lower_device_kernel_launch.cc | 456 ------------------- src/tirx/transform/split_host_device.cc | 493 ++++++++++++++++++++- ...form_merge_dynamic_shared_memory_allocations.py | 7 +- .../transform/test_s_tir_transform_thread_sync.py | 5 +- .../test_tir_transform_annotate_device_regions.py | 44 +- .../test_tir_transform_device_kernel_launch.py | 14 +- .../test_tir_transform_split_host_device.py | 43 +- 13 files changed, 572 insertions(+), 662 deletions(-) diff --git a/include/tvm/tirx/transform.h b/include/tvm/tirx/transform.h index 186ebf3f52..32a3ea8b29 100644 --- a/include/tvm/tirx/transform.h +++ b/include/tvm/tirx/transform.h @@ -163,19 +163,11 @@ TVM_DLL Pass RemapThreadAxis(ffi::Map<ffi::String, IterVar> axis_map); TVM_DLL Pass LowerCustomDatatypes(); /*! - * \brief Annotate locations that should be run on the device + * \brief Annotate, split, and lower host/device functions. * - * Insert `AttrStmt` nodes specifying a target on which regions within - * the PrimFunc should be executed. Only modifies functions that have - * a `tvm::attr::kTarget` attribute, and where that target defines a - * host. - * - * \return The pass. - */ -TVM_DLL Pass AnnotateDeviceRegions(); - -/*! - * \brief Split the function into a host function and device functions. + * This pass first annotates device regions within host functions, + * then splits them into host and device-side PrimFuncs, and finally + * lowers host-to-device calls into the device kernel launch ABI. * * The resulting host-side function will keep the same * `tvm::attr::kTarget` attribute (e.g. `T.target("cuda", @@ -190,23 +182,6 @@ TVM_DLL Pass AnnotateDeviceRegions(); */ TVM_DLL Pass SplitHostDevice(); -/*! - * \brief Lower cross-device function calls. - * - * Prior to this pass, host to device calls are represented as - * subroutine calls, with environment parameters (e.g. env_thread) - * specified internally. The device function is an internal function, - * without a `tvm::attr::kGlobalSymbol` attribute. - * - * After this pass, host to device calls are represented as - * tvm_call_packed built-in. The device function is an - * externally-exposed function, with a non-empty - * `tvm::attr::kGlobalSymbol` attribute. - * - * \return The pass. - */ -TVM_DLL Pass LowerDeviceKernelLaunch(); - /*! * \brief skip assert stmt. * diff --git a/python/tvm/s_tir/backend/adreno/pipeline.py b/python/tvm/s_tir/backend/adreno/pipeline.py index 618970b37e..a185f2e4f0 100644 --- a/python/tvm/s_tir/backend/adreno/pipeline.py +++ b/python/tvm/s_tir/backend/adreno/pipeline.py @@ -109,9 +109,7 @@ def default_tir_pipeline(): passes.extend( [ s_tir.transform.MergeSharedMemoryAllocations(), - tirx.transform.AnnotateDeviceRegions(), tirx.transform.SplitHostDevice(), - tirx.transform.LowerDeviceKernelLaunch(), tirx.transform.MakePackedAPI(), tirx.transform.FP8StorageLegalize(), tirx.transform.BF16StorageLegalize(), diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py index a127e43a0e..070deb7681 100644 --- a/python/tvm/s_tir/pipeline.py +++ b/python/tvm/s_tir/pipeline.py @@ -109,9 +109,7 @@ def default_s_tir_pipeline(): passes.extend( [ s_tir.transform.MergeSharedMemoryAllocations(), - tirx.transform.AnnotateDeviceRegions(), tirx.transform.SplitHostDevice(), - tirx.transform.LowerDeviceKernelLaunch(), tirx.transform.MakePackedAPI(), tirx.transform.FP8StorageLegalize(), tirx.transform.BF16StorageLegalize(), diff --git a/python/tvm/tirx/compilation_pipeline.py b/python/tvm/tirx/compilation_pipeline.py index f964f50668..f79af3493f 100644 --- a/python/tvm/tirx/compilation_pipeline.py +++ b/python/tvm/tirx/compilation_pipeline.py @@ -48,9 +48,7 @@ def default_tir_pipeline(): tirx.transform.FP8ComputeLegalize(), tirx.transform.VerifyMemory(), tirx.transform.AnnotateEntryFunc(), - tirx.transform.AnnotateDeviceRegions(), tirx.transform.SplitHostDevice(), - tirx.transform.LowerDeviceKernelLaunch(), tirx.transform.MakePackedAPI(), tirx.transform.FP8StorageLegalize(), tirx.transform.BF16StorageLegalize(), @@ -89,9 +87,7 @@ def tirx_pipeline(): tirx.transform.FP8ComputeLegalize(), tirx.transform.VerifyMemory(), tirx.transform.AnnotateEntryFunc(), - tirx.transform.AnnotateDeviceRegions(), tirx.transform.SplitHostDevice(), - tirx.transform.LowerDeviceKernelLaunch(), tirx.transform.MakePackedAPI(), tirx.transform.FP8StorageLegalize(), tirx.transform.BF16StorageLegalize(), @@ -122,9 +118,7 @@ def trn_pipeline(): tirx.transform.StmtSimplify(), tirx.transform.RemoveNoOp(), tirx.transform.AnnotateEntryFunc(), - tirx.transform.AnnotateDeviceRegions(), tirx.transform.SplitHostDevice(), - tirx.transform.LowerDeviceKernelLaunch(), tirx.transform.MakePackedAPI(), ] return tvm.ir.transform.Sequential(passes)(mod) diff --git a/python/tvm/tirx/transform/transform.py b/python/tvm/tirx/transform/transform.py index 2c01863d32..56b32dcd8f 100644 --- a/python/tvm/tirx/transform/transform.py +++ b/python/tvm/tirx/transform/transform.py @@ -288,53 +288,19 @@ def MakePackedAPI(): return _ffi_api.MakePackedAPI() # type: ignore -def AnnotateDeviceRegions(): - """Annotate locations that should be run on the device - - Insert `AttrStmt` nodes specifying a target on which regions - within the PrimFunc should be executed. Only modifies functions - that have a `tvm::attr::kTarget` attribute, and where that target - defines a host. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.AnnotateDeviceRegions() # type: ignore - - def SplitHostDevice(): - """Split the function into a host function and device functions. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.SplitHostDevice() # type: ignore - - -def LowerDeviceKernelLaunch(): - """Lower cross-device function calls. - - Prior to this pass, host to device calls are represented as - subroutine calls, with environment parameters (e.g. env_thread) - specified internally. The device function is an internal - function, without a `tvm::attr::kGlobalSymbol` attribute. - - After this pass, host to device calls are represented as - tvm_call_packed built-in. The device function is an - externally-exposed function, with a non-empty - `tvm::attr::kGlobalSymbol` attribute. + """Annotate, split, and lower host/device functions. + This pass first annotates device regions within host functions, + then splits them into host and device-side PrimFuncs, and finally + lowers host-to-device calls into the device kernel launch ABI. Returns ------- fpass : tvm.transform.Pass The result pass """ - return _ffi_api.LowerDeviceKernelLaunch() # type: ignore + return _ffi_api.SplitHostDevice() # type: ignore def SkipAssert(): diff --git a/src/tirx/transform/annotate_device_regions.cc b/src/tirx/transform/annotate_device_regions.cc deleted file mode 100644 index 542acc1876..0000000000 --- a/src/tirx/transform/annotate_device_regions.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file annotate_device_regions.cc - * \brief Split device function from host. - */ -#include <tvm/ffi/cast.h> -#include <tvm/ffi/function.h> -#include <tvm/ffi/reflection/registry.h> -#include <tvm/ir/transform.h> -#include <tvm/target/target.h> -#include <tvm/tirx/builtin.h> -#include <tvm/tirx/expr.h> -#include <tvm/tirx/stmt_functor.h> -#include <tvm/tirx/transform.h> - -namespace tvm { -namespace tirx { - -class DeviceRegionAnnotater : public StmtMutator { - public: - explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {} - - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == tvm::attr::kTarget) { - // If a target attribute already exists, use it as-is. - return ffi::GetRef<Stmt>(op); - } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::device_scope) { - // These attributes are only allowed in device-side code, so - // they should be annotated with the function's default target. - Stmt body = ffi::GetRef<Stmt>(op); - return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); - } else { - // All other annotations are ignored - return StmtMutator::VisitStmt_(op); - } - } - - private: - Target device_target_; -}; - -namespace transform { - -Pass AnnotateDeviceRegions() { - auto pass_func = [](PrimFunc func, IRModule mod, PassContext ctx) -> PrimFunc { - auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget); - TVM_FFI_ICHECK(opt_target) << "AnnotateDeviceRegions: Require the target attribute"; - Target target = opt_target.value(); - - if (target->GetHost()) { - DeviceRegionAnnotater mutator(target.WithoutHost()); - func.CopyOnWrite()->body = mutator(func->body); - } - return func; - }; - - return CreatePrimFuncPass(pass_func, 0, "tirx.AnnotateDeviceRegions", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tirx.transform.AnnotateDeviceRegions", AnnotateDeviceRegions); -} - -} // namespace transform -} // namespace tirx -} // namespace tvm diff --git a/src/tirx/transform/lower_device_kernel_launch.cc b/src/tirx/transform/lower_device_kernel_launch.cc deleted file mode 100644 index 29a9c02e43..0000000000 --- a/src/tirx/transform/lower_device_kernel_launch.cc +++ /dev/null @@ -1,456 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file lower_device_kernel_launch.cc - * \brief Split device function from host. - */ -#include <tvm/ffi/cast.h> -#include <tvm/ffi/function.h> -#include <tvm/ffi/reflection/registry.h> -#include <tvm/ir/transform.h> -#include <tvm/target/target.h> -#include <tvm/tirx/builtin.h> -#include <tvm/tirx/expr.h> -#include <tvm/tirx/stmt_functor.h> -#include <tvm/tirx/transform.h> - -#include "../../runtime/thread_storage_scope.h" -#include "ir_utils.h" - -namespace tvm { -namespace tirx { - -namespace { -struct KernelInfo { - // The device on which the PrimFunc runs - Target target; - - // The externally visible symbol which may refer to the PrimFunc - // when launching a device kernel. - ffi::String global_symbol; - - // The parameters accepted by the PrimFunc. Used to rewrite - // `launch_args` to be in terms of the calling scope. - ffi::Array<Var> params; - - // The launch parameters that should annotate the PrimFunc, if the - // kernel is ever called from the host. - ffi::Array<ffi::String> launch_params; - - // Additional arguments which must be provided to the host-side - // ffi::Function. These may be in terms of the function's parameters - // (e.g. a function that computes the average of `N` elements, and - // which must be launched with `N` CUDA threads). - ffi::Array<PrimExpr> launch_args; -}; - -/*! - * \brief Visitor class to collect device-side program information. - */ -class DeviceInfoCollector : public StmtVisitor { - public: - static KernelInfo Collect(const GlobalVar& gvar, const PrimFunc& func) { - DeviceInfoCollector collector; - collector.info_.target = func->GetAttr<Target>(tvm::attr::kTarget).value().WithoutHost(); - collector.info_.params = func->params; - - collector(func->body); - - // The dynamic shared memory is required to be the last of the - // kernel launch parameters - if (collector.dyn_shmem_size) { - collector.info_.launch_params.push_back( - tvm::runtime::launch_param::kUseDynamicSharedMemoryTag); - } - - collector.info_.global_symbol = - func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint); - - collector.info_.launch_args = collector.info_.launch_params.Map( - [&](const auto& param) { return collector.GetArgument(param); }); - - return collector.info_; - } - - private: - PrimExpr GetArgument(const ffi::String& launch_param) const { - if (launch_param == tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) { - TVM_FFI_ICHECK(dyn_shmem_size.defined()) - << "Compute kernel requires launch parameter \"" << launch_param - << "\", but PrimFunc did not contain AllocBuffer node with shared dynamic scope."; - return dyn_shmem_size.value(); - } - - auto extent = thread_extent.Get(launch_param); - TVM_FFI_ICHECK(extent) << "Compute kernel requires launch parameter \"" << launch_param - << "\", but PrimFunc does not contain AttrStmt \"" << attr::thread_extent - << "\" defining this thread extent"; - return extent.value(); - } - - void VisitStmt_(const BindNode* op) final { - // Track Bind definitions so that thread_extent values and - // dyn_shmem_size expressions that reference locally-bound - // variables (e.g. CSE variables) can be inlined back to - // expressions over function parameters. Substitute earlier - // bindings into the value to handle chains (cse_v2 = f(cse_v1)). - PrimExpr value = bind_map_.size() ? Substitute(op->value, bind_map_) : op->value; - bind_map_.Set(op->var, value); - StmtVisitor::VisitStmt_(op); - } - - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent) { - IterVar iv = Downcast<IterVar>(op->node); - TVM_FFI_ICHECK_NE(iv->thread_tag.length(), 0U); - // thread_extent can appear multiple times - // use the first appearance as def. - if (!defined_thread.count(iv.get())) { - defined_thread.insert(iv.get()); - info_.launch_params.push_back(iv->thread_tag); - // Inline any locally-bound variables (e.g. from CSE) so - // that the extent is expressible in terms of function params. - PrimExpr value = bind_map_.size() ? Substitute(op->value, bind_map_) : op->value; - thread_extent.Set(iv->thread_tag, value); - } - } - - StmtVisitor::VisitStmt_(op); - } - - void VisitStmt_(const AllocBufferNode* op) final { - auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); - if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { - TVM_FFI_ICHECK(!dyn_shmem_size.defined()) - << "Only one dynamic shared memory allocation is allowed."; - TVM_FFI_ICHECK_GT(op->buffer->shape.size(), 0); - - PrimExpr dyn_size = IntImm(DataType::Int(32), 1); - for (const auto& extent : op->buffer->shape) { - dyn_size *= extent; - } - dyn_size *= op->buffer->dtype.bytes(); - - // Inline any locally-bound variables (e.g. from CSE). - if (bind_map_.size()) { - dyn_size = Substitute(dyn_size, bind_map_); - } - dyn_shmem_size = dyn_size; - } - StmtVisitor::VisitStmt_(op); - } - - // The collected results - KernelInfo info_; - // recording what thread axis have been visited. - std::unordered_set<const IterVarNode*> defined_thread; - // The extent of each thread - ffi::Map<ffi::String, PrimExpr> thread_extent; - // The amount of dynamic shared memory used - ffi::Optional<PrimExpr> dyn_shmem_size{std::nullopt}; - // Accumulated Bind definitions for inlining into extent/size expressions. - ffi::Map<Var, PrimExpr> bind_map_; -}; - -class ReturnRemover : public StmtExprMutator { - public: - static Stmt Apply(const Stmt& stmt) { - ReturnRemover mutator; - return mutator(stmt); - } - - private: - using Parent = StmtExprMutator; - Stmt VisitStmt_(const EvaluateNode* op) override { - if (auto* call = op->value.as<CallNode>()) { - if (call->op.same_as(builtin::ret())) { - TVM_FFI_ICHECK_EQ(call->args.size(), 1); - auto as_int = call->args[0].as<IntImmNode>(); - TVM_FFI_ICHECK(as_int && as_int->value == 0) - << "Device kernel may only contain successful return, T.ret(0)"; - return Evaluate(0); - } - } - return Parent::VisitStmt_(op); - } - - PrimExpr VisitExpr_(const CallNode* op) override { - if (op->op.same_as(builtin::ret())) { - TVM_FFI_THROW(InternalError) - << "Call to builtin::ret() should only appear within an Evaluate node"; - } - return Parent::VisitExpr_(op); - } -}; -} // namespace - -class DeviceKernelMutator : public StmtExprMutator { - public: - using Parent = StmtExprMutator; - - explicit DeviceKernelMutator(std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map) - : device_info_map_(std::move(device_info_map)) {} - - PrimFunc RewriteKernelLaunchSite(const GlobalVar& gvar, PrimFunc func) { - TVM_FFI_ICHECK(!current_target_.defined()); - auto it = device_info_map_.find(gvar.get()); - TVM_FFI_ICHECK(it != device_info_map_.end()); - current_target_ = it->second.target; - // Track whether the caller is a host function (i.e. its target - // still has a host attached) and capture its host target. The - // same-target shortcut at the call site is only safe when caller - // and callee are both device-resident; a host caller must take - // the kernel-launch path even if Target::WithoutHost() makes the - // strings match. Conversely, a host caller invoking another host - // helper (e.g. a same-target subroutine that SplitHostDevice - // emitted on the host side) should compare against the host - // target, not the device target stripped by WithoutHost(). - auto full_target = func->GetAttr<Target>(tvm::attr::kTarget).value(); - if (full_target->GetHost().defined()) { - current_caller_host_target_ = full_target->GetHost().value(); - } else { - current_caller_host_target_ = std::nullopt; - } - - auto body = VisitStmt(func->body); - if (!body.same_as(func->body)) { - func.CopyOnWrite()->body = body; - } - - current_target_ = std::nullopt; - current_caller_host_target_ = std::nullopt; - return func; - } - - PrimFunc UpdateKernelAttributes(const GlobalVar& gvar, PrimFunc func) const { - bool is_kernel_launch = device_kernel_launch_.count(gvar.get()); - bool is_call_extern = extern_function_call_.count(gvar.get()); - TVM_FFI_ICHECK(!is_kernel_launch || !is_call_extern) - << "Function " << gvar << " has multiple callees, " - << "and would need to be lowered into a call_extern at some call sites, " - << "and a device kernel launch at others. " - << "This case is not yet supported."; - - if (is_kernel_launch || is_call_extern) { - func = WithAttr(std::move(func), tvm::tirx::attr::kIsGlobalFunc, true); - } - - if (is_kernel_launch) { - const auto& info = device_info_map_.at(gvar.get()); - - // Kernel launches provide an int32 error code to the caller, - // but do not accept any return type from the callee. - { - auto write_ptr = func.CopyOnWrite(); - write_ptr->ret_type = VoidType(); - write_ptr->body = ReturnRemover::Apply(write_ptr->body); - } - - func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, - static_cast<int>(tvm::CallingConv::kDeviceKernelLaunch)}, - {tvm::tirx::attr::kKernelLaunchParams, info.launch_params}, - {tvm::attr::kGlobalSymbol, info.global_symbol}}); - - } else if (is_call_extern && !func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) { - func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint); - } - - return func; - } - - private: - PrimExpr VisitExpr_(const CallNode* op) override { - auto node = Downcast<Call>(Parent::VisitExpr_(op)); - - auto* gvar = op->op.as<GlobalVarNode>(); - if (!gvar) return node; - - auto it = device_info_map_.find(gvar); - TVM_FFI_ICHECK(it != device_info_map_.end()) - << "CallNode attempted subroutine call to " << gvar->name_hint << ", but " - << gvar->name_hint << " did not appear within the IRModule"; - const KernelInfo& dev_info = it->second; - - auto callee_target = dev_info.target; - - // A callee with non-empty launch_params has thread_extent - // bindings in its body, i.e. it is a real device kernel that - // must be invoked via a kernel-launch ABI. Conversely a callee - // with empty launch_params is a plain subroutine (host helper - // or intra-device helper) and is never invoked via kernel launch. - bool callee_is_kernel = dev_info.launch_params.size() > 0; - bool caller_is_host = current_caller_host_target_.has_value(); - - // For host callers, comparisons against the callee target must - // use the caller's *host* target, not the device target stripped - // by WithoutHost(). This handles two cases that the device-side - // comparison gets wrong: - // 1. A host caller invoking a real device kernel whose - // WithoutHost() target happens to match (e.g. kernel target - // "cuda" matches "cuda+host=c" after stripping host). Must - // go through kernel launch, not the same-target shortcut. - // 2. A host caller invoking another host helper with a - // different host target (e.g. SplitHostDevice emits an - // "add_host" with target "c" while the host body still - // carries "cuda+host=c"). Must go through call_extern (or - // same-target subroutine), not kernel launch. - auto caller_target = - caller_is_host ? current_caller_host_target_.value() : current_target_.value(); - - // A host caller invoking a real device kernel must always go - // through the kernel-launch ABI, regardless of any same-target / - // same-device-type coincidence. - bool force_kernel_launch = callee_is_kernel && caller_is_host; - - if (!force_kernel_launch) { - bool same_target = caller_target->str() == callee_target->str(); - if (same_target) { - // Calls within the same target may be handled at codegen time - // as internal subroutine calls. - return node; - } - - bool same_device_type = - caller_target->GetTargetDeviceType() == callee_target->GetTargetDeviceType(); - if (same_device_type) { - // Calls to another target using the same device (e.g. LLVM - // calling a custom TIRToRuntime target) do not require a kernel - // launch, but need to be replaced with call_extern. - extern_function_call_.insert(gvar); - ffi::Array<PrimExpr> args; - args.push_back(StringImm(gvar->name_hint)); - for (const auto& arg : node->args) { - args.push_back(arg); - } - return Call(node->dtype, builtin::call_extern(), args); - } - } - - TVM_FFI_ICHECK(dev_info.launch_params.defined()) - << "CallNode attempted kernel launch to " << gvar->name_hint << " on target " - << dev_info.target << ", but subroutine " << gvar->name_hint - << " did not have the tirx::attr::kKernelLaunchParams attribute " - << "required for cross-target kernel launch"; - - // Collected kernel information may be in terms of the callee's - // arguments, but we need expressions for them in terms of the - // caller's parameters. The param_map allows substitution of - // parameter values into the thread extents, to generate - // expressions that are valid within the caller. - ffi::Map<Var, PrimExpr> param_map = [&]() { - ffi::Map<Var, PrimExpr> param_map; - TVM_FFI_ICHECK_EQ(node->args.size(), dev_info.params.size()) - << "Function " << gvar->name_hint << " accepts " << dev_info.params.size() - << " arguments as input, but is called using " << node->args.size() << " arguments"; - for (size_t i = 0; i < node->args.size(); i++) { - param_map.Set(dev_info.params[i], node->args[i]); - } - return param_map; - }(); - - device_kernel_launch_.insert(gvar); - - ffi::Array<PrimExpr> call_args; - call_args.push_back(StringImm(dev_info.global_symbol)); - for (PrimExpr arg : node->args) { - call_args.push_back(arg); - } - for (const auto& launch_arg : dev_info.launch_args) { - call_args.push_back(Substitute(launch_arg, param_map)); - } - - auto dtype = node->dtype.is_void() ? DataType::Int(32) : node->dtype; - - return Call(dtype, builtin::tvm_call_packed(), call_args); - } - - ffi::Optional<Target> current_target_; - // The host target of the caller currently being rewritten, if the - // caller is a host function (its kTarget has a host attached). - // Used both to detect that the caller is a host function and to - // compare against the callee target on the host side, so that - // host-to-host subroutine calls are not misrouted through the - // device kernel-launch ABI. - ffi::Optional<Target> current_caller_host_target_; - std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map_; - std::unordered_set<const GlobalVarNode*> device_kernel_launch_; - std::unordered_set<const GlobalVarNode*> extern_function_call_; -}; - -namespace transform { - -Pass LowerDeviceKernelLaunch() { - auto pass_func = [](IRModule mod, PassContext ctx) -> IRModule { - auto mutator = [&mod]() { - std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map; - for (const auto& [gvar, base_func] : mod->functions) { - if (auto prim_func = base_func.as<PrimFunc>()) { - device_info_map[gvar.get()] = DeviceInfoCollector::Collect(gvar, prim_func.value()); - } - } - return DeviceKernelMutator(std::move(device_info_map)); - }(); - - { - IRModule updates; - for (const auto& [gvar, base_func] : mod->functions) { - if (auto* ptr = base_func.as<PrimFuncNode>()) { - auto prim_func = mutator.RewriteKernelLaunchSite(gvar, ffi::GetRef<PrimFunc>(ptr)); - if (!prim_func.same_as(base_func)) { - updates->Add(gvar, prim_func); - } - } - } - - if (updates->functions.size()) { - mod.CopyOnWrite()->Update(updates); - } - } - - { - IRModule updates; - for (const auto& [gvar, base_func] : mod->functions) { - if (auto* ptr = base_func.as<PrimFuncNode>()) { - auto prim_func = mutator.UpdateKernelAttributes(gvar, ffi::GetRef<PrimFunc>(ptr)); - if (!prim_func.same_as(base_func)) { - updates->Add(gvar, prim_func); - } - } - } - - if (updates->functions.size()) { - mod.CopyOnWrite()->Update(updates); - } - } - - return mod; - }; - - return tvm::transform::CreateModulePass(pass_func, 0, "tirx.LowerDeviceKernelLaunch", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tirx.transform.LowerDeviceKernelLaunch", LowerDeviceKernelLaunch); -} - -} // namespace transform -} // namespace tirx -} // namespace tvm diff --git a/src/tirx/transform/split_host_device.cc b/src/tirx/transform/split_host_device.cc index 2192691634..801554389c 100644 --- a/src/tirx/transform/split_host_device.cc +++ b/src/tirx/transform/split_host_device.cc @@ -19,8 +19,9 @@ /*! * \file split_host_device.cc - * \brief Split device function from host. + * \brief Annotate and split device functions from host, then lower kernel launches. */ +#include <tvm/ffi/cast.h> #include <tvm/ffi/function.h> #include <tvm/ffi/reflection/registry.h> #include <tvm/ir/global_var_supply.h> @@ -33,11 +34,55 @@ #include <tvm/tirx/stmt_functor.h> #include <tvm/tirx/transform.h> +#include "../../runtime/thread_storage_scope.h" #include "../analysis/var_use_def_analysis.h" +#include "ir_utils.h" namespace tvm { namespace tirx { +// Device-region annotation + +class DeviceRegionAnnotater : public StmtMutator { + public: + explicit DeviceRegionAnnotater(Target device_target) : device_target_(device_target) {} + + Stmt VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == tvm::attr::kTarget) { + // If a target attribute already exists, use it as-is. + return ffi::GetRef<Stmt>(op); + } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::device_scope) { + // These attributes are only allowed in device-side code, so + // they should be annotated with the function's default target. + Stmt body = ffi::GetRef<Stmt>(op); + return AttrStmt(device_target_, tvm::attr::kTarget, 0, body); + } else { + // All other annotations are ignored. + return StmtMutator::VisitStmt_(op); + } + } + + private: + Target device_target_; +}; + +PrimFunc AnnotateDeviceRegionsForSplit(PrimFunc func) { + auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget); + TVM_FFI_ICHECK(opt_target) << "SplitHostDevice: Require the target attribute"; + Target target = opt_target.value(); + + if (target->GetHost()) { + DeviceRegionAnnotater mutator(target.WithoutHost()); + auto body = mutator(func->body); + if (!body.same_as(func->body)) { + func.CopyOnWrite()->body = body; + } + } + return func; +} + +// Host/device function extraction + class HostDeviceSplitter : public StmtMutator { public: explicit HostDeviceSplitter(IRModule* device_mod, std::function<GlobalVar()> var_supply, @@ -152,6 +197,448 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, return func; } +// Device kernel launch lowering + +namespace { + +struct KernelInfo { + // The device on which the PrimFunc runs. + Target target; + + // The externally visible symbol which may refer to the PrimFunc + // when launching a device kernel. + ffi::String global_symbol; + + // The parameters accepted by the PrimFunc. Used to rewrite + // `launch_args` to be in terms of the calling scope. + ffi::Array<Var> params; + + // The launch parameters that should annotate the PrimFunc, if the + // kernel is ever called from the host. + ffi::Array<ffi::String> launch_params; + + // Additional arguments which must be provided to the host-side + // ffi::Function. These may be in terms of the function's parameters + // (e.g. a function that computes the average of `N` elements, and + // which must be launched with `N` CUDA threads). + ffi::Array<PrimExpr> launch_args; +}; + +/*! + * \brief Visitor class to collect device-side program information. + */ +class DeviceInfoCollector : public StmtVisitor { + public: + static KernelInfo Collect(const GlobalVar& gvar, const PrimFunc& func) { + DeviceInfoCollector collector; + collector.info_.target = func->GetAttr<Target>(tvm::attr::kTarget).value().WithoutHost(); + collector.info_.params = func->params; + + collector(func->body); + + // The dynamic shared memory is required to be the last of the + // kernel launch parameters. + if (collector.dyn_shmem_size) { + collector.info_.launch_params.push_back( + tvm::runtime::launch_param::kUseDynamicSharedMemoryTag); + } + + collector.info_.global_symbol = + func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol).value_or(gvar->name_hint); + + collector.info_.launch_args = collector.info_.launch_params.Map( + [&](const auto& param) { return collector.GetArgument(param); }); + + return collector.info_; + } + + private: + PrimExpr GetArgument(const ffi::String& launch_param) const { + if (launch_param == tvm::runtime::launch_param::kUseDynamicSharedMemoryTag) { + TVM_FFI_ICHECK(dyn_shmem_size.defined()) + << "Compute kernel requires launch parameter \"" << launch_param + << "\", but PrimFunc did not contain AllocBuffer node with shared dynamic scope."; + return dyn_shmem_size.value(); + } + + auto extent = thread_extent.Get(launch_param); + TVM_FFI_ICHECK(extent) << "Compute kernel requires launch parameter \"" << launch_param + << "\", but PrimFunc does not contain AttrStmt \"" << attr::thread_extent + << "\" defining this thread extent"; + return extent.value(); + } + + void VisitStmt_(const BindNode* op) final { + // Track Bind definitions so that thread_extent values and + // dyn_shmem_size expressions that reference locally-bound + // variables (e.g. CSE variables) can be inlined back to + // expressions over function parameters. Substitute earlier + // bindings into the value to handle chains (cse_v2 = f(cse_v1)). + PrimExpr value = bind_map_.size() ? Substitute(op->value, bind_map_) : op->value; + bind_map_.Set(op->var, value); + StmtVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AttrStmtNode* op) final { + if (op->attr_key == attr::thread_extent) { + ffi::String thread_tag; + if (auto iv = op->node.as<IterVar>()) { + thread_tag = iv.value()->thread_tag; + TVM_FFI_ICHECK_NE(thread_tag.length(), 0U); + } else if (auto var = op->node.as<Var>()) { + thread_tag = var.value()->name_hint; + } else { + TVM_FFI_THROW(TypeError) << "thread_extent node must be an IterVar or Var, but was " + << op->node.GetTypeKey(); + } + // thread_extent can appear multiple times + // use the first appearance as def. + std::string thread_key = thread_tag; + if (!defined_thread.count(thread_key)) { + defined_thread.insert(thread_key); + info_.launch_params.push_back(thread_tag); + // Inline any locally-bound variables (e.g. from CSE) so + // that the extent is expressible in terms of function params. + PrimExpr value = bind_map_.size() ? Substitute(op->value, bind_map_) : op->value; + thread_extent.Set(thread_tag, value); + } + } + + StmtVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AllocBufferNode* op) final { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data)); + if (storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn") { + TVM_FFI_ICHECK(!dyn_shmem_size.defined()) + << "Only one dynamic shared memory allocation is allowed."; + TVM_FFI_ICHECK_GT(op->buffer->shape.size(), 0); + + PrimExpr dyn_size = IntImm(DataType::Int(32), 1); + for (const auto& extent : op->buffer->shape) { + dyn_size *= extent; + } + dyn_size *= op->buffer->dtype.bytes(); + + // Inline any locally-bound variables (e.g. from CSE). + if (bind_map_.size()) { + dyn_size = Substitute(dyn_size, bind_map_); + } + dyn_shmem_size = dyn_size; + } + StmtVisitor::VisitStmt_(op); + } + + // The collected results. + KernelInfo info_; + // Recording what thread axis have been visited. + std::unordered_set<std::string> defined_thread; + // The extent of each thread. + ffi::Map<ffi::String, PrimExpr> thread_extent; + // The amount of dynamic shared memory used. + ffi::Optional<PrimExpr> dyn_shmem_size{std::nullopt}; + // Accumulated Bind definitions for inlining into extent/size expressions. + ffi::Map<Var, PrimExpr> bind_map_; +}; + +class ReturnRemover : public StmtExprMutator { + public: + static Stmt Apply(const Stmt& stmt) { + ReturnRemover mutator; + return mutator(stmt); + } + + private: + using Parent = StmtExprMutator; + Stmt VisitStmt_(const EvaluateNode* op) override { + if (auto* call = op->value.as<CallNode>()) { + if (call->op.same_as(builtin::ret())) { + TVM_FFI_ICHECK_EQ(call->args.size(), 1); + auto as_int = call->args[0].as<IntImmNode>(); + TVM_FFI_ICHECK(as_int && as_int->value == 0) + << "Device kernel may only contain successful return, T.ret(0)"; + return Evaluate(0); + } + } + return Parent::VisitStmt_(op); + } + + PrimExpr VisitExpr_(const CallNode* op) override { + if (op->op.same_as(builtin::ret())) { + TVM_FFI_THROW(InternalError) + << "Call to builtin::ret() should only appear within an Evaluate node"; + } + return Parent::VisitExpr_(op); + } +}; + +class GlobalVarCallCollector : public StmtExprVisitor { + public: + static std::unordered_set<const GlobalVarNode*> Collect(const IRModule& mod) { + GlobalVarCallCollector collector; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto prim_func = base_func.as<PrimFunc>()) { + collector(prim_func.value()->body); + } + } + return collector.called_gvars_; + } + + private: + using Parent = StmtExprVisitor; + + void VisitExpr_(const CallNode* op) final { + if (auto* gvar = op->op.as<GlobalVarNode>()) { + called_gvars_.insert(gvar); + } + Parent::VisitExpr_(op); + } + + std::unordered_set<const GlobalVarNode*> called_gvars_; +}; + +} // namespace + +class DeviceKernelMutator : public StmtExprMutator { + public: + using Parent = StmtExprMutator; + + explicit DeviceKernelMutator(std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map) + : device_info_map_(std::move(device_info_map)) {} + + PrimFunc RewriteKernelLaunchSite(const GlobalVar& gvar, PrimFunc func) { + TVM_FFI_ICHECK(!current_target_.defined()); + // Track whether the caller is a host function (i.e. its target + // still has a host attached) and capture its host target. The + // same-target shortcut at the call site is only safe when caller + // and callee are both device-resident; a host caller must take + // the kernel-launch path even if Target::WithoutHost() makes the + // strings match. Conversely, a host caller invoking another host + // helper (e.g. a same-target subroutine that SplitHostDevice + // emitted on the host side) should compare against the host + // target, not the device target stripped by WithoutHost(). + auto full_target = func->GetAttr<Target>(tvm::attr::kTarget).value(); + current_target_ = full_target.WithoutHost(); + if (full_target->GetHost().defined()) { + current_caller_host_target_ = full_target->GetHost().value(); + } else { + current_caller_host_target_ = std::nullopt; + } + + auto body = VisitStmt(func->body); + if (!body.same_as(func->body)) { + func.CopyOnWrite()->body = body; + } + + current_target_ = std::nullopt; + current_caller_host_target_ = std::nullopt; + return func; + } + + PrimFunc UpdateKernelAttributes(const GlobalVar& gvar, PrimFunc func) const { + bool is_kernel_launch = device_kernel_launch_.count(gvar.get()); + bool is_call_extern = extern_function_call_.count(gvar.get()); + TVM_FFI_ICHECK(!is_kernel_launch || !is_call_extern) + << "Function " << gvar << " has multiple callees, " + << "and would need to be lowered into a call_extern at some call sites, " + << "and a device kernel launch at others. " + << "This case is not yet supported."; + + if (is_kernel_launch || is_call_extern) { + func = WithAttr(std::move(func), tvm::tirx::attr::kIsGlobalFunc, true); + } + + if (is_kernel_launch) { + const auto& info = device_info_map_.at(gvar.get()); + + // Kernel launches provide an int32 error code to the caller, + // but do not accept any return type from the callee. + { + auto write_ptr = func.CopyOnWrite(); + write_ptr->ret_type = VoidType(); + write_ptr->body = ReturnRemover::Apply(write_ptr->body); + } + + func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, + static_cast<int>(tvm::CallingConv::kDeviceKernelLaunch)}, + {tvm::tirx::attr::kKernelLaunchParams, info.launch_params}, + {tvm::attr::kGlobalSymbol, info.global_symbol}}); + + } else if (is_call_extern && !func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) { + func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint); + } + + return func; + } + + private: + PrimExpr VisitExpr_(const CallNode* op) override { + auto node = Downcast<Call>(Parent::VisitExpr_(op)); + + auto* gvar = op->op.as<GlobalVarNode>(); + if (!gvar) return node; + + auto it = device_info_map_.find(gvar); + TVM_FFI_ICHECK(it != device_info_map_.end()) + << "CallNode attempted subroutine call to " << gvar->name_hint << ", but " + << gvar->name_hint << " did not appear within the IRModule"; + const KernelInfo& dev_info = it->second; + + auto callee_target = dev_info.target; + + // A callee with non-empty launch_params has thread_extent + // bindings in its body, i.e. it is a real device kernel that + // must be invoked via a kernel-launch ABI. Conversely a callee + // with empty launch_params is a plain subroutine (host helper + // or intra-device helper) and is never invoked via kernel launch. + bool callee_is_kernel = dev_info.launch_params.size() > 0; + bool caller_is_host = current_caller_host_target_.has_value(); + + // For host callers, comparisons against the callee target must + // use the caller's *host* target, not the device target stripped + // by WithoutHost(). This handles two cases that the device-side + // comparison gets wrong: + // 1. A host caller invoking a real device kernel whose + // WithoutHost() target happens to match (e.g. kernel target + // "cuda" matches "cuda+host=c" after stripping host). Must + // go through kernel launch, not the same-target shortcut. + // 2. A host caller invoking another host helper with a + // different host target (e.g. SplitHostDevice emits an + // "add_host" with target "c" while the host body still + // carries "cuda+host=c"). Must go through call_extern (or + // same-target subroutine), not kernel launch. + auto caller_target = + caller_is_host ? current_caller_host_target_.value() : current_target_.value(); + + // A host caller invoking a real device kernel must always go + // through the kernel-launch ABI, regardless of any same-target / + // same-device-type coincidence. + bool force_kernel_launch = callee_is_kernel && caller_is_host; + + if (!force_kernel_launch) { + bool same_target = caller_target->str() == callee_target->str(); + if (same_target) { + // Calls within the same target may be handled at codegen time + // as internal subroutine calls. + return node; + } + + bool same_device_type = + caller_target->GetTargetDeviceType() == callee_target->GetTargetDeviceType(); + if (same_device_type) { + // Calls to another target using the same device (e.g. LLVM + // calling a custom TIRToRuntime target) do not require a kernel + // launch, but need to be replaced with call_extern. + extern_function_call_.insert(gvar); + ffi::Array<PrimExpr> args; + args.push_back(StringImm(gvar->name_hint)); + for (const auto& arg : node->args) { + args.push_back(arg); + } + return Call(node->dtype, builtin::call_extern(), args); + } + } + + TVM_FFI_ICHECK(dev_info.launch_params.defined()) + << "CallNode attempted kernel launch to " << gvar->name_hint << " on target " + << dev_info.target << ", but subroutine " << gvar->name_hint + << " did not have the tirx::attr::kKernelLaunchParams attribute " + << "required for cross-target kernel launch"; + + // Collected kernel information may be in terms of the callee's + // arguments, but we need expressions for them in terms of the + // caller's parameters. The param_map allows substitution of + // parameter values into the thread extents, to generate + // expressions that are valid within the caller. + ffi::Map<Var, PrimExpr> param_map = [&]() { + ffi::Map<Var, PrimExpr> param_map; + TVM_FFI_ICHECK_EQ(node->args.size(), dev_info.params.size()) + << "Function " << gvar->name_hint << " accepts " << dev_info.params.size() + << " arguments as input, but is called using " << node->args.size() << " arguments"; + for (size_t i = 0; i < node->args.size(); i++) { + param_map.Set(dev_info.params[i], node->args[i]); + } + return param_map; + }(); + + device_kernel_launch_.insert(gvar); + + ffi::Array<PrimExpr> call_args; + call_args.push_back(StringImm(dev_info.global_symbol)); + for (PrimExpr arg : node->args) { + call_args.push_back(arg); + } + for (const auto& launch_arg : dev_info.launch_args) { + call_args.push_back(Substitute(launch_arg, param_map)); + } + + auto dtype = node->dtype.is_void() ? DataType::Int(32) : node->dtype; + + return Call(dtype, builtin::tvm_call_packed(), call_args); + } + + ffi::Optional<Target> current_target_; + // The host target of the caller currently being rewritten, if the + // caller is a host function (its kTarget has a host attached). + // Used both to detect that the caller is a host function and to + // compare against the callee target on the host side, so that + // host-to-host subroutine calls are not misrouted through the + // device kernel-launch ABI. + ffi::Optional<Target> current_caller_host_target_; + std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map_; + std::unordered_set<const GlobalVarNode*> device_kernel_launch_; + std::unordered_set<const GlobalVarNode*> extern_function_call_; +}; + +IRModule LowerDeviceKernelLaunches(IRModule mod) { + auto mutator = [&mod]() { + std::unordered_set<const GlobalVarNode*> called_gvars = GlobalVarCallCollector::Collect(mod); + std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map; + for (const auto& [gvar, base_func] : mod->functions) { + if (called_gvars.count(gvar.get())) { + if (auto prim_func = base_func.as<PrimFunc>()) { + device_info_map[gvar.get()] = DeviceInfoCollector::Collect(gvar, prim_func.value()); + } + } + } + return DeviceKernelMutator(std::move(device_info_map)); + }(); + + { + IRModule updates; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto* ptr = base_func.as<PrimFuncNode>()) { + auto prim_func = mutator.RewriteKernelLaunchSite(gvar, ffi::GetRef<PrimFunc>(ptr)); + if (!prim_func.same_as(base_func)) { + updates->Add(gvar, prim_func); + } + } + } + + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); + } + } + + { + IRModule updates; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto* ptr = base_func.as<PrimFuncNode>()) { + auto prim_func = mutator.UpdateKernelAttributes(gvar, ffi::GetRef<PrimFunc>(ptr)); + if (!prim_func.same_as(base_func)) { + updates->Add(gvar, prim_func); + } + } + } + + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); + } + } + + return mod; +} + namespace transform { Pass SplitHostDevice() { @@ -164,6 +651,7 @@ Pass SplitHostDevice() { for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as<PrimFunc>()) { PrimFunc func = opt.value(); + func = AnnotateDeviceRegionsForSplit(std::move(func)); auto global_symbol = func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol); auto name_prefix = global_symbol.value_or(gvar->name_hint); @@ -181,7 +669,8 @@ Pass SplitHostDevice() { mod->Update(updates); mod->Update(device_mod); - return ConvertSSA()(mod); + mod = ConvertSSA()(mod); + return LowerDeviceKernelLaunches(mod); }; return tvm::transform::CreateModulePass(pass_func, 0, "tirx.SplitHostDevice", {}); diff --git a/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py index b09c1fd796..86ff73d273 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py @@ -339,12 +339,7 @@ def test_multi_thread_extent_blocks(): # PR #19605 that triggers the scoping bug. target = tvm.target.Target("llvm") mod_with_target = tvm.IRModule({"main": After["main"].with_attr({"target": target})}) - split = tvm.transform.Sequential( - [ - tvm.tirx.transform.AnnotateDeviceRegions(), - tvm.tirx.transform.SplitHostDevice(), - ] - ) + split = tvm.tirx.transform.SplitHostDevice() # If kernel #1 referenced an undefined buf_dyn_shmem, this # would raise during well-formedness checking inside SplitHostDevice. split(mod_with_target) diff --git a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py index 3c4b1397b2..1afe7028b9 100644 --- a/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py +++ b/tests/python/s_tir/transform/test_s_tir_transform_thread_sync.py @@ -30,7 +30,6 @@ def run_passes(func: tvm.tirx.PrimFunc): lambda f: f.with_attr({"global_symbol": "test", "target": cuda_target}) )(mod) - mod = tvm.tirx.transform.AnnotateDeviceRegions()(mod) mod = tvm.tirx.transform.SplitHostDevice()(mod) return tvm.s_tir.transform.ThreadSync("shared")(mod) @@ -89,7 +88,7 @@ def test_sync_shared_dyn(): C_1_1 = T.decl_buffer((1,), data=C_1.data, scope="local") C_1_1[0] = B_1_1[threadIdx_x // 4 * 6 + threadIdx_x % 4] D_1_1 = T.decl_buffer((16,), data=D_1.data, scope="shared.dyn") - T.tvm_storage_sync("shared.dyn") + T.evaluate(T.call_intrin("int32", "tirx.tvm_storage_sync", "shared.dyn")) D_1_1[threadIdx_x] = C_1_1[0] E_1 = T.decl_buffer((16,), data=E.data) E_1[threadIdx_x] = D_1_1[threadIdx_x] @@ -147,7 +146,7 @@ def test_sync_bind(): A_shared_1_1[ax0] = A[blockIdx_x * 512 + ax0] in_thread_A_temp_1_1 = T.decl_buffer((1,), data=in_thread_A_temp_1.data, scope="local") in_thread_A_temp_1_1[0] = T.float32(0) - T.tvm_storage_sync("shared") + T.evaluate(T.call_intrin("int32", "tirx.tvm_storage_sync", "shared")) A_temp_1 = T.bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x]) in_thread_A_temp_1_1[0] = A_temp_1 A_temp_2 = T.bind(in_thread_A_temp_1_1[0] + A_shared_1_1[threadIdx_x + 128]) diff --git a/tests/python/tirx-transform/test_tir_transform_annotate_device_regions.py b/tests/python/tirx-transform/test_tir_transform_annotate_device_regions.py index 2c3cb659e3..b91eac07b7 100644 --- a/tests/python/tirx-transform/test_tir_transform_annotate_device_regions.py +++ b/tests/python/tirx-transform/test_tir_transform_annotate_device_regions.py @@ -21,8 +21,8 @@ from tvm.script import ir as I from tvm.script import tirx as T -def test_annotate_thread_extent(): - """Annotation inserted at the "thread_extent" attribute""" +def test_thread_extent_region_extracted_as_device_kernel(): + """A bare thread_extent is annotated and extracted as a device kernel.""" @I.ir_module class Before: @@ -37,16 +37,30 @@ def test_annotate_thread_extent(): @T.prim_func(s_tir=True) def main(A: T.Buffer(16, "float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) - T.attr(T.target("cuda"), "target", 0) + T.call_packed("main_kernel", A.data, 16) + + @T.prim_func(s_tir=True) + def main_kernel(A_data: T.handle("float32")): + T.func_attr( + { + "target": T.target("cuda"), + "calling_conv": 2, + "tirx.kernel_launch_params": ["threadIdx.x"], + "global_symbol": "main_kernel", + "tirx.noalias": True, + "tirx.is_global_func": True, + } + ) + A = T.decl_buffer(16, dtype="float32", data=A_data) i = T.launch_thread("threadIdx.x", 16) A[i] = 0.0 - After = tvm.tirx.transform.AnnotateDeviceRegions()(Before) + After = tvm.tirx.transform.SplitHostDevice()(Before) tvm.ir.assert_structural_equal(After, Expected) -def test_annotate_device_scope(): - """Annotation inserted at the "device_scope" attribute""" +def test_device_scope_region_extracted_as_device_kernel(): + """A bare device_scope is annotated and extracted as a device kernel.""" @I.ir_module class Before: @@ -61,11 +75,25 @@ def test_annotate_device_scope(): @T.prim_func(s_tir=True) def main(A: T.Buffer(1, "float32")): T.func_attr({"target": T.target("cuda", host="llvm")}) - T.attr(T.target("cuda"), "target", 0) + T.call_packed("main_kernel", A.data) + + @T.prim_func(s_tir=True) + def main_kernel(A_data: T.handle("float32")): + T.func_attr( + { + "target": T.target("cuda"), + "calling_conv": 2, + "tirx.kernel_launch_params": [], + "global_symbol": "main_kernel", + "tirx.noalias": True, + "tirx.is_global_func": True, + } + ) + A = T.decl_buffer(1, dtype="float32", data=A_data) T.attr(0, "device_scope", 0) A[0] = 0.0 - After = tvm.tirx.transform.AnnotateDeviceRegions()(Before) + After = tvm.tirx.transform.SplitHostDevice()(Before) tvm.ir.assert_structural_equal(After, Expected) diff --git a/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py b/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py index 3c3ec106cf..9d596bfdca 100644 --- a/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py +++ b/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py @@ -66,12 +66,12 @@ def test_lower_device_kernel_launch(): A = T.decl_buffer(1, dtype="float32", data=A_data) A[0] = 0.0 - After = tvm.tirx.transform.LowerDeviceKernelLaunch()(Before) + After = tvm.tirx.transform.SplitHostDevice()(Before) tvm.ir.assert_structural_equal(After, Expected) def test_externally_visible_kernel_launch(): - """Like TestLowerDeviceKernelLaunch, with pre-defined global_symbol + """Like the basic kernel launch lowering case, with pre-defined global_symbol Because the host and kernel will be handled by different code generators, the device-side kernel must be externally exposed for @@ -117,7 +117,7 @@ def test_externally_visible_kernel_launch(): A = T.decl_buffer(1, dtype="float32", data=A_data) A[0] = 0.0 - After = tvm.tirx.transform.LowerDeviceKernelLaunch()(Before) + After = tvm.tirx.transform.SplitHostDevice()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -173,7 +173,7 @@ def test_collect_launch_parameter(): i = T.launch_thread("threadIdx.x", 16) A[i] = 0.0 - After = tvm.tirx.transform.LowerDeviceKernelLaunch()(Before) + After = tvm.tirx.transform.SplitHostDevice()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -219,7 +219,7 @@ def test_same_device_different_target(): A = T.decl_buffer(16, dtype="float32", data=A_data) A[0] = 0.0 - After = tvm.tirx.transform.LowerDeviceKernelLaunch()(Before) + After = tvm.tirx.transform.SplitHostDevice()(Before) tvm.ir.assert_structural_equal(After, Expected) @@ -229,7 +229,7 @@ def test_bind_before_thread_extent(): When CSE (or another pass) inserts Bind statements before thread_extent AttrStmts, the extent value may reference a locally-bound variable instead of function parameters. - LowerDeviceKernelLaunch must inline these bindings so that the + SplitHostDevice must inline these bindings so that the launch argument is expressible in terms of the caller's arguments. """ @@ -271,7 +271,7 @@ def test_bind_before_thread_extent(): i = T.launch_thread("threadIdx.x", v) A[i] = 0.0 - After = tvm.tirx.transform.LowerDeviceKernelLaunch()(Before) + After = tvm.tirx.transform.SplitHostDevice()(Before) tvm.ir.assert_structural_equal(After, Expected) diff --git a/tests/python/tirx-transform/test_tir_transform_split_host_device.py b/tests/python/tirx-transform/test_tir_transform_split_host_device.py index fc8ac8419b..9ea715984a 100644 --- a/tests/python/tirx-transform/test_tir_transform_split_host_device.py +++ b/tests/python/tirx-transform/test_tir_transform_split_host_device.py @@ -21,6 +21,12 @@ from tvm.script import ir as I from tvm.script import tirx as T +def test_public_api_surface(): + assert hasattr(tvm.tirx.transform, "SplitHostDevice") + assert not hasattr(tvm.tirx.transform, "AnnotateDeviceRegions") + assert not hasattr(tvm.tirx.transform, "LowerDeviceKernelLaunch") + + def test_ssa_across_entire_module(): """The host and device functions should not share TIR vars @@ -38,13 +44,7 @@ def test_ssa_across_entire_module(): for j in range(16): T.evaluate(i) - after = tvm.ir.transform.Sequential( - [ - tvm.tirx.transform.AnnotateDeviceRegions(), - tvm.tirx.transform.SplitHostDevice(), - tvm.tirx.transform.LowerDeviceKernelLaunch(), - ] - )(before) + after = tvm.tirx.transform.SplitHostDevice()(before) loop_var = after["main"].body.loop_var param_var = after["main_kernel"].params[0] @@ -67,13 +67,16 @@ def test_split_host_device(): @T.prim_func(s_tir=True) def main(n: T.int32): T.func_attr({"target": T.target("cuda", host={"kind": "llvm", "opt-level": 0})}) - Expected.main_kernel(n) + T.call_packed("main_kernel", n) - @T.prim_func(private=True, s_tir=True) + @T.prim_func(s_tir=True) def main_kernel(n: T.int32): T.func_attr( { "target": T.target("cuda"), + "calling_conv": 2, + "tirx.kernel_launch_params": [], + "global_symbol": "main_kernel", "tirx.noalias": True, "tirx.is_global_func": True, } @@ -100,10 +103,10 @@ def test_split_host_device_on_cpu(): @T.prim_func(s_tir=True) def main(n: T.int32): T.func_attr({"target": T.target("cuda", host={"kind": "llvm", "opt-level": 0})}) - err: T.let[T.int32] = Expected.main_kernel(n) - assert err == 0, "Error executing compute kernel" + kernel_error_code: T.let[T.int32] = T.call_extern("int32", "main_kernel", n) + assert kernel_error_code == 0, "Error executing compute kernel" - @T.prim_func(private=True, s_tir=True) + @T.prim_func(s_tir=True) def main_kernel(n: T.int32) -> T.int32: T.func_attr( { @@ -139,13 +142,16 @@ def test_split_host_device_without_func_host_attribute(): @T.prim_func(s_tir=True) def main(n: T.int32): T.func_attr({"target": T.target("llvm")}) - Expected.main_kernel(n) + T.call_packed("main_kernel", n) - @T.prim_func(private=True, s_tir=True) + @T.prim_func(s_tir=True) def main_kernel(n: T.int32): T.func_attr( { "target": T.target("cuda"), + "calling_conv": 2, + "tirx.kernel_launch_params": [], + "global_symbol": "main_kernel", "tirx.noalias": True, "tirx.is_global_func": True, } @@ -201,13 +207,16 @@ def test_split_host_device_name_collision(): @T.prim_func(s_tir=True) def main(n: T.int32): T.func_attr({"target": T.target("cuda", host={"kind": "llvm", "opt-level": 0})}) - Expected.main_kernel_1(n) + T.call_packed("main_kernel_1", n) - @T.prim_func(private=True, s_tir=True) + @T.prim_func(s_tir=True) def main_kernel_1(n: T.int32): T.func_attr( { "target": T.target("cuda"), + "calling_conv": 2, + "tirx.kernel_launch_params": [], + "global_symbol": "main_kernel_1", "tirx.noalias": True, "tirx.is_global_func": True, } @@ -234,7 +243,7 @@ def test_dynamic_launch_thread(): if the only use of a variable occurred in the extent of a `T.launch_thread` statement. - While the lowering pass `LowerDeviceKernelLaunch` will hoist the + While the launch-lowering stage will hoist the computation of the extent from the device kernel to the host function, the IRModule must be well-defined at all stages of lowering. Even if a variable is only used as part of a thread
