llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-flang-fir-hlfir Author: Kareem Ergawy (ergawy) <details> <summary>Changes</summary> Now that we have changes introduced by #145837, mapping reductions from `do concurrent` to OpenMP is almost trivial. This PR adds such mapping. TODO: Add tests --- Full diff: https://github.com/llvm/llvm-project/pull/146033.diff 1 File Affected: - (modified) flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp (+56-27) ``````````diff diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp index 709cf1d0938fa..31076f6eb328f 100644 --- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp +++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp @@ -312,6 +312,19 @@ class DoConcurrentConversion bool isComposite) const { mlir::omp::WsloopOperands wsloopClauseOps; + auto cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion, + mlir::Region &ompRegion) { + if (!firRegion.empty()) { + rewriter.cloneRegionBefore(firRegion, ompRegion, ompRegion.begin()); + auto firYield = + mlir::cast<fir::YieldOp>(ompRegion.back().getTerminator()); + rewriter.setInsertionPoint(firYield); + rewriter.create<mlir::omp::YieldOp>(firYield.getLoc(), + firYield.getOperands()); + rewriter.eraseOp(firYield); + } + }; + // For `local` (and `local_init`) opernads, emit corresponding `private` // clauses and attach these clauses to the workshare loop. 
if (!loop.getLocalVars().empty()) @@ -326,50 +339,65 @@ class DoConcurrentConversion TODO(localizer.getLoc(), "local_init conversion is not supported yet"); - auto oldIP = rewriter.saveInsertionPoint(); + mlir::OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfter(localizer); + auto privatizer = rewriter.create<mlir::omp::PrivateClauseOp>( localizer.getLoc(), sym.getLeafReference().str() + ".omp", localizer.getTypeAttr().getValue(), mlir::omp::DataSharingClauseType::Private); - if (!localizer.getInitRegion().empty()) { - rewriter.cloneRegionBefore(localizer.getInitRegion(), - privatizer.getInitRegion(), - privatizer.getInitRegion().begin()); - auto firYield = mlir::cast<fir::YieldOp>( - privatizer.getInitRegion().back().getTerminator()); - rewriter.setInsertionPoint(firYield); - rewriter.create<mlir::omp::YieldOp>(firYield.getLoc(), - firYield.getOperands()); - rewriter.eraseOp(firYield); - } - - if (!localizer.getDeallocRegion().empty()) { - rewriter.cloneRegionBefore(localizer.getDeallocRegion(), - privatizer.getDeallocRegion(), - privatizer.getDeallocRegion().begin()); - auto firYield = mlir::cast<fir::YieldOp>( - privatizer.getDeallocRegion().back().getTerminator()); - rewriter.setInsertionPoint(firYield); - rewriter.create<mlir::omp::YieldOp>(firYield.getLoc(), - firYield.getOperands()); - rewriter.eraseOp(firYield); - } - - rewriter.restoreInsertionPoint(oldIP); + cloneFIRRegionToOMP(localizer.getInitRegion(), + privatizer.getInitRegion()); + cloneFIRRegionToOMP(localizer.getDeallocRegion(), + privatizer.getDeallocRegion()); wsloopClauseOps.privateVars.push_back(op); wsloopClauseOps.privateSyms.push_back( mlir::SymbolRefAttr::get(privatizer)); } + if (!loop.getReduceVars().empty()) { + for (auto [op, byRef, sym, arg] : llvm::zip_equal( + loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(), + loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(), + loop.getRegionReduceArgs())) { + auto firReducer = + 
mlir::SymbolTable::lookupNearestSymbolFrom<fir::DeclareReductionOp>( + loop, sym); + + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointAfter(firReducer); + + auto ompReducer = rewriter.create<mlir::omp::DeclareReductionOp>( + firReducer.getLoc(), sym.getLeafReference().str() + ".omp", + firReducer.getTypeAttr().getValue()); + + cloneFIRRegionToOMP(firReducer.getAllocRegion(), + ompReducer.getAllocRegion()); + cloneFIRRegionToOMP(firReducer.getInitializerRegion(), + ompReducer.getInitializerRegion()); + cloneFIRRegionToOMP(firReducer.getReductionRegion(), + ompReducer.getReductionRegion()); + cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(), + ompReducer.getAtomicReductionRegion()); + cloneFIRRegionToOMP(firReducer.getCleanupRegion(), + ompReducer.getCleanupRegion()); + + wsloopClauseOps.reductionVars.push_back(op); + wsloopClauseOps.reductionByref.push_back(byRef); + wsloopClauseOps.reductionSyms.push_back( + mlir::SymbolRefAttr::get(ompReducer)); + } + } + auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(loop.getLoc(), wsloopClauseOps); wsloopOp.setComposite(isComposite); Fortran::common::openmp::EntryBlockArgs wsloopArgs; wsloopArgs.priv.vars = wsloopClauseOps.privateVars; + wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars; Fortran::common::openmp::genEntryBlock(rewriter, wsloopArgs, wsloopOp.getRegion()); @@ -393,7 +421,8 @@ class DoConcurrentConversion clauseOps.loopLowerBounds.size()))) rewriter.replaceAllUsesWith(loopNestArg, wsloopArg); - for (unsigned i = 0; i < loop.getLocalVars().size(); ++i) + for (unsigned i = 0; + i < loop.getLocalVars().size() + loop.getReduceVars().size(); ++i) loopNestOp.getRegion().eraseArgument(clauseOps.loopLowerBounds.size()); return loopNestOp; `````````` </details> https://github.com/llvm/llvm-project/pull/146033 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org 
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits