skatrak updated this revision to Diff 556542.
skatrak marked 3 inline comments as done.
skatrak added a comment.

Address review comments.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D147218/new/

https://reviews.llvm.org/D147218

Files:
  flang/include/flang/Lower/OpenMP.h
  flang/lib/Lower/Bridge.cpp
  flang/lib/Lower/OpenMP.cpp
  flang/test/Lower/OpenMP/Todo/requires-unnamed-common.f90
  flang/test/Lower/OpenMP/requires-common.f90
  flang/test/Lower/OpenMP/requires-notarget.f90
  flang/test/Lower/OpenMP/requires.f90

Index: flang/test/Lower/OpenMP/requires.f90
===================================================================
--- /dev/null
+++ flang/test/Lower/OpenMP/requires.f90
@@ -0,0 +1,14 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s
+
+! This test checks the lowering of requires into MLIR when target code exists
+
+!CHECK:      module attributes {
+!CHECK-SAME: omp.requires = #omp<clause_requires reverse_offload|unified_shared_memory>
+program requires
+  !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst)
+  !$omp target
+  !$omp end target
+end program requires
Index: flang/test/Lower/OpenMP/requires-notarget.f90
===================================================================
--- /dev/null
+++ flang/test/Lower/OpenMP/requires-notarget.f90
@@ -0,0 +1,14 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s
+
+! This test checks that requires lowering into MLIR skips creating the
+! omp.requires attribute if there are no target regions or device functions
+! in the compilation unit
+
+!CHECK:      module attributes {
+!CHECK-NOT:  omp.requires
+program requires
+  !$omp requires unified_shared_memory reverse_offload atomic_default_mem_order(seq_cst)
+end program requires
Index: flang/test/Lower/OpenMP/requires-common.f90
===================================================================
--- /dev/null
+++ flang/test/Lower/OpenMP/requires-common.f90
@@ -0,0 +1,19 @@
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s
+
+! This test checks the lowering of requires in a named BLOCK DATA into MLIR
+
+!CHECK:      module attributes {
+!CHECK-SAME: omp.requires = #omp<clause_requires unified_shared_memory>
+block data init
+  !$omp requires unified_shared_memory
+  integer :: x
+  common /block/ x
+  data x / 10 /
+end
+
+subroutine f
+  !$omp declare target
+end subroutine f
Index: flang/test/Lower/OpenMP/Todo/requires-unnamed-common.f90
===================================================================
--- /dev/null
+++ flang/test/Lower/OpenMP/Todo/requires-unnamed-common.f90
@@ -0,0 +1,25 @@
+! This test checks the lowering of REQUIRES inside of an unnamed BLOCK DATA.
+! The symbol of the `symTab` scope of the `BlockDataUnit` PFT node is null in
+! this case, resulting in the inability to store the REQUIRES flags gathered in
+! it.
+
+! RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+! RUN: %flang_fc1 -emit-fir -fopenmp -fopenmp-is-target-device %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -emit-fir %s -o - | FileCheck %s
+! RUN: bbc -fopenmp -fopenmp-is-target-device -emit-fir %s -o - | FileCheck %s
+! XFAIL: *
+
+! The FileCheck directives below describe the expected FIR once supported
+
+!CHECK:         module attributes {
+!CHECK-SAME:    omp.requires = #omp<clause_requires unified_shared_memory>
+block data
+  !$omp requires unified_shared_memory
+  integer :: x
+  common /block/ x
+  data x / 10 /
+end
+
+subroutine f
+  !$omp declare target
+end subroutine f
Index: flang/lib/Lower/OpenMP.cpp
===================================================================
--- flang/lib/Lower/OpenMP.cpp
+++ flang/lib/Lower/OpenMP.cpp
@@ -78,9 +78,7 @@
 static void gatherFuncAndVarSyms(
     const Fortran::parser::OmpObjectList &objList,
     mlir::omp::DeclareTargetCaptureClause clause,
-    llvm::SmallVectorImpl<std::pair<mlir::omp::DeclareTargetCaptureClause,
-                                    Fortran::semantics::Symbol>>
-        &symbolAndClause) {
+    llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
   for (const Fortran::parser::OmpObject &ompObject : objList.v) {
     Fortran::common::visit(
         Fortran::common::visitors{
@@ -2474,6 +2472,71 @@
                                  reductionDeclSymbols));
 }
 
+/// Append to \p symbolAndClause the function and variable symbols affected by
+/// the given 'declare target' directive and return their intended device type.
+static mlir::omp::DeclareTargetDeviceType getDeclareTargetInfo(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclareTargetConstruct &declareTargetConstruct,
+    llvm::SmallVectorImpl<DeclareTargetCapturePair> &symbolAndClause) {
+
+  // The default device type; may be overridden by processDeviceType below
+  mlir::omp::DeclareTargetDeviceType deviceType =
+      mlir::omp::DeclareTargetDeviceType::any;
+  const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
+      declareTargetConstruct.t);
+
+  if (const auto *objectList{
+          Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
+    // Case: declare target(func, var1, var2)
+    gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
+                         symbolAndClause);
+  } else if (const auto *clauseList{
+                 Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
+                     spec.u)}) {
+    if (clauseList->v.empty()) {
+      // Case: declare target, implicit capture of the enclosing procedure
+      symbolAndClause.emplace_back(
+          mlir::omp::DeclareTargetCaptureClause::to,
+          eval.getOwningProcedure()->getSubprogramSymbol());
+    }
+
+    ClauseProcessor cp(converter, *clauseList);
+    cp.processTo(symbolAndClause);
+    cp.processLink(symbolAndClause);
+    cp.processDeviceType(deviceType);
+    cp.processTODO<Fortran::parser::OmpClause::Indirect>(
+        converter.getCurrentLocation(),
+        llvm::omp::Directive::OMPD_declare_target);
+  }
+
+  return deviceType;
+}
+
+static std::optional<mlir::omp::DeclareTargetDeviceType>
+getDeclareTargetFunctionDevice(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclareTargetConstruct
+        &declareTargetConstruct) {
+  llvm::SmallVector<DeclareTargetCapturePair, 0> symbolAndClause;
+  mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
+      converter, eval, declareTargetConstruct, symbolAndClause);
+
+  // Return the device type only if at least one of the targets for the
+  // directive is a function or subroutine present in the module.
+  mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
+  for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
+    mlir::Operation *op = mod.lookupSymbol(
+        converter.mangleName(std::get<Fortran::semantics::Symbol>(symClause)));
+    // lookupSymbol() returns null if nothing in the module has that name;
+    // plain mlir::isa<> asserts on null, so use isa_and_nonnull<> here.
+    if (mlir::isa_and_nonnull<mlir::func::FuncOp>(op))
+      return deviceType;
+  }
+  return std::nullopt;
+}
+
 //===----------------------------------------------------------------------===//
 // genOMP() Code generation helper functions
 //===----------------------------------------------------------------------===//
@@ -2994,35 +3057,8 @@
                        &declareTargetConstruct) {
   llvm::SmallVector<DeclareTargetCapturePair, 0> symbolAndClause;
   mlir::ModuleOp mod = converter.getFirOpBuilder().getModule();
-
-  // The default capture type
-  mlir::omp::DeclareTargetDeviceType deviceType =
-      mlir::omp::DeclareTargetDeviceType::any;
-  const auto &spec = std::get<Fortran::parser::OmpDeclareTargetSpecifier>(
-      declareTargetConstruct.t);
-  if (const auto *objectList{
-          Fortran::parser::Unwrap<Fortran::parser::OmpObjectList>(spec.u)}) {
-    // Case: declare target(func, var1, var2)
-    gatherFuncAndVarSyms(*objectList, mlir::omp::DeclareTargetCaptureClause::to,
-                         symbolAndClause);
-  } else if (const auto *clauseList{
-                 Fortran::parser::Unwrap<Fortran::parser::OmpClauseList>(
-                     spec.u)}) {
-    if (clauseList->v.empty()) {
-      // Case: declare target, implicit capture of function
-      symbolAndClause.emplace_back(
-          mlir::omp::DeclareTargetCaptureClause::to,
-          eval.getOwningProcedure()->getSubprogramSymbol());
-    }
-
-    ClauseProcessor cp(converter, *clauseList);
-    cp.processTo(symbolAndClause);
-    cp.processLink(symbolAndClause);
-    cp.processDeviceType(deviceType);
-    cp.processTODO<Fortran::parser::OmpClause::Indirect>(
-        converter.getCurrentLocation(),
-        llvm::omp::Directive::OMPD_declare_target);
-  }
+  mlir::omp::DeclareTargetDeviceType deviceType = getDeclareTargetInfo(
+      converter, eval, declareTargetConstruct, symbolAndClause);
 
   for (const DeclareTargetCapturePair &symClause : symbolAndClause) {
     mlir::Operation *op = mod.lookupSymbol(
@@ -3126,6 +3162,27 @@
       ompConstruct.u);
 }
 
+void Fortran::lower::analyzeOpenMPDeclarativeConstruct(
+    Fortran::lower::AbstractConverter &converter,
+    Fortran::lower::pft::Evaluation &eval,
+    const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl,
+    bool &ompDeviceCodeFound) {
+  Fortran::common::visit(
+      Fortran::common::visitors{
+          [&](const Fortran::parser::OpenMPDeclareTargetConstruct &declTarget) {
+            mlir::omp::DeclareTargetDeviceType targetType =
+                getDeclareTargetFunctionDevice(converter, eval, declTarget)
+                    .value_or(mlir::omp::DeclareTargetDeviceType::host);
+
+            ompDeviceCodeFound =
+                ompDeviceCodeFound ||
+                targetType != mlir::omp::DeclareTargetDeviceType::host;
+          },
+          [&](const auto &) {},
+      },
+      ompDecl.u);
+}
+
 void Fortran::lower::genOpenMPDeclarativeConstruct(
     Fortran::lower::AbstractConverter &converter,
     Fortran::lower::pft::Evaluation &eval,
@@ -3151,7 +3208,10 @@
           },
           [&](const Fortran::parser::OpenMPRequiresConstruct
                   &requiresConstruct) {
-            TODO(converter.getCurrentLocation(), "OpenMPRequiresConstruct");
+            // Requires directives are gathered and processed in semantics and
+            // then combined in the lowering bridge before triggering codegen
+            // just once. Hence, there is no need to lower each individual
+            // occurrence here.
           },
           [&](const Fortran::parser::OpenMPThreadprivate &threadprivate) {
             // The directive is lowered when instantiating the variable to
@@ -3465,3 +3525,55 @@
     }
   }
 }
+
+bool Fortran::lower::isOpenMPTargetConstruct(
+    const Fortran::parser::OpenMPConstruct &omp) {
+  // Only block and loop constructs are inspected; others keep OMPD_unknown.
+  llvm::omp::Directive directive = llvm::omp::Directive::OMPD_unknown;
+  if (const auto *blockConstruct =
+          std::get_if<Fortran::parser::OpenMPBlockConstruct>(&omp.u))
+    directive = std::get<Fortran::parser::OmpBlockDirective>(
+                    std::get<Fortran::parser::OmpBeginBlockDirective>(
+                        blockConstruct->t).t).v;
+  else if (const auto *loopConstruct =
+               std::get_if<Fortran::parser::OpenMPLoopConstruct>(&omp.u))
+    directive = std::get<Fortran::parser::OmpLoopDirective>(
+                    std::get<Fortran::parser::OmpBeginLoopDirective>(
+                        loopConstruct->t).t).v;
+  return llvm::omp::allTargetSet.test(directive);
+}
+
+void Fortran::lower::genOpenMPRequires(
+    mlir::Operation *mod, const Fortran::semantics::Symbol *symbol) {
+  using MlirRequires = mlir::omp::ClauseRequires;
+  using SemaRequires = Fortran::semantics::WithOmpDeclarative::RequiresFlag;
+
+  auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(mod);
+  if (!offloadMod)
+    return;
+  // Fetch the REQUIRES flags that semantics attached to the symbol, if any.
+  Fortran::semantics::WithOmpDeclarative::RequiresFlags semaFlags;
+  if (symbol) {
+    Fortran::common::visit(
+        [&](const auto &details) {
+          using DetailsTy = std::decay_t<decltype(details)>;
+          if constexpr (std::is_base_of_v<
+                            Fortran::semantics::WithOmpDeclarative, DetailsTy>)
+            if (details.has_ompRequires())
+              semaFlags = *details.ompRequires();
+        },
+        symbol->details());
+  }
+
+  // Translate the semantic flags that have an omp.requires counterpart.
+  MlirRequires mlirFlags = MlirRequires::none;
+  if (semaFlags.test(SemaRequires::ReverseOffload))
+    mlirFlags = mlirFlags | MlirRequires::reverse_offload;
+  if (semaFlags.test(SemaRequires::UnifiedAddress))
+    mlirFlags = mlirFlags | MlirRequires::unified_address;
+  if (semaFlags.test(SemaRequires::UnifiedSharedMemory))
+    mlirFlags = mlirFlags | MlirRequires::unified_shared_memory;
+  if (semaFlags.test(SemaRequires::DynamicAllocators))
+    mlirFlags = mlirFlags | MlirRequires::dynamic_allocators;
+  offloadMod.setRequires(mlirFlags);
+}
Index: flang/lib/Lower/Bridge.cpp
===================================================================
--- flang/lib/Lower/Bridge.cpp
+++ flang/lib/Lower/Bridge.cpp
@@ -50,6 +50,7 @@
 #include "flang/Parser/parse-tree.h"
 #include "flang/Runtime/iostat.h"
 #include "flang/Semantics/runtime-type-info.h"
+#include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/IR/PatternMatch.h"
@@ -294,12 +295,15 @@
     //    that they are available before lowering any function that may use
     //    them.
     bool hasMainProgram = false;
+    const Fortran::semantics::Symbol *globalOmpRequiresSymbol = nullptr;
     for (Fortran::lower::pft::Program::Units &u : pft.getUnits()) {
       std::visit(Fortran::common::visitors{
                      [&](Fortran::lower::pft::FunctionLikeUnit &f) {
                        if (f.isMainProgram())
                          hasMainProgram = true;
                        declareFunction(f);
+                       if (!globalOmpRequiresSymbol)
+                         globalOmpRequiresSymbol = f.getScope().symbol();
                      },
                      [&](Fortran::lower::pft::ModuleLikeUnit &m) {
                        lowerModuleDeclScope(m);
@@ -307,7 +311,10 @@
                             m.nestedFunctions)
                          declareFunction(f);
                      },
-                     [&](Fortran::lower::pft::BlockDataUnit &b) {},
+                     [&](Fortran::lower::pft::BlockDataUnit &b) {
+                       if (!globalOmpRequiresSymbol)
+                         globalOmpRequiresSymbol = b.symTab.symbol();
+                     },
                      [&](Fortran::lower::pft::CompilerDirectiveUnit &d) {},
                  },
                  u);
@@ -352,6 +359,7 @@
       });
 
     finalizeOpenACCLowering();
+    finalizeOpenMPLowering(globalOmpRequiresSymbol);
   }
 
   /// Declare a function.
@@ -2347,10 +2355,16 @@
 
     localSymbols.popScope();
     builder->restoreInsertionPoint(insertPt);
+
+    // Register if a target region was found
+    ompDeviceCodeFound =
+        ompDeviceCodeFound || Fortran::lower::isOpenMPTargetConstruct(omp);
   }
 
   void genFIR(const Fortran::parser::OpenMPDeclarativeConstruct &ompDecl) {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
+    analyzeOpenMPDeclarativeConstruct(*this, getEval(), ompDecl,
+                                      ompDeviceCodeFound);
     genOpenMPDeclarativeConstruct(*this, getEval(), ompDecl);
     for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations())
       genFIR(e);
@@ -4758,6 +4772,16 @@
                                                      accRoutineInfos);
   }
 
+  /// Perform the OpenMP lowering actions that were deferred until the end of
+  /// lowering, such as setting the module's omp.requires attribute.
+  void finalizeOpenMPLowering(
+      const Fortran::semantics::Symbol *globalOmpRequiresSymbol) {
+    if (!ompDeviceCodeFound) // No device code: skip the omp.requires attribute
+      return;
+    Fortran::lower::genOpenMPRequires(getModuleOp().getOperation(),
+                                      globalOmpRequiresSymbol);
+  }
+
   //===--------------------------------------------------------------------===//
 
   Fortran::lower::LoweringBridge &bridge;
@@ -4804,6 +4828,10 @@
 
   /// Deferred OpenACC routine attachment.
   Fortran::lower::AccRoutineInfoMappingList accRoutineInfos;
+
+  /// Whether an OpenMP target region or a declare target function/subroutine
+  /// intended for device offloading has been found while lowering.
+  bool ompDeviceCodeFound = false;
 };
 
 } // namespace
Index: flang/include/flang/Lower/OpenMP.h
===================================================================
--- flang/include/flang/Lower/OpenMP.h
+++ flang/include/flang/Lower/OpenMP.h
@@ -34,6 +34,10 @@
 struct OmpClauseList;
 } // namespace parser
 
+namespace semantics {
+class Symbol;
+} // namespace semantics
+
 namespace lower {
 
 class AbstractConverter;
@@ -49,6 +53,9 @@
 
 void genOpenMPConstruct(AbstractConverter &, pft::Evaluation &,
                         const parser::OpenMPConstruct &);
+void analyzeOpenMPDeclarativeConstruct(
+    AbstractConverter &, pft::Evaluation &,
+    const parser::OpenMPDeclarativeConstruct &, bool &);
 void genOpenMPDeclarativeConstruct(AbstractConverter &, pft::Evaluation &,
                                    const parser::OpenMPDeclarativeConstruct &);
 int64_t getCollapseValue(const Fortran::parser::OmpClauseList &clauseList);
@@ -62,6 +69,10 @@
 void updateReduction(mlir::Operation *, fir::FirOpBuilder &, mlir::Value,
                      mlir::Value, fir::ConvertOp * = nullptr);
 void removeStoreOp(mlir::Operation *, mlir::Value);
+
+bool isOpenMPTargetConstruct(const parser::OpenMPConstruct &);
+void genOpenMPRequires(mlir::Operation *, const Fortran::semantics::Symbol *);
+
 } // namespace lower
 } // namespace Fortran
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to