jdoerfert created this revision.
jdoerfert added reviewers: kiranchandramohan, ABataev, RaviNarayanaswamy, 
gtbercea, grokos, sdmitriev, JonChesterfield, hfinkel, fghanim.
Herald added subscribers: llvm-commits, guansong, bollu, jholewinski.
Herald added projects: clang, LLVM.

This is the first part extracted from D71179 <https://reviews.llvm.org/D71179> 
and cleaned up.

This patch provides parsing support for `omp begin/end declare variant`,
as defined in OpenMP technical report 8 (TR8).

A major purpose of this patch is to provide proper math.h/cmath support
for OpenMP target offloading. See PR42061, PR42798, PR42799.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D74941

Files:
  clang/include/clang/AST/OpenMPClause.h
  clang/include/clang/Basic/DiagnosticParseKinds.td
  clang/include/clang/Parse/Parser.h
  clang/lib/AST/OpenMPClause.cpp
  clang/lib/Basic/OpenMPKinds.cpp
  clang/lib/CodeGen/CGOpenMPRuntime.cpp
  clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
  clang/lib/Parse/ParseOpenMP.cpp
  clang/lib/Sema/SemaOpenMP.cpp
  llvm/include/llvm/Frontend/OpenMP/OMPKinds.def

Index: llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
+++ llvm/include/llvm/Frontend/OpenMP/OMPKinds.def
@@ -91,6 +91,8 @@
 __OMP_DIRECTIVE_EXT(master_taskloop_simd, "master taskloop simd")
 __OMP_DIRECTIVE_EXT(parallel_master_taskloop_simd,
                     "parallel master taskloop simd")
+__OMP_DIRECTIVE_EXT(begin_declare_variant, "begin declare variant")
+__OMP_DIRECTIVE_EXT(end_declare_variant, "end declare variant")
 
 // Has to be the last because Clang implicitly expects it to be.
 __OMP_DIRECTIVE(unknown)
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -3747,6 +3747,8 @@
   case OMPD_end_declare_target:
   case OMPD_requires:
   case OMPD_declare_variant:
+  case OMPD_begin_declare_variant:
+  case OMPD_end_declare_variant:
     llvm_unreachable("OpenMP Directive is not allowed");
   case OMPD_unknown:
     llvm_unreachable("Unknown OpenMP directive");
@@ -4946,6 +4948,8 @@
   case OMPD_declare_simd:
   case OMPD_requires:
   case OMPD_declare_variant:
+  case OMPD_begin_declare_variant:
+  case OMPD_end_declare_variant:
     llvm_unreachable("OpenMP Directive is not allowed");
   case OMPD_unknown:
     llvm_unreachable("Unknown OpenMP directive");
@@ -11074,6 +11078,8 @@
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
     case OMPD_declare_variant:
+    case OMPD_begin_declare_variant:
+    case OMPD_end_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_teams:
@@ -11144,6 +11150,8 @@
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
     case OMPD_declare_variant:
+    case OMPD_begin_declare_variant:
+    case OMPD_end_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_teams:
@@ -11219,6 +11227,8 @@
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
     case OMPD_declare_variant:
+    case OMPD_begin_declare_variant:
+    case OMPD_end_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_simd:
@@ -11291,6 +11301,8 @@
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
     case OMPD_declare_variant:
+    case OMPD_begin_declare_variant:
+    case OMPD_end_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_simd:
@@ -11364,6 +11376,8 @@
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
     case OMPD_declare_variant:
+    case OMPD_begin_declare_variant:
+    case OMPD_end_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_simd:
@@ -11436,6 +11450,8 @@
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
     case OMPD_declare_variant:
+    case OMPD_begin_declare_variant:
+    case OMPD_end_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_simd:
@@ -11507,6 +11523,8 @@
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
     case OMPD_declare_variant:
+    case OMPD_begin_declare_variant:
+    case OMPD_end_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_simd:
@@ -11581,6 +11599,8 @@
     case OMPD_declare_mapper:
     case OMPD_declare_simd:
     case OMPD_declare_variant:
+    case OMPD_begin_declare_variant:
+    case OMPD_end_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_simd:
Index: clang/lib/Parse/ParseOpenMP.cpp
===================================================================
--- clang/lib/Parse/ParseOpenMP.cpp
+++ clang/lib/Parse/ParseOpenMP.cpp
@@ -47,6 +47,8 @@
   OMPD_target_teams_distribute_parallel,
   OMPD_mapper,
   OMPD_variant,
+  OMPD_begin,
+  OMPD_begin_declare,
 };
 
 // Helper to unify the enum class OpenMPDirectiveKind with its extension
@@ -100,6 +102,7 @@
       .Case("update", OMPD_update)
       .Case("mapper", OMPD_mapper)
       .Case("variant", OMPD_variant)
+      .Case("begin", OMPD_begin)
       .Default(OMPD_unknown);
 }
 
@@ -108,18 +111,21 @@
   // E.g.: OMPD_for OMPD_simd ===> OMPD_for_simd
   // TODO: add other combined directives in topological order.
   static const OpenMPDirectiveKindExWrapper F[][3] = {
+      {OMPD_begin, OMPD_declare, OMPD_begin_declare},
+      {OMPD_end, OMPD_declare, OMPD_end_declare},
       {OMPD_cancellation, OMPD_point, OMPD_cancellation_point},
       {OMPD_declare, OMPD_reduction, OMPD_declare_reduction},
       {OMPD_declare, OMPD_mapper, OMPD_declare_mapper},
       {OMPD_declare, OMPD_simd, OMPD_declare_simd},
       {OMPD_declare, OMPD_target, OMPD_declare_target},
       {OMPD_declare, OMPD_variant, OMPD_declare_variant},
+      {OMPD_begin_declare, OMPD_variant, OMPD_begin_declare_variant},
+      {OMPD_end_declare, OMPD_variant, OMPD_end_declare_variant},
       {OMPD_distribute, OMPD_parallel, OMPD_distribute_parallel},
       {OMPD_distribute_parallel, OMPD_for, OMPD_distribute_parallel_for},
       {OMPD_distribute_parallel_for, OMPD_simd,
        OMPD_distribute_parallel_for_simd},
       {OMPD_distribute, OMPD_simd, OMPD_distribute_simd},
-      {OMPD_end, OMPD_declare, OMPD_end_declare},
       {OMPD_end_declare, OMPD_target, OMPD_end_declare_target},
       {OMPD_target, OMPD_data, OMPD_target_data},
       {OMPD_target, OMPD_enter, OMPD_target_enter},
@@ -1339,6 +1345,29 @@
     return;
   }
 
+  OMPTraitInfo TI;
+  if (ParseOMPDeclareVariantMatchClause(Loc, TI))
+    return;
+
+  Optional<std::pair<FunctionDecl *, Expr *>> DeclVarData =
+      Actions.checkOpenMPDeclareVariantFunction(
+          Ptr, AssociatedFunction.get(), TI,
+          SourceRange(Loc, Tok.getLocation()));
+
+  // Skip last tokens.
+  while (Tok.isNot(tok::annot_pragma_openmp_end))
+    ConsumeAnyToken();
+  if (DeclVarData.hasValue() && !TI.Sets.empty())
+    Actions.ActOnOpenMPDeclareVariantDirective(
+        DeclVarData.getValue().first, DeclVarData.getValue().second, TI,
+        SourceRange(Loc, Tok.getLocation()));
+
+  // Skip the last annot_pragma_openmp_end.
+  (void)ConsumeAnnotationToken();
+}
+
+bool Parser::ParseOMPDeclareVariantMatchClause(SourceLocation Loc,
+                                               OMPTraitInfo &TI) {
   // Parse 'match'.
   OpenMPClauseKind CKind = Tok.isAnnotation()
                                ? OMPC_unknown
@@ -1350,7 +1379,7 @@
       ;
     // Skip the last annot_pragma_openmp_end.
     (void)ConsumeAnnotationToken();
-    return;
+    return true;
   }
   (void)ConsumeToken();
   // Parse '('.
@@ -1361,31 +1390,15 @@
       ;
     // Skip the last annot_pragma_openmp_end.
     (void)ConsumeAnnotationToken();
-    return;
+    return true;
   }
 
   // Parse inner context selectors.
-  OMPTraitInfo TI;
   parseOMPContextSelectors(Loc, TI);
 
   // Parse ')'
   (void)T.consumeClose();
-
-  Optional<std::pair<FunctionDecl *, Expr *>> DeclVarData =
-      Actions.checkOpenMPDeclareVariantFunction(
-          Ptr, AssociatedFunction.get(), TI,
-          SourceRange(Loc, Tok.getLocation()));
-
-  // Skip last tokens.
-  while (Tok.isNot(tok::annot_pragma_openmp_end))
-    ConsumeAnyToken();
-  if (DeclVarData.hasValue() && !TI.Sets.empty())
-    Actions.ActOnOpenMPDeclareVariantDirective(
-        DeclVarData.getValue().first, DeclVarData.getValue().second, TI,
-        SourceRange(Loc, Tok.getLocation()));
-
-  // Skip the last annot_pragma_openmp_end.
-  (void)ConsumeAnnotationToken();
+  return false;
 }
 
 /// Parsing of simple OpenMP clauses like 'default' or 'proc_bind'.
@@ -1523,13 +1536,28 @@
   return Actions.BuildDeclaratorGroup(Decls);
 }
 
+bool Parser::parseOMPEndDirective(OpenMPDirectiveKind MatchingKind,
+                                  OpenMPDirectiveKind ExpectedKind,
+                                  OpenMPDirectiveKind FoundKind,
+                                  SourceLocation MatchingLoc,
+                                  SourceLocation FoundLoc) {
+  int DiagSelection = ExpectedKind == OMPD_end_declare_target ? 0 : 1;
+
+  if (FoundKind != ExpectedKind) {
+    Diag(FoundLoc, diag::err_expected_end_declare_target_or_variant)
+        << DiagSelection;
+    Diag(MatchingLoc, diag::note_matching)
+        << ("'#pragma omp " + getOpenMPDirectiveName(MatchingKind) + "'").str();
+    return true;
+  }
+  return false;
+}
+
 void Parser::ParseOMPEndDeclareTargetDirective(OpenMPDirectiveKind DKind,
-                                               SourceLocation DTLoc) {
-  if (DKind != OMPD_end_declare_target) {
-    Diag(Tok, diag::err_expected_end_declare_target);
-    Diag(DTLoc, diag::note_matching) << "'#pragma omp declare target'";
+                                               SourceLocation DKLoc) {
+  if (parseOMPEndDirective(OMPD_declare_target, OMPD_end_declare_target, DKind,
+                           DKLoc, Tok.getLocation()))
     return;
-  }
   ConsumeAnyToken();
   if (Tok.isNot(tok::annot_pragma_openmp_end)) {
     Diag(Tok, diag::warn_omp_extra_tokens_at_eol)
@@ -1741,6 +1769,57 @@
     }
     break;
   }
+  case OMPD_begin_declare_variant: {
+    // The syntax is:
+    // { #pragma omp begin declare variant clause }
+    // <function-declaration-or-definition-sequence>
+    // { #pragma omp end declare variant }
+    //
+    SourceLocation BeginLoc = ConsumeToken();
+
+    OMPTraitInfo TI;
+    if (ParseOMPDeclareVariantMatchClause(Loc, TI))
+      break;
+
+    // Skip last tokens.
+    while (Tok.isNot(tok::annot_pragma_openmp_end))
+      ConsumeAnyToken();
+
+    VariantMatchInfo VMI;
+    ASTContext &ASTCtx = Actions.getASTContext();
+    TI.getAsVariantMatchInfo(ASTCtx, VMI, /* DeviceSetOnly */ true);
+    OMPContext OMPCtx(ASTCtx.getLangOpts().OpenMPIsDevice,
+                      ASTCtx.getTargetInfo().getTriple());
+
+    bool IsApplicableOpenMPSelector = isVariantApplicableInContext(VMI, OMPCtx);
+    if (IsApplicableOpenMPSelector)
+      break;
+
+    // Elide all the code until the matching end declare variant is found.
+    unsigned Nesting = 1;
+    SourceLocation DKLoc;
+    OpenMPDirectiveKind DK = OMPD_unknown;
+    do {
+      DKLoc = Tok.getLocation();
+      DK = parseOpenMPDirectiveKind(*this);
+      if (DK == OMPD_end_declare_variant)
+        --Nesting;
+      if (DK == OMPD_begin_declare_variant)
+        ++Nesting;
+      if (!Nesting || isEofOrEom())
+        break;
+      ConsumeAnyToken();
+    } while (true);
+
+    parseOMPEndDirective(OMPD_begin_declare_variant, OMPD_end_declare_variant,
+                         DK, BeginLoc, DKLoc);
+    break;
+  }
+  case OMPD_end_declare_variant:
+    // FIXME: With the sema changes we will keep track of nesting and be able to
+    // diagnose unmatched OMPD_end_declare_variant.
+    ConsumeToken();
+    break;
   case OMPD_declare_variant:
   case OMPD_declare_simd: {
     // The syntax is:
@@ -2233,6 +2312,8 @@
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_requires:
+  case OMPD_begin_declare_variant:
+  case OMPD_end_declare_variant:
   case OMPD_declare_variant:
     Diag(Tok, diag::err_omp_unexpected_directive)
         << 1 << getOpenMPDirectiveName(DKind);
Index: clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
@@ -801,6 +801,8 @@
     case OMPD_target_update:
     case OMPD_declare_simd:
     case OMPD_declare_variant:
+    case OMPD_begin_declare_variant:
+    case OMPD_end_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_declare_reduction:
@@ -877,6 +879,8 @@
   case OMPD_target_update:
   case OMPD_declare_simd:
   case OMPD_declare_variant:
+  case OMPD_begin_declare_variant:
+  case OMPD_end_declare_variant:
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_declare_reduction:
@@ -1046,6 +1050,8 @@
     case OMPD_target_update:
     case OMPD_declare_simd:
     case OMPD_declare_variant:
+    case OMPD_begin_declare_variant:
+    case OMPD_end_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_declare_reduction:
@@ -1128,6 +1134,8 @@
   case OMPD_target_update:
   case OMPD_declare_simd:
   case OMPD_declare_variant:
+  case OMPD_begin_declare_variant:
+  case OMPD_end_declare_variant:
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_declare_reduction:
Index: clang/lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6717,6 +6717,8 @@
   case OMPD_target_update:
   case OMPD_declare_simd:
   case OMPD_declare_variant:
+  case OMPD_begin_declare_variant:
+  case OMPD_end_declare_variant:
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_declare_reduction:
@@ -7028,6 +7030,8 @@
   case OMPD_target_update:
   case OMPD_declare_simd:
   case OMPD_declare_variant:
+  case OMPD_begin_declare_variant:
+  case OMPD_end_declare_variant:
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_declare_reduction:
@@ -8807,6 +8811,8 @@
     case OMPD_target_update:
     case OMPD_declare_simd:
     case OMPD_declare_variant:
+    case OMPD_begin_declare_variant:
+    case OMPD_end_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_declare_reduction:
@@ -9570,6 +9576,8 @@
     case OMPD_target_update:
     case OMPD_declare_simd:
     case OMPD_declare_variant:
+    case OMPD_begin_declare_variant:
+    case OMPD_end_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_declare_reduction:
@@ -10207,6 +10215,8 @@
     case OMPD_teams_distribute_parallel_for_simd:
     case OMPD_declare_simd:
     case OMPD_declare_variant:
+    case OMPD_begin_declare_variant:
+    case OMPD_end_declare_variant:
     case OMPD_declare_target:
     case OMPD_end_declare_target:
     case OMPD_declare_reduction:
@@ -11077,7 +11087,8 @@
   for (const auto *A : FD->specific_attrs<OMPDeclareVariantAttr>()) {
     const OMPTraitInfo &TI = A->getTraitInfos();
     VMIs.push_back(VariantMatchInfo());
-    TI.getAsVariantMatchInfo(CGM.getContext(), VMIs.back());
+    TI.getAsVariantMatchInfo(CGM.getContext(), VMIs.back(),
+                             /* DeviceSetOnly */ false);
     VariantExprs.push_back(A->getVariantFuncRef());
   }
 
Index: clang/lib/Basic/OpenMPKinds.cpp
===================================================================
--- clang/lib/Basic/OpenMPKinds.cpp
+++ clang/lib/Basic/OpenMPKinds.cpp
@@ -941,6 +941,8 @@
       break;
     }
     break;
+  case OMPD_begin_declare_variant:
+  case OMPD_end_declare_variant:
   case OMPD_declare_target:
   case OMPD_end_declare_target:
   case OMPD_unknown:
@@ -1202,6 +1204,8 @@
   case OMPD_end_declare_target:
   case OMPD_requires:
   case OMPD_declare_variant:
+  case OMPD_begin_declare_variant:
+  case OMPD_end_declare_variant:
     llvm_unreachable("OpenMP Directive is not allowed");
   case OMPD_unknown:
     llvm_unreachable("Unknown OpenMP directive");
Index: clang/lib/AST/OpenMPClause.cpp
===================================================================
--- clang/lib/AST/OpenMPClause.cpp
+++ clang/lib/AST/OpenMPClause.cpp
@@ -1724,9 +1724,12 @@
      << ")";
 }
 
-void OMPTraitInfo::getAsVariantMatchInfo(
-    ASTContext &ASTCtx, llvm::omp::VariantMatchInfo &VMI) const {
+void OMPTraitInfo::getAsVariantMatchInfo(ASTContext &ASTCtx,
+                                         llvm::omp::VariantMatchInfo &VMI,
+                                         bool DeviceSetOnly) const {
   for (const OMPTraitSet &Set : Sets) {
+    if (DeviceSetOnly && Set.Kind != llvm::omp::TraitSet::device)
+      continue;
     for (const OMPTraitSelector &Selector : Set.Selectors) {
 
       // User conditions are special as we evaluate the condition here.
Index: clang/include/clang/Parse/Parser.h
===================================================================
--- clang/include/clang/Parse/Parser.h
+++ clang/include/clang/Parse/Parser.h
@@ -2954,14 +2954,31 @@
   /// Parses OpenMP context selectors.
   bool parseOMPContextSelectors(SourceLocation Loc, OMPTraitInfo &TI);
 
+  /// Parse a `match` clause for a '#pragma omp declare variant'. Return true
+  /// if there was an error.
+  bool ParseOMPDeclareVariantMatchClause(SourceLocation Loc, OMPTraitInfo &TI);
+
   /// Parse clauses for '#pragma omp declare variant'.
   void ParseOMPDeclareVariantClauses(DeclGroupPtrTy Ptr, CachedTokens &Toks,
                                      SourceLocation Loc);
+  /// Parse '#pragma omp end declare variant'.
+  void ParseOMPEndDeclareVariantDirective(OpenMPDirectiveKind DKind,
+                                          SourceLocation Loc);
   /// Parse clauses for '#pragma omp declare target'.
   DeclGroupPtrTy ParseOMPDeclareTargetClauses();
   /// Parse '#pragma omp end declare target'.
   void ParseOMPEndDeclareTargetDirective(OpenMPDirectiveKind DKind,
                                          SourceLocation Loc);
+
+  /// Parse the "end" directive \p ExpectedKind matching the "begin" directive
+  /// \p MatchingKind. The actual kind found is \p FoundKind. Returns true if
+  /// the expected kind was not found.
+  bool parseOMPEndDirective(OpenMPDirectiveKind MatchingKind,
+                            OpenMPDirectiveKind ExpectedKind,
+                            OpenMPDirectiveKind FoundKind,
+                            SourceLocation MatchingLoc,
+                            SourceLocation FoundLoc);
+
   /// Parses declarative OpenMP directives.
   DeclGroupPtrTy ParseOpenMPDeclarativeDirectiveWithExtDecl(
       AccessSpecifier &AS, ParsedAttributesWithRange &Attrs,
Index: clang/include/clang/Basic/DiagnosticParseKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticParseKinds.td
+++ clang/include/clang/Basic/DiagnosticParseKinds.td
@@ -1248,6 +1248,8 @@
   "unexpected '%0' clause, '%1' is specified already">;
 def err_expected_end_declare_target : Error<
   "expected '#pragma omp end declare target'">;
+def err_expected_end_declare_target_or_variant : Error<
+  "expected '#pragma omp end declare %select{target|variant}0'">;
 def err_omp_declare_target_unexpected_clause: Error<
   "unexpected '%0' clause, only %select{'to' or 'link'|'to', 'link' or 'device_type'}1 clauses expected">;
 def err_omp_expected_clause: Error<
Index: clang/include/clang/AST/OpenMPClause.h
===================================================================
--- clang/include/clang/AST/OpenMPClause.h
+++ clang/include/clang/AST/OpenMPClause.h
@@ -6697,9 +6697,11 @@
   /// former is a flat representation the actual main difference is that the
   /// latter uses clang::Expr to store the score/condition while the former is
   /// independent of clang. Thus, expressions and conditions are evaluated in
-  /// this method.
+  /// this method. If \p DeviceSetOnly is set, only the device selector set, if
+  /// present, is put into \p VMI.
   void getAsVariantMatchInfo(ASTContext &ASTCtx,
-                             llvm::omp::VariantMatchInfo &VMI) const;
+                             llvm::omp::VariantMatchInfo &VMI,
+                             bool DeviceSetOnly) const;
 
   /// Print a human readable representation into \p OS.
   void print(llvm::raw_ostream &OS, const PrintingPolicy &Policy) const;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to