ggeorgakoudis updated this revision to Diff 414139.
ggeorgakoudis added a comment.
Herald added subscribers: kbarton, nemanjai.

Fix reverting of tentative parsing (annotation tokens are now undone)
Fix device code generation for metadirectives


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D120573/new/

https://reviews.llvm.org/D120573

Files:
  clang/include/clang/AST/OpenMPClause.h
  clang/include/clang/AST/RecursiveASTVisitor.h
  clang/include/clang/AST/StmtOpenMP.h
  clang/include/clang/Lex/Preprocessor.h
  clang/include/clang/Parse/Parser.h
  clang/include/clang/Sema/Sema.h
  clang/lib/AST/OpenMPClause.cpp
  clang/lib/AST/StmtOpenMP.cpp
  clang/lib/AST/StmtProfile.cpp
  clang/lib/CodeGen/CGOpenMPRuntime.cpp
  clang/lib/CodeGen/CGStmtOpenMP.cpp
  clang/lib/Lex/PPCaching.cpp
  clang/lib/Parse/ParseOpenMP.cpp
  clang/lib/Sema/SemaOpenMP.cpp
  clang/lib/Sema/TreeTransform.h
  clang/lib/Serialization/ASTReader.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/tools/libclang/CIndex.cpp
  llvm/include/llvm/Frontend/OpenMP/OMP.td

Index: llvm/include/llvm/Frontend/OpenMP/OMP.td
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -368,7 +368,9 @@
 def OMPC_Align : Clause<"align"> {
   let clangClass = "OMPAlignClause";
 }
-def OMPC_When: Clause<"when"> {}
+def OMPC_When: Clause<"when"> {
+  let clangClass = "OMPWhenClause"; // Also represents metadirective 'default'.
+}

 def OMPC_Bind : Clause<"bind"> {
   let clangClass = "OMPBindClause";
Index: clang/tools/libclang/CIndex.cpp
===================================================================
--- clang/tools/libclang/CIndex.cpp
+++ clang/tools/libclang/CIndex.cpp
@@ -2591,6 +2591,17 @@
 }
 void OMPClauseEnqueue::VisitOMPBindClause(const OMPBindClause *C) {}

+void OMPClauseEnqueue::VisitOMPWhenClause(const OMPWhenClause *C) {
+  // Enqueue the context selector's score/condition expressions so cursor
+  // visitation reaches them, then the variant directive itself.
+  C->getTraitInfo().anyScoreOrCondition([this](Expr *&E, bool /*IsScore*/) {
+    if (E)
+      Visitor->AddStmt(E);
+    return false;
+  });
+  Visitor->AddStmt(C->getDirective());
+}
+
 } // namespace

 void EnqueueVisitor::EnqueueChildren(const OMPClause *S) {
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -6864,6 +6864,16 @@
   Record.AddSourceLocation(C->getBindKindLoc());
 }

+void OMPClauseWriter::VisitOMPWhenClause(OMPWhenClause *C) {
+  // Written as the context-selector trait info followed by the variant
+  // directive; OMPClauseReader::VisitOMPWhenClause reads them back in the
+  // same order.
+  // NOTE(review): LParenLoc is neither written here nor read back, so it is
+  // lost across serialization; confirm whether it should be recorded.
+  Record.writeOMPTraitInfo(&C->getTraitInfo());
+  Record.AddStmt(C->getDirective());
+}
+
 void ASTRecordWriter::writeOMPTraitInfo(const OMPTraitInfo *TI) {
   writeUInt32(TI->Sets.size());
   for (const auto &Set : TI->Sets) {
Index: clang/lib/Serialization/ASTReader.cpp
===================================================================
--- clang/lib/Serialization/ASTReader.cpp
+++ clang/lib/Serialization/ASTReader.cpp
@@ -12994,6 +12994,15 @@
   C->setBindKindLoc(Record.readSourceLocation());
 }

+void OMPClauseReader::VisitOMPWhenClause(OMPWhenClause *C) {
+  // Read back in the order OMPClauseWriter::VisitOMPWhenClause writes them:
+  // the context-selector trait info, then the variant directive.
+  // NOTE(review): unlike VisitOMPAlignClause below, LParenLoc is not read, so
+  // it keeps its default-constructed value; confirm whether it should be.
+  C->setTraitInfo(Record.readOMPTraitInfo());
+  C->setDirective(Record.readStmt());
+}
+
 void OMPClauseReader::VisitOMPAlignClause(OMPAlignClause *C) {
   C->setAlignment(Record.readExpr());
   C->setLParenLoc(Record.readSourceLocation());
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -2284,6 +2284,18 @@
     return getSema().ActOnOpenMPAlignClause(A, StartLoc, LParenLoc, EndLoc);
   }

+  /// Build a new OpenMP 'when' clause.
+  ///
+  /// By default, performs semantic analysis to build the new OpenMP clause.
+  /// Subclasses may override this routine to provide different behavior.
+  OMPClause *RebuildOMPWhenClause(OMPTraitInfo &TI, Stmt *Directive,
+                                  SourceLocation StartLoc,
+                                  SourceLocation LParenLoc,
+                                  SourceLocation EndLoc) {
+    return getSema().ActOnOpenMPWhenClause(TI, Directive, StartLoc, LParenLoc,
+                                           EndLoc);
+  }
+
   /// Rebuild the operand to an Objective-C \@synchronized statement.
   ///
   /// By default, performs semantic analysis to build the new statement.
@@ -10323,6 +10335,34 @@
       C->getLParenLoc(), C->getEndLoc());
 }

+template <typename Derived>
+OMPClause *TreeTransform<Derived>::TransformOMPWhenClause(OMPWhenClause *C) {
+  // Work on a copy of the trait info so the original (e.g. template) clause
+  // keeps its untransformed score/condition expressions.
+  OMPTraitInfo &TI = getSema().getASTContext().getNewOMPTraitInfo();
+  TI = C->getTraitInfo();
+  bool Invalid = TI.anyScoreOrCondition([&](Expr *&E, bool /*IsScore*/) {
+    if (!E)
+      return false;
+    ExprResult ER = getDerived().TransformExpr(E);
+    if (ER.isInvalid())
+      return true;
+    E = ER.get();
+    return false;
+  });
+  if (Invalid)
+    return nullptr;
+
+  // The variant directive may reference template parameters as well.
+  StmtResult Directive = getDerived().TransformStmt(C->getDirective());
+  if (Directive.isInvalid())
+    return nullptr;
+
+  return getDerived().RebuildOMPWhenClause(TI, Directive.get(),
+                                           C->getBeginLoc(), C->getLParenLoc(),
+                                           C->getEndLoc());
+}
+
 //===----------------------------------------------------------------------===//
 // Expression transformation
 //===----------------------------------------------------------------------===//
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -21631,3 +21631,24 @@
   return OMPBindClause::Create(Context, Kind, KindLoc, StartLoc, LParenLoc,
                                EndLoc);
 }
+
+// Build a 'when' clause holding the context-selector trait info and the parsed
+// variant directive. The parser also uses this for a metadirective 'default'
+// clause, whose trait info has no selector sets.
+// NOTE(review): an invalid Directive is stored via Directive.get(), which may
+// be null; several consumers null-check getDirective() - confirm all do.
+OMPClause *Sema::ActOnOpenMPWhenClause(OMPTraitInfo &TI, StmtResult Directive,
+                                       SourceLocation StartLoc,
+                                       SourceLocation LParenLoc,
+                                       SourceLocation EndLoc) {
+  return new (Context)
+      OMPWhenClause(TI, Directive.get(), StartLoc, LParenLoc, EndLoc);
+}
+
+// Build the metadirective node from its 'when'/'default' clauses; the node has
+// no associated statement of its own.
+StmtResult Sema::ActOnOpenMPMetaDirective(ArrayRef<OMPClause *> Clauses,
+                                          SourceLocation StartLoc,
+                                          SourceLocation EndLoc) {
+  return OMPMetaDirective::Create(Context, StartLoc, EndLoc, Clauses);
+}
Index: clang/lib/Parse/ParseOpenMP.cpp
===================================================================
--- clang/lib/Parse/ParseOpenMP.cpp
+++ clang/lib/Parse/ParseOpenMP.cpp
@@ -2488,12 +2488,15 @@
-    // First iteration of parsing all clauses of metadirective.
-    // This iteration only parses and collects all context selector ignoring the
-    // associated directives.
-    TentativeParsingAction TPA(*this);
+    // Parse each 'when'/'default' clause, including its variant directive,
+    // under a per-clause tentative parse. That parse is then reverted (undoing
+    // annotation tokens) and the clause's tokens are skipped.
     ASTContext &ASTContext = Actions.getASTContext();
+    // Only set below if some variant directive has an associated statement;
+    // otherwise the statement after the pragma must not be consumed.
+    HasAssociatedStatement = false;

     BalancedDelimiterTracker T(*this, tok::l_paren,
                                tok::annot_pragma_openmp_end);
     while (Tok.isNot(tok::annot_pragma_openmp_end)) {
+      TentativeParsingAction TPA(*this);
       OpenMPClauseKind CKind = Tok.isAnnotation()
                                    ? OMPC_unknown
                                    : getOpenMPClauseKind(PP.getSpelling(Tok));
@@ -2505,122 +2508,114 @@
         return Directive;

+      // NOTE(review): the early return just above (on a missing '(') leaves
+      // TPA neither committed nor reverted unless that is handled out of view.
       OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo();
-      if (CKind == OMPC_when) {
-        // parse and get OMPTraitInfo to pass to the When clause
-        parseOMPContextSelectors(Loc, TI);
-        if (TI.Sets.size() == 0) {
-          Diag(Tok, diag::err_omp_expected_context_selector) << "when clause";
-          TPA.Commit();
-          return Directive;
-        }
-
-        // Parse ':'
-        if (Tok.is(tok::colon))
-          ConsumeAnyToken();
-        else {
-          Diag(Tok, diag::err_omp_expected_colon) << "when clause";
-          TPA.Commit();
-          return Directive;
-        }
-      }
-      // Skip Directive for now. We will parse directive in the second iteration
-      int paren = 0;
-      while (Tok.isNot(tok::r_paren) || paren != 0) {
-        if (Tok.is(tok::l_paren))
-          paren++;
-        if (Tok.is(tok::r_paren))
-          paren--;
-        if (Tok.is(tok::annot_pragma_openmp_end)) {
-          Diag(Tok, diag::err_omp_expected_punc)
-              << getOpenMPClauseName(CKind) << 0;
-          TPA.Commit();
-          return Directive;
-        }
-        ConsumeAnyToken();
-      }
-      // Parse ')'
-      if (Tok.is(tok::r_paren))
-        T.consumeClose();
-
-      VariantMatchInfo VMI;
-      TI.getAsVariantMatchInfo(ASTContext, VMI);
-
-      VMIs.push_back(VMI);
-    }
-
-    TPA.Revert();
-    // End of the first iteration. Parser is reset to the start of metadirective
+      if (CKind == OMPC_when || CKind == OMPC_default) {
+        // If it is a "when" clause parse the context selectors for the trait
+        // info, a "default" clause will have empty trait info.
+        if (CKind == OMPC_when) {
+          // parse and get OMPTraitInfo to pass to the When clause
+          parseOMPContextSelectors(Loc, TI);
+          if (TI.Sets.size() == 0) {
+            Diag(Tok, diag::err_omp_expected_context_selector) << "when clause";
+            TPA.Commit();
+            return Directive;
+          }

-    std::function<void(StringRef)> DiagUnknownTrait = [this, Loc](
-                                                          StringRef ISATrait) {
-      // TODO Track the selector locations in a way that is accessible here to
-      // improve the diagnostic location.
-      Diag(Loc, diag::warn_unknown_declare_variant_isa_trait) << ISATrait;
-    };
-    TargetOMPContext OMPCtx(ASTContext, std::move(DiagUnknownTrait),
-                            /* CurrentFunctionDecl */ nullptr,
-                            ArrayRef<llvm::omp::TraitProperty>());
+          // Parse ':'
+          if (Tok.is(tok::colon))
+            ConsumeAnyToken();
+          else {
+            Diag(Tok, diag::err_omp_expected_colon) << "when clause";
+            TPA.Commit();
+            return Directive;
+          }

-    // A single match is returned for OpenMP 5.0
-    int BestIdx = getBestVariantMatchForContext(VMIs, OMPCtx);
+          VariantMatchInfo VMI;
+          TI.getAsVariantMatchInfo(ASTContext, VMI);
+          VMIs.push_back(VMI);
+        }

-    int Idx = 0;
-    // In OpenMP 5.0 metadirective is either replaced by another directive or
-    // ignored.
-    // TODO: In OpenMP 5.1 generate multiple directives based upon the matches
-    // found by getBestWhenMatchForContext.
-    while (Tok.isNot(tok::annot_pragma_openmp_end)) {
-      // OpenMP 5.0 implementation - Skip to the best index found.
-      if (Idx++ != BestIdx) {
+        // TODO: currently expects a directive, OpenMP 5.1 specifies nothing
+        // directive when there is none.
+        ReadDirectiveWithinMetadirective = true;
+        StmtResult WhenDirective =
+            ParseOpenMPDeclarativeOrExecutableDirective(StmtCtx);
+        ReadDirectiveWithinMetadirective = false;
+        if (OMPExecutableDirective *D =
+                dyn_cast_or_null<OMPExecutableDirective>(WhenDirective.get()))
+          if (D->hasAssociatedStmt())
+            HasAssociatedStatement = true;
+
+        // Revert back to the beginning of the clause.
+        TPA.Revert(/*RevertAnnotations*/ true);
+
+        // Skip until the end of this clause.
         ConsumeToken();  // Consume clause name
         T.consumeOpen(); // Consume '('
         int paren = 0;
-        // Skip everything inside the clause
         while (Tok.isNot(tok::r_paren) || paren != 0) {
           if (Tok.is(tok::l_paren))
             paren++;
           if (Tok.is(tok::r_paren))
             paren--;
+          if (Tok.is(tok::annot_pragma_openmp_end)) {
+            Diag(Tok, diag::err_omp_expected_punc)
+                << getOpenMPClauseName(CKind) << 0;
+            return Directive;
+          }
           ConsumeAnyToken();
         }
-        // Parse ')'
-        if (Tok.is(tok::r_paren))
-          T.consumeClose();
-        continue;
-      }
-
-      OpenMPClauseKind CKind = Tok.isAnnotation()
-                                   ? OMPC_unknown
-                                   : getOpenMPClauseKind(PP.getSpelling(Tok));
-      SourceLocation Loc = ConsumeToken();
-
-      // Parse '('.
-      T.consumeOpen();

-      // Skip ContextSelectors for when clause
-      if (CKind == OMPC_when) {
-        OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo();
-        // parse and skip the ContextSelectors
-        parseOMPContextSelectors(Loc, TI);
-
-        // Parse ':'
-        ConsumeAnyToken();
+        assert(Tok.is(tok::r_paren) &&
+               "Expected right paren ending the clause.");
+        T.consumeClose();
+
+        // Build the clause only after its ')' has been consumed, so that
+        // T.getCloseLocation() is this clause's ')' and not a previous one's.
+        auto *WhenClause = Actions.ActOnOpenMPWhenClause(
+            TI, WhenDirective, Loc, T.getOpenLocation(), T.getCloseLocation());
+        Clauses.push_back(WhenClause);
-      }
+      } else {
+        // Only 'when' and 'default' clauses are accepted here. Diagnose
+        // anything else and finish the tentative parse, as the other error
+        // paths do.
+        Diag(Loc, diag::err_omp_unexpected_clause)
+            << getOpenMPClauseName(CKind)
+            << getOpenMPDirectiveName(OMPD_metadirective);
+        TPA.Commit();
+        return Directive;
+      }
+    }

-      // If no directive is passed, skip in OpenMP 5.0.
-      // TODO: Generate nothing directive from OpenMP 5.1.
-      if (Tok.is(tok::r_paren)) {
-        SkipUntil(tok::annot_pragma_openmp_end);
-        break;
-      }
+    // Skip until the end of the metadirective.
+    SkipUntil(tok::annot_pragma_openmp_end);
+    // Skip any associated statement.
+    if (HasAssociatedStatement)
+      ParseStatement();
+
+    std::function<void(StringRef)> DiagUnknownTrait =
+        [this, Loc](StringRef ISATrait) {
+          // TODO Track the selector locations in a way that is accessible here
+          // to improve the diagnostic location.
+          Diag(Loc, diag::warn_unknown_declare_variant_isa_trait) << ISATrait;
+        };
+    TargetOMPContext OMPCtx(ASTContext, std::move(DiagUnknownTrait),
+                            /* CurrentFunctionDecl */ nullptr,
+                            ArrayRef<llvm::omp::TraitProperty>());

-      // Parse Directive
-      ReadDirectiveWithinMetadirective = true;
-      Directive = ParseOpenMPDeclarativeOrExecutableDirective(StmtCtx);
-      ReadDirectiveWithinMetadirective = false;
-      break;
+    // Find applicable clauses.
+    SmallVector<OMPClause *, 5> ApplicableClauses;
+    for (OMPClause *C : Clauses) {
+      OMPTraitInfo &TI = cast<OMPWhenClause>(C)->getTraitInfo();
+      VariantMatchInfo VMI;
+      TI.getAsVariantMatchInfo(ASTContext, VMI);
+      if (isVariantApplicableInContext(VMI, OMPCtx, /*DeviceOnly*/ false))
+        ApplicableClauses.push_back(C);
     }
-    break;
+
+    return Actions.ActOnOpenMPMetaDirective(ApplicableClauses, Loc,
+                                            Tok.getLastLoc());
   }
   case OMPD_threadprivate: {
     // FIXME: Should this be permitted in C++?
Index: clang/lib/Lex/PPCaching.cpp
===================================================================
--- clang/lib/Lex/PPCaching.cpp
+++ clang/lib/Lex/PPCaching.cpp
@@ -37,12 +37,27 @@

 // Make Preprocessor re-lex the tokens that were lexed since
 // EnableBacktrackAtThisPos() was previously called.
-void Preprocessor::Backtrack() {
-  assert(!BacktrackPositions.empty()
-         && "EnableBacktrackAtThisPos was not called!");
+//
+// If RevertAnnotations is true, the cached tokens from the backtrack position
+// onwards are replaced by the tokens recorded in UnannotatedCachedTokens as
+// they were lexed, dropping annotation tokens formed over that range.
+void Preprocessor::Backtrack(bool RevertAnnotations) {
+  assert(!BacktrackPositions.empty() &&
+         "EnableBacktrackAtThisPos was not called!");
   CachedLexPos = BacktrackPositions.back();
   BacktrackPositions.pop_back();
   recomputeCurLexerKind();
+
+  if (RevertAnnotations) {
+    // NOTE(review): CachedLexPos indexes CachedTokens; it only indexes the
+    // same token in UnannotatedCachedTokens if no earlier token was collapsed
+    // into an annotation or inserted since the cache was last cleared.
+    assert(CachedLexPos <= UnannotatedCachedTokens.size() &&
+           "Unannotated token cache is out of sync with CachedTokens");
+    CachedTokens.erase(CachedTokens.begin() + CachedLexPos, CachedTokens.end());
+    CachedTokens.append(UnannotatedCachedTokens.begin() + CachedLexPos,
+                        UnannotatedCachedTokens.end());
+  }
 }

 void Preprocessor::CachingLex(Token &Result) {
@@ -66,6 +81,7 @@
     // Cache the lexed token.
     EnterCachingLexModeUnchecked();
     CachedTokens.push_back(Result);
+    UnannotatedCachedTokens.push_back(Result);
     ++CachedLexPos;
     return;
   }
@@ -75,6 +91,7 @@
   } else {
     // All cached tokens were consumed.
     CachedTokens.clear();
+    UnannotatedCachedTokens.clear();
     CachedLexPos = 0;
   }
 }
@@ -106,8 +123,10 @@
   assert(CachedLexPos + N > CachedTokens.size() && "Confused caching.");
   ExitCachingLexMode();
   for (size_t C = CachedLexPos + N - CachedTokens.size(); C > 0; --C) {
-    CachedTokens.push_back(Token());
-    Lex(CachedTokens.back());
+    Token Result;
+    Lex(Result);
+    CachedTokens.push_back(Result);
+    UnannotatedCachedTokens.push_back(Result);
   }
   EnterCachingLexMode();
   return CachedTokens.back();
Index: clang/lib/CodeGen/CGStmtOpenMP.cpp
===================================================================
--- clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -24,6 +24,7 @@
 #include "clang/AST/StmtVisitor.h"
 #include "clang/Basic/OpenMPKinds.h"
 #include "clang/Basic/PrettyStackTrace.h"
+#include "clang/Parse/ParseDiagnostic.h"
 #include "llvm/BinaryFormat/Dwarf.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
@@ -1788,8 +1789,86 @@
   checkForLastprivateConditionalUpdate(*this, S);
 }

-void CodeGenFunction::EmitOMPMetaDirective(const OMPMetaDirective &S) {
-  EmitStmt(S.getIfStmt());
+// Emit a metadirective. Clauses whose user condition is not a compile-time
+// constant are tested at run time, in clause order, each branching to the end
+// once its variant has been emitted; the best match among the remaining
+// (statically resolvable) clauses is emitted as the fallback.
+void CodeGenFunction::EmitOMPMetaDirective(const OMPMetaDirective &D) {
+  llvm::BasicBlock *AfterBlock =
+      createBasicBlock("omp.meta.user.condition.after");
+
+  SmallVector<const OMPWhenClause *, 4> StaticWhenClauses;
+  SmallVector<VariantMatchInfo, 4> StaticVMIs;
+
+  for (const auto *C : D.getClausesOfKind<OMPWhenClause>()) {
+    OMPTraitInfo &TI = C->getTraitInfo();
+
+    // Continuation block for when this clause's dynamic condition is false.
+    // Only created once a dynamic condition is found, so clauses without one
+    // do not leave an unused basic block behind.
+    llvm::BasicBlock *ExitBlock = nullptr;
+
+    // Emit code to generate a dynamic condition, returns true if there is a
+    // condition, false otherwise.
+    auto GenerateCond = [&](Expr *&E, bool IsScore) {
+      if (IsScore || !E)
+        return false;
+
+      // Do not emit code if the expression is statically resolvable, will be
+      // handled as a static when clause.
+      if (E->getIntegerConstantExpr(getContext()))
+        return false;
+
+      ExitBlock = createBasicBlock("omp.meta.user.condition.exit");
+      llvm::BasicBlock *TrueBlock = createBasicBlock("omp.meta.user.condition");
+      EmitBranchOnBoolExpr(E, TrueBlock, ExitBlock,
+                           getProfileCount(C->getDirective()));
+      EmitBlock(TrueBlock);
+      if (const Stmt *Variant = C->getDirective())
+        EmitStmt(Variant);
+      EmitBranch(AfterBlock);
+
+      return true;
+    };
+
+    // If there is no dynamic condition for the clause then add it to the
+    // static clauses and resolve later the best match to generate code for.
+    // This also handles the default clause.
+    if (!TI.anyScoreOrCondition(GenerateCond)) {
+      StaticWhenClauses.push_back(C);
+      VariantMatchInfo VMI;
+      TI.getAsVariantMatchInfo(getContext(), VMI);
+      StaticVMIs.push_back(VMI);
+    } else {
+      EmitBlock(ExitBlock);
+    }
+  }
+
+  // Emit code for static clauses, if any.
+  if (!StaticWhenClauses.empty()) {
+    std::function<void(StringRef)> DiagUnknownTrait = [&](StringRef ISATrait) {
+      CGM.getDiags().Report(D.getBeginLoc(),
+                            diag::warn_unknown_declare_variant_isa_trait)
+          << ISATrait;
+    };
+
+    TargetOMPContext OMPCtx(
+        getContext(), std::move(DiagUnknownTrait),
+        /* CurrentFunctionDecl */ nullptr,
+        /* ConstructTraits */ ArrayRef<llvm::omp::TraitProperty>());
+
+    // A negative index means no static clause is applicable (e.g. a user
+    // condition that folds to false and no 'default'); emit nothing then
+    // instead of indexing out of bounds.
+    int BestIdx = getBestVariantMatchForContext(StaticVMIs, OMPCtx);
+    if (BestIdx >= 0) {
+      if (const Stmt *Variant = StaticWhenClauses[BestIdx]->getDirective())
+        EmitStmt(Variant);
+    }
+    EmitBranch(AfterBlock);
+  }
+
+  EmitBlock(AfterBlock);
 }

 namespace {
Index: clang/lib/CodeGen/CGOpenMPRuntime.cpp
===================================================================
--- clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -10762,6 +10762,15 @@
   }

   if (const auto *E = dyn_cast<OMPExecutableDirective>(S)) {
+    // A metadirective's variant directives are held by its 'when'/'default'
+    // clauses rather than by an associated statement, so scan each of them
+    // for target regions here.
+    if (E->getDirectiveKind() == OMPD_metadirective) {
+      for (const auto *C : E->getClausesOfKind<OMPWhenClause>())
+        if (C->getDirective())
+          scanForTargetRegionsFunctions(C->getDirective(), ParentName);
+    }
+
     if (!E->hasAssociatedStmt() || !E->getAssociatedStmt())
       return;

Index: clang/lib/AST/StmtProfile.cpp
===================================================================
--- clang/lib/AST/StmtProfile.cpp
+++ clang/lib/AST/StmtProfile.cpp
@@ -886,6 +886,18 @@
 }
 void OMPClauseProfiler::VisitOMPOrderClause(const OMPOrderClause *C) {}
 void OMPClauseProfiler::VisitOMPBindClause(const OMPBindClause *C) {}
+
+void OMPClauseProfiler::VisitOMPWhenClause(const OMPWhenClause *C) {
+  // Profile the context selector's score/condition expressions as well, so
+  // clauses that differ only in those expressions do not profile identically.
+  C->getTraitInfo().anyScoreOrCondition([this](Expr *&E, bool /*IsScore*/) {
+    if (E)
+      Profiler->VisitStmt(E);
+    return false;
+  });
+  if (C->getDirective())
+    Profiler->VisitStmt(C->getDirective());
+}
 } // namespace

 void
Index: clang/lib/AST/StmtOpenMP.cpp
===================================================================
--- clang/lib/AST/StmtOpenMP.cpp
+++ clang/lib/AST/StmtOpenMP.cpp
@@ -262,11 +262,14 @@
 OMPMetaDirective *OMPMetaDirective::Create(const ASTContext &C,
                                            SourceLocation StartLoc,
                                            SourceLocation EndLoc,
-                                           ArrayRef<OMPClause *> Clauses,
-                                           Stmt *AssociatedStmt, Stmt *IfStmt) {
-  auto *Dir = createDirective<OMPMetaDirective>(
-      C, Clauses, AssociatedStmt, /*NumChildren=*/1, StartLoc, EndLoc);
-  Dir->setIfStmt(IfStmt);
+                                           ArrayRef<OMPClause *> Clauses) {
+  // The variant directives live in the clauses, so the node itself has no
+  // associated statement and no children.
+  // NOTE(review): CreateEmpty (presumably used when deserializing) must match
+  // this layout; verify it no longer reserves an associated stmt or a child.
+  auto *Dir =
+      createDirective<OMPMetaDirective>(C, Clauses, /*AssociatedStmt*/ nullptr,
+                                        /*NumChildren=*/0, StartLoc, EndLoc);
   return Dir;
 }

Index: clang/lib/AST/OpenMPClause.cpp
===================================================================
--- clang/lib/AST/OpenMPClause.cpp
+++ clang/lib/AST/OpenMPClause.cpp
@@ -15,6 +15,7 @@
 #include "clang/AST/Attr.h"
 #include "clang/AST/Decl.h"
 #include "clang/AST/DeclOpenMP.h"
+#include "clang/AST/StmtOpenMP.h"
 #include "clang/Basic/LLVM.h"
 #include "clang/Basic/OpenMPKinds.h"
 #include "clang/Basic/TargetInfo.h"
@@ -2334,6 +2335,31 @@
      << ")";
 }

+void OMPClausePrinter::VisitOMPWhenClause(OMPWhenClause *Node) {
+  // A metadirective 'default' clause is represented as a 'when' clause with
+  // no context selectors; print it back as 'default(<directive>)'.
+  OMPTraitInfo &TI = Node->getTraitInfo();
+  const bool IsDefault = TI.Sets.empty();
+  OS << (IsDefault ? "default(" : "when(");
+  if (!IsDefault)
+    TI.print(OS, Policy);
+
+  if (const auto *D =
+          dyn_cast_or_null<OMPExecutableDirective>(Node->getDirective())) {
+    // Only a 'when' clause separates its selectors from the directive with
+    // ':'; 'default' has no selectors to separate.
+    if (!IsDefault)
+      OS << ":";
+    OS << getOpenMPDirectiveName(D->getDirectiveKind());
+    for (OMPClause *Clause : D->clauses())
+      if (Clause && !Clause->isImplicit()) {
+        OS << ' ';
+        Visit(Clause);
+      }
+  }
+  OS << ")";
+}
+
 void OMPTraitInfo::getAsVariantMatchInfo(ASTContext &ASTCtx,
                                          VariantMatchInfo &VMI) const {
   for (const OMPTraitSet &Set : Sets) {
@@ -2348,13 +2374,13 @@
                    TraitProperty::user_condition_unknown &&
                "Ill-formed user condition, expected unknown trait property!");

+        // If Condition is statically resolvable add it as a trait, otherwise
+        // do nothing since codegen will generate dynamic conditions.
         if (Optional<APSInt> CondVal =
                 Selector.ScoreOrCondition->getIntegerConstantExpr(ASTCtx))
           VMI.addTrait(CondVal->isZero() ? TraitProperty::user_condition_false
                                          : TraitProperty::user_condition_true,
                        "<condition>");
-        else
-          VMI.addTrait(TraitProperty::user_condition_false, "<condition>");
         continue;
       }

Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -10543,7 +10543,7 @@
   /// Called on well-formed '\#pragma omp metadirective' after parsing
-  /// of the  associated statement.
+  /// of all of its clauses. The directive has no associated statement.
   StmtResult ActOnOpenMPMetaDirective(ArrayRef<OMPClause *> Clauses,
-                                      Stmt *AStmt, SourceLocation StartLoc,
+                                      SourceLocation StartLoc,
                                       SourceLocation EndLoc);

   // OpenMP directives and clauses.
@@ -11139,7 +11139,10 @@
                                      SourceLocation LParenLoc,
                                      SourceLocation EndLoc);
   /// Called on well-formed 'when' clause.
+  /// Also used for a metadirective 'default' clause, whose \p TI has no
+  /// selector sets. \p Directive is the variant directive of the clause.
-  OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, SourceLocation StartLoc,
+  OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, StmtResult Directive,
+                                   SourceLocation StartLoc,
                                    SourceLocation LParenLoc,
                                    SourceLocation EndLoc);
   /// Called on well-formed 'default' clause.
Index: clang/include/clang/Parse/Parser.h
===================================================================
--- clang/include/clang/Parse/Parser.h
+++ clang/include/clang/Parse/Parser.h
@@ -964,9 +964,12 @@
       P.PP.CommitBacktrackedTokens();
       isActive = false;
     }
-    void Revert() {
+    /// Undo the tentative parse. If \p RevertAnnotations is true, annotation
+    /// tokens formed during it are discarded as well, so the original tokens
+    /// are seen again.
+    void Revert(bool RevertAnnotations = false) {
       assert(isActive && "Parsing action was finished!");
-      P.PP.Backtrack();
+      P.PP.Backtrack(RevertAnnotations);
       P.PreferredType = PrevPreferredType;
       P.Tok = PrevTok;
       P.TentativelyDeclaredIdentifiers.resize(
Index: clang/include/clang/Lex/Preprocessor.h
===================================================================
--- clang/include/clang/Lex/Preprocessor.h
+++ clang/include/clang/Lex/Preprocessor.h
@@ -899,6 +899,11 @@
   /// Cached tokens are stored here when we do backtracking or
   /// lookahead. They are "lexed" by the CachingLex() method.
   CachedTokensTy CachedTokens;
+
+  /// Tokens appended alongside CachedTokens as they are lexed. Backtrack with
+  /// RevertAnnotations restores from here, discarding annotation tokens formed
+  /// since the backtrack point.
+  CachedTokensTy UnannotatedCachedTokens;

   /// The position of the cached token that CachingLex() should
   /// "lex" next.
@@ -1458,7 +1463,10 @@

   /// Make Preprocessor re-lex the tokens that were lexed since
   /// EnableBacktrackAtThisPos() was previously called.
-  void Backtrack();
+  /// If \p RevertAnnotations is true, annotation tokens formed since then are
+  /// discarded in favour of the tokens as originally lexed.
+  void Backtrack(bool RevertAnnotations);
+  void Backtrack() { Backtrack(/*RevertAnnotations=*/false); }

   /// True if EnableBacktrackAtThisPos() was called and
   /// caching of tokens is on.
Index: clang/include/clang/AST/StmtOpenMP.h
===================================================================
--- clang/include/clang/AST/StmtOpenMP.h
+++ clang/include/clang/AST/StmtOpenMP.h
@@ -5456,7 +5456,6 @@
 class OMPMetaDirective final : public OMPExecutableDirective {
   friend class ASTStmtReader;
   friend class OMPExecutableDirective;
-  Stmt *IfStmt;

   OMPMetaDirective(SourceLocation StartLoc, SourceLocation EndLoc)
       : OMPExecutableDirective(OMPMetaDirectiveClass,
@@ -5467,16 +5466,15 @@
                                llvm::omp::OMPD_metadirective, SourceLocation(),
                                SourceLocation()) {}

-  void setIfStmt(Stmt *S) { IfStmt = S; }
-
 public:
+  /// Creates a metadirective from its 'when'/'default' clauses, which hold the
+  /// variant directives; the node has no associated statement.
   static OMPMetaDirective *Create(const ASTContext &C, SourceLocation StartLoc,
                                   SourceLocation EndLoc,
-                                  ArrayRef<OMPClause *> Clauses,
-                                  Stmt *AssociatedStmt, Stmt *IfStmt);
+                                  ArrayRef<OMPClause *> Clauses);
+
   static OMPMetaDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses,
                                        EmptyShell);
-  Stmt *getIfStmt() const { return IfStmt; }

   static bool classof(const Stmt *T) {
     return T->getStmtClass() == OMPMetaDirectiveClass;
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3700,6 +3700,18 @@
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPWhenClause(OMPWhenClause *C) {
+  for (const OMPTraitSet &Set : C->getTraitInfo().Sets) {
+    for (const OMPTraitSelector &Selector : Set.Selectors) {
+      if (Selector.Kind == llvm::omp::TraitSelector::user_condition &&
+          Selector.ScoreOrCondition)
+        TRY_TO(TraverseStmt(Selector.ScoreOrCondition));
+    }
+  }
+  return true;
+}
+
 // FIXME: look at the following tricky-seeming exprs to see if we
 // need to recurse on anything.  These are ones that have methods
 // returning decls or qualtypes or nestednamespecifier -- though I'm
Index: clang/include/clang/AST/OpenMPClause.h
===================================================================
--- clang/include/clang/AST/OpenMPClause.h
+++ clang/include/clang/AST/OpenMPClause.h
@@ -8612,6 +8612,86 @@
   }
 };

+class OMPTraitInfo;
+
+/// This represents 'when' clause in the '#pragma omp metadirective'
+/// directive. A metadirective 'default' clause is represented by the same
+/// class, with trait info that has no selector sets.
+///
+/// \code
+/// #pragma omp metadirective when(user={condition(N<10)}: parallel)
+/// \endcode
+/// In this example directive '#pragma omp metadirective' has simple 'when'
+/// clause with user defined condition.
+class OMPWhenClause final : public OMPClause {
+  friend class OMPClauseReader;
+
+  /// Context selectors of the clause (no selector sets for 'default').
+  OMPTraitInfo *TI = nullptr;
+
+  /// The associated variant directive; may be null.
+  Stmt *Directive = nullptr;
+
+  /// Location of '('.
+  SourceLocation LParenLoc;
+
+public:
+  /// Build 'when' clause with arguments \a T for traits, \a D for the
+  /// associated directive.
+  ///
+  /// \param T TraitInfo containing information about the context selector
+  /// \param D The statement associated with the when clause
+  /// \param StartLoc Starting location of the clause.
+  /// \param LParenLoc Location of '('.
+  /// \param EndLoc Ending location of the clause.
+  OMPWhenClause(OMPTraitInfo &T, Stmt *D, SourceLocation StartLoc,
+                SourceLocation LParenLoc, SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_when, StartLoc, EndLoc), TI(&T), Directive(D),
+        LParenLoc(LParenLoc) {}
+
+  /// Build an empty clause; the trait info and directive stay null until set
+  /// (e.g. by OMPClauseReader).
+  OMPWhenClause()
+      : OMPClause(llvm::omp::OMPC_when, SourceLocation(), SourceLocation()) {}
+
+  /// Sets the location of '('.
+  void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
+
+  /// Returns the location of '('.
+  SourceLocation getLParenLoc() const { return LParenLoc; }
+
+  /// Returns the associated OpenMP directive.
+  Stmt *getDirective() const { return Directive; }
+
+  /// Set the associated OpenMP directive.
+  void setDirective(Stmt *S) { Directive = S; }
+
+  /// Returns the OMPTraitInfo; must not be called before it has been set.
+  OMPTraitInfo &getTraitInfo() const { return *TI; }
+
+  /// Set the OMPTraitInfo
+  void setTraitInfo(OMPTraitInfo *T) { TI = T; }
+
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_when;
+  }
+};
+
 /// This class implements a simple visitor for OMPClause
 /// subclasses.
 template<class ImplClass, template <typename> class Ptr, typename RetTy>
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to