skatrak updated this revision to Diff 531437. skatrak added a comment. Update patch to integrate with related patch D149337 <https://reviews.llvm.org/D149337> and address reviewer's comments.
Repository: rG LLVM Github Monorepo CHANGES SINCE LAST ACTION https://reviews.llvm.org/D147218/new/ https://reviews.llvm.org/D147218 Files: flang/include/flang/Lower/OpenMP.h flang/lib/Lower/Bridge.cpp flang/lib/Lower/OpenMP.cpp flang/test/Lower/OpenMP/requires-notarget.f90 flang/test/Lower/OpenMP/requires.f90
Index: flang/test/Lower/OpenMP/requires.f90
===================================================================
--- /dev/null
+++ flang/test/Lower/OpenMP/requires.f90
@@ -0,0 +1,13 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+! This test checks the lowering of requires into MLIR
+
+!CHECK: module attributes {
+!CHECK-SAME: omp.requires = #omp<clause_requires reverse_offload|unified_shared_memory>
+program requires
+  !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst)
+end program requires
+
+subroutine f
+  !$omp declare target
+end subroutine f
Index: flang/test/Lower/OpenMP/requires-notarget.f90
===================================================================
--- /dev/null
+++ flang/test/Lower/OpenMP/requires-notarget.f90
@@ -0,0 +1,11 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+! This test checks that requires lowering into MLIR skips creating the
+! omp.requires attribute with target-related clauses if there are no device
+! functions in the compilation unit
+
+!CHECK: module attributes {
+!CHECK-NOT: omp.requires
+program requires
+  !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst)
+end program requires
Index: flang/lib/Lower/OpenMP.cpp
===================================================================
--- flang/lib/Lower/OpenMP.cpp
+++ flang/lib/Lower/OpenMP.cpp
@@ -2594,16 +2594,14 @@
   converter.bindSymbol(sym, symThreadprivateExv);
 }
 
-void handleDeclareTarget(Fortran::lower::AbstractConverter &converter,
-                         Fortran::lower::pft::Evaluation &eval,
-                         const Fortran::parser::OpenMPDeclareTargetConstruct
-                             &declareTargetConstruct) {
-  llvm::SmallVector<std::pair<mlir::omp::DeclareTargetCaptureClause,
-                              Fortran::semantics::Symbol>,
-                    0>
-      symbolAndClause;
-  mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
-
+/// Extract the list of function and variable symbols affected by the given
+/// 'declare target' directive and return the intended device type for them.
+static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
+    SmallVectorImpl<std::pair<mlir::omp::DeclareTargetCaptureClause,
+                              Fortran::semantics::Symbol>> &symbolAndClause) {
+  // Gather the symbols and clauses
   auto findFuncAndVarSyms = [&](const Fortran::parser::OmpObjectList &objList,
                                 mlir::omp::DeclareTargetCaptureClause clause) {
     for (const Fortran::parser::OmpObject &ompObject : objList.v) {
@@ -2628,6 +2626,7 @@
       Fortran::parser::OmpDeviceTypeClause::Type::Any;
   const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
       declareTargetConstruct.t);
+
   if (const auto *objectList{
           Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
     // Case: declare target(func, var1, var2)
@@ -2662,6 +2661,29 @@
     }
   }
 
+  switch (deviceType) {
+  case Fortran::parser::OmpDeviceTypeClause::Type::Any:
+    return mlir::omp::DeclareTargetDeviceType::any;
+  case Fortran::parser::OmpDeviceTypeClause::Type::Host:
+    return mlir::omp::DeclareTargetDeviceType::host;
+  case Fortran::parser::OmpDeviceTypeClause::Type::Nohost:
+    return mlir::omp::DeclareTargetDeviceType::nohost;
+  }
+  llvm_unreachable("unexpected OpenMP declare target device type");
+}
+
+void genDeclareTarget(Fortran::lower::AbstractConverter &converter,
+                      Fortran::lower::pft::Evaluation &eval,
+                      const Fortran::parser::OpenMPDeclareTargetConstruct
+                          &declareTargetConstruct) {
+  llvm::SmallVector<std::pair<mlir::omp::DeclareTargetCaptureClause,
+                              Fortran::semantics::Symbol>,
+                    0>
+      symbolAndClause;
+  mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
+  mlir::omp::DeclareTargetDeviceType deviceType =
+      getDeclareTargetInfo(eval, declareTargetConstruct, symbolAndClause);
+
   for (std::pair<mlir::omp::DeclareTargetCaptureClause,
                  Fortran::semantics::Symbol>
            symClause : symbolAndClause) {
@@ -2688,35 +2710,47 @@
           converter.getCurrentLocation(),
           "Attempt to apply declare target on unsupported operation");
 
-    mlir::omp::DeclareTargetDeviceType newDeviceType;
-    switch (deviceType) {
-    case Fortran::parser::OmpDeviceTypeClause::Type::Nohost:
-      newDeviceType = mlir::omp::DeclareTargetDeviceType::nohost;
-      break;
-    case Fortran::parser::OmpDeviceTypeClause::Type::Host:
-      newDeviceType = mlir::omp::DeclareTargetDeviceType::host;
-      break;
-    case Fortran::parser::OmpDeviceTypeClause::Type::Any:
-      newDeviceType = mlir::omp::DeclareTargetDeviceType::any;
-      break;
-    }
-
     // The function or global already has a declare target applied to it,
     // very likely through implicit capture (usage in another declare
     // target function/subroutine). It should be marked as any if it has
     // been assigned both host and nohost, else we skip, as there is no
     // change
     if (declareTargetOp.isDeclareTarget()) {
-      if (declareTargetOp.getDeclareTargetDeviceType() != newDeviceType)
+      if (declareTargetOp.getDeclareTargetDeviceType() != deviceType)
         declareTargetOp.setDeclareTarget(
             mlir::omp::DeclareTargetDeviceType::any, std::get<0>(symClause));
       continue;
     }
 
-    declareTargetOp.setDeclareTarget(newDeviceType, std::get<0>(symClause));
+    declareTargetOp.setDeclareTarget(deviceType, std::get<0>(symClause));
   }
 }
 
