https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/116219
>From d2b7ffef6e2a1cce81f47e7d1886551aef677ed8 Mon Sep 17 00:00:00 2001 From: Sergio Afonso <safon...@amd.com> Date: Thu, 14 Nov 2024 12:24:15 +0000 Subject: [PATCH] [Flang][OpenMP] Lowering of host-evaluated clauses This patch adds support for lowering OpenMP clauses and expressions attached to constructs nested inside of a target region that need to be evaluated in the host device. This is done through the use of the `OpenMP_HostEvalClause` `omp.target` set of operands and entry block arguments. When lowering clauses for a target construct, a more involved `processHostEvalClauses()` function is called, which looks at the current and potentially other nested constructs in order to find and lower clauses that need to be processed outside of the `omp.target` operation under construction. This populates an instance of a global structure with the resulting MLIR values. The resulting list of host-evaluated values is used to initialize the `host_eval` operands when constructing the `omp.target` operation, and then replaced with the corresponding block arguments after creating that operation's region. Afterwards, while lowering nested operations, those that might potentially be evaluated in the host (e.g. `num_teams`, `thread_limit`, `num_threads` and `collapse`) check first whether there is an active global host-evaluated information structure and whether it holds values referring to these clauses. If that is the case, the stored values (referring to `omp.target` entry block arguments at that stage) are used instead of lowering clauses again. --- flang/include/flang/Common/OpenMP-utils.h | 20 +- flang/lib/Common/OpenMP-utils.cpp | 9 +- flang/lib/Lower/OpenMP/OpenMP.cpp | 449 +++++++++++++++++- flang/test/Lower/OpenMP/host-eval.f90 | 157 ++++++ flang/test/Lower/OpenMP/target-spmd.f90 | 191 ++++++++ .../Dialect/OpenMP/OpenMPClauseOperands.h | 6 + 6 files changed, 805 insertions(+), 27 deletions(-) create mode 100644 flang/test/Lower/OpenMP/host-eval.f90 create mode 100644 flang/test/Lower/OpenMP/target-spmd.f90 diff --git a/flang/include/flang/Common/OpenMP-utils.h b/flang/include/flang/Common/OpenMP-utils.h index e6a3f1bac1c605..827f13bc4758e2 100644 --- a/flang/include/flang/Common/OpenMP-utils.h +++ b/flang/include/flang/Common/OpenMP-utils.h @@ -34,6 +34,7 @@ struct EntryBlockArgsEntry { /// Structure holding the information needed to create and bind entry block /// arguments associated to all clauses that can define them. struct EntryBlockArgs { + llvm::ArrayRef<mlir::Value> hostEvalVars; EntryBlockArgsEntry inReduction; EntryBlockArgsEntry map; EntryBlockArgsEntry priv; @@ -49,18 +50,25 @@ struct EntryBlockArgs { } auto getSyms() const { - return llvm::concat<const Fortran::semantics::Symbol *const>( - inReduction.syms, map.syms, priv.syms, reduction.syms, - taskReduction.syms, useDeviceAddr.syms, useDevicePtr.syms); + return llvm::concat<const semantics::Symbol *const>(inReduction.syms, + map.syms, priv.syms, reduction.syms, taskReduction.syms, + useDeviceAddr.syms, useDevicePtr.syms); } auto getVars() const { - return llvm::concat<const mlir::Value>(inReduction.vars, map.vars, - priv.vars, reduction.vars, taskReduction.vars, useDeviceAddr.vars, - useDevicePtr.vars); + return llvm::concat<const mlir::Value>(hostEvalVars, inReduction.vars, + map.vars, priv.vars, reduction.vars, taskReduction.vars, + useDeviceAddr.vars, useDevicePtr.vars); } }; +/// Create an entry block for the given region, including the clause-defined +/// arguments specified. 
+/// +/// \param [in] builder - MLIR operation builder. +/// \param [in] args - entry block arguments information for the given +/// operation. +/// \param [in] region - Empty region in which to create the entry block. mlir::Block *genEntryBlock( mlir::OpBuilder &builder, const EntryBlockArgs &args, mlir::Region ®ion); } // namespace Fortran::common::openmp diff --git a/flang/lib/Common/OpenMP-utils.cpp b/flang/lib/Common/OpenMP-utils.cpp index f5115f475d6a19..47e89fe6dd1ee9 100644 --- a/flang/lib/Common/OpenMP-utils.cpp +++ b/flang/lib/Common/OpenMP-utils.cpp @@ -18,10 +18,10 @@ mlir::Block *genEntryBlock(mlir::OpBuilder &builder, const EntryBlockArgs &args, llvm::SmallVector<mlir::Type> types; llvm::SmallVector<mlir::Location> locs; - unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() + - args.priv.vars.size() + args.reduction.vars.size() + - args.taskReduction.vars.size() + args.useDeviceAddr.vars.size() + - args.useDevicePtr.vars.size(); + unsigned numVars = args.hostEvalVars.size() + args.inReduction.vars.size() + + args.map.vars.size() + args.priv.vars.size() + + args.reduction.vars.size() + args.taskReduction.vars.size() + + args.useDeviceAddr.vars.size() + args.useDevicePtr.vars.size(); types.reserve(numVars); locs.reserve(numVars); @@ -34,6 +34,7 @@ mlir::Block *genEntryBlock(mlir::OpBuilder &builder, const EntryBlockArgs &args, // Populate block arguments in clause name alphabetical order to match // expected order by the BlockArgOpenMPOpInterface. + extractTypeLoc(args.hostEvalVars); extractTypeLoc(args.inReduction.vars); extractTypeLoc(args.map.vars); extractTypeLoc(args.priv.vars); diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index cd4b25a17722c1..ac64032d5e08ed 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -55,6 +55,149 @@ static void genOMPDispatch(lower::AbstractConverter &converter, const ConstructQueue &queue, ConstructQueue::const_iterator item); +static void processHostEvalClauses(lower::AbstractConverter &converter, + semantics::SemanticsContext &semaCtx, + lower::StatementContext &stmtCtx, + lower::pft::Evaluation &eval, + mlir::Location loc); + +namespace { +/// Structure holding information that is needed to pass host-evaluated +/// information to later lowering stages. +class HostEvalInfo { +public: + // Allow this function access to private members in order to initialize them. + friend void ::processHostEvalClauses(lower::AbstractConverter &, + semantics::SemanticsContext &, + lower::StatementContext &, + lower::pft::Evaluation &, + mlir::Location); + + /// Fill \c vars with values stored in \c ops. + /// + /// The order in which values are stored matches the one expected by \see + /// bindOperands(). + void collectValues(llvm::SmallVectorImpl<mlir::Value> &vars) const { + vars.append(ops.loopLowerBounds); + vars.append(ops.loopUpperBounds); + vars.append(ops.loopSteps); + + if (ops.numTeamsLower) + vars.push_back(ops.numTeamsLower); + + if (ops.numTeamsUpper) + vars.push_back(ops.numTeamsUpper); + + if (ops.numThreads) + vars.push_back(ops.numThreads); + + if (ops.threadLimit) + vars.push_back(ops.threadLimit); + } + + /// Update \c ops, replacing all values with the corresponding block argument + /// in \c args. + /// + /// The order in which values are stored in \c args is the same as the one + /// used by \see collectValues(). 
+ void bindOperands(llvm::ArrayRef<mlir::BlockArgument> args) { + assert(args.size() == + ops.loopLowerBounds.size() + ops.loopUpperBounds.size() + + ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) + + (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) + + (ops.threadLimit ? 1 : 0) && + "invalid block argument list"); + int argIndex = 0; + for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i) + ops.loopLowerBounds[i] = args[argIndex++]; + + for (size_t i = 0; i < ops.loopUpperBounds.size(); ++i) + ops.loopUpperBounds[i] = args[argIndex++]; + + for (size_t i = 0; i < ops.loopSteps.size(); ++i) + ops.loopSteps[i] = args[argIndex++]; + + if (ops.numTeamsLower) + ops.numTeamsLower = args[argIndex++]; + + if (ops.numTeamsUpper) + ops.numTeamsUpper = args[argIndex++]; + + if (ops.numThreads) + ops.numThreads = args[argIndex++]; + + if (ops.threadLimit) + ops.threadLimit = args[argIndex++]; + } + + /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated + /// values and Fortran symbols, respectively, if they have already been + /// initialized but not yet applied. + /// + /// \returns whether an update was performed. If not, these clauses were not + /// evaluated in the host device. + bool apply(mlir::omp::LoopNestOperands &clauseOps, + llvm::SmallVectorImpl<const semantics::Symbol *> &ivOut) { + if (iv.empty() || loopNestApplied) { + loopNestApplied = true; + return false; + } + + loopNestApplied = true; + clauseOps.loopLowerBounds = ops.loopLowerBounds; + clauseOps.loopUpperBounds = ops.loopUpperBounds; + clauseOps.loopSteps = ops.loopSteps; + ivOut.append(iv); + return true; + } + + /// Update \p clauseOps with the corresponding host-evaluated values if they + /// have already been initialized but not yet applied. + /// + /// \returns whether an update was performed. If not, these clauses were not + /// evaluated in the host device. + bool apply(mlir::omp::ParallelOperands &clauseOps) { + if (!ops.numThreads || parallelApplied) { + parallelApplied = true; + return false; + } + + parallelApplied = true; + clauseOps.numThreads = ops.numThreads; + return true; + } + + /// Update \p clauseOps with the corresponding host-evaluated values if they + /// have already been initialized. + /// + /// \returns whether an update was performed. If not, these clauses were not + /// evaluated in the host device. + bool apply(mlir::omp::TeamsOperands &clauseOps) { + if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit) + return false; + + clauseOps.numTeamsLower = ops.numTeamsLower; + clauseOps.numTeamsUpper = ops.numTeamsUpper; + clauseOps.threadLimit = ops.threadLimit; + return true; + } + +private: + mlir::omp::HostEvaluatedOperands ops; + llvm::SmallVector<const semantics::Symbol *> iv; + bool loopNestApplied = false, parallelApplied = false; +}; +} // namespace + +/// Stack of \see HostEvalInfo to represent the current nest of \c omp.target +/// operations being created. +/// +/// The current implementation prevents nested 'target' regions from breaking +/// the handling of the outer region by keeping a stack of information +/// structures, but it will probably still require some further work to support +/// reverse offloading. +static llvm::SmallVector<HostEvalInfo, 0> hostEvalInfo; + /// Bind symbols to their corresponding entry block arguments. 
/// /// The binding will be performed inside of the current block, which does not @@ -176,6 +319,8 @@ static void bindEntryBlockArgs(lower::AbstractConverter &converter, }; // Process in clause name alphabetical order to match block arguments order. + // Do not bind host_eval variables because they cannot be used inside of the + // corresponding region, except for very specific cases handled separately. bindPrivateLike(args.inReduction.syms, args.inReduction.vars, op.getInReductionBlockArgs()); bindMapLike(args.map.syms, op.getMapBlockArgs()); @@ -213,6 +358,256 @@ extractMappedBaseValues(llvm::ArrayRef<mlir::Value> vars, }); } +/// Get the directive enumeration value corresponding to the given OpenMP +/// construct PFT node. +llvm::omp::Directive +extractOmpDirective(const parser::OpenMPConstruct &ompConstruct) { + return common::visit( + common::visitors{ + [](const parser::OpenMPAllocatorsConstruct &c) { + return llvm::omp::OMPD_allocators; + }, + [](const parser::OpenMPAtomicConstruct &c) { + return llvm::omp::OMPD_atomic; + }, + [](const parser::OpenMPBlockConstruct &c) { + return std::get<parser::OmpBlockDirective>( + std::get<parser::OmpBeginBlockDirective>(c.t).t) + .v; + }, + [](const parser::OpenMPCriticalConstruct &c) { + return llvm::omp::OMPD_critical; + }, + [](const parser::OpenMPDeclarativeAllocate &c) { + return llvm::omp::OMPD_allocate; + }, + [](const parser::OpenMPExecutableAllocate &c) { + return llvm::omp::OMPD_allocate; + }, + [](const parser::OpenMPLoopConstruct &c) { + return std::get<parser::OmpLoopDirective>( + std::get<parser::OmpBeginLoopDirective>(c.t).t) + .v; + }, + [](const parser::OpenMPSectionConstruct &c) { + return llvm::omp::OMPD_section; + }, + [](const parser::OpenMPSectionsConstruct &c) { + return std::get<parser::OmpSectionsDirective>( + std::get<parser::OmpBeginSectionsDirective>(c.t).t) + .v; + }, + [](const parser::OpenMPStandaloneConstruct &c) { + return common::visit( + common::visitors{ + [](const parser::OpenMPSimpleStandaloneConstruct &c) { + return std::get<parser::OmpSimpleStandaloneDirective>(c.t) + .v; + }, + [](const parser::OpenMPFlushConstruct &c) { + return llvm::omp::OMPD_flush; + }, + [](const parser::OpenMPCancelConstruct &c) { + return llvm::omp::OMPD_cancel; + }, + [](const parser::OpenMPCancellationPointConstruct &c) { + return llvm::omp::OMPD_cancellation_point; + }, + [](const parser::OpenMPDepobjConstruct &c) { + return llvm::omp::OMPD_depobj; + }}, + c.u); + }, + [](const parser::OpenMPUtilityConstruct &c) { + return common::visit( + common::visitors{[](const parser::OmpErrorDirective &c) { + return llvm::omp::OMPD_error; + }, + [](const parser::OmpNothingDirective &c) { + return llvm::omp::OMPD_nothing; + }}, + c.u); + }}, + ompConstruct.u); +} + +/// Populate the global \see hostEvalInfo after processing clauses for the given +/// \p eval OpenMP target construct, or nested constructs, if these must be +/// evaluated outside of the target region per the spec. +/// +/// In particular, this will ensure that in 'target teams' and equivalent nested +/// constructs, the \c thread_limit and \c num_teams clauses will be evaluated +/// in the host. Additionally, loop bounds, steps and the \c num_threads clause +/// will also be evaluated in the host if a target SPMD construct is detected +/// (i.e. 'target teams distribute parallel do [simd]' or equivalent nesting). 
+/// +/// The result, stored as a global, is intended to be used to populate the \c +/// host_eval operands of the associated \c omp.target operation, and also to be +/// checked and used by later lowering steps to populate the corresponding +/// operands of the \c omp.teams, \c omp.parallel or \c omp.loop_nest +/// operations. +static void processHostEvalClauses(lower::AbstractConverter &converter, + semantics::SemanticsContext &semaCtx, + lower::StatementContext &stmtCtx, + lower::pft::Evaluation &eval, + mlir::Location loc) { + // Obtain the list of clauses of the given OpenMP block or loop construct + // evaluation. Other evaluations passed to this lambda keep `clauses` + // unchanged. + auto extractClauses = [&semaCtx](lower::pft::Evaluation &eval, + List<Clause> &clauses) { + const auto *ompEval = eval.getIf<parser::OpenMPConstruct>(); + if (!ompEval) + return; + + const parser::OmpClauseList *beginClauseList = nullptr; + const parser::OmpClauseList *endClauseList = nullptr; + common::visit( + common::visitors{ + [&](const parser::OpenMPBlockConstruct &ompConstruct) { + const auto &beginDirective = + std::get<parser::OmpBeginBlockDirective>(ompConstruct.t); + beginClauseList = + &std::get<parser::OmpClauseList>(beginDirective.t); + endClauseList = &std::get<parser::OmpClauseList>( + std::get<parser::OmpEndBlockDirective>(ompConstruct.t).t); + }, + [&](const parser::OpenMPLoopConstruct &ompConstruct) { + const auto &beginDirective = + std::get<parser::OmpBeginLoopDirective>(ompConstruct.t); + beginClauseList = + &std::get<parser::OmpClauseList>(beginDirective.t); + + if (auto &endDirective = + std::get<std::optional<parser::OmpEndLoopDirective>>( + ompConstruct.t)) + endClauseList = + &std::get<parser::OmpClauseList>(endDirective->t); + }, + [&](const auto &) {}}, + ompEval->u); + + assert(beginClauseList && "expected begin directive"); + clauses.append(makeClauses(*beginClauseList, semaCtx)); + + if (endClauseList) + clauses.append(makeClauses(*endClauseList, semaCtx)); + }; + + // Return the directive that is immediately nested inside of the given + // `parent` evaluation, if it is its only non-end-statement nested evaluation + // and it represents an OpenMP construct. + auto extractOnlyOmpNestedDir = [](lower::pft::Evaluation &parent) + -> std::optional<llvm::omp::Directive> { + if (!parent.hasNestedEvaluations()) + return std::nullopt; + + llvm::omp::Directive dir; + auto &nested = parent.getFirstNestedEvaluation(); + if (const auto *ompEval = nested.getIf<parser::OpenMPConstruct>()) + dir = extractOmpDirective(*ompEval); + else + return std::nullopt; + + for (auto &sibling : parent.getNestedEvaluations()) + if (&sibling != &nested && !sibling.isEndStmt()) + return std::nullopt; + + return dir; + }; + + // Process the given evaluation assuming it's part of a 'target' construct or + // captured by one, and store results in the global `hostEvalInfo`. + std::function<void(lower::pft::Evaluation &, const List<Clause> &)> + processEval; + processEval = [&](lower::pft::Evaluation &eval, const List<Clause> &clauses) { + using namespace llvm::omp; + ClauseProcessor cp(converter, semaCtx, clauses); + + // Call `processEval` recursively with the immediately nested evaluation and + // its corresponding clauses if there is a single nested evaluation + // representing an OpenMP directive that passes the given test. 
+ auto processSingleNestedIf = [&](llvm::function_ref<bool(Directive)> test) { + std::optional<Directive> nestedDir = extractOnlyOmpNestedDir(eval); + if (!nestedDir || !test(*nestedDir)) + return; + + lower::pft::Evaluation &nestedEval = eval.getFirstNestedEvaluation(); + List<lower::omp::Clause> nestedClauses; + extractClauses(nestedEval, nestedClauses); + processEval(nestedEval, nestedClauses); + }; + + const auto *ompEval = eval.getIf<parser::OpenMPConstruct>(); + if (!ompEval) + return; + + HostEvalInfo &hostInfo = hostEvalInfo.back(); + + switch (extractOmpDirective(*ompEval)) { + // Cases where 'teams' and target SPMD clauses might be present. + case OMPD_teams_distribute_parallel_do: + case OMPD_teams_distribute_parallel_do_simd: + cp.processThreadLimit(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_target_teams_distribute_parallel_do: + case OMPD_target_teams_distribute_parallel_do_simd: + cp.processNumTeams(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_distribute_parallel_do: + case OMPD_distribute_parallel_do_simd: + cp.processCollapse(loc, eval, hostInfo.ops, hostInfo.iv); + cp.processNumThreads(stmtCtx, hostInfo.ops); + break; + + // Cases where 'teams' clauses might be present, and target SPMD is + // possible by looking at nested evaluations. + case OMPD_teams: + cp.processThreadLimit(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_target_teams: + cp.processNumTeams(stmtCtx, hostInfo.ops); + processSingleNestedIf([](Directive nestedDir) { + return nestedDir == OMPD_distribute_parallel_do || + nestedDir == OMPD_distribute_parallel_do_simd; + }); + break; + + // Cases where only 'teams' host-evaluated clauses might be present. + case OMPD_teams_distribute: + case OMPD_teams_distribute_simd: + cp.processThreadLimit(stmtCtx, hostInfo.ops); + [[fallthrough]]; + case OMPD_target_teams_distribute: + case OMPD_target_teams_distribute_simd: + cp.processNumTeams(stmtCtx, hostInfo.ops); + break; + + // Standalone 'target' case. + case OMPD_target: { + processSingleNestedIf( + [](Directive nestedDir) { return topTeamsSet.test(nestedDir); }); + break; + } + default: + break; + } + }; + + assert(!hostEvalInfo.empty() && "expected HOST_EVAL info structure"); + + const auto *ompEval = eval.getIf<parser::OpenMPConstruct>(); + assert(ompEval && + llvm::omp::allTargetSet.test(extractOmpDirective(*ompEval)) && + "expected TARGET construct evaluation"); + + // Use the whole list of clauses passed to the construct here, rather than the + // ones only applied to omp.target. + List<lower::omp::Clause> clauses; + extractClauses(eval, clauses); + processEval(eval, clauses); +} + static lower::pft::Evaluation * getCollapsedLoopEval(lower::pft::Evaluation &eval, int collapseValue) { // Return the Evaluation of the innermost collapsed loop, or the current one @@ -913,6 +1308,8 @@ static void genBodyOfTargetOp( mlir::Region ®ion = targetOp.getRegion(); mlir::Block *entryBlock = genEntryBlock(firOpBuilder, args, region); bindEntryBlockArgs(converter, targetOp, args); + if (!hostEvalInfo.empty()) + hostEvalInfo.back().bindOperands(argIface.getHostEvalBlockArgs()); // Check if cloning the bounds introduced any dependency on the outer region. 
// If so, then either clone them as well if they are MemoryEffectFree, or else @@ -1126,7 +1523,10 @@ genLoopNestClauses(lower::AbstractConverter &converter, mlir::Location loc, mlir::omp::LoopNestOperands &clauseOps, llvm::SmallVectorImpl<const semantics::Symbol *> &iv) { ClauseProcessor cp(converter, semaCtx, clauses); - cp.processCollapse(loc, eval, clauseOps, iv); + + if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps, iv)) + cp.processCollapse(loc, eval, clauseOps, iv); + clauseOps.loopInclusive = converter.getFirOpBuilder().getUnitAttr(); } @@ -1168,7 +1568,10 @@ static void genParallelClauses( ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_parallel, clauseOps); - cp.processNumThreads(stmtCtx, clauseOps); + + if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) + cp.processNumThreads(stmtCtx, clauseOps); + cp.processProcBind(clauseOps); cp.processReduction(loc, clauseOps, reductionSyms); } @@ -1215,8 +1618,8 @@ static void genSingleClauses(lower::AbstractConverter &converter, static void genTargetClauses( lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx, - lower::StatementContext &stmtCtx, const List<Clause> &clauses, - mlir::Location loc, bool processHostOnlyClauses, + lower::StatementContext &stmtCtx, lower::pft::Evaluation &eval, + const List<Clause> &clauses, mlir::Location loc, mlir::omp::TargetOperands &clauseOps, llvm::SmallVectorImpl<const semantics::Symbol *> &hasDeviceAddrSyms, llvm::SmallVectorImpl<const semantics::Symbol *> &isDevicePtrSyms, @@ -1226,13 +1629,15 @@ static void genTargetClauses( cp.processDepend(clauseOps); cp.processDevice(stmtCtx, clauseOps); cp.processHasDeviceAddr(clauseOps, hasDeviceAddrSyms); + if (!hostEvalInfo.empty()) { + // Only process host_eval if compiling for the host device. + processHostEvalClauses(converter, semaCtx, stmtCtx, eval, loc); + hostEvalInfo.back().collectValues(clauseOps.hostEvalVars); + } cp.processIf(llvm::omp::Directive::OMPD_target, clauseOps); cp.processIsDevicePtr(clauseOps, isDevicePtrSyms); cp.processMap(loc, stmtCtx, clauseOps, &mapSyms); - - if (processHostOnlyClauses) - cp.processNowait(clauseOps); - + cp.processNowait(clauseOps); cp.processThreadLimit(stmtCtx, clauseOps); cp.processTODO<clause::Allocate, clause::Defaultmap, clause::Firstprivate, @@ -1344,10 +1749,13 @@ static void genTeamsClauses(lower::AbstractConverter &converter, ClauseProcessor cp(converter, semaCtx, clauses); cp.processAllocate(clauseOps); cp.processIf(llvm::omp::Directive::OMPD_teams, clauseOps); - cp.processNumTeams(stmtCtx, clauseOps); - cp.processThreadLimit(stmtCtx, clauseOps); - // TODO Support delayed privatization. + if (hostEvalInfo.empty() || !hostEvalInfo.back().apply(clauseOps)) { + cp.processNumTeams(stmtCtx, clauseOps); + cp.processThreadLimit(stmtCtx, clauseOps); + } + + // TODO Support delayed privatization. 
cp.processTODO<clause::Reduction>(loc, llvm::omp::Directive::OMPD_teams); } @@ -1721,17 +2129,19 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, ConstructQueue::const_iterator item) { fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); lower::StatementContext stmtCtx; + bool isTargetDevice = + llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp()) + .getIsTargetDevice(); - bool processHostOnlyClauses = - !llvm::cast<mlir::omp::OffloadModuleInterface>(*converter.getModuleOp()) - .getIsTargetDevice(); + // Introduce a new host_eval information structure for this target region. + if (!isTargetDevice) + hostEvalInfo.emplace_back(); mlir::omp::TargetOperands clauseOps; llvm::SmallVector<const semantics::Symbol *> mapSyms, isDevicePtrSyms, hasDeviceAddrSyms; - genTargetClauses(converter, semaCtx, stmtCtx, item->clauses, loc, - processHostOnlyClauses, clauseOps, hasDeviceAddrSyms, - isDevicePtrSyms, mapSyms); + genTargetClauses(converter, semaCtx, stmtCtx, eval, item->clauses, loc, + clauseOps, hasDeviceAddrSyms, isDevicePtrSyms, mapSyms); DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, /*shouldCollectPreDeterminedSymbols=*/ @@ -1841,6 +2251,7 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, extractMappedBaseValues(clauseOps.mapVars, mapBaseValues); EntryBlockArgs args; + args.hostEvalVars = clauseOps.hostEvalVars; // TODO: Add in_reduction syms and vars. args.map.syms = mapSyms; args.map.vars = mapBaseValues; @@ -1849,6 +2260,10 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable, genBodyOfTargetOp(converter, symTable, semaCtx, eval, targetOp, args, loc, queue, item, dsp); + + // Remove the host_eval information structure created for this target region. + if (!isTargetDevice) + hostEvalInfo.pop_back(); return targetOp; } diff --git a/flang/test/Lower/OpenMP/host-eval.f90 b/flang/test/Lower/OpenMP/host-eval.f90 new file mode 100644 index 00000000000000..32c52462b86a76 --- /dev/null +++ b/flang/test/Lower/OpenMP/host-eval.f90 @@ -0,0 +1,157 @@ +! The "thread_limit" clause was added to the "target" construct in OpenMP 5.1. +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 %s -o - | FileCheck %s --check-prefixes=BOTH,HOST +! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 -fopenmp-is-target-device %s -o - | FileCheck %s --check-prefixes=BOTH,DEVICE + +! BOTH-LABEL: func.func @_QPteams +subroutine teams() + ! BOTH: omp.target + + ! HOST-SAME: host_eval(%{{.*}} -> %[[NUM_TEAMS:.*]], %{{.*}} -> %[[THREAD_LIMIT:.*]] : i32, i32) + + ! DEVICE-NOT: host_eval({{.*}}) + ! DEVICE-SAME: { + !$omp target + + ! BOTH: omp.teams + + ! HOST-SAME: num_teams( to %[[NUM_TEAMS]] : i32) thread_limit(%[[THREAD_LIMIT]] : i32) + ! DEVICE-SAME: num_teams({{.*}}) thread_limit({{.*}}) + !$omp teams num_teams(1) thread_limit(2) + call foo() + !$omp end teams + + !$omp end target + + ! BOTH: omp.teams + ! BOTH-SAME: num_teams({{.*}}) thread_limit({{.*}}) { + !$omp teams num_teams(1) thread_limit(2) + call foo() + !$omp end teams +end subroutine teams + +! BOTH-LABEL: func.func @_QPdistribute_parallel_do +subroutine distribute_parallel_do() + ! BOTH: omp.target + + ! HOST-SAME: host_eval(%{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]], %{{.*}} -> %[[NUM_THREADS:.*]] : i32, i32, i32, i32) + + ! DEVICE-NOT: host_eval({{.*}}) + ! DEVICE-SAME: { + + ! BOTH: omp.teams + !$omp target teams + + ! BOTH: omp.parallel + + ! HOST-SAME: num_threads(%[[NUM_THREADS]] : i32) + ! 
DEVICE-SAME: num_threads({{.*}}) + + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.wsloop + ! BOTH-NEXT: omp.loop_nest + + ! HOST-SAME: (%{{.*}}) : i32 = (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]]) + !$omp distribute parallel do num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do + !$omp end target teams + + ! BOTH: omp.target + ! BOTH-NOT: host_eval({{.*}}) + ! BOTH-SAME: { + ! BOTH: omp.teams + !$omp target teams + call foo() !< Prevents this from being SPMD. + + ! BOTH: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.wsloop + !$omp distribute parallel do num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do + !$omp end target teams + + ! BOTH: omp.teams + !$omp teams + + ! BOTH: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.wsloop + !$omp distribute parallel do num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do + !$omp end teams +end subroutine distribute_parallel_do + +! BOTH-LABEL: func.func @_QPdistribute_parallel_do_simd +subroutine distribute_parallel_do_simd() + ! BOTH: omp.target + + ! HOST-SAME: host_eval(%{{.*}} -> %[[LB:.*]], %{{.*}} -> %[[UB:.*]], %{{.*}} -> %[[STEP:.*]], %{{.*}} -> %[[NUM_THREADS:.*]] : i32, i32, i32, i32) + + ! DEVICE-NOT: host_eval({{.*}}) + ! DEVICE-SAME: { + + ! BOTH: omp.teams + !$omp target teams + + ! BOTH: omp.parallel + + ! HOST-SAME: num_threads(%[[NUM_THREADS]] : i32) + ! DEVICE-SAME: num_threads({{.*}}) + + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.wsloop + ! BOTH-NEXT: omp.simd + ! BOTH-NEXT: omp.loop_nest + + ! HOST-SAME: (%{{.*}}) : i32 = (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]]) + !$omp distribute parallel do simd num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do simd + !$omp end target teams + + ! BOTH: omp.target + ! BOTH-NOT: host_eval({{.*}}) + ! BOTH-SAME: { + ! BOTH: omp.teams + !$omp target teams + call foo() !< Prevents this from being SPMD. + + ! BOTH: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.wsloop + ! BOTH-NEXT: omp.simd + !$omp distribute parallel do simd num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do simd + !$omp end target teams + + ! BOTH: omp.teams + !$omp teams + + ! BOTH: omp.parallel + ! BOTH-SAME: num_threads({{.*}}) + ! BOTH: omp.distribute + ! BOTH-NEXT: omp.wsloop + ! BOTH-NEXT: omp.simd + !$omp distribute parallel do simd num_threads(1) + do i=1,10 + call foo() + end do + !$omp end distribute parallel do simd + !$omp end teams +end subroutine distribute_parallel_do_simd diff --git a/flang/test/Lower/OpenMP/target-spmd.f90 b/flang/test/Lower/OpenMP/target-spmd.f90 new file mode 100644 index 00000000000000..43613819ccc8e9 --- /dev/null +++ b/flang/test/Lower/OpenMP/target-spmd.f90 @@ -0,0 +1,191 @@ +! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s + +! CHECK-LABEL: func.func @_QPdistribute_parallel_do_generic() { +subroutine distribute_parallel_do_generic() + ! CHECK: omp.target + ! CHECK-NOT: host_eval({{.*}}) + ! CHECK-SAME: { + !$omp target + !$omp teams + !$omp distribute parallel do + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do + call bar() !< Prevents this from being SPMD. + !$omp end teams + !$omp end target + + ! CHECK: omp.target + ! CHECK-NOT: host_eval({{.*}}) + ! 
CHECK-SAME: { + !$omp target teams + !$omp distribute parallel do + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do + call bar() !< Prevents this from being SPMD. + !$omp end target teams + + ! CHECK: omp.target + ! CHECK-NOT: host_eval({{.*}}) + ! CHECK-SAME: { + !$omp target teams + !$omp distribute parallel do + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do + + !$omp distribute parallel do + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do + !$omp end target teams +end subroutine distribute_parallel_do_generic + +! CHECK-LABEL: func.func @_QPdistribute_parallel_do_spmd() { +subroutine distribute_parallel_do_spmd() + ! CHECK: omp.target + ! CHECK-SAME: host_eval({{.*}}) + !$omp target + !$omp teams + !$omp distribute parallel do + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do + !$omp end teams + !$omp end target + + ! CHECK: omp.target + ! CHECK-SAME: host_eval({{.*}}) + !$omp target teams + !$omp distribute parallel do + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do + !$omp end target teams +end subroutine distribute_parallel_do_spmd + +! CHECK-LABEL: func.func @_QPdistribute_parallel_do_simd_generic() { +subroutine distribute_parallel_do_simd_generic() + ! CHECK: omp.target + ! CHECK-NOT: host_eval({{.*}}) + ! CHECK-SAME: { + !$omp target + !$omp teams + !$omp distribute parallel do simd + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do simd + call bar() !< Prevents this from being SPMD. + !$omp end teams + !$omp end target + + ! CHECK: omp.target + ! CHECK-NOT: host_eval({{.*}}) + ! CHECK-SAME: { + !$omp target teams + !$omp distribute parallel do simd + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do simd + call bar() !< Prevents this from being SPMD. + !$omp end target teams + + ! CHECK: omp.target + ! CHECK-NOT: host_eval({{.*}}) + ! CHECK-SAME: { + !$omp target teams + !$omp distribute parallel do simd + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do simd + + !$omp distribute parallel do simd + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do simd + !$omp end target teams +end subroutine distribute_parallel_do_simd_generic + +! CHECK-LABEL: func.func @_QPdistribute_parallel_do_simd_spmd() { +subroutine distribute_parallel_do_simd_spmd() + ! CHECK: omp.target + ! CHECK-SAME: host_eval({{.*}}) + !$omp target + !$omp teams + !$omp distribute parallel do simd + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do simd + !$omp end teams + !$omp end target + + ! CHECK: omp.target + ! CHECK-SAME: host_eval({{.*}}) + !$omp target teams + !$omp distribute parallel do simd + do i = 1, 10 + call foo(i) + end do + !$omp end distribute parallel do simd + !$omp end target teams +end subroutine distribute_parallel_do_simd_spmd + +! CHECK-LABEL: func.func @_QPteams_distribute_parallel_do_spmd() { +subroutine teams_distribute_parallel_do_spmd() + ! CHECK: omp.target + ! CHECK-SAME: host_eval({{.*}}) + !$omp target + !$omp teams distribute parallel do + do i = 1, 10 + call foo(i) + end do + !$omp end teams distribute parallel do + !$omp end target +end subroutine teams_distribute_parallel_do_spmd + +! CHECK-LABEL: func.func @_QPteams_distribute_parallel_do_simd_spmd() { +subroutine teams_distribute_parallel_do_simd_spmd() + ! CHECK: omp.target + ! 
CHECK-SAME: host_eval({{.*}}) + !$omp target + !$omp teams distribute parallel do simd + do i = 1, 10 + call foo(i) + end do + !$omp end teams distribute parallel do simd + !$omp end target +end subroutine teams_distribute_parallel_do_simd_spmd + +! CHECK-LABEL: func.func @_QPtarget_teams_distribute_parallel_do_spmd() { +subroutine target_teams_distribute_parallel_do_spmd() + ! CHECK: omp.target + ! CHECK-SAME: host_eval({{.*}}) + !$omp target teams distribute parallel do + do i = 1, 10 + call foo(i) + end do + !$omp end target teams distribute parallel do +end subroutine target_teams_distribute_parallel_do_spmd + +! CHECK-LABEL: func.func @_QPtarget_teams_distribute_parallel_do_simd_spmd() { +subroutine target_teams_distribute_parallel_do_simd_spmd() + ! CHECK: omp.target + ! CHECK-SAME: host_eval({{.*}}) + !$omp target teams distribute parallel do simd + do i = 1, 10 + call foo(i) + end do + !$omp end target teams distribute parallel do simd +end subroutine target_teams_distribute_parallel_do_simd_spmd diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h index 1247a871f93c6d..f9a85626a3f149 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h @@ -41,6 +41,12 @@ struct DeviceTypeClauseOps { // Extra operation operand structures. //===----------------------------------------------------------------------===// +/// Clauses that correspond to operations other than omp.target, but might have +/// to be evaluated outside of a parent target region. +using HostEvaluatedOperands = + detail::Clauses<LoopRelatedClauseOps, NumTeamsClauseOps, + NumThreadsClauseOps, ThreadLimitClauseOps>; + // TODO: Add `indirect` clause. using DeclareTargetOperands = detail::Clauses<DeviceTypeClauseOps>;
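
A note for reviewers skimming the FileCheck expectations above: pieced together from host-eval.f90 and target-spmd.f90, the host-compilation output for a target SPMD nest has roughly the shape sketched below. This is a hand-written approximation, not literal flang output: the SSA names are invented, the surrounding fir/hlfir code, map/privatization operands, wrapper attributes and some terminators are elided, and the arith.constant definitions merely stand in for whatever the clause and loop-control expressions actually lower to on the host.

  // Hypothetical Fortran input (combines the patterns exercised by the new tests):
  //   !$omp target
  //   !$omp teams num_teams(8) thread_limit(64)
  //   !$omp distribute parallel do num_threads(32)
  //   do i = 1, 1000
  //     call foo(i)
  //   end do
  %lb = arith.constant 1 : i32
  %ub = arith.constant 1000 : i32
  %step = arith.constant 1 : i32
  %nteams = arith.constant 8 : i32
  %nthreads = arith.constant 32 : i32
  %tlimit = arith.constant 64 : i32
  // Clause and loop-bound expressions are evaluated on the host and forwarded
  // through host_eval; the operand order here follows HostEvalInfo::collectValues().
  omp.target host_eval(%lb -> %a_lb, %ub -> %a_ub, %step -> %a_step,
                       %nteams -> %a_nteams, %nthreads -> %a_nthreads,
                       %tlimit -> %a_tlimit : i32, i32, i32, i32, i32, i32) {
    // Nested operations consume the entry block arguments instead of
    // re-lowering the same clauses inside the target region.
    omp.teams num_teams( to %a_nteams : i32) thread_limit(%a_tlimit : i32) {
      omp.parallel num_threads(%a_nthreads : i32) {
        omp.distribute {
          omp.wsloop {
            omp.loop_nest (%iv) : i32 = (%a_lb) to (%a_ub) inclusive step (%a_step) {
              // loop body
              omp.yield
            }
          }
        }
        omp.terminator
      }
      omp.terminator
    }
    omp.terminator
  }

When the same input is compiled with -fopenmp-is-target-device, no host_eval clause is emitted and the num_teams/thread_limit/num_threads and loop-bound expressions are lowered inside the region instead, which is what the DEVICE prefixes in host-eval.f90 check. The intent of the host-side form is that the launch parameters of the target SPMD kernel (trip count, team count, thread count) are available to the host before the region is launched.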