This is an automated email from the ASF dual-hosted git repository. spectrometerHBH pushed a commit to branch tir-bench in repository https://gitbox.apache.org/repos/asf/tvm.git
commit bad4d0a932dd8557d7b8a632bf1fac3e4539d503 Author: Bohan Hou <[email protected]> AuthorDate: Sun May 24 09:52:18 2026 -0700 feat(tirx): add typed pointer byte-offset intrinsic (#641) --- include/tvm/tirx/builtin.h | 9 +++++++++ python/tvm/tirx/__init__.py | 3 ++- python/tvm/tirx/op.py | 11 +++++++++++ python/tvm/tirx/script/builder/ir.py | 2 ++ src/target/source/codegen_c.cc | 33 ++++++++++++++++++++++++++++++++- src/target/source/codegen_c.h | 10 ++++++++++ src/tirx/op/builtin.cc | 4 ++++ src/tirx/op/op.cc | 9 +++++++++ 8 files changed, 79 insertions(+), 2 deletions(-) diff --git a/include/tvm/tirx/builtin.h b/include/tvm/tirx/builtin.h index 8627d55574..ff61386699 100644 --- a/include/tvm/tirx/builtin.h +++ b/include/tvm/tirx/builtin.h @@ -273,6 +273,15 @@ TVM_DLL const Op& prefetch(); */ TVM_DLL const Op& tvm_access_ptr(); +/*! + * \brief Advance a handle by byte_offset bytes (not DType elements), then cast it to DType*. + * + * DType* ptr_byte_offset(void* data, int byte_offset, Expr dtype) { + * return reinterpret_cast<DType*>(reinterpret_cast<char*>(data) + byte_offset); + * } + */ +TVM_DLL const Op& ptr_byte_offset(); + /*! * \brief Create a function local static handle that iniitalizes to nullptr. * can be used to cache function local static resources.
diff --git a/python/tvm/tirx/__init__.py b/python/tvm/tirx/__init__.py index 10de65a564..efda655066 100644 --- a/python/tvm/tirx/__init__.py +++ b/python/tvm/tirx/__init__.py @@ -55,7 +55,8 @@ from .op import tvm_stack_alloca, tvm_stack_make_shape, tvm_stack_make_array from .op import tvm_tuple, handle_add_byte_offset, tvm_struct_get, tvm_struct_set from .op import address_of, lookup_param, assume, undef from .op import continue_loop, break_loop -from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, tvm_throw_last_error +from .op import tvm_thread_allreduce, type_annotation, tvm_access_ptr, ptr_byte_offset +from .op import tvm_throw_last_error from .op import ( tvm_load_matrix_sync, tvm_store_matrix_sync, diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py index 39276b1e4c..ddf64d2e9c 100644 --- a/python/tvm/tirx/op.py +++ b/python/tvm/tirx/op.py @@ -876,6 +876,17 @@ def tvm_access_ptr(ptype, data, offset, extent, rw_mask): return call_intrin("handle", "tirx.tvm_access_ptr", ptype, data, offset, extent, rw_mask) +def ptr_byte_offset(data, byte_offset, dtype): + """Offset ``data`` by ``byte_offset`` bytes and cast the result to ``dtype*``. + + ``byte_offset`` counts bytes, not ``dtype`` elements; ``dtype`` may be a string or a + type annotation. Use this to derive a typed local pointer from a byte-addressed handle.
+ """ + if isinstance(dtype, str): + dtype = type_annotation(dtype) + return call_intrin("handle", "tirx.ptr_byte_offset", data, byte_offset, dtype) + + def tvm_throw_last_error(): """Throw TVMGetLastError() diff --git a/python/tvm/tirx/script/builder/ir.py b/python/tvm/tirx/script/builder/ir.py index 8adc802cb6..da24e71a7d 100644 --- a/python/tvm/tirx/script/builder/ir.py +++ b/python/tvm/tirx/script/builder/ir.py @@ -3558,6 +3558,7 @@ trunc = _op_wrapper(_tir_op.trunc) truncdiv = _op_wrapper(_tir_op.truncdiv) truncmod = _op_wrapper(_tir_op.truncmod) tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr) +ptr_byte_offset = _op_wrapper(_tir_op.ptr_byte_offset) tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error) tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca) tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape) @@ -3882,6 +3883,7 @@ __all__ = [ "truncdiv", "truncmod", "tvm_access_ptr", + "ptr_byte_offset", "tvm_throw_last_error", "tvm_stack_alloca", "tvm_stack_make_shape", diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index fb3d5f5f38..421cde706f 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -30,6 +30,7 @@ #include <iomanip> #include "../../arith/pattern_match.h" +#include "../../tirx/ir/buffer_common.h" #include "codegen_params.h" namespace tvm { @@ -42,6 +43,7 @@ void CodeGenC::Init(bool output_ssa) { print_ssa_form_ = output_ssa; } void CodeGenC::InitFuncState(const PrimFunc& f) { alloc_storage_scope_.clear(); handle_data_type_.clear(); + pointer_offset_vars_.clear(); CodeGenSourceBase::ClearFuncState(); ReserveKeywordsAsUnique(); } @@ -395,6 +397,16 @@ void CodeGenC::RegisterHandleType(const VarNode* buf_var, DataType t) { } } +void CodeGenC::RegisterHandleTypeFromPointer(const tirx::Var& var, const PrimExpr* value) { + if (value == nullptr) return; + auto* call = value->as<tirx::CallNode>(); + if (call == nullptr || !call->op.same_as(builtin::ptr_byte_offset()))
return; + std::optional<DataType> value_dtype = tirx::GetPointerType(GetType(*value)); + if (!value_dtype.has_value()) return; + RegisterHandleType(var.get(), value_dtype.value()); + pointer_offset_vars_.insert(var.get()); +} + void CodeGenC::PrintVecElemLoad(const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*) os << vec << ".s" << std::hex << i << std::dec; @@ -708,7 +720,15 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) if (load) { TVM_FFI_ICHECK_EQ(load->indices.size(), 1) << "CodeGenC only supports flat memory allocations."; - os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), load->indices[0]) << "))"; + const VarNode* data = load->buffer->data.get(); + if (pointer_offset_vars_.count(data) && load->dtype == load->buffer->dtype && + load->dtype.bits() >= 8 && HandleTypeMatch(data, load->dtype) && !IsVolatile(data)) { + os << "(" << GetVarID(data) << " + "; + this->PrintExpr(load->indices[0], os); + os << ")"; + } else {  // mixed-dtype, sub-byte, or volatile access: keep the buffer-ref form + os << "(&(" << GetBufferRef(load->dtype, load->buffer.get(), load->indices[0]) << "))"; + } } else { auto* var = op->args[0].as<tirx::VarNode>(); TVM_FFI_ICHECK(var) @@ -738,6 +758,15 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) os << "("; this->PrintExpr(op->args[0], os); os << " == NULL)"; + } else if (op->op.same_as(builtin::ptr_byte_offset())) { + TVM_FFI_ICHECK_EQ(op->args.size(), 3U); + os << "(("; + PrintType(op->args[2].dtype(), os); + os << "*)(((char*)"; + this->PrintExpr(op->args[0], os); + os << ") + "; + this->PrintExpr(op->args[1], os); + os << "))"; } else if (op->op.same_as(builtin::handle_add_byte_offset())) { TVM_FFI_ICHECK_EQ(op->args.size(), 2U); os << "((void*)((char*)"; @@ -953,6 +982,7 @@ void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) } else { let_binding_[op->var] = op; } + RegisterHandleTypeFromPointer(op->var, &op->value); std::string value = PrintExpr(op->value); if (print_ssa_form_) { TVM_FFI_ICHECK(!var_idmap_.count(op->var.get())); @@
-1077,6 +1107,7 @@ void CodeGenC::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOLINT(* } void CodeGenC::VisitStmt_(const BindNode* op) { + RegisterHandleTypeFromPointer(op->var, &op->value); std::string value = PrintExpr(op->value); if (print_ssa_form_) { TVM_FFI_ICHECK(!var_idmap_.count(op->var.get())); diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index f1d04bf4aa..b044f3f3a4 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -302,6 +302,14 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>, * \param t The type to be checked. */ void RegisterHandleType(const VarNode* buf_var, DataType t); + /*! + * \brief If \p value is a ptr_byte_offset call, register \p var as a typed pointer. + * + * Other let/Bind handles keep their default (void*) handling so generic buffer + * views do not change the generated code. Null values and any other expression + * are ignored; only ptr_byte_offset values opt into typed pointer arithmetic. + */ + void RegisterHandleTypeFromPointer(const tirx::Var& var, const PrimExpr* value); // override void PrintSSAAssign(const std::string& target, const std::string& src, DataType t) override; /*! \brief reserves common C keywords */ @@ -318,6 +326,8 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>, std::unordered_map<const VarNode*, std::string> alloc_storage_scope_; /*! \brief the data type of allocated buffers */ std::unordered_map<const VarNode*, DataType> handle_data_type_; + /*! \brief Vars bound to ptr_byte_offset; address_of(buf[i]) on them may print as ptr + i. */ + std::unordered_set<const VarNode*> pointer_offset_vars_; /*! \brief Record of ops that have pre-defined global symbol.
*/ OpAttrMap<TGlobalSymbol> op_attr_global_symbol_ = Op::GetAttrMap<TGlobalSymbol>("TGlobalSymbol"); // cache commonly used ops diff --git a/src/tirx/op/builtin.cc b/src/tirx/op/builtin.cc index 40cbddef96..2f34efe7ab 100644 --- a/src/tirx/op/builtin.cc +++ b/src/tirx/op/builtin.cc @@ -176,6 +176,10 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_access_ptr) .set_num_inputs(5) .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kSpecialCallArg)); +TIR_DEFINE_BUILTIN_FUNC(ptr_byte_offset) + .set_num_inputs(3) + .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure)); + TIR_DEFINE_BUILTIN_FUNC(tvm_static_handle) .set_num_inputs(0) .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kSpecialCallArg)); diff --git a/src/tirx/op/op.cc b/src/tirx/op/op.cc index c2772ad69f..a6f00bc09d 100644 --- a/src/tirx/op/op.cc +++ b/src/tirx/op/op.cc @@ -86,6 +86,15 @@ Type GetType(const PrimExpr& expr) { << "to be a type annotation, but found " << type_annotation->op; return PointerType(PrimType(type_annotation->dtype)); } + if (access->op.same_as(builtin::ptr_byte_offset())) { + TVM_FFI_ICHECK_EQ(access->args.size(), 3U); + auto type_annotation = Downcast<Call>(access->args[2]); + static auto builtin_op = Op::Get("tirx.type_annotation"); + TVM_FFI_ICHECK(type_annotation->op.same_as(builtin_op)) + << "Expected the third argument of builtin ptr_byte_offset() " + << "to be a type annotation, but found " << type_annotation->op; + return PointerType(PrimType(type_annotation->dtype)); + } } if (auto* address_of = expr.as<tirx::CallNode>()) {
