llvmbot wrote:
@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Process `gpu.return` in the AbstractResult pass when the function is a `gpu.func`. (A simplified before/after sketch follows the diff below.)

---

Full diff: https://github.com/llvm/llvm-project/pull/119035.diff

2 Files Affected:

- (modified) flang/lib/Optimizer/Transforms/AbstractResult.cpp (+74-49)
- (added) flang/test/Fir/CUDA/cuda-abstract-result.mlir (+37)

``````````diff
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index 2ed66cc83eefb5..b0327cc10e9de6 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -234,6 +234,60 @@ class SaveResultOpConversion
   }
 };
 
+template <typename OpTy>
+static mlir::LogicalResult
+processReturnLikeOp(OpTy ret, mlir::Value newArg,
+                    mlir::PatternRewriter &rewriter) {
+  auto loc = ret.getLoc();
+  rewriter.setInsertionPoint(ret);
+  mlir::Value resultValue = ret.getOperand(0);
+  fir::LoadOp resultLoad;
+  mlir::Value resultStorage;
+  // Identify result local storage.
+  if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
+    resultLoad = load;
+    resultStorage = load.getMemref();
+    // The result alloca may be behind a fir.declare, if any.
+    if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
+      resultStorage = declare.getMemref();
+  }
+  // Replace old local storage with new storage argument, unless
+  // the derived type is C_PTR/C_FUN_PTR, in which case the return
+  // type is updated to return void* (no new argument is passed).
+  if (fir::isa_builtin_cptr_type(resultValue.getType())) {
+    auto module = ret->template getParentOfType<mlir::ModuleOp>();
+    FirOpBuilder builder(rewriter, module);
+    mlir::Value cptr = resultValue;
+    if (resultLoad) {
+      // Replace whole derived type load by component load.
+      cptr = resultLoad.getMemref();
+      rewriter.setInsertionPoint(resultLoad);
+    }
+    mlir::Value newResultValue =
+        fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
+    newResultValue = builder.createConvert(
+        loc, getVoidPtrType(ret.getContext()), newResultValue);
+    rewriter.setInsertionPoint(ret);
+    rewriter.replaceOpWithNewOp<OpTy>(ret, mlir::ValueRange{newResultValue});
+  } else if (resultStorage) {
+    resultStorage.replaceAllUsesWith(newArg);
+    rewriter.replaceOpWithNewOp<OpTy>(ret);
+  } else {
+    // The result storage may have been optimized out by a memory to
+    // register pass, this is possible for fir.box results, or fir.record
+    // with no length parameters. Simply store the result in the result
+    // storage at the return point.
+    rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
+    rewriter.replaceOpWithNewOp<OpTy>(ret);
+  }
+  // Delete result old local storage if unused.
+  if (resultStorage)
+    if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
+      if (alloc->use_empty())
+        rewriter.eraseOp(alloc);
+  return mlir::success();
+}
+
 class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -242,55 +296,23 @@ class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
   llvm::LogicalResult
   matchAndRewrite(mlir::func::ReturnOp ret,
                   mlir::PatternRewriter &rewriter) const override {
-    auto loc = ret.getLoc();
-    rewriter.setInsertionPoint(ret);
-    mlir::Value resultValue = ret.getOperand(0);
-    fir::LoadOp resultLoad;
-    mlir::Value resultStorage;
-    // Identify result local storage.
-    if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
-      resultLoad = load;
-      resultStorage = load.getMemref();
-      // The result alloca may be behind a fir.declare, if any.
-      if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
-        resultStorage = declare.getMemref();
-    }
-    // Replace old local storage with new storage argument, unless
-    // the derived type is C_PTR/C_FUN_PTR, in which case the return
-    // type is updated to return void* (no new argument is passed).
-    if (fir::isa_builtin_cptr_type(resultValue.getType())) {
-      auto module = ret->getParentOfType<mlir::ModuleOp>();
-      FirOpBuilder builder(rewriter, module);
-      mlir::Value cptr = resultValue;
-      if (resultLoad) {
-        // Replace whole derived type load by component load.
-        cptr = resultLoad.getMemref();
-        rewriter.setInsertionPoint(resultLoad);
-      }
-      mlir::Value newResultValue =
-          fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
-      newResultValue = builder.createConvert(
-          loc, getVoidPtrType(ret.getContext()), newResultValue);
-      rewriter.setInsertionPoint(ret);
-      rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
-          ret, mlir::ValueRange{newResultValue});
-    } else if (resultStorage) {
-      resultStorage.replaceAllUsesWith(newArg);
-      rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
-    } else {
-      // The result storage may have been optimized out by a memory to
-      // register pass, this is possible for fir.box results, or fir.record
-      // with no length parameters. Simply store the result in the result
-      // storage. at the return point.
-      rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
-      rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
-    }
-    // Delete result old local storage if unused.
-    if (resultStorage)
-      if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
-        if (alloc->use_empty())
-          rewriter.eraseOp(alloc);
-    return mlir::success();
+    return processReturnLikeOp(ret, newArg, rewriter);
+  }
+
+private:
+  mlir::Value newArg;
+};
+
+class GPUReturnOpConversion
+    : public mlir::OpRewritePattern<mlir::gpu::ReturnOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  GPUReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
+      : OpRewritePattern(context), newArg{newArg} {}
+  llvm::LogicalResult
+  matchAndRewrite(mlir::gpu::ReturnOp ret,
+                  mlir::PatternRewriter &rewriter) const override {
+    return processReturnLikeOp(ret, newArg, rewriter);
   }
 
 private:
@@ -373,6 +395,9 @@ class AbstractResultOpt
       patterns.insert<ReturnOpConversion>(context, newArg);
       target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
           [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
+      patterns.insert<GPUReturnOpConversion>(context, newArg);
+      target.addDynamicallyLegalOp<mlir::gpu::ReturnOp>(
+          [](mlir::gpu::ReturnOp ret) { return ret.getOperands().empty(); });
       assert(func.getFunctionType() ==
              getNewFunctionType(funcTy, shouldBoxResult));
     } else {
diff --git a/flang/test/Fir/CUDA/cuda-abstract-result.mlir b/flang/test/Fir/CUDA/cuda-abstract-result.mlir
new file mode 100644
index 00000000000000..8c59487ca5cd5c
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-abstract-result.mlir
@@ -0,0 +1,37 @@
+// RUN: fir-opt -pass-pipeline='builtin.module(gpu.module(gpu.func(abstract-result)))' %s | FileCheck %s
+
+gpu.module @test {
+  gpu.func @_QMinterval_mPtest1(%arg0: !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, %arg1: !fir.ref<f32>) -> !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}> {
+    %c1_i32 = arith.constant 1 : i32
+    %18 = fir.dummy_scope : !fir.dscope
+    %19 = fir.declare %arg0 dummy_scope %18 {uniq_name = "_QMinterval_mFtest1Ea"} : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.dscope) -> !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>
+    %20 = fir.declare %arg1 dummy_scope %18 {uniq_name = "_QMinterval_mFtest1Eb"} : (!fir.ref<f32>, !fir.dscope) -> !fir.ref<f32>
+    %21 = fir.alloca !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}> {bindc_name = "c", uniq_name = "_QMinterval_mFtest1Ec"}
+    %22 = fir.declare %21 {uniq_name = "_QMinterval_mFtest1Ec"} : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>) -> !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>
+    %23 = fir.alloca i32 {bindc_name = "warpsize", uniq_name = "_QMcudadeviceECwarpsize"}
+    %24 = fir.declare %23 {uniq_name = "_QMcudadeviceECwarpsize"} : (!fir.ref<i32>) -> !fir.ref<i32>
+    %25 = fir.field_index inf, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+    %26 = fir.coordinate_of %19, %25 : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.field) -> !fir.ref<f32>
+    %27 = fir.load %20 : !fir.ref<f32>
+    %28 = arith.negf %27 fastmath<contract> : f32
+    %29 = fir.load %26 : !fir.ref<f32>
+    %30 = fir.call @__fadd_rd(%29, %28) proc_attrs<bind_c> fastmath<contract> : (f32, f32) -> f32
+    %31 = fir.field_index inf, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+    %32 = fir.coordinate_of %22, %31 : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.field) -> !fir.ref<f32>
+    fir.store %30 to %32 : !fir.ref<f32>
+    %33 = fir.field_index sup, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+    %34 = fir.coordinate_of %19, %33 : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.field) -> !fir.ref<f32>
+    %35 = fir.load %20 : !fir.ref<f32>
+    %36 = arith.negf %35 fastmath<contract> : f32
+    %37 = fir.load %34 : !fir.ref<f32>
+    %38 = fir.call @__fadd_ru(%37, %36) proc_attrs<bind_c> fastmath<contract> : (f32, f32) -> f32
+    %39 = fir.field_index sup, !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+    %40 = fir.coordinate_of %22, %39 : (!fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, !fir.field) -> !fir.ref<f32>
+    fir.store %38 to %40 : !fir.ref<f32>
+    %41 = fir.load %22 : !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>
+    gpu.return %41 : !fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>
+  }
+}
+
+// CHECK: gpu.func @_QMinterval_mPtest1(%arg0: !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, %arg1: !fir.ref<!fir.type<_QMinterval_mTinterval{inf:f32,sup:f32}>>, %arg2: !fir.ref<f32>) {
+// CHECK: gpu.return{{$}}
``````````

https://github.com/llvm/llvm-project/pull/119035
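For readers unfamiliar with the AbstractResult pass, here is a minimal before/after sketch of the rewrite it now also applies to `gpu.func`. The module `@m`, function `@f`, and record type `!fir.type<t{x:f32}>` are hypothetical stand-ins for the interval type in the test above, not taken from the patch itself:

```mlir
// Before the pass: a device function returns a derived-type value by value.
gpu.module @m {
  gpu.func @f() -> !fir.type<t{x:f32}> {
    %0 = fir.alloca !fir.type<t{x:f32}>
    %1 = fir.load %0 : !fir.ref<!fir.type<t{x:f32}>>
    gpu.return %1 : !fir.type<t{x:f32}>
  }
}

// After the pass: the result travels through a reference argument prepended
// to the signature; uses of the local result storage are redirected to that
// argument, the now-unused alloca is erased, and gpu.return drops its operand.
// (A dead fir.load of %arg0 may linger until a later cleanup pass.)
gpu.module @m {
  gpu.func @f(%arg0: !fir.ref<!fir.type<t{x:f32}>>) {
    gpu.return
  }
}
```

This matches the CHECK lines in the test above: `@_QMinterval_mPtest1` gains a leading `!fir.ref` result argument and its `gpu.return` becomes operand-less, exactly as `ReturnOpConversion` already did for `func.return`.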