+/// Inspect an OpenMP declarative construct ahead of its lowering and set
+/// \p ompDeviceCodeFound to true if it is a 'declare target' directive that
+/// applies to a function or subroutine with a device type other than 'host'.
+void Fortran::lower::analyzeOpenMPDeclarativeConstruct(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl,
+    bool &ompDeviceCodeFound) {
+  std::visit(
+      Fortran::common::visitors{
+          [&](const Fortran::parser::OpenMPDeclareTargetConstruct &declTarget) {
+            mlir::omp::DeclareTargetDeviceType targetType =
+                Fortran::lower::getOpenMPDeclareTargetFunctionDevice(
+                    converter, eval, declTarget)
+                    .value_or(mlir::omp::DeclareTargetDeviceType::host);
+
+            ompDeviceCodeFound =
+                ompDeviceCodeFound ||
+                targetType != mlir::omp::DeclareTargetDeviceType::host;
+          },
+          [&](const auto &) {},
+      },
+      ompDecl.u);
+}
+
 void Fortran::lower::genOpenMPDeclarativeConstruct(
     Fortran::lower::AbstractConverter &converter,
     Fortran::lower::pft::Evaluation &eval,
@@ -2739,11 +2773,14 @@
           },
           [&](const Fortran::parser::OpenMPDeclareTargetConstruct
                   &declareTargetConstruct) {
-            handleDeclareTarget(converter, eval, declareTargetConstruct);
+            genDeclareTarget(converter, eval, declareTargetConstruct);
           },
           [&](const Fortran::parser::OpenMPRequiresConstruct
                   &requiresConstruct) {
-            TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct");
+            // Requires directives are gathered and processed in semantics in
+            // order to support modules, and then combined in the lowering
+            // bridge before triggering codegen just once. Hence, there is no
+            // need for codegen for each individual occurrence here.
           },
           [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) {
             // The directive is lowered when instantiating the variable to
@@ -2965,3 +3002,103 @@
     }
   }
 }
+
+/// Return the device type of the given 'declare target' directive if it
+/// applies to at least one function or subroutine present in the module, or
+/// std::nullopt otherwise.
+std::optional<mlir::omp::DeclareTargetDeviceType>
+Fortran::lower::getOpenMPDeclareTargetFunctionDevice(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclareTargetConstruct
+        &declareTargetConstruct) {
+  llvm::SmallVector<std::pair<mlir::omp::DeclareTargetCaptureClause,
+                              Fortran::semantics::Symbol>,
+                    0>
+      symbolAndClause;
+  mlir::omp::DeclareTargetDeviceType deviceType =
+      getDeclareTargetInfo(eval, declareTargetConstruct, symbolAndClause);
+
+  // Return the device type only if at least one of the targets for the
+  // directive is a function or subroutine. lookupSymbol returns null when the
+  // module holds no operation with the mangled name, so tolerate that case.
+  mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
+  for (const auto &symClause : symbolAndClause) {
+    mlir::Operation *op =
+        mod.lookupSymbol(converter.mangleName(std::get<1>(symClause)));
+
+    if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
+      return deviceType;
+  }
+
+  return std::nullopt;
+}
+
+/// Return true if the given OpenMP construct is a 'target' block or loop
+/// construct, including combined directives such as 'target teams
+/// distribute'. Any other construct, including 'target data', yields false.
+bool Fortran::lower::isOpenMPTargetConstruct(
+    const Fortran::parser::OpenMPConstruct &omp) {
+  std::optional<llvm::omp::Directive> dir;
+  if (const auto *block =
+          std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u)) {
+    const auto &beginBlockDir{
+        std::get<Fortran::parser::OmpBeginBlockDirective>(block->t)};
+    dir = std::get<Fortran::parser::OmpBlockDirective>(beginBlockDir.t).v;
+  } else if (const auto *loop =
+                 std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u)) {
+    // Combined 'target' loop directives (e.g. 'target teams distribute
+    // parallel do') are represented as loop constructs, not block constructs.
+    const auto &beginLoopDir{
+        std::get<Fortran::parser::OmpBeginLoopDirective>(loop->t)};
+    dir = std::get<Fortran::parser::OmpLoopDirective>(beginLoopDir.t).v;
+  }
+
+  if (!dir)
+    return false;
+
+  switch (*dir) {
+  case llvm::omp::Directive::OMPD_target:
+  case llvm::omp::Directive::OMPD_target_parallel:
+  case llvm::omp::Directive::OMPD_target_parallel_do:
+  case llvm::omp::Directive::OMPD_target_parallel_do_simd:
+  case llvm::omp::Directive::OMPD_target_simd:
+  case llvm::omp::Directive::OMPD_target_teams:
+  case llvm::omp::Directive::OMPD_target_teams_distribute:
+  case llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do:
+  case llvm::omp::Directive::OMPD_target_teams_distribute_parallel_do_simd:
+  case llvm::omp::Directive::OMPD_target_teams_distribute_simd:
+    return true;
+  default:
+    return false;
+  }
+}
+
+/// Translate the clauses of an OpenMP 'requires' directive into MLIR OpenMP
+/// dialect flags. Clauses with no corresponding flag are ignored.
+mlir::omp::ClauseRequires Fortran::lower::extractOpenMPRequiresClauses(
+    const Fortran::parser::OmpClauseList &clauseList) {
+  using mlir::omp::ClauseRequires, Fortran::parser::OmpClause;
+  auto requiresFlags = ClauseRequires::none;
+
+  for (const OmpClause &clause : clauseList.v) {
+    if (std::get_if<OmpClause::DynamicAllocators>(&clause.u))
+      requiresFlags = requiresFlags | ClauseRequires::dynamic_allocators;
+    else if (std::get_if<OmpClause::ReverseOffload>(&clause.u))
+      requiresFlags = requiresFlags | ClauseRequires::reverse_offload;
+    else if (std::get_if<OmpClause::UnifiedAddress>(&clause.u))
+      requiresFlags = requiresFlags | ClauseRequires::unified_address;
+    else if (std::get_if<OmpClause::UnifiedSharedMemory>(&clause.u))
+      requiresFlags = requiresFlags | ClauseRequires::unified_shared_memory;
+  }
+
+  return requiresFlags;
+}
+
+/// Attach the given 'requires' flags to \p mod if it implements the OpenMP
+/// offload module interface; otherwise, do nothing.
+void Fortran::lower::genOpenMPRequires(mlir::Operation *mod,
+                                       mlir::omp::ClauseRequires flags) {
+  if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod))
+    offloadMod.setRequires(flags);
+}
Index: flang/lib/Lower/Bridge.cpp
===================================================================
--- flang/lib/Lower/Bridge.cpp
+++ flang/lib/Lower/Bridge.cpp
@@ -50,6 +50,7 @@
 #include "flang/Parser/parse-tree.h"
 #include "flang/Runtime/iostat.h"
 #include "flang/Semantics/runtime-type-info.h"
+#include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/IR/PatternMatch.h"
@@ -62,6 +63,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FileSystem.h"
 #include "llvm/Support/Path.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include <optional>
 
 #define DEBUG_TYPE "flang-lower-bridge"
@@ -288,20 +290,34 @@
     // that they are available before lowering any function that may use
     // them.
     bool hasMainProgram = false;
+    Fortran::semantics::OmpRequiresFlags ompRequiresFlags =
+        Fortran::semantics::OmpRequiresFlags::None;
+    std::optional<Fortran::parser::OmpAtomicDefaultMemOrderClause::Type>
+        ompAtomicDefaultMemOrder;
     for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
       std::visit(Fortran::common::visitors{
                      [&](Fortran::lower::pft::FunctionLikeUnit &f) {
                        if (f.isMainProgram())
                          hasMainProgram = true;
                        declareFunction(f);
+                       ompProcessTopLevelSymbol(f.getScope().symbol(),
+                                                ompRequiresFlags,
+                                                ompAtomicDefaultMemOrder);
                      },
                      [&](Fortran::lower::pft::ModuleLikeUnit &m) {
                        lowerModuleDeclScope(m);
                        for (Fortran::lower::pft::FunctionLikeUnit &f :
                             m.nestedFunctions)
                          declareFunction(f);
+                       ompProcessTopLevelSymbol(m.getScope().symbol(),
+                                                ompRequiresFlags,
+                                                ompAtomicDefaultMemOrder);
+                     },
+                     [&](Fortran::lower::pft::BlockDataUnit &b) {
+                       ompProcessTopLevelSymbol(b.symTab.symbol(),
+                                                ompRequiresFlags,
+                                                ompAtomicDefaultMemOrder);
                      },
-                     [&](Fortran::lower::pft::BlockDataUnit &b) {},
                      [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
                  },
                  u);
@@ -344,6 +360,24 @@
       fir::runtime::genEnvironmentDefaults(*builder, toLocation(),
                                            bridge.getEnvironmentDefaults());
     });
+
+    // Set the module attribute related to OpenMP requires directives
+    if (ompDeviceCodeFound) {
+      using MlirRequires = mlir::omp::ClauseRequires;
+      using SemaRequires = Fortran::semantics::OmpRequiresFlags;
+      MlirRequires flags = MlirRequires::none;
+
+      if (ompRequiresFlags & SemaRequires::ReverseOffload)
+        flags = flags | MlirRequires::reverse_offload;
+      if (ompRequiresFlags & SemaRequires::UnifiedAddress)
+        flags = flags | MlirRequires::unified_address;
+      if (ompRequiresFlags & SemaRequires::UnifiedSharedMemory)
+        flags = flags | MlirRequires::unified_shared_memory;
+      if (ompRequiresFlags & SemaRequires::DynamicAllocators)
+        flags = flags | MlirRequires::dynamic_allocators;
+
+      Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(), flags);
+    }
   }
 
   /// Declare a function.
@@ -1191,6 +1225,51 @@
     activeConstructStack.pop_back();
   }
 
+  /// Accumulate the OpenMP 'requires' information attached to the given
+  /// top-level program unit symbol into \p ompRequiresFlags and
+  /// \p ompAtomicDefaultMemOrder. Conflicting atomic_default_mem_order values
+  /// across program units result in a fatal error. A null symbol is ignored.
+  void ompProcessTopLevelSymbol(
+      const Fortran::semantics::Symbol *symbol,
+      Fortran::semantics::OmpRequiresFlags &ompRequiresFlags,
+      std::optional<Fortran::parser::OmpAtomicDefaultMemOrderClause::Type>
+          &ompAtomicDefaultMemOrder) {
+    if (!symbol)
+      return;
+
+    Fortran::common::visit(
+        [&](const auto &details) {
+          if constexpr (std::is_base_of_v<
+                            Fortran::semantics::WithOmpDeclarative,
+                            std::decay_t<decltype(details)>>) {
+            // Collect OpenMP 'requires' clauses.
+            if (details.has_ompRequires())
+              ompRequiresFlags |= *details.ompRequires();
+
+            // Make sure any atomic_default_mem_order OpenMP 'requires' clauses
+            // obtained for different top-level symbols match.
+            if (details.has_ompAtomicDefaultMemOrder()) {
+              Fortran::parser::OmpAtomicDefaultMemOrderClause::Type memOrder{
+                  *details.ompAtomicDefaultMemOrder()};
+              if (ompAtomicDefaultMemOrder &&
+                  memOrder != *ompAtomicDefaultMemOrder)
+                fir::emitFatalError(
+                    getCurrentLocation(),
+                    llvm::StringRef{
+                        "incompatible OpenMP requires atomic_default_mem_order "
+                        "clauses found: '"} +
+                        Fortran::parser::OmpAtomicDefaultMemOrderClause::
+                            EnumToString(memOrder) +
+                        llvm::StringRef{"' and '"} +
+                        Fortran::parser::OmpAtomicDefaultMemOrderClause::
+                            EnumToString(*ompAtomicDefaultMemOrder));
+              ompAtomicDefaultMemOrder = memOrder;
+            }
+          }
+        },
+        symbol->details());
+  }
+
   //===--------------------------------------------------------------------===//
   // Termination of symbolically referenced execution units
   //===--------------------------------------------------------------------===//
@@ -2201,10 +2280,16 @@
 
     localSymbols.popScope();
     builder->restoreInsertionPoint(insertPt);
+
+    // Register if a target region was found
+    ompDeviceCodeFound =
+        ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp);
   }
 
   void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
+    analyzeOpenMPDeclarativeConstruct(*this, getEval(), ompDecl,
+                                      ompDeviceCodeFound);
     genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl);
     for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
       genFIR(e);
@@ -4530,6 +4615,10 @@
 
   /// A counter for uniquing names in `literalNamesMap`.
   std::uint64_t uniqueLitId = 0;
+
+  /// Whether an OpenMP target region or declare target function/subroutine
+  /// intended for device offloading has been detected
+  bool ompDeviceCodeFound = false;
 };
 
 } // namespace
Index: flang/include/flang/Lower/OpenMP.h
===================================================================
--- flang/include/flang/Lower/OpenMP.h
+++ flang/include/flang/Lower/OpenMP.h
@@ -13,13 +13,10 @@
 #ifndef FORTRAN_LOWER_OPENMP_H
 #define FORTRAN_LOWER_OPENMP_H
 
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include <cinttypes>
+#include <optional>
 
-namespace mlir {
-class Value;
-class Operation;
-} // namespace mlir
-
 namespace fir {
 class FirOpBuilder;
 class ConvertOp;
@@ -29,6 +26,7 @@
 namespace parser {
 struct OpenMPConstruct;
 struct OpenMPDeclarativeConstruct;
+struct OpenMPDeclareTargetConstruct;
 struct OmpEndLoopDirective;
 struct OmpClauseList;
 } // namespace parser
@@ -44,6 +42,12 @@
 
 void genOpenMPConstruct(AbstractConverter &, pft::Evaluation &,
                         const parser::OpenMPConstruct &);
+/// Inspect an OpenMP declarative construct before lowering it and set
+/// \p ompDeviceCodeFound to true if it is a 'declare target' directive applied
+/// to a function or subroutine with a device type other than 'host'.
+void analyzeOpenMPDeclarativeConstruct(
+    AbstractConverter &, pft::Evaluation &,
+    const parser::OpenMPDeclarativeConstruct &, bool &ompDeviceCodeFound);
 void genOpenMPDeclarativeConstruct(AbstractConverter &, pft::Evaluation &,
                                    const parser::OpenMPDeclarativeConstruct &);
 int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
@@ -56,6 +60,22 @@
 void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
                      mlir::Value, fir::ConvertOp * = nullptr);
 void removeStoreOp(mlir::Operation *, mlir::Value);
+
+/// Return the device type of a 'declare target' directive if it applies to
+/// at least one function or subroutine, or std::nullopt otherwise.
+std::optional<mlir::omp::DeclareTargetDeviceType>
+getOpenMPDeclareTargetFunctionDevice(
+    Fortran::lower::AbstractConverter &, Fortran::lower::pft::Evaluation &,
+    const Fortran::parser::OpenMPDeclareTargetConstruct &);
+/// Return whether the given construct is an OpenMP 'target' construct.
+bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
+
+/// Translate the clauses of a 'requires' directive into MLIR OpenMP flags.
+mlir::omp::ClauseRequires
+extractOpenMPRequiresClauses(const Fortran::parser::OmpClauseList &);
+/// Attach 'requires' flags to the operation if it is an offload module.
+void genOpenMPRequires(mlir::Operation *, mlir::omp::ClauseRequires);
+
 } // namespace lower
 } // namespace Fortran
 
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits