koops created this revision.
koops added reviewers: dreachem, soumitra, tianshilei1992, cchen.
Herald added a subscriber: arphaman.
Herald added a project: All.
koops requested review of this revision.
Herald added a reviewer: jdoerfert.
Herald added subscribers: llvm-commits, cfe-commits, sstefan1.
Herald added projects: clang, LLVM.

This patch adds support for the 'fail' clause of "#pragma omp atomic compare".
Only Parser and AST support is implemented for now.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D123235

Files:
  clang/include/clang/AST/ASTNodeTraverser.h
  clang/include/clang/AST/OpenMPClause.h
  clang/include/clang/AST/RecursiveASTVisitor.h
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Basic/OpenMPKinds.def
  clang/include/clang/Parse/Parser.h
  clang/include/clang/Sema/Sema.h
  clang/lib/AST/OpenMPClause.cpp
  clang/lib/AST/StmtProfile.cpp
  clang/lib/Basic/OpenMPKinds.cpp
  clang/lib/Parse/ParseOpenMP.cpp
  clang/lib/Sema/SemaOpenMP.cpp
  clang/lib/Sema/TreeTransform.h
  clang/lib/Serialization/ASTReader.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/test/OpenMP/atomic_ast_print.cpp
  clang/test/OpenMP/atomic_messages.cpp
  clang/tools/libclang/CIndex.cpp
  llvm/include/llvm/Frontend/OpenMP/OMP.td

Index: llvm/include/llvm/Frontend/OpenMP/OMP.td
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -180,6 +180,7 @@
 def OMPC_Update : Clause<"update"> { let clangClass = "OMPUpdateClause"; }
 def OMPC_Capture : Clause<"capture"> { let clangClass = "OMPCaptureClause"; }
 def OMPC_Compare : Clause<"compare"> { let clangClass = "OMPCompareClause"; }
+def OMPC_Fail : Clause<"fail"> { let clangClass = "OMPFailClause"; }
 def OMPC_SeqCst : Clause<"seq_cst"> { let clangClass = "OMPSeqCstClause"; }
 def OMPC_AcqRel : Clause<"acq_rel"> { let clangClass = "OMPAcqRelClause"; }
 def OMPC_Acquire : Clause<"acquire"> { let clangClass = "OMPAcquireClause"; }
@@ -545,7 +546,8 @@
     VersionedClause<OMPC_Acquire, 50>,
     VersionedClause<OMPC_Release, 50>,
     VersionedClause<OMPC_Relaxed, 50>,
-    VersionedClause<OMPC_Hint, 50>
+    VersionedClause<OMPC_Hint, 50>,
+    VersionedClause<OMPC_Fail, 51>
   ];
 }
 def OMP_Target : Directive<"target"> {
Index: clang/tools/libclang/CIndex.cpp
===================================================================
--- clang/tools/libclang/CIndex.cpp
+++ clang/tools/libclang/CIndex.cpp
@@ -2281,6 +2281,8 @@
 
 void OMPClauseEnqueue::VisitOMPCompareClause(const OMPCompareClause *) {}
 
+void OMPClauseEnqueue::VisitOMPFailClause(const OMPFailClause *) {}
+
 void OMPClauseEnqueue::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseEnqueue::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
Index: clang/test/OpenMP/atomic_messages.cpp
===================================================================
--- clang/test/OpenMP/atomic_messages.cpp
+++ clang/test/OpenMP/atomic_messages.cpp
@@ -958,6 +958,24 @@
 // expected-error@+1 {{directive '#pragma omp atomic' cannot contain more than one 'capture' clause}}
 #pragma omp atomic compare compare capture capture
   { v = a; if (a > b) a = b; }
+// expected-error@+1 {{expected 'compare' clause with the 'fail' modifier}}
+#pragma omp atomic fail(seq_cst) 
+  if(v == a) { v = a; }
+// expected-error@+1 {{expected '(' after 'fail'}}
+#pragma omp atomic compare fail
+  if(v < a) { v = a; }
+// expected-error@+1 {{expected a memory order clause}}
+#pragma omp atomic compare fail(capture)
+  if(v < a) { v = a; }
+ // expected-error@+2 {{expected ')' after 'atomic compare fail'}} 
+ // expected-warning@+1 {{extra tokens at the end of '#pragma omp atomic' are ignored}}
+#pragma omp atomic compare fail(seq_cst | acquire)
+  if(v < a) { v = a; }
+// expected-error@+1 {{directive '#pragma omp atomic' cannot contain more than one 'fail' clause}}
+#pragma omp atomic compare fail(relaxed) fail(seq_cst)
+  if(v < a) { v = a; }
+
+
 #endif
   // expected-note@+1 {{in instantiation of function template specialization 'mixed<int>' requested here}}
   return mixed<int>();
Index: clang/test/OpenMP/atomic_ast_print.cpp
===================================================================
--- clang/test/OpenMP/atomic_ast_print.cpp
+++ clang/test/OpenMP/atomic_ast_print.cpp
@@ -226,6 +226,16 @@
   { v = a; if (a < b) { a = b; } }
 #pragma omp atomic compare capture hint(6)
   { v = a == b; if (v) a = c; }
+#pragma omp atomic compare fail(acq_rel)
+  { if (a < c) { a = c; } }
+#pragma omp atomic compare fail(acquire)
+  { if (a < c) { a = c; } }
+#pragma omp atomic compare fail(release)
+  { if (a < c) { a = c; } }
+#pragma omp atomic compare fail(relaxed)
+  { if (a < c) { a = c; } }
+#pragma omp atomic compare fail(seq_cst)
+  { if (a < c) { a = c; } }
 #endif
   return T();
 }
@@ -1099,6 +1109,16 @@
   { v = a; if (a < b) { a = b; } }
 #pragma omp atomic compare capture hint(6)
   { v = a == b; if (v) a = c; }
+#pragma omp atomic compare fail(acq_rel)
+  if(a < b) { a = b; }
+#pragma omp atomic compare fail(acquire)
+  if(a < b) { a = b; }
+#pragma omp atomic compare fail(release)
+  if(a < b) { a = b; }
+#pragma omp atomic compare fail(relaxed)
+  if(a < b) { a = b; }
+#pragma omp atomic compare fail(seq_cst)
+  if(a < b) { a = b; }
 #endif
   // CHECK-NEXT: #pragma omp atomic
   // CHECK-NEXT: a++;
@@ -1429,6 +1449,26 @@
   // CHECK-51-NEXT: if (v)
   // CHECK-51-NEXT: a = c;
   // CHECK-51-NEXT: }
+  // CHECK-51-NEXT: #pragma omp atomic compare fail(acquire)
+  // CHECK-51-NEXT: if (a < b) {
+  // CHECK-51-NEXT: a = b;
+  // CHECK-51-NEXT: }
+  // CHECK-51-NEXT: #pragma omp atomic compare fail(acquire)
+  // CHECK-51-NEXT: if (a < b) {
+  // CHECK-51-NEXT: a = b;
+  // CHECK-51-NEXT: }
+  // CHECK-51-NEXT: #pragma omp atomic compare fail(relaxed)
+  // CHECK-51-NEXT: if (a < b) {
+  // CHECK-51-NEXT: a = b;
+  // CHECK-51-NEXT: }
+  // CHECK-51-NEXT: #pragma omp atomic compare fail(relaxed)
+  // CHECK-51-NEXT: if (a < b) {
+  // CHECK-51-NEXT: a = b;
+  // CHECK-51-NEXT: }
+  // CHECK-51-NEXT: #pragma omp atomic compare fail(seq_cst)
+  // CHECK-51-NEXT: if (a < b) {
+  // CHECK-51-NEXT: a = b;
+  // CHECK-51-NEXT: }
   // expect-note@+1 {{in instantiation of function template specialization 'foo<int>' requested here}}
   return foo(a);
 }
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -6336,6 +6336,14 @@
 
 void OMPClauseWriter::VisitOMPCompareClause(OMPCompareClause *) {}
 
+void OMPClauseWriter::VisitOMPFailClause(OMPFailClause *C) {
+    //Record.AddSourceLocation(C->getLParenLoc());
+    //Copied from VisitOMPUpdateClause
+    Record.AddSourceLocation(C->getLParenLoc());
+    Record.AddSourceLocation(C->getArgumentLoc());
+    Record.writeEnum(C->getMemOrderClauseKind());
+}
+
 void OMPClauseWriter::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseWriter::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
Index: clang/lib/Serialization/ASTReader.cpp
===================================================================
--- clang/lib/Serialization/ASTReader.cpp
+++ clang/lib/Serialization/ASTReader.cpp
@@ -11699,6 +11699,9 @@
   case llvm::omp::OMPC_compare:
     C = new (Context) OMPCompareClause();
     break;
+  case llvm::omp::OMPC_fail:
+    C = OMPFailClause::CreateEmpty(Context);
+    break;
   case llvm::omp::OMPC_seq_cst:
     C = new (Context) OMPSeqCstClause();
     break;
@@ -12059,6 +12062,31 @@
 
 void OMPClauseReader::VisitOMPCompareClause(OMPCompareClause *) {}
 
+void OMPClauseReader::VisitOMPFailClause(OMPFailClause *C) {
+    C->setLParenLoc(Record.readSourceLocation());
+    SourceLocation SourceLoc = Record.readSourceLocation();
+    C->setArgumentLoc(SourceLoc);
+    OpenMPAtomicDefaultMemOrderClauseKind CKind = Record.readEnum<OpenMPAtomicDefaultMemOrderClauseKind>();
+    C->setMemOrderClauseKind(CKind);
+
+    SourceLocation EndLoc;
+    OMPClause *MemoryOrderClause = NULL;
+    switch(CKind) {
+    case OMPC_ATOMIC_DEFAULT_MEM_ORDER_acquire:
+      MemoryOrderClause = new (Context) OMPAcquireClause(SourceLoc, EndLoc);
+      break;
+    case OMPC_ATOMIC_DEFAULT_MEM_ORDER_relaxed:
+      MemoryOrderClause = new (Context) OMPRelaxedClause(SourceLoc, EndLoc);
+      break;
+    case OMPC_ATOMIC_DEFAULT_MEM_ORDER_seq_cst: 
+      MemoryOrderClause = new (Context) OMPSeqCstClause(SourceLoc, EndLoc);
+      break;
+    default:
+      break;
+    }
+    C->setMemOrderClause(MemoryOrderClause);
+}
+
 void OMPClauseReader::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseReader::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -9523,6 +9523,13 @@
   return C;
 }
 
+template <typename Derived>
+OMPClause *
+TreeTransform<Derived>::TransformOMPFailClause(OMPFailClause *C) {
+  // No need to rebuild this clause, no template-dependent parameters.
+  return C;
+}
+
 template <typename Derived>
 OMPClause *
 TreeTransform<Derived>::TransformOMPSeqCstClause(OMPSeqCstClause *C) {
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -11973,6 +11973,71 @@
 
   return checkType(ErrorInfo);
 }
+
+class OpenMPAtomicFailChecker {
+
+  protected:
+    Sema &SemaRef;
+    ASTContext &Context;
+
+  public:
+    // Error descriptor type which will be returned to Sema
+    unsigned int ErrorNo;
+
+    OpenMPAtomicFailChecker(Sema &S) : SemaRef(S), Context(S.getASTContext()) {}
+  public:
+    /// Check that at most one 'fail' clause is present and that its argument
+    /// is a memory order clause; on failure, set ErrorNo and *ErrorLoc.
+    bool checkSubClause(ArrayRef<OMPClause *> Clauses,
+                        SourceLocation *ErrorLoc);
+    /// Return the error descriptor that will guide the error message emission.
+    unsigned getErrorDesc() const { return ErrorNo; }
+};
+
+bool
+OpenMPAtomicFailChecker::checkSubClause(ArrayRef<OMPClause *> Clauses,
+                                        SourceLocation *ErrorLoc) {
+  int no_of_fails = 0;
+  ErrorNo = 0;
+  SourceLocation ClauseLoc;
+  for (const OMPClause *C : Clauses) {
+    if(C->getClauseKind() == OMPC_fail) {
+      no_of_fails++;
+      const OMPFailClause *fC = static_cast<const OMPFailClause *>(C);
+      //SourceLocation failClauseLoc = fC->getBeginLoc();
+      const OMPClause *memOrderC = fC->const_getMemoryOrder();
+      /* Clauses contains OMPC_fail and the subclause */
+      if(memOrderC) {
+        OpenMPClauseKind clauseKind = memOrderC->getClauseKind();
+        if((clauseKind == OMPC_acq_rel) ||
+           (clauseKind == OMPC_acquire) ||
+           (clauseKind == OMPC_relaxed) ||
+           (clauseKind == OMPC_release) ||
+           (clauseKind == OMPC_seq_cst)) {
+		switch(clauseKind) {
+                case OMPC_acq_rel : clauseKind = OMPC_acquire;
+				    break;
+                case OMPC_release : clauseKind = OMPC_relaxed;
+				    break;
+                default : break;
+		}
+            continue;
+        } else {
+          ErrorNo = diag::err_omp_atomic_fail_wrong_or_no_clauses;
+          *ErrorLoc = memOrderC->getBeginLoc();
+          continue;
+        }
+      }
+    }
+  }
+  if(no_of_fails > 1) {
+    ErrorNo = diag::err_omp_atomic_fail_extra_clauses;
+    *ErrorLoc = ClauseLoc;
+  }
+
+  return !ErrorNo;
+}
+
 } // namespace
 
 StmtResult Sema::ActOnOpenMPAtomicDirective(ArrayRef<OMPClause *> Clauses,
@@ -11993,6 +12058,8 @@
   SourceLocation AtomicKindLoc;
   OpenMPClauseKind MemOrderKind = OMPC_unknown;
   SourceLocation MemOrderLoc;
+  llvm::omp::Clause SubClause = OMPC_unknown;
+  SourceLocation SubClauseLoc;
   bool MutexClauseEncountered = false;
   llvm::SmallSet<OpenMPClauseKind, 2> EncounteredAtomicKinds;
   for (const OMPClause *C : Clauses) {
@@ -12021,6 +12088,16 @@
       }
       break;
     }
+    case OMPC_fail: {
+      if (AtomicKind != OMPC_compare) {
+        Diag(C->getBeginLoc(), diag::err_omp_atomic_fail_no_compare)
+            << SourceRange(C->getBeginLoc(), C->getEndLoc());
+        return StmtError();
+      }
+      SubClause = OMPC_fail;
+      SubClauseLoc = C->getBeginLoc();
+      break;
+    }
     case OMPC_seq_cst:
     case OMPC_acq_rel:
     case OMPC_acquire:
@@ -12499,6 +12576,17 @@
       CE = Checker.getCond();
       // We reuse IsXLHSInRHSPart to tell if it is in the form 'x ordop expr'.
       IsXLHSInRHSPart = Checker.isXBinopExpr();
+      if (SubClause == OMPC_fail) {
+        OpenMPAtomicFailChecker Checker(*this);
+        SourceLocation ErrorLoc, NoteLoc;
+        NoteLoc = ErrorLoc = Body->getBeginLoc();
+        ErrorLoc = SubClauseLoc;
+        if(!Checker.checkSubClause(Clauses,&ErrorLoc)) {
+          unsigned errorNo = Checker.getErrorDesc();
+          Diag(ErrorLoc, errorNo);
+          return StmtError();
+        }
+      }
     }
   }
 
@@ -16432,6 +16520,9 @@
   case OMPC_compare:
     Res = ActOnOpenMPCompareClause(StartLoc, EndLoc);
     break;
+  case OMPC_fail:
+    Res = ActOnOpenMPFailClause(StartLoc, EndLoc);
+    break;
   case OMPC_seq_cst:
     Res = ActOnOpenMPSeqCstClause(StartLoc, EndLoc);
     break;
@@ -16584,6 +16675,11 @@
   return new (Context) OMPCompareClause(StartLoc, EndLoc);
 }
 
+OMPClause *Sema::ActOnOpenMPFailClause(SourceLocation StartLoc,
+                                       SourceLocation EndLoc) {
+  return OMPFailClause::Create(Context, StartLoc, EndLoc);
+}
+
 OMPClause *Sema::ActOnOpenMPSeqCstClause(SourceLocation StartLoc,
                                          SourceLocation EndLoc) {
   return new (Context) OMPSeqCstClause(StartLoc, EndLoc);
Index: clang/lib/Parse/ParseOpenMP.cpp
===================================================================
--- clang/lib/Parse/ParseOpenMP.cpp
+++ clang/lib/Parse/ParseOpenMP.cpp
@@ -3233,6 +3233,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_fail:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -3614,6 +3615,46 @@
       Val.getValue().Loc, Val.getValue().RLoc);
 }
 
+OMPClause *Parser::ParseOpenMPFailClause(OMPClause *clause) {
+
+  OMPFailClause *failClause = static_cast<OMPFailClause *>(clause);
+  SourceLocation LParenLoc;
+  if(Tok.is(tok::l_paren)) {
+    LParenLoc = Tok.getLocation();
+    ConsumeAnyToken();
+  } else {
+    Diag(diag::err_expected_lparen_after)
+        << getOpenMPClauseName(OMPC_fail);
+    return clause;
+  }
+
+
+  OpenMPClauseKind CKind = Tok.isAnnotation()
+				 ? OMPC_unknown
+				 : getOpenMPClauseKind(PP.getSpelling(Tok));
+  if(CKind == OMPC_unknown) {
+    Diag(diag::err_omp_expected_clause)
+        << ("atomic compare fail");
+    return clause;
+  }
+  OMPClause *MemoryOrderClause = ParseOpenMPClause(CKind, false);
+  SourceLocation MemOrderLoc;
+  //Store Memory Order SubClause for Sema.
+  if(MemoryOrderClause) {
+    MemOrderLoc = Tok.getLocation();
+  }
+
+  if(Tok.is(tok::r_paren)) {
+    failClause->initFailClause(LParenLoc,MemoryOrderClause,MemOrderLoc);
+    ConsumeAnyToken();
+  } else {
+    const IdentifierInfo *Arg = Tok.getIdentifierInfo();
+    Diag(Tok, diag::err_expected_rparen_after) << (Arg ? Arg->getName() : "atomic compare fail");
+  }
+
+  return clause;
+}
+
 /// Parsing of OpenMP clauses like 'ordered'.
 ///
 ///    ordered-clause:
@@ -3646,7 +3687,11 @@
 
   if (ParseOnly)
     return nullptr;
-  return Actions.ActOnOpenMPClause(Kind, Loc, Tok.getLocation());
+  OMPClause *clause = Actions.ActOnOpenMPClause(Kind, Loc, Tok.getLocation());
+  if(Kind == llvm::omp::Clause::OMPC_fail) {
+    clause = ParseOpenMPFailClause(clause);
+  }
+  return clause;
 }
 
 /// Parsing of OpenMP clauses with single expressions and some additional
Index: clang/lib/Basic/OpenMPKinds.cpp
===================================================================
--- clang/lib/Basic/OpenMPKinds.cpp
+++ clang/lib/Basic/OpenMPKinds.cpp
@@ -365,6 +365,16 @@
 #include "clang/Basic/OpenMPKinds.def"
     }
     llvm_unreachable("Invalid OpenMP 'depend' clause type");
+  case OMPC_fail:
+    switch (Type) {
+    case OMPC_ATOMIC_DEFAULT_MEM_ORDER_unknown:
+      return "unknown";
+#define OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND(Name) 		               \
+  case OMPC_ATOMIC_DEFAULT_MEM_ORDER_##Name:				       \
+    return #Name;
+#include "clang/Basic/OpenMPKinds.def"
+    }
+    llvm_unreachable("Invalid OpenMP 'fail' clause type");
   case OMPC_device:
     switch (Type) {
     case OMPC_DEVICE_unknown:
Index: clang/lib/AST/StmtProfile.cpp
===================================================================
--- clang/lib/AST/StmtProfile.cpp
+++ clang/lib/AST/StmtProfile.cpp
@@ -557,6 +557,8 @@
 
 void OMPClauseProfiler::VisitOMPCompareClause(const OMPCompareClause *) {}
 
+void OMPClauseProfiler::VisitOMPFailClause(const OMPFailClause *) {}
+
 void OMPClauseProfiler::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseProfiler::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
Index: clang/lib/AST/OpenMPClause.cpp
===================================================================
--- clang/lib/AST/OpenMPClause.cpp
+++ clang/lib/AST/OpenMPClause.cpp
@@ -127,6 +127,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_fail:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -219,6 +220,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_fail:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -410,6 +412,25 @@
   return Clause;
 }
 
+OMPFailClause *
+OMPFailClause::Create(const ASTContext &C, SourceLocation StartLoc,
+                      SourceLocation EndLoc) {
+  void *Mem =
+      C.Allocate(totalSizeToAlloc<SourceLocation, OpenMPAtomicDefaultMemOrderClauseKind>(2, 1), alignof(OMPFailClause));
+  auto *Clause =
+      new (Mem) OMPFailClause(StartLoc, EndLoc);
+  return Clause;
+}
+
+OMPFailClause *
+OMPFailClause::CreateEmpty(const ASTContext &C) {
+  void *Mem =
+      C.Allocate(totalSizeToAlloc<SourceLocation, OpenMPAtomicDefaultMemOrderClauseKind>(2, 1), alignof(OMPFailClause));
+  auto *Clause =
+      new (Mem) OMPFailClause();
+  return Clause;
+}
+
 void OMPPrivateClause::setPrivateCopies(ArrayRef<Expr *> VL) {
   assert(VL.size() == varlist_size() &&
          "Number of private copies is not the same as the preallocated buffer");
@@ -1798,6 +1819,19 @@
   OS << "compare";
 }
 
+void OMPClausePrinter::VisitOMPFailClause(OMPFailClause *Node) {
+  OS << "fail";
+  //if(Node->getMemOrderClauseKind() == OMPC_ATOMIC_DEFAULT_MEM_ORDER_seq_cst) {
+    //printf("VisitOMPFailClause # seq_cst \n");
+  //}
+  if(Node) {
+    OS << "(";
+    OS << getOpenMPSimpleClauseTypeName(Node->getClauseKind(),
+                                        Node->getMemOrderClauseKind());
+    OS << ")";
+  } 
+}
+
 void OMPClausePrinter::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   OS << "seq_cst";
 }
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11295,6 +11295,9 @@
   /// Called on well-formed 'compare' clause.
   OMPClause *ActOnOpenMPCompareClause(SourceLocation StartLoc,
                                       SourceLocation EndLoc);
+  /// Called on well-formed 'fail' clause.
+  OMPClause *ActOnOpenMPFailClause(SourceLocation StartLoc,
+                                   SourceLocation EndLoc);
   /// Called on well-formed 'seq_cst' clause.
   OMPClause *ActOnOpenMPSeqCstClause(SourceLocation StartLoc,
                                      SourceLocation EndLoc);
Index: clang/include/clang/Parse/Parser.h
===================================================================
--- clang/include/clang/Parse/Parser.h
+++ clang/include/clang/Parse/Parser.h
@@ -431,6 +431,8 @@
   /// a statement expression and builds a suitable expression statement.
   StmtResult handleExprStmt(ExprResult E, ParsedStmtContext StmtCtx);
 
+  OMPClause *ParseOpenMPFailClause(OMPClause *clause);
+
 public:
   Parser(Preprocessor &PP, Sema &Actions, bool SkipFunctionBodies);
   ~Parser() override;
Index: clang/include/clang/Basic/OpenMPKinds.def
===================================================================
--- clang/include/clang/Basic/OpenMPKinds.def
+++ clang/include/clang/Basic/OpenMPKinds.def
@@ -113,9 +113,11 @@
 OPENMP_LINEAR_KIND(uval)
 
 // Modifiers for 'atomic_default_mem_order' clause.
-OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND(seq_cst)
 OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND(acq_rel)
+OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND(acquire)
 OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND(relaxed)
+OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND(release)
+OPENMP_ATOMIC_DEFAULT_MEM_ORDER_KIND(seq_cst)
 
 // Map types for 'map' clause.
 OPENMP_MAP_KIND(alloc)
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10594,6 +10594,10 @@
   "expect binary operator in conditional expression|expect '<', '>' or '==' as order operator|expect comparison in a form of 'x == e', 'e == x', 'x ordop expr', or 'expr ordop x'|"
   "expect lvalue for result value|expect scalar value|expect integer value|unexpected 'else' statement|expect '==' operator|expect an assignment statement 'v = x'|"
   "expect a 'if' statement|expect no more than two statements|expect a compound statement|expect 'else' statement|expect a form 'r = x == e; if (r) ...'}0">;
+def err_omp_atomic_fail_wrong_or_no_clauses : Error<"expected a memory order clause">;
+def err_omp_atomic_fail_extra_mem_order_clauses : Error<"directive '#pragma omp atomic compare fail' cannot contain more than one memory order clause">;
+def err_omp_atomic_fail_extra_clauses : Error<"directive '#pragma omp atomic compare' cannot contain more than one fail clause">;
+def err_omp_atomic_fail_no_compare : Error<"expected 'compare' clause with the 'fail' modifier">;
 def err_omp_atomic_several_clauses : Error<
   "directive '#pragma omp atomic' cannot contain more than one 'read', 'write', 'update', 'capture', or 'compare' clause">;
 def err_omp_several_mem_order_clauses : Error<
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3294,6 +3294,11 @@
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPFailClause(OMPFailClause *) {
+  return true;
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   return true;
Index: clang/include/clang/AST/OpenMPClause.h
===================================================================
--- clang/include/clang/AST/OpenMPClause.h
+++ clang/include/clang/AST/OpenMPClause.h
@@ -2266,6 +2266,139 @@
   }
 };
 
+/// This represents 'fail' clause in the '#pragma omp atomic'
+/// directive.
+///
+/// \code
+/// #pragma omp atomic compare fail(seq_cst)
+/// \endcode
+/// In this example directive '#pragma omp atomic compare' has 'fail' clause.
+class OMPFailClause final : public OMPClause,
+      private llvm::TrailingObjects<OMPFailClause, SourceLocation,
+                                    OpenMPAtomicDefaultMemOrderClauseKind> {
+  OMPClause *MemoryOrderClause;
+
+  friend class OMPClauseReader;
+  friend TrailingObjects;
+
+  /// Define the sizes of each trailing object array except the last one. This
+  /// is required for TrailingObjects to work properly.
+  size_t numTrailingObjects(OverloadToken<SourceLocation>) const {
+    // 2 locations: for '(' and argument location.
+    return 2;
+  }
+
+  /// Sets the location of '(' in fail clause.
+  void setLParenLoc(SourceLocation Loc) {
+    *getTrailingObjects<SourceLocation>() = Loc;
+  }
+
+  /// Sets the location of memoryOrder clause argument in fail clause.
+  void setArgumentLoc(SourceLocation Loc) {
+    *std::next(getTrailingObjects<SourceLocation>(), 1) = Loc;
+  }
+
+  /// Sets the mem_order clause for 'atomic compare fail' directive.
+  void setMemOrderClauseKind(OpenMPAtomicDefaultMemOrderClauseKind memOrder) {
+    OpenMPAtomicDefaultMemOrderClauseKind *dMOCK = getTrailingObjects<OpenMPAtomicDefaultMemOrderClauseKind>();
+    //*getTrailingObjects<OpenMPAtomicDefaultMemOrderClauseKind>() = memOrder;
+    *dMOCK = memOrder;
+  }
+
+  /// Sets the mem_order clause for 'atomic compare fail' directive.
+  void setMemOrderClause(OMPClause *MemoryOrderClauseParam) {
+    MemoryOrderClause = MemoryOrderClauseParam;
+  }
+public:
+  /// Build 'fail' clause.
+  ///
+  /// \param StartLoc Starting location of the clause.
+  /// \param EndLoc Ending location of the clause.
+  OMPFailClause(SourceLocation StartLoc, SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_fail, StartLoc, EndLoc) {}
+
+  /// Build an empty clause.
+  OMPFailClause()
+      : OMPClause(llvm::omp::OMPC_fail, SourceLocation(), SourceLocation()) {
+  }
+
+  static OMPFailClause *CreateEmpty(const ASTContext &C);
+  static OMPFailClause *Create(const ASTContext &C,
+                              SourceLocation StartLoc,
+                              SourceLocation EndLoc);
+
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_fail;
+  }
+
+  void
+  initFailClause(SourceLocation LParenLoc,
+                 OMPClause *memOClause,
+                 SourceLocation MemOrderLoc) {
+
+    setLParenLoc(LParenLoc);
+    MemoryOrderClause = memOClause;
+    setArgumentLoc(MemOrderLoc);
+
+    OpenMPAtomicDefaultMemOrderClauseKind memClauseKind = OMPC_ATOMIC_DEFAULT_MEM_ORDER_unknown;
+    OpenMPClauseKind clauseKind = MemoryOrderClause->getClauseKind();
+    switch(clauseKind) {
+      case llvm::omp::OMPC_acq_rel:
+      case llvm::omp::OMPC_acquire:
+        memClauseKind = OMPC_ATOMIC_DEFAULT_MEM_ORDER_acquire;
+	break;
+      case llvm::omp::OMPC_relaxed:
+      case llvm::omp::OMPC_release:
+        memClauseKind = OMPC_ATOMIC_DEFAULT_MEM_ORDER_relaxed;
+	break;
+      case llvm::omp::OMPC_seq_cst:
+        memClauseKind = OMPC_ATOMIC_DEFAULT_MEM_ORDER_seq_cst;
+	break;
+      default : break;
+    }
+    setMemOrderClauseKind(memClauseKind);
+  }
+
+  /// Gets the location of '(' in fail clause.
+  SourceLocation getLParenLoc() const {
+    return *getTrailingObjects<SourceLocation>();
+  }
+
+  OMPClause *getMemoryOrder() {
+    return MemoryOrderClause;
+  }
+
+  const OMPClause *const_getMemoryOrder() const {
+    return static_cast<const OMPClause *>(MemoryOrderClause);
+  }
+
+  /// Gets the location of memoryOrder clause argument in fail clause.
+  SourceLocation getArgumentLoc() const {
+    return *std::next(getTrailingObjects<SourceLocation>(), 1);
+  }
+
+  /// Gets the memory order kind stored in this 'fail' clause.
+  OpenMPAtomicDefaultMemOrderClauseKind getMemOrderClauseKind() const {
+    return *getTrailingObjects<OpenMPAtomicDefaultMemOrderClauseKind>();
+  }
+};
+
 /// This represents 'seq_cst' clause in the '#pragma omp atomic'
 /// directive.
 ///
Index: clang/include/clang/AST/ASTNodeTraverser.h
===================================================================
--- clang/include/clang/AST/ASTNodeTraverser.h
+++ clang/include/clang/AST/ASTNodeTraverser.h
@@ -214,6 +214,10 @@
   }
 
   void Visit(const OMPClause *C) {
+    if(OMPFailClause::classof(C)) {
+      Visit(static_cast<const OMPFailClause *>(C));
+      return;
+    }
     getNodeDelegate().AddChild([=] {
       getNodeDelegate().Visit(C);
       for (const auto *S : C->children())
@@ -221,6 +225,13 @@
     });
   }
 
+  void Visit(const OMPFailClause *C) {
+    getNodeDelegate().AddChild([=] {
+      getNodeDelegate().Visit(C);
+      const OMPClause *mOC = C->const_getMemoryOrder();
+      Visit(mOC);
+    });
+  }
   void Visit(const GenericSelectionExpr::ConstAssociation &A) {
     getNodeDelegate().AddChild([=] {
       getNodeDelegate().Visit(A);
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to