https://github.com/clementval created https://github.com/llvm/llvm-project/pull/121524
Convert the op to a new entry point in the runtime `CUFSyncGlobalDescriptor` From 822f3740a56b689c50aa5c983910e2115da0e62c Mon Sep 17 00:00:00 2001 From: Valentin Clement <clement...@gmail.com> Date: Thu, 2 Jan 2025 13:58:14 -0800 Subject: [PATCH] [flang][cuda] Convert cuf.sync_descriptor to runtime call --- flang/include/flang/Runtime/CUDA/descriptor.h | 4 ++ .../Optimizer/Transforms/CUFOpConversion.cpp | 42 ++++++++++++++++++- flang/runtime/CUDA/descriptor.cpp | 7 ++++ flang/test/Fir/CUDA/cuda-sync-desc.mlir | 20 +++++++++ 4 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 flang/test/Fir/CUDA/cuda-sync-desc.mlir diff --git a/flang/include/flang/Runtime/CUDA/descriptor.h b/flang/include/flang/Runtime/CUDA/descriptor.h index 55878aaac57fb3..0ee7feca10e44c 100644 --- a/flang/include/flang/Runtime/CUDA/descriptor.h +++ b/flang/include/flang/Runtime/CUDA/descriptor.h @@ -33,6 +33,10 @@ void *RTDECL(CUFGetDeviceAddress)( void RTDECL(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src, const char *sourceFile = nullptr, int sourceLine = 0); +/// Get the device address registered with the \p hostPtr and sync the device descriptor from the host descriptor. 
+void RTDECL(CUFSyncGlobalDescriptor)( + void *hostPtr, const char *sourceFile = nullptr, int sourceLine = 0); + } // extern "C" } // namespace Fortran::runtime::cuda diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index fb0ef246546444..f08f9e412b8857 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -788,6 +788,45 @@ struct CUFLaunchOpConversion const mlir::SymbolTable &symTab; }; +struct CUFSyncDescriptorOpConversion + : public mlir::OpRewritePattern<cuf::SyncDescriptorOp> { + using OpRewritePattern::OpRewritePattern; + + CUFSyncDescriptorOpConversion(mlir::MLIRContext *context, + const mlir::SymbolTable &symTab) + : OpRewritePattern(context), symTab{symTab} {} + + mlir::LogicalResult + matchAndRewrite(cuf::SyncDescriptorOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + auto globalOp = mod.lookupSymbol<fir::GlobalOp>(op.getGlobalName()); + if (!globalOp) + return mlir::failure(); + + auto hostAddr = builder.create<fir::AddrOfOp>( + loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName()); + mlir::func::FuncOp callee = + fir::runtime::getRuntimeFunc<mkRTKey(CUFSyncGlobalDescriptor)>(loc, + builder); + auto fTy = callee.getFunctionType(); + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); + llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( + builder, loc, fTy, hostAddr, sourceFile, sourceLine)}; + builder.create<fir::CallOp>(loc, callee, args); + op.erase(); + return mlir::success(); + } + +private: + const mlir::SymbolTable &symTab; +}; + class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> { public: void 
runOnOperation() override { @@ -851,7 +890,8 @@ void cuf::populateCUFToFIRConversionPatterns( CUFFreeOpConversion>(patterns.getContext()); patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab, &dl, &converter); - patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab); + patterns.insert<CUFLaunchOpConversion, CUFSyncDescriptorOpConversion>( + patterns.getContext(), symtab); } void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab, diff --git a/flang/runtime/CUDA/descriptor.cpp b/flang/runtime/CUDA/descriptor.cpp index 391c47e84241d4..947eeb66aa3d6c 100644 --- a/flang/runtime/CUDA/descriptor.cpp +++ b/flang/runtime/CUDA/descriptor.cpp @@ -46,6 +46,13 @@ void RTDEF(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src, (void *)dst, (const void *)src, count, cudaMemcpyHostToDevice)); } +void RTDEF(CUFSyncGlobalDescriptor)( + void *hostPtr, const char *sourceFile, int sourceLine) { + void *devAddr{RTNAME(CUFGetDeviceAddress)(hostPtr, sourceFile, sourceLine)}; + RTNAME(CUFDescriptorSync) + ((Descriptor *)devAddr, (Descriptor *)hostPtr, sourceFile, sourceLine); +} + RT_EXT_API_GROUP_END } } // namespace Fortran::runtime::cuda diff --git a/flang/test/Fir/CUDA/cuda-sync-desc.mlir b/flang/test/Fir/CUDA/cuda-sync-desc.mlir new file mode 100644 index 00000000000000..20b317f34a7f26 --- /dev/null +++ b/flang/test/Fir/CUDA/cuda-sync-desc.mlir @@ -0,0 +1,20 @@ +// RUN: fir-opt --cuf-convert %s | FileCheck %s + +module attributes {dlti.dl_spec = #dlti.dl_spec<i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, f80 = dense<128> : vector<2xi64>, i1 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, 
!llvm.ptr<272> = dense<64> : vector<4xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 (g...@github.com:clementval/llvm-project.git f37e52237791f58438790c77edeb8de08f692987)", llvm.target_triple = "x86_64-unknown-linux-gnu"} { + fir.global @_QMdevptrEdev_ptr {data_attr = #cuf.cuda<device>} : !fir.box<!fir.ptr<!fir.array<?xf32>>> { + %0 = fir.zero_bits !fir.ptr<!fir.array<?xf32>> + %c0 = arith.constant 0 : index + %1 = fir.shape %c0 : (index) -> !fir.shape<1> + %2 = fir.embox %0(%1) {allocator_idx = 2 : i32} : (!fir.ptr<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.ptr<!fir.array<?xf32>>> + fir.has_value %2 : !fir.box<!fir.ptr<!fir.array<?xf32>>> + } + func.func @_QQmain() { + cuf.sync_descriptor @_QMdevptrEdev_ptr + return + } +} + +// CHECK-LABEL: func.func @_QQmain() +// CHECK: %[[HOST_ADDR:.*]] = fir.address_of(@_QMdevptrEdev_ptr) : !fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>> +// CHECK: %[[HOST_ADDR_PTR:.*]] = fir.convert %[[HOST_ADDR]] : (!fir.ref<!fir.box<!fir.ptr<!fir.array<?xf32>>>>) -> !fir.llvm_ptr<i8> +// CHECK: fir.call @_FortranACUFSyncGlobalDescriptor(%[[HOST_ADDR_PTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits