https://github.com/kparzysz updated https://github.com/llvm/llvm-project/pull/87247
>From f725face892cef4faf9f17d4b549541bdbcd7e08 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek <krzysztof.parzys...@amd.com> Date: Fri, 29 Mar 2024 09:20:41 -0500 Subject: [PATCH 1/3] [flang][OpenMP] Move clause/object conversion to happen early, in genOMP This removes the last use of genOmpObectList2, which has now been removed. --- flang/lib/Lower/OpenMP/ClauseProcessor.h | 5 +- flang/lib/Lower/OpenMP/DataSharingProcessor.h | 5 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 424 +++++++++--------- flang/lib/Lower/OpenMP/Utils.cpp | 30 +- flang/lib/Lower/OpenMP/Utils.h | 6 +- 5 files changed, 218 insertions(+), 252 deletions(-) diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index db7a1b8335f818..f4d659b70cfee7 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -49,9 +49,8 @@ class ClauseProcessor { public: ClauseProcessor(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses) - : converter(converter), semaCtx(semaCtx), - clauses(makeClauses(clauses, semaCtx)) {} + const List<Clause> &clauses) + : converter(converter), semaCtx(semaCtx), clauses(clauses) {} // 'Unique' clauses: They can appear at most once in the clause list. 
bool processCollapse( diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.h b/flang/lib/Lower/OpenMP/DataSharingProcessor.h index c11ee299c5d085..ef7b14327278e3 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.h +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.h @@ -78,13 +78,12 @@ class DataSharingProcessor { public: DataSharingProcessor(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &opClauseList, + const List<Clause> &clauses, Fortran::lower::pft::Evaluation &eval, bool useDelayedPrivatization = false, Fortran::lower::SymMap *symTable = nullptr) : hasLastPrivateOp(false), converter(converter), - firOpBuilder(converter.getFirOpBuilder()), - clauses(omp::makeClauses(opClauseList, semaCtx)), eval(eval), + firOpBuilder(converter.getFirOpBuilder()), clauses(clauses), eval(eval), useDelayedPrivatization(useDelayedPrivatization), symTable(symTable) {} // Privatisation is split into two steps. diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index edae453972d3d9..23dc25ac1ae9a1 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -17,6 +17,7 @@ #include "DataSharingProcessor.h" #include "DirectivesCommon.h" #include "ReductionProcessor.h" +#include "Utils.h" #include "flang/Common/idioms.h" #include "flang/Lower/Bridge.h" #include "flang/Lower/ConvertExpr.h" @@ -310,14 +311,15 @@ static void getDeclareTargetInfo( } else if (const auto *clauseList{ Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>( spec.u)}) { - if (clauseList->v.empty()) { + List<Clause> clauses = makeClauses(*clauseList, semaCtx); + if (clauses.empty()) { // Case: declare target, implicit capture of function symbolAndClause.emplace_back( mlir::omp::DeclareTargetCaptureClause::to, eval.getOwningProcedure()->getSubprogramSymbol()); } - ClauseProcessor cp(converter, semaCtx, *clauseList); + ClauseProcessor cp(converter, semaCtx, 
clauses); cp.processDeviceType(clauseOps); cp.processEnter(symbolAndClause); cp.processLink(symbolAndClause); @@ -597,14 +599,11 @@ static void removeStoreOp(mlir::Operation *reductionOp, mlir::Value symVal) { // TODO: Generate the reduction operation during lowering instead of creating // and removing operations since this is not a robust approach. Also, removing // ops in the builder (instead of a rewriter) is probably not the best approach. -static void -genOpenMPReduction(Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauseList) { +static void genOpenMPReduction(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const List<Clause> &clauses) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); - List<Clause> clauses{makeClauses(clauseList, semaCtx)}; - for (const Clause &clause : clauses) { if (const auto &reductionClause = std::get_if<clause::Reduction>(&clause.u)) { @@ -812,7 +811,7 @@ struct OpWithBodyGenInfo { return *this; } - OpWithBodyGenInfo &setClauses(const Fortran::parser::OmpClauseList *value) { + OpWithBodyGenInfo &setClauses(const List<Clause> *value) { clauses = value; return *this; } @@ -848,7 +847,7 @@ struct OpWithBodyGenInfo { /// [in] is this an outer operation - prevents privatization. bool outerCombined = false; /// [in] list of clauses to process. - const Fortran::parser::OmpClauseList *clauses = nullptr; + const List<Clause> *clauses = nullptr; /// [in] if provided, processes the construct's data-sharing attributes. 
DataSharingProcessor *dsp = nullptr; /// [in] if provided, list of reduction symbols @@ -1226,36 +1225,33 @@ static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) { // Code generation functions for clauses //===----------------------------------------------------------------------===// -static void genCriticalDeclareClauses( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - mlir::omp::CriticalClauseOps &clauseOps, llvm::StringRef name) { +static void +genCriticalDeclareClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const List<Clause> &clauses, mlir::Location loc, + mlir::omp::CriticalClauseOps &clauseOps, + llvm::StringRef name) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processHint(clauseOps); clauseOps.nameAttr = mlir::StringAttr::get(converter.getFirOpBuilder().getContext(), name); } -static void genFlushClauses( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, - const std::optional<Fortran::parser::OmpObjectList> &objects, - const std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>> - &clauses, - mlir::Location loc, llvm::SmallVectorImpl<mlir::Value> &operandRange) { - if (objects) - genObjectList2(*objects, converter, operandRange); - - if (clauses && clauses->size() > 0) +static void genFlushClauses(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + const ObjectList &objects, + const List<Clause> &clauses, mlir::Location loc, + llvm::SmallVectorImpl<mlir::Value> &operandRange) { + genObjectList(objects, converter, operandRange); + + if (clauses.size() > 0) TODO(converter.getCurrentLocation(), "Handle OmpMemoryOrderClause"); } static void genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const 
Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List<Clause> &clauses, mlir::Location loc, mlir::omp::OrderedRegionClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processTODO<clause::Simd>(loc, llvm::omp::Directive::OMPD_ordered); @@ -1264,9 +1260,9 @@ genOrderedRegionClauses(Fortran::lower::AbstractConverter &converter, static void genParallelClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - bool processReduction, mlir::omp::ParallelClauseOps &clauseOps, + Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses, + mlir::Location loc, bool processReduction, + mlir::omp::ParallelClauseOps &clauseOps, llvm::SmallVectorImpl<mlir::Type> &reductionTypes, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); @@ -1286,8 +1282,7 @@ static void genParallelClauses( static void genSectionsClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List<Clause> &clauses, mlir::Location loc, bool clausesFromBeginSections, mlir::omp::SectionsClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); @@ -1304,9 +1299,8 @@ static void genSimdLoopClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::StatementContext &stmtCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - mlir::omp::SimdLoopClauseOps &clauseOps, + Fortran::lower::pft::Evaluation &eval, const List<Clause> &clauses, + mlir::Location loc, mlir::omp::SimdLoopClauseOps &clauseOps, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv) { ClauseProcessor 
cp(converter, semaCtx, clauses); cp.processCollapse(loc, eval, clauseOps, iv); @@ -1324,9 +1318,8 @@ static void genSimdLoopClauses( static void genSingleClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &beginClauses, - const Fortran::parser::OmpClauseList &endClauses, - mlir::Location loc, + const List<Clause> &beginClauses, + const List<Clause> &endClauses, mlir::Location loc, mlir::omp::SingleClauseOps &clauseOps) { ClauseProcessor bcp(converter, semaCtx, beginClauses); bcp.processAllocate(clauseOps); @@ -1340,9 +1333,8 @@ static void genSingleClauses(Fortran::lower::AbstractConverter &converter, static void genTargetClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - bool processHostOnlyClauses, bool processReduction, + Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses, + mlir::Location loc, bool processHostOnlyClauses, bool processReduction, mlir::omp::TargetClauseOps &clauseOps, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &mapSyms, llvm::SmallVectorImpl<mlir::Location> &mapSymLocs, @@ -1368,9 +1360,8 @@ static void genTargetClauses( static void genTargetDataClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - mlir::omp::TargetDataClauseOps &clauseOps, + Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses, + mlir::Location loc, mlir::omp::TargetDataClauseOps &clauseOps, llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes, llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &useDeviceSyms) { @@ -1401,9 +1392,8 @@ static void 
genTargetDataClauses( static void genTargetEnterExitUpdateDataClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, mlir::Location loc, - llvm::omp::Directive directive, + Fortran::lower::StatementContext &stmtCtx, const List<Clause> &clauses, + mlir::Location loc, llvm::omp::Directive directive, mlir::omp::TargetEnterExitUpdateDataClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processDepend(clauseOps); @@ -1422,8 +1412,7 @@ static void genTargetEnterExitUpdateDataClauses( static void genTaskClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List<Clause> &clauses, mlir::Location loc, mlir::omp::TaskClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); @@ -1442,8 +1431,7 @@ static void genTaskClauses(Fortran::lower::AbstractConverter &converter, static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List<Clause> &clauses, mlir::Location loc, mlir::omp::TaskgroupClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); @@ -1453,8 +1441,7 @@ static void genTaskgroupClauses(Fortran::lower::AbstractConverter &converter, static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List<Clause> &clauses, mlir::Location loc, mlir::omp::TaskwaitClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processTODO<clause::Depend, 
clause::Nowait>( @@ -1464,8 +1451,7 @@ static void genTaskwaitClauses(Fortran::lower::AbstractConverter &converter, static void genTeamsClauses(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::StatementContext &stmtCtx, - const Fortran::parser::OmpClauseList &clauses, - mlir::Location loc, + const List<Clause> &clauses, mlir::Location loc, mlir::omp::TeamsClauseOps &clauseOps) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); @@ -1482,9 +1468,8 @@ static void genWsloopClauses( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::StatementContext &stmtCtx, - Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauses, - const Fortran::parser::OmpClauseList *endClauses, mlir::Location loc, + Fortran::lower::pft::Evaluation &eval, const List<Clause> &beginClauses, + const List<Clause> &endClauses, mlir::Location loc, mlir::omp::WsloopClauseOps &clauseOps, llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &iv, llvm::SmallVectorImpl<mlir::Type> &reductionTypes, @@ -1501,8 +1486,8 @@ static void genWsloopClauses( if (ReductionProcessor::doReductionByRef(clauseOps.reductionVars)) clauseOps.reductionByRefAttr = firOpBuilder.getUnitAttr(); - if (endClauses) { - ClauseProcessor ecp(converter, semaCtx, *endClauses); + if (!endClauses.empty()) { + ClauseProcessor ecp(converter, semaCtx, endClauses); ecp.processNowait(clauseOps); } @@ -1525,8 +1510,7 @@ static mlir::omp::CriticalOp genCriticalOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList, + mlir::Location loc, const List<Clause> &clauses, const std::optional<Fortran::parser::Name> &name) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); 
mlir::FlatSymbolRefAttr nameAttr; @@ -1537,7 +1521,7 @@ genCriticalOp(Fortran::lower::AbstractConverter &converter, auto global = mod.lookupSymbol<mlir::omp::CriticalDeclareOp>(nameStr); if (!global) { mlir::omp::CriticalClauseOps clauseOps; - genCriticalDeclareClauses(converter, semaCtx, clauseList, loc, clauseOps, + genCriticalDeclareClauses(converter, semaCtx, clauses, loc, clauseOps, nameStr); mlir::OpBuilder modBuilder(mod.getBodyRegion()); @@ -1556,8 +1540,7 @@ static mlir::omp::DistributeOp genDistributeOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List<Clause> &clauses) { TODO(loc, "Distribute construct"); return nullptr; } @@ -1566,12 +1549,9 @@ static mlir::omp::FlushOp genFlushOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const std::optional<Fortran::parser::OmpObjectList> &objectList, - const std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>> - &clauseList) { + const ObjectList &objects, const List<Clause> &clauses) { llvm::SmallVector<mlir::Value> operandRange; - genFlushClauses(converter, semaCtx, objectList, clauseList, loc, - operandRange); + genFlushClauses(converter, semaCtx, objects, clauses, loc, operandRange); return converter.getFirOpBuilder().create<mlir::omp::FlushOp>( converter.getCurrentLocation(), operandRange); @@ -1591,7 +1571,7 @@ static mlir::omp::OrderedOp genOrderedOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + const List<Clause> &clauses) { TODO(loc, "OMPD_ordered"); return nullptr; } @@ -1600,10 +1580,9 @@ static mlir::omp::OrderedRegionOp 
genOrderedRegionOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List<Clause> &clauses) { mlir::omp::OrderedRegionClauseOps clauseOps; - genOrderedRegionClauses(converter, semaCtx, clauseList, loc, clauseOps); + genOrderedRegionClauses(converter, semaCtx, clauses, loc, clauseOps); return genOpWithBody<mlir::omp::OrderedRegionOp>( OpWithBodyGenInfo(converter, semaCtx, loc, eval).setGenNested(genNested), @@ -1615,8 +1594,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, Fortran::lower::SymMap &symTable, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList, + mlir::Location loc, const List<Clause> &clauses, bool outerCombined = false) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; @@ -1624,7 +1602,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<const Fortran::semantics::Symbol *> privateSyms; llvm::SmallVector<mlir::Type> reductionTypes; llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms; - genParallelClauses(converter, semaCtx, stmtCtx, clauseList, loc, + genParallelClauses(converter, semaCtx, stmtCtx, clauses, loc, /*processReduction=*/!outerCombined, clauseOps, reductionTypes, reductionSyms); @@ -1637,7 +1615,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) .setOuterCombined(outerCombined) - .setClauses(&clauseList) + .setClauses(&clauses) .setReductions(&reductionSyms, &reductionTypes) .setGenRegionEntryCb(reductionCallback); @@ -1645,7 +1623,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter, return 
genOpWithBody<mlir::omp::ParallelOp>(genInfo, clauseOps); bool privatize = !outerCombined; - DataSharingProcessor dsp(converter, semaCtx, clauseList, eval, + DataSharingProcessor dsp(converter, semaCtx, clauses, eval, /*useDelayedPrivatization=*/true, &symTable); if (privatize) @@ -1692,14 +1670,13 @@ static mlir::omp::SectionOp genSectionOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List<Clause> &clauses) { // Currently only private/firstprivate clause is handled, and // all privatization is done within `omp.section` operations. return genOpWithBody<mlir::omp::SectionOp>( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) - .setClauses(&clauseList)); + .setClauses(&clauses)); } static mlir::omp::SectionsOp @@ -1716,18 +1693,17 @@ static mlir::omp::SimdLoopOp genSimdLoopOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { - DataSharingProcessor dsp(converter, semaCtx, clauseList, eval); + const List<Clause> &clauses) { + DataSharingProcessor dsp(converter, semaCtx, clauses, eval); dsp.processStep1(); Fortran::lower::StatementContext stmtCtx; mlir::omp::SimdLoopClauseOps clauseOps; llvm::SmallVector<const Fortran::semantics::Symbol *> iv; - genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauseList, loc, - clauseOps, iv); + genSimdLoopClauses(converter, semaCtx, stmtCtx, eval, clauses, loc, clauseOps, + iv); - auto *nestedEval = - getCollapsedLoopEval(eval, Fortran::lower::getCollapseValue(clauseList)); + auto *nestedEval = getCollapsedLoopEval(eval, getCollapseValue(clauses)); auto ivCallback = [&](mlir::Operation *op) { return genLoopVars(op, converter, loc, iv); @@ -1735,7 
+1711,7 @@ genSimdLoopOp(Fortran::lower::AbstractConverter &converter, return genOpWithBody<mlir::omp::SimdLoopOp>( OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) - .setClauses(&clauseList) + .setClauses(&clauses) .setDataSharingProcessor(&dsp) .setGenRegionEntryCb(ivCallback), clauseOps); @@ -1745,17 +1721,16 @@ static mlir::omp::SingleOp genSingleOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList &endClauseList) { + mlir::Location loc, const List<Clause> &beginClauses, + const List<Clause> &endClauses) { mlir::omp::SingleClauseOps clauseOps; - genSingleClauses(converter, semaCtx, beginClauseList, endClauseList, loc, + genSingleClauses(converter, semaCtx, beginClauses, endClauses, loc, clauseOps); return genOpWithBody<mlir::omp::SingleOp>( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) - .setClauses(&beginClauseList), + .setClauses(&beginClauses), clauseOps); } @@ -1763,8 +1738,7 @@ static mlir::omp::TargetOp genTargetOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList, + mlir::Location loc, const List<Clause> &clauses, bool outerCombined = false) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; @@ -1777,7 +1751,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<const Fortran::semantics::Symbol *> mapSyms; llvm::SmallVector<mlir::Location> mapSymLocs; llvm::SmallVector<mlir::Type> mapSymTypes; - genTargetClauses(converter, semaCtx, stmtCtx, clauseList, loc, + genTargetClauses(converter, semaCtx, stmtCtx, clauses, loc, processHostOnlyClauses, 
/*processReduction=*/outerCombined, clauseOps, mapSyms, mapSymLocs, mapSymTypes); @@ -1875,14 +1849,13 @@ static mlir::omp::TargetDataOp genTargetDataOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List<Clause> &clauses) { Fortran::lower::StatementContext stmtCtx; mlir::omp::TargetDataClauseOps clauseOps; llvm::SmallVector<mlir::Type> useDeviceTypes; llvm::SmallVector<mlir::Location> useDeviceLocs; llvm::SmallVector<const Fortran::semantics::Symbol *> useDeviceSyms; - genTargetDataClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps, + genTargetDataClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps, useDeviceTypes, useDeviceLocs, useDeviceSyms); auto targetDataOp = @@ -1894,11 +1867,11 @@ genTargetDataOp(Fortran::lower::AbstractConverter &converter, return targetDataOp; } -template <typename OpTy> -static OpTy genTargetEnterExitUpdateDataOp( - Fortran::lower::AbstractConverter &converter, - Fortran::semantics::SemanticsContext &semaCtx, mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { +template <typename OpTy> static OpTy +genTargetEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter, + Fortran::semantics::SemanticsContext &semaCtx, + mlir::Location loc, + const List<Clause> &clauses) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); Fortran::lower::StatementContext stmtCtx; @@ -1915,8 +1888,8 @@ static OpTy genTargetEnterExitUpdateDataOp( } mlir::omp::TargetEnterExitUpdateDataClauseOps clauseOps; - genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauseList, - loc, directive, clauseOps); + genTargetEnterExitUpdateDataClauses(converter, semaCtx, stmtCtx, clauses, loc, + directive, clauseOps); return firOpBuilder.create<OpTy>(loc, clauseOps); } @@ -1925,16 +1898,15 @@ 
static mlir::omp::TaskOp genTaskOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List<Clause> &clauses) { Fortran::lower::StatementContext stmtCtx; mlir::omp::TaskClauseOps clauseOps; - genTaskClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps); + genTaskClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps); return genOpWithBody<mlir::omp::TaskOp>( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) - .setClauses(&clauseList), + .setClauses(&clauses), clauseOps); } @@ -1942,15 +1914,14 @@ static mlir::omp::TaskgroupOp genTaskgroupOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + mlir::Location loc, const List<Clause> &clauses) { mlir::omp::TaskgroupClauseOps clauseOps; - genTaskgroupClauses(converter, semaCtx, clauseList, loc, clauseOps); + genTaskgroupClauses(converter, semaCtx, clauses, loc, clauseOps); return genOpWithBody<mlir::omp::TaskgroupOp>( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) - .setClauses(&clauseList), + .setClauses(&clauses), clauseOps); } @@ -1958,7 +1929,7 @@ static mlir::omp::TaskloopOp genTaskloopOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const Fortran::parser::OmpClauseList &clauseList) { + const List<Clause> &clauses) { TODO(loc, "Taskloop construct"); } @@ -1966,9 +1937,9 @@ static mlir::omp::TaskwaitOp genTaskwaitOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const 
Fortran::parser::OmpClauseList &clauseList) { + const List<Clause> &clauses) { mlir::omp::TaskwaitClauseOps clauseOps; - genTaskwaitClauses(converter, semaCtx, clauseList, loc, clauseOps); + genTaskwaitClauses(converter, semaCtx, clauses, loc, clauseOps); return converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(loc, clauseOps); } @@ -1984,17 +1955,17 @@ static mlir::omp::TeamsOp genTeamsOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, bool genNested, - mlir::Location loc, const Fortran::parser::OmpClauseList &clauseList, + mlir::Location loc, const List<Clause> &clauses, bool outerCombined = false) { Fortran::lower::StatementContext stmtCtx; mlir::omp::TeamsClauseOps clauseOps; - genTeamsClauses(converter, semaCtx, stmtCtx, clauseList, loc, clauseOps); + genTeamsClauses(converter, semaCtx, stmtCtx, clauses, loc, clauseOps); return genOpWithBody<mlir::omp::TeamsOp>( OpWithBodyGenInfo(converter, semaCtx, loc, eval) .setGenNested(genNested) .setOuterCombined(outerCombined) - .setClauses(&clauseList), + .setClauses(&clauses), clauseOps); } @@ -2002,9 +1973,8 @@ static mlir::omp::WsloopOp genWsloopOp(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, mlir::Location loc, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList) { - DataSharingProcessor dsp(converter, semaCtx, beginClauseList, eval); + const List<Clause> &beginClauses, const List<Clause> &endClauses) { + DataSharingProcessor dsp(converter, semaCtx, beginClauses, eval); dsp.processStep1(); Fortran::lower::StatementContext stmtCtx; @@ -2012,12 +1982,10 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter, llvm::SmallVector<const Fortran::semantics::Symbol *> iv; llvm::SmallVector<mlir::Type> reductionTypes; llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSyms; 
- genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauseList, - endClauseList, loc, clauseOps, iv, reductionTypes, - reductionSyms); + genWsloopClauses(converter, semaCtx, stmtCtx, eval, beginClauses, endClauses, + loc, clauseOps, iv, reductionTypes, reductionSyms); - auto *nestedEval = getCollapsedLoopEval( - eval, Fortran::lower::getCollapseValue(beginClauseList)); + auto *nestedEval = getCollapsedLoopEval(eval, getCollapseValue(beginClauses)); auto ivCallback = [&](mlir::Operation *op) { return genLoopAndReductionVars(op, converter, loc, iv, reductionSyms, @@ -2026,7 +1994,7 @@ genWsloopOp(Fortran::lower::AbstractConverter &converter, return genOpWithBody<mlir::omp::WsloopOp>( OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval) - .setClauses(&beginClauseList) + .setClauses(&beginClauses) .setDataSharingProcessor(&dsp) .setReductions(&reductionSyms, &reductionTypes) .setGenRegionEntryCb(ivCallback), @@ -2041,8 +2009,8 @@ static void genCompositeDistributeParallelDo( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { + const List<Clause> &beginClauses, + const List<Clause> &endClauses, mlir::Location loc) { TODO(loc, "Composite DISTRIBUTE PARALLEL DO"); } @@ -2050,8 +2018,8 @@ static void genCompositeDistributeParallelDoSimd( Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { + const List<Clause> &beginClauses, + const List<Clause> &endClauses, mlir::Location loc) { TODO(loc, "Composite DISTRIBUTE PARALLEL DO SIMD"); } @@ -2059,8 +2027,8 @@ static void genCompositeDistributeSimd( Fortran::lower::AbstractConverter &converter, 
Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, mlir::Location loc) { + const List<Clause> &beginClauses, + const List<Clause> &endClauses, mlir::Location loc) { TODO(loc, "Composite DISTRIBUTE SIMD"); } @@ -2068,10 +2036,10 @@ static void genCompositeDoSimd(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, + const List<Clause> &beginClauses, + const List<Clause> &endClauses, mlir::Location loc) { - ClauseProcessor cp(converter, semaCtx, beginClauseList); + ClauseProcessor cp(converter, semaCtx, beginClauses); cp.processTODO<clause::Aligned, clause::Allocate, clause::Linear, clause::Order, clause::Safelen, clause::Simdlen>( loc, llvm::omp::OMPD_do_simd); @@ -2083,15 +2051,15 @@ genCompositeDoSimd(Fortran::lower::AbstractConverter &converter, // When support for vectorization is enabled, then we need to add handling of // if clause. Currently if clause can be skipped because we always assume // SIMD length = 1. 
- genWsloopOp(converter, semaCtx, eval, loc, beginClauseList, endClauseList); + genWsloopOp(converter, semaCtx, eval, loc, beginClauses, endClauses); } static void genCompositeTaskloopSimd(Fortran::lower::AbstractConverter &converter, Fortran::semantics::SemanticsContext &semaCtx, Fortran::lower::pft::Evaluation &eval, - const Fortran::parser::OmpClauseList &beginClauseList, - const Fortran::parser::OmpClauseList *endClauseList, + const List<Clause> &beginClauses, + const List<Clause> &endClauses, mlir::Location loc) { TODO(loc, "Composite TASKLOOP SIMD"); } @@ -2201,8 +2169,9 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, const auto &directive = std::get<Fortran::parser::OmpSimpleStandaloneDirective>( simpleStandaloneConstruct.t); - const auto &clauseList = - std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t); + List<Clause> clauses = makeClauses( + std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t), + semaCtx); mlir::Location currentLocation = converter.genLocation(directive.source); switch (directive.v) { @@ -2212,29 +2181,29 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, genBarrierOp(converter, semaCtx, eval, currentLocation); break; case llvm::omp::Directive::OMPD_taskwait: - genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauseList); + genTaskwaitOp(converter, semaCtx, eval, currentLocation, clauses); break; case llvm::omp::Directive::OMPD_taskyield: genTaskyieldOp(converter, semaCtx, eval, currentLocation); break; case llvm::omp::Directive::OMPD_target_data: genTargetDataOp(converter, semaCtx, eval, /*genNested=*/true, - currentLocation, clauseList); + currentLocation, clauses); break; case llvm::omp::Directive::OMPD_target_enter_data: genTargetEnterExitUpdateDataOp<mlir::omp::TargetEnterDataOp>( - converter, semaCtx, currentLocation, clauseList); + converter, semaCtx, currentLocation, clauses); break; case llvm::omp::Directive::OMPD_target_exit_data: 
genTargetEnterExitUpdateDataOp<mlir::omp::TargetExitDataOp>( - converter, semaCtx, currentLocation, clauseList); + converter, semaCtx, currentLocation, clauses); break; case llvm::omp::Directive::OMPD_target_update: genTargetEnterExitUpdateDataOp<mlir::omp::TargetUpdateOp>( - converter, semaCtx, currentLocation, clauseList); + converter, semaCtx, currentLocation, clauses); break; case llvm::omp::Directive::OMPD_ordered: - genOrderedOp(converter, semaCtx, eval, currentLocation, clauseList); + genOrderedOp(converter, semaCtx, eval, currentLocation, clauses); break; } } @@ -2251,8 +2220,14 @@ genOMP(Fortran::lower::AbstractConverter &converter, const auto &clauseList = std::get<std::optional<std::list<Fortran::parser::OmpMemoryOrderClause>>>( flushConstruct.t); + ObjectList objects = + objectList ? makeObjects(*objectList, semaCtx) : ObjectList{}; + List<Clause> clauses = + clauseList ? makeList(*clauseList, + [&](auto &&s) { return makeClause(s.v, semaCtx); }) + : List<Clause>{}; mlir::Location currentLocation = converter.genLocation(verbatim.source); - genFlushOp(converter, semaCtx, eval, currentLocation, objectList, clauseList); + genFlushOp(converter, semaCtx, eval, currentLocation, objects, clauses); } static void @@ -2357,44 +2332,44 @@ genOMP(Fortran::lower::AbstractConverter &converter, converter.genLocation(beginBlockDirective.source); const auto origDirective = std::get<Fortran::parser::OmpBlockDirective>(beginBlockDirective.t).v; - const auto &beginClauseList = - std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t); - const auto &endClauseList = - std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t); + List<Clause> beginClauses = makeClauses( + std::get<Fortran::parser::OmpClauseList>(beginBlockDirective.t), semaCtx); + List<Clause> endClauses = makeClauses( + std::get<Fortran::parser::OmpClauseList>(endBlockDirective.t), semaCtx); assert(llvm::omp::blockConstructSet.test(origDirective) && "Expected block construct"); - for (const 
Fortran::parser::OmpClause &clause : beginClauseList.v) { + for (const Clause &clause : beginClauses) { mlir::Location clauseLocation = converter.genLocation(clause.source); - if (!std::get_if<Fortran::parser::OmpClause::If>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::NumThreads>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::ProcBind>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::Allocate>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::Default>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::Final>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::Priority>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::Reduction>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::Depend>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::Private>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::Firstprivate>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::Copyin>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::Shared>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::Threads>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::Map>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::ThreadLimit>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::NumTeams>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::Simd>(&clause.u)) { + if (!std::get_if<clause::If>(&clause.u) && + !std::get_if<clause::NumThreads>(&clause.u) && + !std::get_if<clause::ProcBind>(&clause.u) && + !std::get_if<clause::Allocate>(&clause.u) && + !std::get_if<clause::Default>(&clause.u) && + !std::get_if<clause::Final>(&clause.u) && + !std::get_if<clause::Priority>(&clause.u) && + !std::get_if<clause::Reduction>(&clause.u) && + !std::get_if<clause::Depend>(&clause.u) && + !std::get_if<clause::Private>(&clause.u) && + 
!std::get_if<clause::Firstprivate>(&clause.u) && + !std::get_if<clause::Copyin>(&clause.u) && + !std::get_if<clause::Shared>(&clause.u) && + !std::get_if<clause::Threads>(&clause.u) && + !std::get_if<clause::Map>(&clause.u) && + !std::get_if<clause::UseDevicePtr>(&clause.u) && + !std::get_if<clause::UseDeviceAddr>(&clause.u) && + !std::get_if<clause::ThreadLimit>(&clause.u) && + !std::get_if<clause::NumTeams>(&clause.u) && + !std::get_if<clause::Simd>(&clause.u)) { TODO(clauseLocation, "OpenMP Block construct clause"); } } - for (const auto &clause : endClauseList.v) { + for (const Clause &clause : endClauses) { mlir::Location clauseLocation = converter.genLocation(clause.source); - if (!std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u) && - !std::get_if<Fortran::parser::OmpClause::Copyprivate>(&clause.u)) + if (!std::get_if<clause::Nowait>(&clause.u) && + !std::get_if<clause::Copyprivate>(&clause.u)) TODO(clauseLocation, "OpenMP Block construct clause"); } @@ -2413,44 +2388,44 @@ genOMP(Fortran::lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_ordered: // 2.17.9 ORDERED construct. genOrderedRegionOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_parallel: // 2.6 PARALLEL construct. genParallelOp(converter, symTable, semaCtx, eval, genNested, - currentLocation, beginClauseList, outerCombined); + currentLocation, beginClauses, outerCombined); break; case llvm::omp::Directive::OMPD_single: // 2.8.2 SINGLE construct. genSingleOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, endClauseList); + beginClauses, endClauses); break; case llvm::omp::Directive::OMPD_target: // 2.12.5 TARGET construct. genTargetOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, outerCombined); + beginClauses, outerCombined); break; case llvm::omp::Directive::OMPD_target_data: // 2.12.2 TARGET DATA construct. 
genTargetDataOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_task: // 2.10.1 TASK construct. genTaskOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_taskgroup: // 2.17.6 TASKGROUP construct. genTaskgroupOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_teams: // 2.7 TEAMS construct. // FIXME Pass the outerCombined argument or rename it to better describe // what it represents if it must always be `false` in this context. genTeamsOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_workshare: // 2.8.3 WORKSHARE construct. @@ -2458,7 +2433,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, // implementation for this feature will come later. For the codes // that use this construct, add a single construct for now. 
genSingleOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, endClauseList); + beginClauses, endClauses); break; default: llvm_unreachable("Unexpected block construct"); @@ -2476,11 +2451,12 @@ genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPCriticalConstruct &criticalConstruct) { const auto &cd = std::get<Fortran::parser::OmpCriticalDirective>(criticalConstruct.t); - const auto &clauseList = std::get<Fortran::parser::OmpClauseList>(cd.t); + List<Clause> clauses = + makeClauses(std::get<Fortran::parser::OmpClauseList>(cd.t), semaCtx); const auto &name = std::get<std::optional<Fortran::parser::Name>>(cd.t); mlir::Location currentLocation = converter.getCurrentLocation(); genCriticalOp(converter, semaCtx, eval, /*genNested=*/true, currentLocation, - clauseList, name); + clauses, name); } static void @@ -2499,8 +2475,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPLoopConstruct &loopConstruct) { const auto &beginLoopDirective = std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t); - const auto &beginClauseList = - std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t); + List<Clause> beginClauses = makeClauses( + std::get<Fortran::parser::OmpClauseList>(beginLoopDirective.t), semaCtx); mlir::Location currentLocation = converter.genLocation(beginLoopDirective.source); const auto origDirective = @@ -2509,15 +2485,15 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, assert(llvm::omp::loopConstructSet.test(origDirective) && "Expected loop construct"); - const auto *endClauseList = [&]() { - using RetTy = const Fortran::parser::OmpClauseList *; + List<Clause> endClauses = [&]() { if (auto &endLoopDirective = std::get<std::optional<Fortran::parser::OmpEndLoopDirective>>( loopConstruct.t)) { - return RetTy( - &std::get<Fortran::parser::OmpClauseList>((*endLoopDirective).t)); + return makeClauses( + 
std::get<Fortran::parser::OmpClauseList>(endLoopDirective->t), + semaCtx); } - return RetTy(); + return List<Clause>{}; }(); std::optional<llvm::omp::Directive> nextDir = origDirective; @@ -2530,29 +2506,29 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_distribute_parallel_do: // 2.9.4.3 DISTRIBUTE PARALLEL Worksharing-Loop construct. genCompositeDistributeParallelDo(converter, semaCtx, eval, - beginClauseList, endClauseList, + beginClauses, endClauses, currentLocation); break; case llvm::omp::Directive::OMPD_distribute_parallel_do_simd: // 2.9.4.4 DISTRIBUTE PARALLEL Worksharing-Loop SIMD construct. genCompositeDistributeParallelDoSimd(converter, semaCtx, eval, - beginClauseList, endClauseList, + beginClauses, endClauses, currentLocation); break; case llvm::omp::Directive::OMPD_distribute_simd: // 2.9.4.2 DISTRIBUTE SIMD construct. - genCompositeDistributeSimd(converter, semaCtx, eval, beginClauseList, - endClauseList, currentLocation); + genCompositeDistributeSimd(converter, semaCtx, eval, beginClauses, + endClauses, currentLocation); break; case llvm::omp::Directive::OMPD_do_simd: // 2.9.3.2 Worksharing-Loop SIMD construct. - genCompositeDoSimd(converter, semaCtx, eval, beginClauseList, - endClauseList, currentLocation); + genCompositeDoSimd(converter, semaCtx, eval, beginClauses, + endClauses, currentLocation); break; case llvm::omp::Directive::OMPD_taskloop_simd: // 2.10.3 TASKLOOP SIMD construct. - genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauseList, - endClauseList, currentLocation); + genCompositeTaskloopSimd(converter, semaCtx, eval, beginClauses, + endClauses, currentLocation); break; default: llvm_unreachable("Unexpected composite construct"); @@ -2563,12 +2539,12 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_distribute: // 2.9.4.1 DISTRIBUTE construct. 
genDistributeOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_do: // 2.9.2 Worksharing-Loop construct. - genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauseList, - endClauseList); + genWsloopOp(converter, semaCtx, eval, currentLocation, beginClauses, + endClauses); break; case llvm::omp::Directive::OMPD_parallel: // 2.6 PARALLEL construct. @@ -2577,24 +2553,24 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, // Maybe rename the argument if it represents something else or // initialize it properly. genParallelOp(converter, symTable, semaCtx, eval, genNested, - currentLocation, beginClauseList, + currentLocation, beginClauses, /*outerCombined=*/true); break; case llvm::omp::Directive::OMPD_simd: // 2.9.3.1 SIMD construct. genSimdLoopOp(converter, semaCtx, eval, currentLocation, - beginClauseList); - genOpenMPReduction(converter, semaCtx, beginClauseList); + beginClauses); + genOpenMPReduction(converter, semaCtx, beginClauses); break; case llvm::omp::Directive::OMPD_target: // 2.12.5 TARGET construct. genTargetOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, /*outerCombined=*/true); + beginClauses, /*outerCombined=*/true); break; case llvm::omp::Directive::OMPD_taskloop: // 2.10.2 TASKLOOP construct. genTaskloopOp(converter, semaCtx, eval, currentLocation, - beginClauseList); + beginClauses); break; case llvm::omp::Directive::OMPD_teams: // 2.7 TEAMS construct. @@ -2603,7 +2579,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter, // Maybe rename the argument if it represents something else or // initialize it properly. 
genTeamsOp(converter, semaCtx, eval, genNested, currentLocation, - beginClauseList, /*outerCombined=*/true); + beginClauses, /*outerCombined=*/true); break; case llvm::omp::Directive::OMPD_loop: case llvm::omp::Directive::OMPD_masked: @@ -2639,14 +2615,15 @@ genOMP(Fortran::lower::AbstractConverter &converter, const Fortran::parser::OpenMPSectionsConstruct §ionsConstruct) { const auto &beginSectionsDirective = std::get<Fortran::parser::OmpBeginSectionsDirective>(sectionsConstruct.t); - const auto &beginClauseList = - std::get<Fortran::parser::OmpClauseList>(beginSectionsDirective.t); + List<Clause> beginClauses = makeClauses( + std::get<Fortran::parser::OmpClauseList>(beginSectionsDirective.t), + semaCtx); // Process clauses before optional omp.parallel, so that new variables are // allocated outside of the parallel region mlir::Location currentLocation = converter.getCurrentLocation(); mlir::omp::SectionsClauseOps clauseOps; - genSectionsClauses(converter, semaCtx, beginClauseList, currentLocation, + genSectionsClauses(converter, semaCtx, beginClauses, currentLocation, /*clausesFromBeginSections=*/true, clauseOps); // Parallel wrapper of PARALLEL SECTIONS construct @@ -2655,14 +2632,15 @@ genOMP(Fortran::lower::AbstractConverter &converter, .v; if (dir == llvm::omp::Directive::OMPD_parallel_sections) { genParallelOp(converter, symTable, semaCtx, eval, - /*genNested=*/false, currentLocation, beginClauseList, + /*genNested=*/false, currentLocation, beginClauses, /*outerCombined=*/true); } else { const auto &endSectionsDirective = std::get<Fortran::parser::OmpEndSectionsDirective>(sectionsConstruct.t); - const auto &endClauseList = - std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t); - genSectionsClauses(converter, semaCtx, endClauseList, currentLocation, + List<Clause> endClauses = makeClauses( + std::get<Fortran::parser::OmpClauseList>(endSectionsDirective.t), + semaCtx); + genSectionsClauses(converter, semaCtx, endClauses, currentLocation, 
/*clausesFromBeginSections=*/false, clauseOps); } @@ -2678,7 +2656,7 @@ genOMP(Fortran::lower::AbstractConverter &converter, llvm::zip(sectionBlocks.v, eval.getNestedEvaluations())) { symTable.pushScope(); genSectionOp(converter, semaCtx, neval, /*genNested=*/true, currentLocation, - beginClauseList); + beginClauses); symTable.popScope(); firOpBuilder.restoreInsertionPoint(ip); } diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp index b9c0660aa4da8e..da3f2be73e5095 100644 --- a/flang/lib/Lower/OpenMP/Utils.cpp +++ b/flang/lib/Lower/OpenMP/Utils.cpp @@ -36,6 +36,17 @@ namespace Fortran { namespace lower { namespace omp { +int64_t getCollapseValue(const List<Clause> &clauses) { + auto iter = llvm::find_if(clauses, [](const Clause &clause) { + return clause.id == llvm::omp::Clause::OMPC_collapse; + }); + if (iter != clauses.end()) { + const auto &collapse = std::get<clause::Collapse>(iter->u); + return evaluate::ToInt64(collapse.v).value(); + } + return 1; +} + void genObjectList(const ObjectList &objects, Fortran::lower::AbstractConverter &converter, llvm::SmallVectorImpl<mlir::Value> &operands) { @@ -52,25 +63,6 @@ void genObjectList(const ObjectList &objects, } } -void genObjectList2(const Fortran::parser::OmpObjectList &objectList, - Fortran::lower::AbstractConverter &converter, - llvm::SmallVectorImpl<mlir::Value> &operands) { - auto addOperands = [&](Fortran::lower::SymbolRef sym) { - const mlir::Value variable = converter.getSymbolAddress(sym); - if (variable) { - operands.push_back(variable); - } else if (const auto *details = - sym->detailsIf<Fortran::semantics::HostAssocDetails>()) { - operands.push_back(converter.getSymbolAddress(details->symbol())); - converter.copySymbolBinding(details->symbol(), sym); - } - }; - for (const Fortran::parser::OmpObject &ompObject : objectList.v) { - Fortran::semantics::Symbol *sym = getOmpObjectSymbol(ompObject); - addOperands(*sym); - } -} - mlir::Type 
getLoopVarType(Fortran::lower::AbstractConverter &converter, std::size_t loopVarTypeSize) { // OpenMP runtime requires 32-bit or 64-bit loop variables. diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h index 4074bf73987d5b..b3a9f7f30c98bd 100644 --- a/flang/lib/Lower/OpenMP/Utils.h +++ b/flang/lib/Lower/OpenMP/Utils.h @@ -58,6 +58,8 @@ void gatherFuncAndVarSyms( const ObjectList &objects, mlir::omp::DeclareTargetCaptureClause clause, llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause); +int64_t getCollapseValue(const List<Clause> &clauses); + Fortran::semantics::Symbol * getOmpObjectSymbol(const Fortran::parser::OmpObject &ompObject); @@ -65,10 +67,6 @@ void genObjectList(const ObjectList &objects, Fortran::lower::AbstractConverter &converter, llvm::SmallVectorImpl<mlir::Value> &operands); -void genObjectList2(const Fortran::parser::OmpObjectList &objectList, - Fortran::lower::AbstractConverter &converter, - llvm::SmallVectorImpl<mlir::Value> &operands); - } // namespace omp } // namespace lower } // namespace Fortran >From 291dc48d5e0b7e0ee39681a1276bd1d63f456b01 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek <krzysztof.parzys...@amd.com> Date: Mon, 1 Apr 2024 10:07:45 -0500 Subject: [PATCH 2/3] [Frontend][OpenMP] Refactor getLeafConstructs, add getCompoundConstruct Emit a special leaf construct table in DirectiveEmitter.cpp, which will allow both decomposition of a construct into leafs, and composition of constituent constructs into a single compound construct (if possible).
--- llvm/include/llvm/Frontend/OpenMP/OMP.h | 7 + llvm/lib/Frontend/OpenMP/OMP.cpp | 64 +++++- llvm/test/TableGen/directive1.td | 19 +- llvm/test/TableGen/directive2.td | 19 +- llvm/unittests/Frontend/CMakeLists.txt | 1 + llvm/unittests/Frontend/OpenMPComposeTest.cpp | 41 ++++ llvm/utils/TableGen/DirectiveEmitter.cpp | 194 +++++++++++------- 7 files changed, 258 insertions(+), 87 deletions(-) create mode 100644 llvm/unittests/Frontend/OpenMPComposeTest.cpp diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.h b/llvm/include/llvm/Frontend/OpenMP/OMP.h index a85cd9d344c6d7..4ed47f15dfe59e 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.h @@ -15,4 +15,11 @@ #include "llvm/Frontend/OpenMP/OMP.h.inc" +#include "llvm/ADT/ArrayRef.h" + +namespace llvm::omp { +ArrayRef<Directive> getLeafConstructs(Directive D); +Directive getCompoundConstruct(ArrayRef<Directive> Parts); +} // namespace llvm::omp + #endif // LLVM_FRONTEND_OPENMP_OMP_H diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp index 4f2f95392648b3..dd99d3d074fd1e 100644 --- a/llvm/lib/Frontend/OpenMP/OMP.cpp +++ b/llvm/lib/Frontend/OpenMP/OMP.cpp @@ -8,12 +8,74 @@ #include "llvm/Frontend/OpenMP/OMP.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/ErrorHandling.h" +#include <algorithm> +#include <iterator> +#include <type_traits> + using namespace llvm; -using namespace omp; +using namespace llvm::omp; #define GEN_DIRECTIVES_IMPL #include "llvm/Frontend/OpenMP/OMP.inc" + +namespace llvm::omp { +ArrayRef<Directive> getLeafConstructs(Directive D) { + auto Idx = static_cast<int>(D); + if (Idx < 0 || Idx >= static_cast<int>(Directive_enumSize)) + return {}; + const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]]; + return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1])); +} + 
+Directive getCompoundConstruct(ArrayRef<Directive> Parts) { + if (Parts.empty()) + return OMPD_unknown; + + // Parts don't have to be leafs, so expand them into leafs first. + // Store the expanded leafs in the same format as rows in the leaf + // table (generated by tablegen). + SmallVector<Directive> RawLeafs(2); + for (Directive P : Parts) { + ArrayRef<Directive> Ls = getLeafConstructs(P); + if (!Ls.empty()) + RawLeafs.append(Ls.begin(), Ls.end()); + else + RawLeafs.push_back(P); + } + + auto GivenLeafs{ArrayRef<Directive>(RawLeafs).drop_front(2)}; + if (GivenLeafs.size() == 1) + return GivenLeafs.front(); + RawLeafs[1] = static_cast<Directive>(GivenLeafs.size()); + + auto Iter = llvm::lower_bound( + LeafConstructTable, + static_cast<std::decay_t<decltype(*LeafConstructTable)>>(RawLeafs.data()), + [](const auto *RowA, const auto *RowB) { + const auto *BeginA = &RowA[2]; + const auto *EndA = BeginA + static_cast<int>(RowA[1]); + const auto *BeginB = &RowB[2]; + const auto *EndB = BeginB + static_cast<int>(RowB[1]); + if (BeginA == EndA && BeginB == EndB) + return static_cast<int>(RowA[0]) < static_cast<int>(RowB[0]); + return std::lexicographical_compare(BeginA, EndA, BeginB, EndB); + }); + + if (Iter == std::end(LeafConstructTable)) + return OMPD_unknown; + + // Verify that we got a match. 
+ Directive Found = (*Iter)[0]; + ArrayRef<Directive> FoundLeafs = getLeafConstructs(Found); + if (FoundLeafs == GivenLeafs) + return Found; + return OMPD_unknown; +} +} // namespace llvm::omp diff --git a/llvm/test/TableGen/directive1.td b/llvm/test/TableGen/directive1.td index 3184f625ead928..e6150210e7e9a4 100644 --- a/llvm/test/TableGen/directive1.td +++ b/llvm/test/TableGen/directive1.td @@ -52,6 +52,7 @@ def TDL_DirA : Directive<"dira"> { // CHECK-EMPTY: // CHECK-NEXT: #include "llvm/ADT/ArrayRef.h" // CHECK-NEXT: #include "llvm/ADT/BitmaskEnum.h" +// CHECK-NEXT: #include <cstddef> // CHECK-EMPTY: // CHECK-NEXT: namespace llvm { // CHECK-NEXT: class StringRef; @@ -112,7 +113,7 @@ def TDL_DirA : Directive<"dira"> { // CHECK-NEXT: /// Return true if \p C is a valid clause for \p D in version \p Version. // CHECK-NEXT: bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version); // CHECK-EMPTY: -// CHECK-NEXT: llvm::ArrayRef<Directive> getLeafConstructs(Directive D); +// CHECK-NEXT: constexpr std::size_t getMaxLeafCount() { return 0; } // CHECK-NEXT: Association getDirectiveAssociation(Directive D); // CHECK-NEXT: AKind getAKind(StringRef); // CHECK-NEXT: llvm::StringRef getTdlAKindName(AKind); @@ -359,13 +360,6 @@ def TDL_DirA : Directive<"dira"> { // IMPL-NEXT: llvm_unreachable("Invalid Tdl Directive kind"); // IMPL-NEXT: } // IMPL-EMPTY: -// IMPL-NEXT: llvm::ArrayRef<llvm::tdl::Directive> llvm::tdl::getLeafConstructs(llvm::tdl::Directive Dir) { -// IMPL-NEXT: switch (Dir) { -// IMPL-NEXT: default: -// IMPL-NEXT: return ArrayRef<llvm::tdl::Directive>{}; -// IMPL-NEXT: } // switch (Dir) -// IMPL-NEXT: } -// IMPL-EMPTY: // IMPL-NEXT: llvm::tdl::Association llvm::tdl::getDirectiveAssociation(llvm::tdl::Directive Dir) { // IMPL-NEXT: switch (Dir) { // IMPL-NEXT: case llvm::tdl::Directive::TDLD_dira: @@ -374,4 +368,13 @@ def TDL_DirA : Directive<"dira"> { // IMPL-NEXT: llvm_unreachable("Unexpected directive"); // IMPL-NEXT: } // IMPL-EMPTY: +// 
IMPL-NEXT: static_assert(sizeof(llvm::tdl::Directive) == sizeof(int)); +// IMPL-NEXT: {{.*}} static const llvm::tdl::Directive LeafConstructTable[][2] = { +// IMPL-NEXT: llvm::tdl::TDLD_dira, static_cast<llvm::tdl::Directive>(0), +// IMPL-NEXT: }; +// IMPL-EMPTY: +// IMPL-NEXT: {{.*}} static const int LeafConstructTableOrdering[] = { +// IMPL-NEXT: 0, +// IMPL-NEXT: }; +// IMPL-EMPTY: // IMPL-NEXT: #endif // GEN_DIRECTIVES_IMPL diff --git a/llvm/test/TableGen/directive2.td b/llvm/test/TableGen/directive2.td index d6fa4835c8dfdc..1750022e1f94ea 100644 --- a/llvm/test/TableGen/directive2.td +++ b/llvm/test/TableGen/directive2.td @@ -45,6 +45,7 @@ def TDL_DirA : Directive<"dira"> { // CHECK-NEXT: #define LLVM_Tdl_INC // CHECK-EMPTY: // CHECK-NEXT: #include "llvm/ADT/ArrayRef.h" +// CHECK-NEXT: #include <cstddef> // CHECK-EMPTY: // CHECK-NEXT: namespace llvm { // CHECK-NEXT: class StringRef; @@ -88,7 +89,7 @@ def TDL_DirA : Directive<"dira"> { // CHECK-NEXT: /// Return true if \p C is a valid clause for \p D in version \p Version. 
// CHECK-NEXT: bool isAllowedClauseForDirective(Directive D, Clause C, unsigned Version); // CHECK-EMPTY: -// CHECK-NEXT: llvm::ArrayRef<Directive> getLeafConstructs(Directive D); +// CHECK-NEXT: constexpr std::size_t getMaxLeafCount() { return 0; } // CHECK-NEXT: Association getDirectiveAssociation(Directive D); // CHECK-NEXT: } // namespace tdl // CHECK-NEXT: } // namespace llvm @@ -290,13 +291,6 @@ def TDL_DirA : Directive<"dira"> { // IMPL-NEXT: llvm_unreachable("Invalid Tdl Directive kind"); // IMPL-NEXT: } // IMPL-EMPTY: -// IMPL-NEXT: llvm::ArrayRef<llvm::tdl::Directive> llvm::tdl::getLeafConstructs(llvm::tdl::Directive Dir) { -// IMPL-NEXT: switch (Dir) { -// IMPL-NEXT: default: -// IMPL-NEXT: return ArrayRef<llvm::tdl::Directive>{}; -// IMPL-NEXT: } // switch (Dir) -// IMPL-NEXT: } -// IMPL-EMPTY: // IMPL-NEXT: llvm::tdl::Association llvm::tdl::getDirectiveAssociation(llvm::tdl::Directive Dir) { // IMPL-NEXT: switch (Dir) { // IMPL-NEXT: case llvm::tdl::Directive::TDLD_dira: @@ -305,4 +299,13 @@ def TDL_DirA : Directive<"dira"> { // IMPL-NEXT: llvm_unreachable("Unexpected directive"); // IMPL-NEXT: } // IMPL-EMPTY: +// IMPL-NEXT: static_assert(sizeof(llvm::tdl::Directive) == sizeof(int)); +// IMPL-NEXT: {{.*}} static const llvm::tdl::Directive LeafConstructTable[][2] = { +// IMPL-NEXT: llvm::tdl::TDLD_dira, static_cast<llvm::tdl::Directive>(0), +// IMPL-NEXT: }; +// IMPL-EMPTY: +// IMPL-NEXT: {{.*}} static const int LeafConstructTableOrdering[] = { +// IMPL-NEXT: 0, +// IMPL-NEXT: }; +// IMPL-EMPTY: // IMPL-NEXT: #endif // GEN_DIRECTIVES_IMPL diff --git a/llvm/unittests/Frontend/CMakeLists.txt b/llvm/unittests/Frontend/CMakeLists.txt index c6f60142d6276a..ddb6a16cbb984e 100644 --- a/llvm/unittests/Frontend/CMakeLists.txt +++ b/llvm/unittests/Frontend/CMakeLists.txt @@ -14,6 +14,7 @@ add_llvm_unittest(LLVMFrontendTests OpenMPContextTest.cpp OpenMPIRBuilderTest.cpp OpenMPParsingTest.cpp + OpenMPComposeTest.cpp DEPENDS acc_gen diff --git 
a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp new file mode 100644 index 00000000000000..29b1be4eb3432c --- /dev/null +++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp @@ -0,0 +1,41 @@ +//===- llvm/unittests/Frontend/OpenMPComposeTest.cpp ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Frontend/OpenMP/OMP.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace llvm::omp; + +TEST(Composition, GetLeafConstructs) { + ArrayRef<Directive> L1 = getLeafConstructs(OMPD_loop); + ASSERT_EQ(L1, (ArrayRef<Directive>{})); + ArrayRef<Directive> L2 = getLeafConstructs(OMPD_parallel_for); + ASSERT_EQ(L2, (ArrayRef<Directive>{OMPD_parallel, OMPD_for})); + ArrayRef<Directive> L3 = getLeafConstructs(OMPD_parallel_for_simd); + ASSERT_EQ(L3, (ArrayRef<Directive>{OMPD_parallel, OMPD_for, OMPD_simd})); +} + +TEST(Composition, GetCompoundConstruct) { + Directive C1 = + getCompoundConstruct({OMPD_target, OMPD_teams, OMPD_distribute}); + ASSERT_EQ(C1, OMPD_target_teams_distribute); + Directive C2 = getCompoundConstruct({OMPD_target}); + ASSERT_EQ(C2, OMPD_target); + Directive C3 = getCompoundConstruct({OMPD_target, OMPD_masked}); + ASSERT_EQ(C3, OMPD_unknown); + Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute}); + ASSERT_EQ(C4, OMPD_target_teams_distribute); + Directive C5 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute}); + ASSERT_EQ(C5, OMPD_target_teams_distribute); + Directive C6 = getCompoundConstruct({}); + ASSERT_EQ(C6, OMPD_unknown); + Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd}); + ASSERT_EQ(C7, OMPD_parallel_for_simd); +} diff --git 
a/llvm/utils/TableGen/DirectiveEmitter.cpp b/llvm/utils/TableGen/DirectiveEmitter.cpp index e0edf1720f8ac5..2d2b7748491897 100644 --- a/llvm/utils/TableGen/DirectiveEmitter.cpp +++ b/llvm/utils/TableGen/DirectiveEmitter.cpp @@ -20,6 +20,9 @@ #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" +#include <numeric> +#include <vector> + using namespace llvm; namespace { @@ -39,7 +42,8 @@ class IfDefScope { }; } // namespace -// Generate enum class +// Generate enum class. Entries are emitted in the order in which they appear +// in the `Records` vector. static void GenerateEnumClass(const std::vector<Record *> &Records, raw_ostream &OS, StringRef Enum, StringRef Prefix, const DirectiveLanguage &DirLang, @@ -175,6 +179,16 @@ bool DirectiveLanguage::HasValidityErrors() const { return HasDuplicateClausesInDirectives(getDirectives()); } +// Count the maximum number of leaf constituents per construct. +static size_t GetMaxLeafCount(const DirectiveLanguage &DirLang) { + size_t MaxCount = 0; + for (Record *R : DirLang.getDirectives()) { + size_t Count = Directive{R}.getLeafConstructs().size(); + MaxCount = std::max(MaxCount, Count); + } + return MaxCount; +} + // Generate the declaration section for the enumeration in the directive // language static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { @@ -189,6 +203,7 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { if (DirLang.hasEnableBitmaskEnumInNamespace()) OS << "#include \"llvm/ADT/BitmaskEnum.h\"\n"; + OS << "#include <cstddef>\n"; // for size_t OS << "\n"; OS << "namespace llvm {\n"; OS << "class StringRef;\n"; @@ -244,7 +259,8 @@ static void EmitDirectivesDecl(RecordKeeper &Records, raw_ostream &OS) { OS << "bool isAllowedClauseForDirective(Directive D, " << "Clause C, unsigned Version);\n"; OS << "\n"; - OS << "llvm::ArrayRef<Directive> getLeafConstructs(Directive D);\n"; + OS << "constexpr std::size_t getMaxLeafCount() { return " + << 
GetMaxLeafCount(DirLang) << "; }\n"; OS << "Association getDirectiveAssociation(Directive D);\n"; if (EnumHelperFuncs.length() > 0) { OS << EnumHelperFuncs; @@ -396,6 +412,19 @@ GenerateCaseForVersionedClauses(const std::vector<Record *> &Clauses, } } +static std::string GetDirectiveName(const DirectiveLanguage &DirLang, + const Record *Rec) { + Directive Dir{Rec}; + return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::" + + DirLang.getDirectivePrefix() + Dir.getFormattedName()) + .str(); +} + +static std::string GetDirectiveType(const DirectiveLanguage &DirLang) { + return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + "::Directive") + .str(); +} + // Generate the isAllowedClauseForDirective function implementation. static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang, raw_ostream &OS) { @@ -450,77 +479,102 @@ static void GenerateIsAllowedClause(const DirectiveLanguage &DirLang, OS << "}\n"; // End of function isAllowedClauseForDirective } -// Generate the getLeafConstructs function implementation. -static void GenerateGetLeafConstructs(const DirectiveLanguage &DirLang, - raw_ostream &OS) { - auto getQualifiedName = [&](StringRef Formatted) -> std::string { - return (llvm::Twine("llvm::") + DirLang.getCppNamespace() + - "::Directive::" + DirLang.getDirectivePrefix() + Formatted) - .str(); - }; - - // For each list of leaves, generate a static local object, then - // return a reference to that object for a given directive, e.g. +static void EmitLeafTable(const DirectiveLanguage &DirLang, raw_ostream &OS, + StringRef TableName) { + // The leaf constructs are emitted in a form of a 2D table, where each + // row corresponds to a directive (and there is a row for each directive). 
// - // static ListTy leafConstructs_A_B = { A, B }; - // static ListTy leafConstructs_C_D_E = { C, D, E }; - // switch (Dir) { - // case A_B: - // return leafConstructs_A_B; - // case C_D_E: - // return leafConstructs_C_D_E; - // } - - // Map from a record that defines a directive to the name of the - // local object with the list of its leaves. - DenseMap<Record *, std::string> ListNames; - - std::string DirectiveTypeName = - std::string("llvm::") + DirLang.getCppNamespace().str() + "::Directive"; - - OS << '\n'; - - // ArrayRef<...> llvm::<ns>::GetLeafConstructs(llvm::<ns>::Directive Dir) - OS << "llvm::ArrayRef<" << DirectiveTypeName - << "> llvm::" << DirLang.getCppNamespace() << "::getLeafConstructs(" - << DirectiveTypeName << " Dir) "; - OS << "{\n"; - - // Generate the locals. - for (Record *R : DirLang.getDirectives()) { - Directive Dir{R}; + // Each row consists of + // - the id of the directive itself, + // - the number of leaf constructs that follow (0 for leaf directives), + // - ids of the leaf constructs (none if the directive is itself a leaf). + // The total number of these entries is at most MaxLeafCount+2. If this + // number is less than that, it is padded to occupy exactly MaxLeafCount+2 + // entries in memory. + // + // The rows are stored in lexicographical order of their leaf sequences. + // This is intended to enable binary search when mapping a sequence of + // leaves back to the compound directive. + // The consequence of that is that in order to find a row corresponding + // to the given directive, we'd need to scan the first element of each + // row. To avoid this, an auxiliary ordering table is created, such that + // the row for Dir_A is table[auxiliary[Dir_A]].
+ + std::vector<Record *> Directives = DirLang.getDirectives(); + DenseMap<Record *, size_t> DirId; // Record * -> llvm::omp::Directive + + for (auto [Idx, Rec] : llvm::enumerate(Directives)) + DirId.insert(std::make_pair(Rec, Idx)); + + using LeafList = std::vector<int>; + int MaxLeafCount = GetMaxLeafCount(DirLang); + + // The initial leaf table, rows order is same as directive order. + std::vector<LeafList> LeafTable(Directives.size()); + for (auto [Idx, Rec] : llvm::enumerate(Directives)) { + Directive Dir{Rec}; + std::vector<Record *> Leaves = Dir.getLeafConstructs(); + + auto &List = LeafTable[Idx]; + List.resize(MaxLeafCount + 2); + List[0] = Idx; // The id of the directive itself. + List[1] = Leaves.size(); // The number of leaves to follow. + + for (int I = 0; I != MaxLeafCount; ++I) + List[I + 2] = + static_cast<size_t>(I) < Leaves.size() ? DirId.at(Leaves[I]) : -1; + } - std::vector<Record *> LeafConstructs = Dir.getLeafConstructs(); - if (LeafConstructs.empty()) - continue; + // Avoid sorting the vector<vector> array, instead sort an index array. + // It will also be useful later to create the auxiliary indexing array. + std::vector<int> Ordering(Directives.size()); + std::iota(Ordering.begin(), Ordering.end(), 0); + + llvm::sort(Ordering, [&](int A, int B) { + auto &LeavesA = LeafTable[A]; + auto &LeavesB = LeafTable[B]; + if (LeavesA[1] == 0 && LeavesB[1] == 0) + return LeavesA[0] < LeavesB[0]; + return std::lexicographical_compare(&LeavesA[2], &LeavesA[2] + LeavesA[1], + &LeavesB[2], &LeavesB[2] + LeavesB[1]); + }); - std::string ListName = "leafConstructs_" + Dir.getFormattedName(); - OS << " static const " << DirectiveTypeName << ' ' << ListName - << "[] = {\n"; - for (Record *L : LeafConstructs) { - Directive LeafDir{L}; - OS << " " << getQualifiedName(LeafDir.getFormattedName()) << ",\n"; + // Emit the table + + // The directives are emitted into a scoped enum, for which the underlying + // type is `int` (by default). 
The code above uses `int` to store directive + // ids, so make sure that we catch it when something changes in the + // underlying type. + std::string DirectiveType = GetDirectiveType(DirLang); + OS << "\nstatic_assert(sizeof(" << DirectiveType << ") == sizeof(int));\n"; + + OS << "[[maybe_unused]] static const " << DirectiveType << ' ' << TableName + << "[][" << MaxLeafCount + 2 << "] = {\n"; + for (size_t I = 0, E = Directives.size(); I != E; ++I) { + auto &Leaves = LeafTable[Ordering[I]]; + OS << " " << GetDirectiveName(DirLang, Directives[Leaves[0]]); + OS << ", static_cast<" << DirectiveType << ">(" << Leaves[1] << "),"; + for (size_t I = 2, E = Leaves.size(); I != E; ++I) { + int Idx = Leaves[I]; + if (Idx >= 0) + OS << ' ' << GetDirectiveName(DirLang, Directives[Leaves[I]]) << ','; + else + OS << " static_cast<" << DirectiveType << ">(-1),"; } - OS << " };\n"; - ListNames.insert(std::make_pair(R, std::move(ListName))); - } - - if (!ListNames.empty()) OS << '\n'; - OS << " switch (Dir) {\n"; - for (Record *R : DirLang.getDirectives()) { - auto F = ListNames.find(R); - if (F == ListNames.end()) - continue; - - Directive Dir{R}; - OS << " case " << getQualifiedName(Dir.getFormattedName()) << ":\n"; - OS << " return " << F->second << ";\n"; } - OS << " default:\n"; - OS << " return ArrayRef<" << DirectiveTypeName << ">{};\n"; - OS << " } // switch (Dir)\n"; - OS << "}\n"; + OS << "};\n\n"; + + // Emit the auxiliary index table: it's the inverse of the `Ordering` + // table above. 
+ OS << "[[maybe_unused]] static const int " << TableName << "Ordering[] = {\n"; + OS << " "; + std::vector<int> Reverse(Ordering.size()); + for (int I = 0, E = Ordering.size(); I != E; ++I) + Reverse[Ordering[I]] = I; + for (int Idx : Reverse) + OS << ' ' << Idx << ','; + OS << "\n};\n"; } static void GenerateGetDirectiveAssociation(const DirectiveLanguage &DirLang, @@ -1105,11 +1159,11 @@ void EmitDirectivesBasicImpl(const DirectiveLanguage &DirLang, // isAllowedClauseForDirective(Directive D, Clause C, unsigned Version) GenerateIsAllowedClause(DirLang, OS); - // getLeafConstructs(Directive D) - GenerateGetLeafConstructs(DirLang, OS); - // getDirectiveAssociation(Directive D) GenerateGetDirectiveAssociation(DirLang, OS); + + // Leaf table for getLeafConstructs, etc. + EmitLeafTable(DirLang, OS, "LeafConstructTable"); } // Generate the implemenation section for the enumeration in the directive >From 0d92781c7a52ed2fbab33ae6e7b3dae61cfd42ae Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek <krzysztof.parzys...@amd.com> Date: Tue, 2 Apr 2024 08:20:15 -0500 Subject: [PATCH 3/3] Address review comments --- llvm/lib/Frontend/OpenMP/OMP.cpp | 10 ++++++++-- llvm/unittests/Frontend/OpenMPComposeTest.cpp | 10 ++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Frontend/OpenMP/OMP.cpp b/llvm/lib/Frontend/OpenMP/OMP.cpp index dd99d3d074fd1e..7504c9076fde1b 100644 --- a/llvm/lib/Frontend/OpenMP/OMP.cpp +++ b/llvm/lib/Frontend/OpenMP/OMP.cpp @@ -27,8 +27,8 @@ using namespace llvm::omp; namespace llvm::omp { ArrayRef<Directive> getLeafConstructs(Directive D) { - auto Idx = static_cast<int>(D); - if (Idx < 0 || Idx >= static_cast<int>(Directive_enumSize)) + auto Idx = static_cast<std::size_t>(D); + if (Idx >= Directive_enumSize) return {}; const auto *Row = LeafConstructTable[LeafConstructTableOrdering[Idx]]; return ArrayRef(&Row[2], &Row[2] + static_cast<int>(Row[1])); @@ -50,6 +50,12 @@ Directive getCompoundConstruct(ArrayRef<Directive> Parts) { 
RawLeafs.push_back(P); } + // RawLeafs will be used as the key in the binary search. The search doesn't + // guarantee that the exact same entry will be found (since RawLeafs may + // not correspond to any compound directive). Because of that, we will + // need to compare the search result with the given set of leaves. + // Also, if there is only one leaf in the list, it corresponds to itself, + // and no search is necessary. auto GivenLeafs{ArrayRef<Directive>(RawLeafs).drop_front(2)}; if (GivenLeafs.size() == 1) return GivenLeafs.front(); diff --git a/llvm/unittests/Frontend/OpenMPComposeTest.cpp b/llvm/unittests/Frontend/OpenMPComposeTest.cpp index 29b1be4eb3432c..c3e0880ece8641 100644 --- a/llvm/unittests/Frontend/OpenMPComposeTest.cpp +++ b/llvm/unittests/Frontend/OpenMPComposeTest.cpp @@ -32,10 +32,8 @@ TEST(Composition, GetCompoundConstruct) { ASSERT_EQ(C3, OMPD_unknown); Directive C4 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute}); ASSERT_EQ(C4, OMPD_target_teams_distribute); - Directive C5 = getCompoundConstruct({OMPD_target, OMPD_teams_distribute}); - ASSERT_EQ(C5, OMPD_target_teams_distribute); - Directive C6 = getCompoundConstruct({}); - ASSERT_EQ(C6, OMPD_unknown); - Directive C7 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd}); - ASSERT_EQ(C7, OMPD_parallel_for_simd); + Directive C5 = getCompoundConstruct({}); + ASSERT_EQ(C5, OMPD_unknown); + Directive C6 = getCompoundConstruct({OMPD_parallel_for, OMPD_simd}); + ASSERT_EQ(C6, OMPD_parallel_for_simd); } _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits