https://github.com/kparzysz created https://github.com/llvm/llvm-project/pull/77758
Introduce `genNestedEvaluations` that lowers all evaluations nested in the given evaluation, accounting for a potential COLLAPSE directive. Recursive lowering [2/5] >From fd51d9b3ad850579787cb31a5423498e09d51f0c Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek <krzysztof.parzys...@amd.com> Date: Tue, 9 Jan 2024 12:22:06 -0600 Subject: [PATCH] [Flang][OpenMP] Push genEval calls to individual operations, NFC Introduce `genNestedEvaluations` that lowers all evaluations nested in the given evaluation, accounting for a potential COLLAPSE directive. Recursive lowering [2/5] --- flang/lib/Lower/OpenMP.cpp | 128 +++++++++++++++++++------------------ 1 file changed, 66 insertions(+), 62 deletions(-) diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index 99690b03eca1d3..496b4ba27a0533 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -110,6 +110,32 @@ static void gatherFuncAndVarSyms( } } +static Fortran::lower::pft::Evaluation * +getEvalPastCollapse(Fortran::lower::pft::Evaluation &eval, int collapseValue) { + if (collapseValue == 0) + return &eval; + + Fortran::lower::pft::Evaluation *curEval = &eval.getFirstNestedEvaluation(); + for (int i = 1; i < collapseValue; i++) { + // The nested evaluations should be DoConstructs (i.e. they should form + // a loop nest). Each DoConstruct is a tuple <NonLabelDoStmt, Block, + // EndDoStmt>. 
+ assert(curEval->isA<Fortran::parser::DoConstruct>()); + curEval = &*std::next(curEval->getNestedEvaluations().begin()); + } + return curEval; +} + +static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + int collapseValue = 0) { + Fortran::lower::pft::Evaluation *curEval = + getEvalPastCollapse(eval, collapseValue); + + for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) + converter.genEval(e); +} + //===----------------------------------------------------------------------===// // DataSharingProcessor //===----------------------------------------------------------------------===// @@ -2944,7 +2970,7 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter, static void genOMP(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, + Fortran::lower::SymMap &symTable, Fortran::lower::pft::Evaluation &eval, Fortran::semantics::SemanticsContext &semanticsContext, const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) { std::visit( @@ -3025,15 +3051,17 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter, createBodyOfOp<mlir::omp::SimdLoopOp>(simdLoopOp, converter, loc, eval, &loopOpClauseList, iv, /*outer=*/false, &dsp); + + genNestedEvaluations(converter, eval, + Fortran::lower::getCollapseValue(loopOpClauseList)); } -static void -createWsLoop(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, - llvm::omp::Directive ompDirective, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, - mlir::Location loc) { +static void createWsLoop(Fortran::lower::AbstractConverter &converter, + Fortran::lower::pft::Evaluation &eval, + llvm::omp::Directive ompDirective, + const Fortran::parser::OmpClauseList &beginClauseList, + const Fortran::parser::OmpClauseList *endClauseList, + mlir::Location loc) { fir::FirOpBuilder &firOpBuilder = 
converter.getFirOpBuilder(); DataSharingProcessor dsp(converter, beginClauseList, eval); dsp.processStep1(); @@ -3107,9 +3135,13 @@ createWsLoop(Fortran::lower::AbstractConverter &converter, createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, loc, eval, &beginClauseList, iv, /*outer=*/false, &dsp); + + genNestedEvaluations(converter, eval, + Fortran::lower::getCollapseValue(beginClauseList)); } static void genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, Fortran::lower::pft::Evaluation &eval, Fortran::semantics::SemanticsContext &semanticsContext, const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { @@ -3179,11 +3211,12 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, createWsLoop(converter, eval, ompDirective, loopOpClauseList, endClauseList, currentLocation); } + genOpenMPReduction(converter, loopOpClauseList); } static void genOMP(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, + Fortran::lower::SymMap &symTable, Fortran::lower::pft::Evaluation &eval, Fortran::semantics::SemanticsContext &semanticsContext, const Fortran::parser::OpenMPBlockConstruct &blockConstruct) { const auto &beginBlockDirective = @@ -3298,11 +3331,14 @@ genOMP(Fortran::lower::AbstractConverter &converter, break; } } + + genNestedEvaluations(converter, eval); + genOpenMPReduction(converter, beginClauseList); } static void genOMP(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, + Fortran::lower::SymMap &symTable, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); mlir::Location currentLocation = converter.getCurrentLocation(); @@ -3336,11 +3372,12 @@ genOMP(Fortran::lower::AbstractConverter &converter, }(); createBodyOfOp<mlir::omp::CriticalOp>(criticalOp, converter, currentLocation, eval); + genNestedEvaluations(converter, 
eval); } static void genOMP(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, + Fortran::lower::SymMap &symTable, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) { mlir::Location currentLocation = converter.getCurrentLocation(); const Fortran::parser::OpenMPConstruct *parentOmpConstruct = @@ -3359,14 +3396,17 @@ genOMP(Fortran::lower::AbstractConverter &converter, .t); // Currently only private/firstprivate clause is handled, and // all privatization is done within `omp.section` operations. + symTable.pushScope(); genOpWithBody<mlir::omp::SectionOp>(converter, eval, currentLocation, /*outerCombined=*/false, &sectionsClauseList); + genNestedEvaluations(converter, eval); + symTable.popScope(); } static void genOMP(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, + Fortran::lower::SymMap &symTable, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) { mlir::Location currentLocation = converter.getCurrentLocation(); llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands; @@ -3406,11 +3446,13 @@ genOMP(Fortran::lower::AbstractConverter &converter, /*reduction_vars=*/mlir::ValueRange(), /*reductions=*/nullptr, allocateOperands, allocatorOperands, nowaitClauseOperand); + + genNestedEvaluations(converter, eval); } static void genOMP(Fortran::lower::AbstractConverter &converter, - Fortran::lower::pft::Evaluation &eval, + Fortran::lower::SymMap &symTable, Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) { std::visit( Fortran::common::visitors{ @@ -3504,6 +3546,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, } static void genOMP(Fortran::lower::AbstractConverter &converter, + Fortran::lower::SymMap &symTable, Fortran::semantics::SemanticsContext &semanticsContext, 
Fortran::lower::pft::Evaluation &eval, const 
Fortran::parser::OpenMPConstruct &ompConstruct) { @@ -3511,17 +3554,18 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, Fortran::common::visitors{ [&](const Fortran::parser::OpenMPStandaloneConstruct &standaloneConstruct) { - genOMP(converter, eval, semanticsContext, standaloneConstruct); + genOMP(converter, symTable, eval, semanticsContext, + standaloneConstruct); }, [&](const Fortran::parser::OpenMPSectionsConstruct &sectionsConstruct) { - genOMP(converter, eval, sectionsConstruct); + genOMP(converter, symTable, eval, sectionsConstruct); }, [&](const Fortran::parser::OpenMPSectionConstruct &sectionConstruct) { - genOMP(converter, eval, sectionConstruct); + genOMP(converter, symTable, eval, sectionConstruct); }, [&](const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { - genOMP(converter, eval, semanticsContext, loopConstruct); + genOMP(converter, symTable, eval, semanticsContext, loopConstruct); }, [&](const Fortran::parser::OpenMPDeclarativeAllocate &execAllocConstruct) { @@ -3536,14 +3580,14 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, TODO(converter.getCurrentLocation(), "OpenMPAllocatorsConstruct"); }, [&](const Fortran::parser::OpenMPBlockConstruct &blockConstruct) { - genOMP(converter, eval, semanticsContext, blockConstruct); + genOMP(converter, symTable, eval, semanticsContext, blockConstruct); }, [&](const Fortran::parser::OpenMPAtomicConstruct &atomicConstruct) { - genOMP(converter, eval, atomicConstruct); + genOMP(converter, symTable, eval, atomicConstruct); }, [&](const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { - genOMP(converter, eval, criticalConstruct); + genOMP(converter, symTable, eval, criticalConstruct); }, }, ompConstruct.u); @@ -3607,47 +3651,8 @@ void Fortran::lower::genOpenMPConstruct( Fortran::semantics::SemanticsContext &semanticsContext, Fortran::lower::pft::Evaluation 
&eval, const Fortran::parser::OpenMPConstruct &omp) { - symTable.pushScope(); - genOMP(converter, 
semanticsContext, eval, omp); - - const Fortran::parser::OpenMPLoopConstruct *ompLoop = - std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u); - const Fortran::parser::OpenMPBlockConstruct *ompBlock = - std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u); - - // If loop is part of an OpenMP Construct then the OpenMP dialect - // workshare loop operation has already been created. Only the - // body needs to be created here and the do_loop can be skipped. - // Skip the number of collapsed loops, which is 1 when there is a - // no collapse requested. - - Fortran::lower::pft::Evaluation *curEval = &eval; - const Fortran::parser::OmpClauseList *loopOpClauseList = nullptr; - if (ompLoop) { - loopOpClauseList = &std::get<Fortran::parser::OmpClauseList>( - std::get<Fortran::parser::OmpBeginLoopDirective>(ompLoop->t).t); - int64_t collapseValue = Fortran::lower::getCollapseValue(*loopOpClauseList); - - curEval = &curEval->getFirstNestedEvaluation(); - for (int64_t i = 1; i < collapseValue; i++) { - curEval = &*std::next(curEval->getNestedEvaluations().begin()); - } - } - - for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) - converter.genEval(e); - - if (ompLoop) { - genOpenMPReduction(converter, *loopOpClauseList); - } else if (ompBlock) { - const auto &blockStart = - std::get<Fortran::parser::OmpBeginBlockDirective>(ompBlock->t); - const auto &blockClauses = - std::get<Fortran::parser::OmpClauseList>(blockStart.t); - genOpenMPReduction(converter, blockClauses); - } - + genOMP(converter, symTable, semanticsContext, eval, omp); symTable.popScope(); } @@ -3656,8 +3661,7 @@ void Fortran::lower::genOpenMPDeclarativeConstruct( Fortran::lower::pft::Evaluation &eval, const Fortran::parser::OpenMPDeclarativeConstruct &omp) { genOMP(converter, eval, omp); - for (Fortran::lower::pft::Evaluation &e : eval.getNestedEvaluations()) - converter.genEval(e); + genNestedEvaluations(converter, eval); } void Fortran::lower::genOpenMPSymbolProperties( 
_______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits