tianshilei1992 created this revision.
tianshilei1992 added reviewers: jdoerfert, ABataev, carlo.bertolli.
Herald added subscribers: arphaman, guansong, yaxunl.
tianshilei1992 requested review of this revision.
Herald added subscribers: llvm-commits, cfe-commits, sstefan1.
Herald added projects: clang, LLVM.

This patch adds parser support for the `atomic compare` clause. Support
in Sema and CodeGen will follow in later patches; for now, Sema simply emits an
error when the clause is encountered.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D115561

Files:
  clang/include/clang/AST/OpenMPClause.h
  clang/include/clang/AST/RecursiveASTVisitor.h
  clang/include/clang/AST/StmtOpenMP.h
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Sema/Sema.h
  clang/lib/AST/OpenMPClause.cpp
  clang/lib/AST/StmtOpenMP.cpp
  clang/lib/AST/StmtProfile.cpp
  clang/lib/Basic/OpenMPKinds.cpp
  clang/lib/CodeGen/CGStmtOpenMP.cpp
  clang/lib/Parse/ParseOpenMP.cpp
  clang/lib/Sema/SemaOpenMP.cpp
  clang/lib/Sema/TreeTransform.h
  clang/lib/Serialization/ASTReader.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/tools/libclang/CIndex.cpp
  llvm/include/llvm/Frontend/OpenMP/OMP.td

Index: llvm/include/llvm/Frontend/OpenMP/OMP.td
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -180,6 +180,7 @@
 def OMPC_Write : Clause<"write"> { let clangClass = "OMPWriteClause"; }
 def OMPC_Update : Clause<"update"> { let clangClass = "OMPUpdateClause"; }
 def OMPC_Capture : Clause<"capture"> { let clangClass = "OMPCaptureClause"; }
+def OMPC_Compare : Clause<"compare"> { let clangClass = "OMPCompareClause"; }
 def OMPC_SeqCst : Clause<"seq_cst"> { let clangClass = "OMPSeqCstClause"; }
 def OMPC_AcqRel : Clause<"acq_rel"> { let clangClass = "OMPAcqRelClause"; }
 def OMPC_Acquire : Clause<"acquire"> { let clangClass = "OMPAcquireClause"; }
@@ -536,6 +537,7 @@
     VersionedClause<OMPC_Write>,
     VersionedClause<OMPC_Update>,
     VersionedClause<OMPC_Capture>,
+    VersionedClause<OMPC_Compare, 51>
   ];
   let allowedOnceClauses = [
     VersionedClause<OMPC_SeqCst>,
Index: clang/tools/libclang/CIndex.cpp
===================================================================
--- clang/tools/libclang/CIndex.cpp
+++ clang/tools/libclang/CIndex.cpp
@@ -2273,6 +2273,8 @@
 
 void OMPClauseEnqueue::VisitOMPCaptureClause(const OMPCaptureClause *) {}
 
+void OMPClauseEnqueue::VisitOMPCompareClause(const OMPCompareClause *) {}
+
 void OMPClauseEnqueue::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseEnqueue::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -6248,6 +6248,8 @@
 
 void OMPClauseWriter::VisitOMPCaptureClause(OMPCaptureClause *) {}
 
+void OMPClauseWriter::VisitOMPCompareClause(OMPCompareClause *) {}
+
 void OMPClauseWriter::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseWriter::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
Index: clang/lib/Serialization/ASTReader.cpp
===================================================================
--- clang/lib/Serialization/ASTReader.cpp
+++ clang/lib/Serialization/ASTReader.cpp
@@ -11761,6 +11761,9 @@
   case llvm::omp::OMPC_capture:
     C = new (Context) OMPCaptureClause();
     break;
+  case llvm::omp::OMPC_compare:
+    C = new (Context) OMPCompareClause();
+    break;
   case llvm::omp::OMPC_seq_cst:
     C = new (Context) OMPSeqCstClause();
     break;
@@ -12119,6 +12122,8 @@
 
 void OMPClauseReader::VisitOMPCaptureClause(OMPCaptureClause *) {}
 
+void OMPClauseReader::VisitOMPCompareClause(OMPCompareClause *) {}
+
 void OMPClauseReader::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseReader::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -9429,6 +9429,13 @@
   return C;
 }
 
+template <typename Derived>
+OMPClause *
+TreeTransform<Derived>::TransformOMPCompareClause(OMPCompareClause *C) {
+  // No need to rebuild this clause, no template-dependent parameters.
+  return C;
+}
+
 template <typename Derived>
 OMPClause *
 TreeTransform<Derived>::TransformOMPSeqCstClause(OMPSeqCstClause *C) {
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -6371,6 +6371,7 @@
       case OMPC_write:
       case OMPC_update:
       case OMPC_capture:
+      case OMPC_compare:
       case OMPC_seq_cst:
       case OMPC_acq_rel:
       case OMPC_acquire:
@@ -10950,24 +10951,27 @@
   OpenMPClauseKind MemOrderKind = OMPC_unknown;
   SourceLocation MemOrderLoc;
   for (const OMPClause *C : Clauses) {
-    if (C->getClauseKind() == OMPC_read || C->getClauseKind() == OMPC_write ||
-        C->getClauseKind() == OMPC_update ||
-        C->getClauseKind() == OMPC_capture) {
+    switch (C->getClauseKind()) {
+    case OMPC_read:
+    case OMPC_write:
+    case OMPC_update:
+    case OMPC_capture:
+    case OMPC_compare:
       if (AtomicKind != OMPC_unknown) {
         Diag(C->getBeginLoc(), diag::err_omp_atomic_several_clauses)
             << SourceRange(C->getBeginLoc(), C->getEndLoc());
         Diag(AtomicKindLoc, diag::note_omp_previous_mem_order_clause)
             << getOpenMPClauseName(AtomicKind);
-      } else {
-        AtomicKind = C->getClauseKind();
-        AtomicKindLoc = C->getBeginLoc();
+        return StmtError();
       }
-    }
-    if (C->getClauseKind() == OMPC_seq_cst ||
-        C->getClauseKind() == OMPC_acq_rel ||
-        C->getClauseKind() == OMPC_acquire ||
-        C->getClauseKind() == OMPC_release ||
-        C->getClauseKind() == OMPC_relaxed) {
+      AtomicKind = C->getClauseKind();
+      AtomicKindLoc = C->getBeginLoc();
+      break;
+    case OMPC_seq_cst:
+    case OMPC_acq_rel:
+    case OMPC_acquire:
+    case OMPC_release:
+    case OMPC_relaxed:
       if (MemOrderKind != OMPC_unknown) {
         Diag(C->getBeginLoc(), diag::err_omp_several_mem_order_clauses)
             << getOpenMPDirectiveName(OMPD_atomic) << 0
@@ -10978,8 +10982,13 @@
         MemOrderKind = C->getClauseKind();
         MemOrderLoc = C->getBeginLoc();
       }
+      break;
+    default:
+      // Other clauses accepted on 'atomic' (e.g. 'hint') are not checked here.
+      break;
     }
   }
+
   // OpenMP 5.0, 2.17.7 atomic Construct, Restrictions
   // If atomic-clause is read then memory-order-clause must not be acq_rel or
   // release.
@@ -11007,10 +11016,18 @@
   if (auto *EWC = dyn_cast<ExprWithCleanups>(Body))
     Body = EWC->getSubExpr();
 
+  // Stands for 'x' in the spec
   Expr *X = nullptr;
+  // Stands for 'v' in the spec
   Expr *V = nullptr;
+  // Stands for 'd' in the spec
+  Expr *D = nullptr;
+  // Stands for 'e' or 'expr' in the spec
   Expr *E = nullptr;
+  // Stands for update-stmt in the spec
   Expr *UE = nullptr;
+  // Stands for conditional statement in the spec
+  Expr *CE = nullptr;
   bool IsXLHSInRHSPart = false;
   bool IsPostfixUpdate = false;
   // OpenMP [2.12.6, atomic Construct]
@@ -11092,8 +11109,8 @@
     if (ErrorFound != NoError) {
       Diag(ErrorLoc, diag::err_omp_atomic_read_not_expression_statement)
           << ErrorRange;
-      Diag(NoteLoc, diag::note_omp_atomic_read_write) << ErrorFound
-                                                      << NoteRange;
+      Diag(NoteLoc, diag::note_omp_atomic_read_write)
+          << ErrorFound << NoteRange;
       return StmtError();
     }
     if (CurContext->isDependentContext())
@@ -11154,8 +11171,8 @@
     if (ErrorFound != NoError) {
       Diag(ErrorLoc, diag::err_omp_atomic_write_not_expression_statement)
           << ErrorRange;
-      Diag(NoteLoc, diag::note_omp_atomic_read_write) << ErrorFound
-                                                      << NoteRange;
+      Diag(NoteLoc, diag::note_omp_atomic_read_write)
+          << ErrorFound << NoteRange;
       return StmtError();
     }
     if (CurContext->isDependentContext())
@@ -11171,9 +11188,10 @@
     //  x = expr binop x;
     OpenMPAtomicUpdateChecker Checker(*this);
     if (Checker.checkStatement(
-            Body, (AtomicKind == OMPC_update)
-                      ? diag::err_omp_atomic_update_not_expression_statement
-                      : diag::err_omp_atomic_not_expression_statement,
+            Body,
+            (AtomicKind == OMPC_update)
+                ? diag::err_omp_atomic_update_not_expression_statement
+                : diag::err_omp_atomic_not_expression_statement,
             diag::note_omp_atomic_update))
       return StmtError();
     if (!CurContext->isDependentContext()) {
@@ -11387,21 +11405,25 @@
             SourceRange(Body->getBeginLoc(), Body->getBeginLoc());
         ErrorFound = NotACompoundStatement;
       }
-      if (ErrorFound != NoError) {
-        Diag(ErrorLoc, diag::err_omp_atomic_capture_not_compound_statement)
-            << ErrorRange;
-        Diag(NoteLoc, diag::note_omp_atomic_capture) << ErrorFound << NoteRange;
-        return StmtError();
-      }
-      if (CurContext->isDependentContext())
-        UE = V = E = X = nullptr;
     }
+    if (ErrorFound != NoError) {
+      Diag(ErrorLoc, diag::err_omp_atomic_capture_not_compound_statement)
+          << ErrorRange;
+      Diag(NoteLoc, diag::note_omp_atomic_capture) << ErrorFound << NoteRange;
+      return StmtError();
+    }
+    if (CurContext->isDependentContext())
+      UE = V = E = X = nullptr;
+  } else if (AtomicKind == OMPC_compare) {
+    // TODO: Sema and CodeGen for 'atomic compare' are not implemented yet.
+    Diag(Body->getBeginLoc(), diag::err_omp_atomic_compare);
+    return StmtError();
   }
 
   setFunctionHasBranchProtectedScope();
 
   return OMPAtomicDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                    X, V, E, UE, IsXLHSInRHSPart,
+                                    X, V, E, D, UE, CE, IsXLHSInRHSPart,
                                     IsPostfixUpdate);
 }
 
@@ -13476,6 +13498,7 @@
   case OMPC_write:
   case OMPC_update:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -14307,6 +14330,7 @@
   case OMPC_write:
   case OMPC_update:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -14768,6 +14792,7 @@
   case OMPC_read:
   case OMPC_write:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -15073,6 +15098,7 @@
   case OMPC_write:
   case OMPC_update:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -15261,6 +15287,9 @@
   case OMPC_capture:
     Res = ActOnOpenMPCaptureClause(StartLoc, EndLoc);
     break;
+  case OMPC_compare:
+    Res = ActOnOpenMPCompareClause(StartLoc, EndLoc);
+    break;
   case OMPC_seq_cst:
     Res = ActOnOpenMPSeqCstClause(StartLoc, EndLoc);
     break;
@@ -15407,6 +15436,11 @@
   return new (Context) OMPCaptureClause(StartLoc, EndLoc);
 }
 
+OMPClause *Sema::ActOnOpenMPCompareClause(SourceLocation StartLoc,
+                                          SourceLocation EndLoc) {
+  return new (Context) OMPCompareClause(StartLoc, EndLoc);
+}
+
 OMPClause *Sema::ActOnOpenMPSeqCstClause(SourceLocation StartLoc,
                                          SourceLocation EndLoc) {
   return new (Context) OMPSeqCstClause(StartLoc, EndLoc);
@@ -15875,6 +15909,7 @@
   case OMPC_write:
   case OMPC_update:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
Index: clang/lib/Parse/ParseOpenMP.cpp
===================================================================
--- clang/lib/Parse/ParseOpenMP.cpp
+++ clang/lib/Parse/ParseOpenMP.cpp
@@ -3192,6 +3192,7 @@
   case OMPC_read:
   case OMPC_write:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
Index: clang/lib/CodeGen/CGStmtOpenMP.cpp
===================================================================
--- clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -5901,8 +5901,8 @@
 static void emitOMPAtomicExpr(CodeGenFunction &CGF, OpenMPClauseKind Kind,
                               llvm::AtomicOrdering AO, bool IsPostfixUpdate,
                               const Expr *X, const Expr *V, const Expr *E,
-                              const Expr *UE, bool IsXLHSInRHSPart,
-                              SourceLocation Loc) {
+                              const Expr *D, const Expr *UE, const Expr *CE,
+                              bool IsXLHSInRHSPart, SourceLocation Loc) {
   switch (Kind) {
   case OMPC_read:
     emitOMPAtomicReadExpr(CGF, AO, X, V, Loc);
@@ -5918,6 +5918,9 @@
     emitOMPAtomicCaptureExpr(CGF, AO, IsPostfixUpdate, V, X, E, UE,
                              IsXLHSInRHSPart, Loc);
     break;
+  case OMPC_compare:
+    llvm_unreachable("atomic compare is not supported yet, and this should "
+                     "never be reached.");
   case OMPC_if:
   case OMPC_final:
   case OMPC_num_threads:
@@ -6061,8 +6064,8 @@
   LexicalScope Scope(*this, S.getSourceRange());
   EmitStopPoint(S.getAssociatedStmt());
   emitOMPAtomicExpr(*this, Kind, AO, S.isPostfixUpdate(), S.getX(), S.getV(),
-                    S.getExpr(), S.getUpdateExpr(), S.isXLHSInRHSPart(),
-                    S.getBeginLoc());
+                    S.getE(), S.getD(), S.getUpdateExpr(), S.getCondExpr(),
+                    S.isXLHSInRHSPart(), S.getBeginLoc());
 }

 static void emitCommonOMPTargetDirective(CodeGenFunction &CGF,
Index: clang/lib/Basic/OpenMPKinds.cpp
===================================================================
--- clang/lib/Basic/OpenMPKinds.cpp
+++ clang/lib/Basic/OpenMPKinds.cpp
@@ -163,6 +163,7 @@
   case OMPC_read:
   case OMPC_write:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -428,6 +429,7 @@
   case OMPC_read:
   case OMPC_write:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
Index: clang/lib/AST/StmtProfile.cpp
===================================================================
--- clang/lib/AST/StmtProfile.cpp
+++ clang/lib/AST/StmtProfile.cpp
@@ -551,6 +551,8 @@
 
 void OMPClauseProfiler::VisitOMPCaptureClause(const OMPCaptureClause *) {}
 
+void OMPClauseProfiler::VisitOMPCompareClause(const OMPCompareClause *) {}
+
 void OMPClauseProfiler::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseProfiler::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
Index: clang/lib/AST/StmtOpenMP.cpp
===================================================================
--- clang/lib/AST/StmtOpenMP.cpp
+++ clang/lib/AST/StmtOpenMP.cpp
@@ -863,16 +863,20 @@
                                                    !IsStandalone);
 }
 
-OMPAtomicDirective *OMPAtomicDirective::Create(
-    const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
-    ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, Expr *X, Expr *V,
-    Expr *E, Expr *UE, bool IsXLHSInRHSPart, bool IsPostfixUpdate) {
+OMPAtomicDirective *
+OMPAtomicDirective::Create(const ASTContext &C, SourceLocation StartLoc,
+                           SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
+                           Stmt *AssociatedStmt, Expr *X, Expr *V, Expr *E,
+                           Expr *D, Expr *UE, Expr *CE, bool IsXLHSInRHSPart,
+                           bool IsPostfixUpdate) {
   auto *Dir = createDirective<OMPAtomicDirective>(
-      C, Clauses, AssociatedStmt, /*NumChildren=*/4, StartLoc, EndLoc);
+      C, Clauses, AssociatedStmt, /*NumChildren=*/6, StartLoc, EndLoc);
   Dir->setX(X);
   Dir->setV(V);
-  Dir->setExpr(E);
+  Dir->setE(E);
+  Dir->setD(D);
   Dir->setUpdateExpr(UE);
+  Dir->setCondExpr(CE);
   Dir->IsXLHSInRHSPart = IsXLHSInRHSPart;
   Dir->IsPostfixUpdate = IsPostfixUpdate;
   return Dir;
Index: clang/lib/AST/OpenMPClause.cpp
===================================================================
--- clang/lib/AST/OpenMPClause.cpp
+++ clang/lib/AST/OpenMPClause.cpp
@@ -126,6 +126,7 @@
   case OMPC_write:
   case OMPC_update:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -217,6 +218,7 @@
   case OMPC_write:
   case OMPC_update:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -1792,6 +1794,10 @@
   OS << "capture";
 }
 
+void OMPClausePrinter::VisitOMPCompareClause(OMPCompareClause *) {
+  OS << "compare";
+}
+
 void OMPClausePrinter::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   OS << "seq_cst";
 }
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11190,6 +11190,9 @@
   /// Called on well-formed 'capture' clause.
   OMPClause *ActOnOpenMPCaptureClause(SourceLocation StartLoc,
                                       SourceLocation EndLoc);
+  /// Called on well-formed 'compare' clause.
+  OMPClause *ActOnOpenMPCompareClause(SourceLocation StartLoc,
+                                      SourceLocation EndLoc);
   /// Called on well-formed 'seq_cst' clause.
   OMPClause *ActOnOpenMPSeqCstClause(SourceLocation StartLoc,
                                      SourceLocation EndLoc);
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10485,10 +10485,11 @@
   " '{v = x; x = x binop expr;}', '{v = x; x = expr binop x;}', '{x = x binop expr; v = x;}', '{x = expr binop x; v = x;}' or '{v = x; x = expr;}',"
   " '{v = x; x++;}', '{v = x; ++x;}', '{++x; v = x;}', '{x++; v = x;}', '{v = x; x--;}', '{v = x; --x;}', '{--x; v = x;}', '{x--; v = x;}'"
   " where x is an lvalue expression with scalar type">;
+def err_omp_atomic_compare : Error<"atomic compare is not supported for now">;
 def note_omp_atomic_capture: Note<
   "%select{expected assignment expression|expected compound statement|expected exactly two expression statements|expected in right hand side of the first expression}0">;
 def err_omp_atomic_several_clauses : Error<
-  "directive '#pragma omp atomic' cannot contain more than one 'read', 'write', 'update' or 'capture' clause">;
+  "directive '#pragma omp atomic' cannot contain more than one 'read', 'write', 'update', 'capture', or 'compare' clause">;
 def err_omp_several_mem_order_clauses : Error<
   "directive '#pragma omp %0' cannot contain more than one %select{'seq_cst', 'relaxed', |}1'acq_rel', 'acquire' or 'release' clause">;
 def err_omp_atomic_incompatible_mem_order_clause : Error<
Index: clang/include/clang/AST/StmtOpenMP.h
===================================================================
--- clang/include/clang/AST/StmtOpenMP.h
+++ clang/include/clang/AST/StmtOpenMP.h
@@ -2863,20 +2863,29 @@
     POS_V,
     POS_E,
     POS_UpdateExpr,
+    POS_D,
+    POS_CondExpr
   };
 
   /// Set 'x' part of the associated expression/statement.
   void setX(Expr *X) { Data->getChildren()[DataPositionTy::POS_X] = X; }
+  /// Set 'v' part of the associated expression/statement.
+  void setV(Expr *V) { Data->getChildren()[DataPositionTy::POS_V] = V; }
+  /// Set 'd' part of the associated expression/statement.
+  void setD(Expr *D) { Data->getChildren()[DataPositionTy::POS_D] = D; }
+  /// Set 'e' or 'expr' part of the associated expression/statement.
+  void setE(Expr *E) { Data->getChildren()[DataPositionTy::POS_E] = E; }
   /// Set helper expression of the form
   /// 'OpaqueValueExpr(x) binop OpaqueValueExpr(expr)' or
   /// 'OpaqueValueExpr(expr) binop OpaqueValueExpr(x)'.
   void setUpdateExpr(Expr *UE) {
     Data->getChildren()[DataPositionTy::POS_UpdateExpr] = UE;
   }
-  /// Set 'v' part of the associated expression/statement.
-  void setV(Expr *V) { Data->getChildren()[DataPositionTy::POS_V] = V; }
-  /// Set 'expr' part of the associated expression/statement.
-  void setExpr(Expr *E) { Data->getChildren()[DataPositionTy::POS_E] = E; }
+  /// Set the conditional expression part of the associated expression/statement
+  /// in atomic compare.
+  void setCondExpr(Expr *CE) {
+    Data->getChildren()[DataPositionTy::POS_CondExpr] = CE;
+  }
 
 public:
   /// Creates directive with a list of \a Clauses and 'x', 'v' and 'expr'
@@ -2901,7 +2910,8 @@
   static OMPAtomicDirective *
   Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
          ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, Expr *X, Expr *V,
-         Expr *E, Expr *UE, bool IsXLHSInRHSPart, bool IsPostfixUpdate);
+         Expr *E, Expr *D, Expr *UE, Expr *CE, bool IsXLHSInRHSPart,
+         bool IsPostfixUpdate);
 
   /// Creates an empty directive with the place for \a NumClauses
   /// clauses.
@@ -2919,6 +2929,27 @@
   const Expr *getX() const {
     return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_X]);
   }
+  /// Get 'd' part of the associated expression/statement.
+  Expr *getD() {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_D]);
+  }
+  const Expr *getD() const {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_D]);
+  }
+  /// Get 'e' part of the associated expression/statement.
+  Expr *getE() {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_E]);
+  }
+  const Expr *getE() const {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_E]);
+  }
+  /// Get 'v' part of the associated expression/statement.
+  Expr *getV() {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_V]);
+  }
+  const Expr *getV() const {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_V]);
+  }
   /// Get helper expression of the form
   /// 'OpaqueValueExpr(x) binop OpaqueValueExpr(expr)' or
   /// 'OpaqueValueExpr(expr) binop OpaqueValueExpr(x)'.
@@ -2930,6 +2961,15 @@
     return cast_or_null<Expr>(
         Data->getChildren()[DataPositionTy::POS_UpdateExpr]);
   }
+  /// Get the conditional expression.
+  Expr *getCondExpr() {
+    return cast_or_null<Expr>(
+        Data->getChildren()[DataPositionTy::POS_CondExpr]);
+  }
+  const Expr *getCondExpr() const {
+    return cast_or_null<Expr>(
+        Data->getChildren()[DataPositionTy::POS_CondExpr]);
+  }
   /// Return true if helper update expression has form
   /// 'OpaqueValueExpr(x) binop OpaqueValueExpr(expr)' and false if it has form
   /// 'OpaqueValueExpr(expr) binop OpaqueValueExpr(x)'.
@@ -2937,20 +2977,6 @@
   /// Return true if 'v' expression must be updated to original value of
   /// 'x', false if 'v' must be updated to the new value of 'x'.
   bool isPostfixUpdate() const { return IsPostfixUpdate; }
-  /// Get 'v' part of the associated expression/statement.
-  Expr *getV() {
-    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_V]);
-  }
-  const Expr *getV() const {
-    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_V]);
-  }
-  /// Get 'expr' part of the associated expression/statement.
-  Expr *getExpr() {
-    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_E]);
-  }
-  const Expr *getExpr() const {
-    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_E]);
-  }
 
   static bool classof(const Stmt *T) {
     return T->getStmtClass() == OMPAtomicDirectiveClass;
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3226,6 +3226,11 @@
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPCompareClause(OMPCompareClause *) {
+  return true;
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   return true;
Index: clang/include/clang/AST/OpenMPClause.h
===================================================================
--- clang/include/clang/AST/OpenMPClause.h
+++ clang/include/clang/AST/OpenMPClause.h
@@ -2224,6 +2224,47 @@
   }
 };
 
+/// This represents 'compare' clause in the '#pragma omp atomic'
+/// directive.
+///
+/// \code
+/// #pragma omp atomic compare
+/// \endcode
+/// In this example directive '#pragma omp atomic' has 'compare' clause.
+class OMPCompareClause : public OMPClause {
+public:
+  /// Build 'compare' clause.
+  ///
+  /// \param StartLoc Starting location of the clause.
+  /// \param EndLoc Ending location of the clause.
+  OMPCompareClause(SourceLocation StartLoc, SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_compare, StartLoc, EndLoc) {}
+
+  /// Build an empty clause.
+  OMPCompareClause()
+      : OMPClause(llvm::omp::OMPC_compare, SourceLocation(), SourceLocation()) {
+  }
+
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_compare;
+  }
+};
+
 /// This represents 'seq_cst' clause in the '#pragma omp atomic'
 /// directive.
 ///
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to