https://github.com/ergawy created https://github.com/llvm/llvm-project/pull/156610
Extends `do concurrent` to OpenMP device mapping by adding support for mapping `reduce` specifiers to omp `reduction` clauses. The changes attach 2 `reduction` clauses to the mapped OpenMP construct: one on the `teams` part of the construct and one on the `wsloop` part. From f748bd2e10415fc11f55bde946cab3a72e33ab2f Mon Sep 17 00:00:00 2001 From: ergawy <kareem.erg...@amd.com> Date: Tue, 2 Sep 2025 08:36:34 -0500 Subject: [PATCH] [flang][OpenMP] `do concurrent`: support `reduce` on device Extends `do concurrent` to OpenMP device mapping by adding support for mapping `reduce` specifiers to omp `reduction` clauses. The changes attach 2 `reduction` clauses to the mapped OpenMP construct: one on the `teams` part of the construct and one on the `wsloop` part. --- .../OpenMP/DoConcurrentConversion.cpp | 117 ++++++++++-------- .../DoConcurrent/reduce_device.mlir | 53 ++++++++ 2 files changed, 121 insertions(+), 49 deletions(-) create mode 100644 flang/test/Transforms/DoConcurrent/reduce_device.mlir diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp index 66b778fecc208..135382abb0227 100644 --- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp @@ -140,6 +140,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop, for (mlir::Value local : loop.getLocalVars()) liveIns.push_back(local); + + for (mlir::Value reduce : loop.getReduceVars()) + liveIns.push_back(reduce); } /// Collects values that are local to a loop: "loop-local values". 
A loop-local @@ -272,7 +275,7 @@ class DoConcurrentConversion targetOp = genTargetOp(doLoop.getLoc(), rewriter, mapper, loopNestLiveIns, targetClauseOps, loopNestClauseOps, liveInShapeInfoMap); - genTeamsOp(doLoop.getLoc(), rewriter); + genTeamsOp(rewriter, loop, mapper); } mlir::omp::ParallelOp parallelOp = @@ -488,46 +491,7 @@ class DoConcurrentConversion if (!mapToDevice) genPrivatizers(rewriter, mapper, loop, wsloopClauseOps); - if (!loop.getReduceVars().empty()) { - for (auto [op, byRef, sym, arg] : llvm::zip_equal( - loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(), - loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(), - loop.getRegionReduceArgs())) { - auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>( - sym.getLeafReference()); - - mlir::OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(firReducer); - std::string ompReducerName = sym.getLeafReference().str() + ".omp"; - - auto ompReducer = - moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>( - rewriter.getStringAttr(ompReducerName)); - - if (!ompReducer) { - ompReducer = mlir::omp::DeclareReductionOp::create( - rewriter, firReducer.getLoc(), ompReducerName, - firReducer.getTypeAttr().getValue()); - - cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(), - ompReducer.getAllocRegion()); - cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(), - ompReducer.getInitializerRegion()); - cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(), - ompReducer.getReductionRegion()); - cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(), - ompReducer.getAtomicReductionRegion()); - cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(), - ompReducer.getCleanupRegion()); - moduleSymbolTable.insert(ompReducer); - } - - wsloopClauseOps.reductionVars.push_back(op); - wsloopClauseOps.reductionByref.push_back(byRef); - wsloopClauseOps.reductionSyms.push_back( - mlir::SymbolRefAttr::get(ompReducer)); - } - } + 
genReductions(rewriter, mapper, loop, wsloopClauseOps); auto wsloopOp = mlir::omp::WsloopOp::create(rewriter, loop.getLoc(), wsloopClauseOps); @@ -549,8 +513,6 @@ class DoConcurrentConversion rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back()); mlir::omp::YieldOp::create(rewriter, loop->getLoc()); - loop->getParentOfType<mlir::ModuleOp>().print( - llvm::errs(), mlir::OpPrintingFlags().assumeVerified()); return {loopNestOp, wsloopOp}; } @@ -771,15 +733,26 @@ class DoConcurrentConversion liveInName, shape); } - mlir::omp::TeamsOp - genTeamsOp(mlir::Location loc, - mlir::ConversionPatternRewriter &rewriter) const { - auto teamsOp = rewriter.create<mlir::omp::TeamsOp>( - loc, /*clauses=*/mlir::omp::TeamsOperands{}); + mlir::omp::TeamsOp genTeamsOp(mlir::ConversionPatternRewriter &rewriter, + fir::DoConcurrentLoopOp loop, + mlir::IRMapping &mapper) const { + mlir::omp::TeamsOperands teamsOps; + genReductions(rewriter, mapper, loop, teamsOps); + + mlir::Location loc = loop.getLoc(); + auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(loc, teamsOps); + Fortran::common::openmp::EntryBlockArgs teamsArgs; + teamsArgs.reduction.vars = teamsOps.reductionVars; + Fortran::common::openmp::genEntryBlock(rewriter, teamsArgs, + teamsOp.getRegion()); - rewriter.createBlock(&teamsOp.getRegion()); rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc)); + for (auto [loopVar, teamsArg] : llvm::zip_equal( + loop.getReduceVars(), teamsOp.getRegion().getArguments())) { + mapper.map(loopVar, teamsArg); + } + return teamsOp; } @@ -846,6 +819,52 @@ class DoConcurrentConversion } } + void genReductions(mlir::ConversionPatternRewriter &rewriter, + mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop, + mlir::omp::ReductionClauseOps &reductionClauseOps) const { + if (!loop.getReduceVars().empty()) { + for (auto [var, byRef, sym, arg] : llvm::zip_equal( + loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(), + 
loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(), + loop.getRegionReduceArgs())) { + auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>( + sym.getLeafReference()); + + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(firReducer); + std::string ompReducerName = sym.getLeafReference().str() + ".omp"; + + auto ompReducer = + moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>( + rewriter.getStringAttr(ompReducerName)); + + if (!ompReducer) { + ompReducer = mlir::omp::DeclareReductionOp::create( + rewriter, firReducer.getLoc(), ompReducerName, + firReducer.getTypeAttr().getValue()); + + cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(), + ompReducer.getAllocRegion()); + cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(), + ompReducer.getInitializerRegion()); + cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(), + ompReducer.getReductionRegion()); + cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(), + ompReducer.getAtomicReductionRegion()); + cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(), + ompReducer.getCleanupRegion()); + moduleSymbolTable.insert(ompReducer); + } + + reductionClauseOps.reductionVars.push_back( + mapToDevice ? 
mapper.lookup(var) : var); + reductionClauseOps.reductionByref.push_back(byRef); + reductionClauseOps.reductionSyms.push_back( + mlir::SymbolRefAttr::get(ompReducer)); + } + } + } + bool mapToDevice; llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip; mlir::SymbolTable &moduleSymbolTable; diff --git a/flang/test/Transforms/DoConcurrent/reduce_device.mlir b/flang/test/Transforms/DoConcurrent/reduce_device.mlir new file mode 100644 index 0000000000000..3e46692a15dca --- /dev/null +++ b/flang/test/Transforms/DoConcurrent/reduce_device.mlir @@ -0,0 +1,53 @@ +// RUN: fir-opt --omp-do-concurrent-conversion="map-to=device" %s -o - | FileCheck %s + +fir.declare_reduction @add_reduction_f32 : f32 init { +^bb0(%arg0: f32): + %cst = arith.constant 0.000000e+00 : f32 + fir.yield(%cst : f32) +} combiner { +^bb0(%arg0: f32, %arg1: f32): + %0 = arith.addf %arg0, %arg1 fastmath<contract> : f32 + fir.yield(%0 : f32) +} + +func.func @_QPfoo() { + %0 = fir.dummy_scope : !fir.dscope + %3 = fir.alloca f32 {bindc_name = "s", uniq_name = "_QFfooEs"} + %4:2 = hlfir.declare %3 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>) + %c1 = arith.constant 1 : index + %c10 = arith.constant 1 : index + fir.do_concurrent { + %7 = fir.alloca i32 {bindc_name = "i"} + %8:2 = hlfir.declare %7 {uniq_name = "_QFfooEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>) + fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) reduce(@add_reduction_f32 #fir.reduce_attr<add> %4#0 -> %arg1 : !fir.ref<f32>) { + %9 = fir.convert %arg0 : (index) -> i32 + fir.store %9 to %8#0 : !fir.ref<i32> + %10:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>) + %11 = fir.load %10#0 : !fir.ref<f32> + %cst = arith.constant 1.000000e+00 : f32 + %12 = arith.addf %11, %cst fastmath<contract> : f32 + hlfir.assign %12 to %10#0 : f32, !fir.ref<f32> + } + } + return +} + +// CHECK: omp.declare_reduction @[[OMP_RED:.*.omp]] : f32 + +// CHECK: 
%[[S_DECL:.*]]:2 = hlfir.declare %6 {uniq_name = "_QFfooEs"} +// CHECK: %[[S_MAP:.*]] = omp.map.info var_ptr(%[[S_DECL]]#1 + +// CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[S_MAP]] -> %[[S_TARGET_ARG:.*]] : {{.*}}) { +// CHECK: %[[S_DEV_DECL:.*]]:2 = hlfir.declare %[[S_TARGET_ARG]] +// CHECK: omp.teams reduction(@[[OMP_RED]] %[[S_DEV_DECL]]#0 -> %[[RED_TEAMS_ARG:.*]] : !fir.ref<f32>) { +// CHECK: omp.parallel { +// CHECK: omp.distribute { +// CHECK: omp.wsloop reduction(@[[OMP_RED]] %[[RED_TEAMS_ARG]] -> %[[RED_WS_ARG:.*]] : {{.*}}) { +// CHECK: %[[S_WS_DECL:.*]]:2 = hlfir.declare %[[RED_WS_ARG]] {uniq_name = "_QFfooEs"} +// CHECK: %[[S_VAL:.*]] = fir.load %[[S_WS_DECL]]#0 +// CHECK: %[[RED_RES:.*]] = arith.addf %[[S_VAL]], %{{.*}} fastmath<contract> : f32 +// CHECK: hlfir.assign %[[RED_RES]] to %[[S_WS_DECL]]#0 +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits