tianshilei1992 updated this revision to Diff 370036.
tianshilei1992 added a comment.

rebase


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D102449/new/

https://reviews.llvm.org/D102449

Files:
  clang/include/clang/AST/OpenMPClause.h
  clang/include/clang/AST/RecursiveASTVisitor.h
  clang/include/clang/AST/StmtOpenMP.h
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Sema/Sema.h
  clang/lib/AST/OpenMPClause.cpp
  clang/lib/AST/StmtOpenMP.cpp
  clang/lib/AST/StmtProfile.cpp
  clang/lib/Basic/OpenMPKinds.cpp
  clang/lib/CodeGen/CGStmtOpenMP.cpp
  clang/lib/Parse/ParseOpenMP.cpp
  clang/lib/Sema/SemaOpenMP.cpp
  clang/lib/Sema/TreeTransform.h
  clang/lib/Serialization/ASTReader.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/test/OpenMP/atomic_compare_codegen.cpp
  clang/tools/libclang/CIndex.cpp
  llvm/include/llvm/Frontend/OpenMP/OMP.td

Index: llvm/include/llvm/Frontend/OpenMP/OMP.td
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -160,6 +160,7 @@
 def OMPC_Write : Clause<"write"> { let clangClass = "OMPWriteClause"; }
 def OMPC_Update : Clause<"update"> { let clangClass = "OMPUpdateClause"; }
 def OMPC_Capture : Clause<"capture"> { let clangClass = "OMPCaptureClause"; }
+def OMPC_Compare : Clause<"compare"> { let clangClass = "OMPCompareClause"; }
 def OMPC_SeqCst : Clause<"seq_cst"> { let clangClass = "OMPSeqCstClause"; }
 def OMPC_AcqRel : Clause<"acq_rel"> { let clangClass = "OMPAcqRelClause"; }
 def OMPC_Acquire : Clause<"acquire"> { let clangClass = "OMPAcquireClause"; }
@@ -500,6 +501,7 @@
     VersionedClause<OMPC_Write>,
     VersionedClause<OMPC_Update>,
     VersionedClause<OMPC_Capture>,
+    VersionedClause<OMPC_Compare, 51>
   ];
   let allowedOnceClauses = [
     VersionedClause<OMPC_SeqCst>,
Index: clang/tools/libclang/CIndex.cpp
===================================================================
--- clang/tools/libclang/CIndex.cpp
+++ clang/tools/libclang/CIndex.cpp
@@ -2271,6 +2271,8 @@
 
 void OMPClauseEnqueue::VisitOMPCaptureClause(const OMPCaptureClause *) {}
 
+void OMPClauseEnqueue::VisitOMPCompareClause(const OMPCompareClause *) {}
+
 void OMPClauseEnqueue::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseEnqueue::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
Index: clang/test/OpenMP/atomic_compare_codegen.cpp
===================================================================
--- /dev/null
+++ clang/test/OpenMP/atomic_compare_codegen.cpp
@@ -0,0 +1,383 @@
+// RUN: %clang_cc1 -verify -triple x86_64-apple-darwin10 -target-cpu core2 -fopenmp -fopenmp-version=51 -x c -emit-llvm %s -o - | FileCheck %s
+// RUN: %clang_cc1 -fopenmp -fopenmp-version=51 -x c -triple x86_64-apple-darwin10 -target-cpu core2 -emit-pch -o %t %s
+// RUN: %clang_cc1 -fopenmp -fopenmp-version=51 -x c -triple x86_64-apple-darwin10 -target-cpu core2 -include-pch %t -verify %s -emit-llvm -o - | FileCheck %s
+
+// RUN: %clang_cc1 -verify -triple x86_64-apple-darwin10 -target-cpu core2 -fopenmp-simd -fopenmp-version=51 -x c -emit-llvm %s -o - | FileCheck --check-prefix SIMD-ONLY0 %s
+// RUN: %clang_cc1 -fopenmp-simd -fopenmp-version=51 -x c -triple x86_64-apple-darwin10 -target-cpu core2 -emit-pch -o %t %s
+// RUN: %clang_cc1 -fopenmp-simd -fopenmp-version=51 -x c -triple x86_64-apple-darwin10 -target-cpu core2 -include-pch %t -verify %s -emit-llvm -o - | FileCheck --check-prefix SIMD-ONLY0 %s
+// SIMD-ONLY0-NOT: {{__kmpc|__tgt}}
+// expected-no-diagnostics
+
+// FIXME: The CHECK prefix only pins that foo() is emitted; the IR for each
+// 'atomic compare' form still needs assertions, e.g. regenerated with
+// utils/update_cc_test_checks.py.
+// CHECK-LABEL: @foo(
+
+#ifndef HEADER
+#define HEADER
+
+// Integral types
+char cx, ce, cd;
+unsigned char ucx, uce, ucd;
+short sx, se, sd;
+unsigned short usx, use, usd;
+int ix, ie, id;
+unsigned int uix, uie, uid;
+long lx, le, ld;
+unsigned long ulx, ule, uld;
+long long llx, lle, lld;
+unsigned long long ullx, ulle, ulld;
+
+void foo() {
+// char
+// {
+
+#pragma omp atomic compare
+  cx = ce > cx ? ce : cx;
+
+#pragma omp atomic compare
+  cx = ce < cx ? ce : cx;
+
+#pragma omp atomic compare
+  cx = cx > ce ? ce : cx;
+
+#pragma omp atomic compare
+  cx = cx < ce ? ce : cx;
+
+#pragma omp atomic compare
+  cx = cx == ce ? cd : cx;
+
+#pragma omp atomic compare
+  if (ce > cx) { cx = ce; }
+
+#pragma omp atomic compare
+  if (ce < cx) { cx = ce; }
+
+#pragma omp atomic compare
+  if (cx > ce) { cx = ce; }
+
+#pragma omp atomic compare
+  if (cx < ce) { cx = ce; }
+
+#pragma omp atomic compare
+  if (cx == ce) { cx = cd; }
+
+// }
+
+// unsigned char
+// {
+
+#pragma omp atomic compare
+  ucx = uce > ucx ? uce : ucx;
+
+#pragma omp atomic compare
+  ucx = uce < ucx ? uce : ucx;
+
+#pragma omp atomic compare
+  ucx = ucx > uce ? uce : ucx;
+
+#pragma omp atomic compare
+  ucx = ucx < uce ? uce : ucx;
+
+#pragma omp atomic compare
+  ucx = ucx == uce ? ucd : ucx;
+
+#pragma omp atomic compare
+  if (uce > ucx) { ucx = uce; }
+
+#pragma omp atomic compare
+  if (uce < ucx) { ucx = uce; }
+
+#pragma omp atomic compare
+  if (ucx > uce) { ucx = uce; }
+
+#pragma omp atomic compare
+  if (ucx < uce) { ucx = uce; }
+
+#pragma omp atomic compare
+  if (ucx == uce) { ucx = ucd; }
+
+// }
+
+// short
+// {
+
+#pragma omp atomic compare
+  sx = se > sx ? se : sx;
+
+#pragma omp atomic compare
+  sx = se < sx ? se : sx;
+
+#pragma omp atomic compare
+  sx = sx > se ? se : sx;
+
+#pragma omp atomic compare
+  sx = sx < se ? se : sx;
+
+#pragma omp atomic compare
+  sx = sx == se ? sd : sx;
+
+#pragma omp atomic compare
+  if (se > sx) { sx = se; }
+
+#pragma omp atomic compare
+  if (se < sx) { sx = se; }
+
+#pragma omp atomic compare
+  if (sx > se) { sx = se; }
+
+#pragma omp atomic compare
+  if (sx < se) { sx = se; }
+
+#pragma omp atomic compare
+  if (sx == se) { sx = sd; }
+
+// }
+
+// unsigned short
+// {
+
+#pragma omp atomic compare
+  usx = use > usx ? use : usx;
+
+#pragma omp atomic compare
+  usx = use < usx ? use : usx;
+
+#pragma omp atomic compare
+  usx = usx > use ? use : usx;
+
+#pragma omp atomic compare
+  usx = usx < use ? use : usx;
+
+#pragma omp atomic compare
+  usx = usx == use ? usd : usx;
+
+#pragma omp atomic compare
+  if (use > usx) { usx = use; }
+
+#pragma omp atomic compare
+  if (use < usx) { usx = use; }
+
+#pragma omp atomic compare
+  if (usx > use) { usx = use; }
+
+#pragma omp atomic compare
+  if (usx < use) { usx = use; }
+
+#pragma omp atomic compare
+  if (usx == use) { usx = usd; }
+
+// }
+
+// int
+// {
+
+#pragma omp atomic compare
+  ix = ie > ix ? ie : ix;
+
+#pragma omp atomic compare
+  ix = ie < ix ? ie : ix;
+
+#pragma omp atomic compare
+  ix = ix > ie ? ie : ix;
+
+#pragma omp atomic compare
+  ix = ix < ie ? ie : ix;
+
+#pragma omp atomic compare
+  ix = ix == ie ? id : ix;
+
+#pragma omp atomic compare
+  if (ie > ix) { ix = ie; }
+
+#pragma omp atomic compare
+  if (ie < ix) { ix = ie; }
+
+#pragma omp atomic compare
+  if (ix > ie) { ix = ie; }
+
+#pragma omp atomic compare
+  if (ix < ie) { ix = ie; }
+
+#pragma omp atomic compare
+  if (ix == ie) { ix = id; }
+
+// }
+
+// unsigned int
+// {
+
+#pragma omp atomic compare
+  uix = uie > uix ? uie : uix;
+
+#pragma omp atomic compare
+  uix = uie < uix ? uie : uix;
+
+#pragma omp atomic compare
+  uix = uix > uie ? uie : uix;
+
+#pragma omp atomic compare
+  uix = uix < uie ? uie : uix;
+
+#pragma omp atomic compare
+  uix = uix == uie ? uid : uix;
+
+#pragma omp atomic compare
+  if (uie > uix) { uix = uie; }
+
+#pragma omp atomic compare
+  if (uie < uix) { uix = uie; }
+
+#pragma omp atomic compare
+  if (uix > uie) { uix = uie; }
+
+#pragma omp atomic compare
+  if (uix < uie) { uix = uie; }
+
+#pragma omp atomic compare
+  if (uix == uie) { uix = uid; }
+
+// }
+
+// long
+// {
+
+#pragma omp atomic compare
+  lx = le > lx ? le : lx;
+
+#pragma omp atomic compare
+  lx = le < lx ? le : lx;
+
+#pragma omp atomic compare
+  lx = lx > le ? le : lx;
+
+#pragma omp atomic compare
+  lx = lx < le ? le : lx;
+
+#pragma omp atomic compare
+  lx = lx == le ? ld : lx;
+
+#pragma omp atomic compare
+  if (le > lx) { lx = le; }
+
+#pragma omp atomic compare
+  if (le < lx) { lx = le; }
+
+#pragma omp atomic compare
+  if (lx > le) { lx = le; }
+
+#pragma omp atomic compare
+  if (lx < le) { lx = le; }
+
+#pragma omp atomic compare
+  if (lx == le) { lx = ld; }
+
+// }
+
+// unsigned long
+// {
+
+#pragma omp atomic compare
+  ulx = ule > ulx ? ule : ulx;
+
+#pragma omp atomic compare
+  ulx = ule < ulx ? ule : ulx;
+
+#pragma omp atomic compare
+  ulx = ulx > ule ? ule : ulx;
+
+#pragma omp atomic compare
+  ulx = ulx < ule ? ule : ulx;
+
+#pragma omp atomic compare
+  ulx = ulx == ule ? uld : ulx;
+
+#pragma omp atomic compare
+  if (ule > ulx) { ulx = ule; }
+
+#pragma omp atomic compare
+  if (ule < ulx) { ulx = ule; }
+
+#pragma omp atomic compare
+  if (ulx > ule) { ulx = ule; }
+
+#pragma omp atomic compare
+  if (ulx < ule) { ulx = ule; }
+
+#pragma omp atomic compare
+  if (ulx == ule) { ulx = uld; }
+
+// }
+
+// long long
+// {
+
+#pragma omp atomic compare
+  llx = lle > llx ? lle : llx;
+
+#pragma omp atomic compare
+  llx = lle < llx ? lle : llx;
+
+#pragma omp atomic compare
+  llx = llx > lle ? lle : llx;
+
+#pragma omp atomic compare
+  llx = llx < lle ? lle : llx;
+
+#pragma omp atomic compare
+  llx = llx == lle ? lld : llx;
+
+#pragma omp atomic compare
+  if (lle > llx) { llx = lle; }
+
+#pragma omp atomic compare
+  if (lle < llx) { llx = lle; }
+
+#pragma omp atomic compare
+  if (llx > lle) { llx = lle; }
+
+#pragma omp atomic compare
+  if (llx < lle) { llx = lle; }
+
+#pragma omp atomic compare
+  if (llx == lle) { llx = lld; }
+
+// }
+
+// unsigned long long
+// {
+
+#pragma omp atomic compare
+  ullx = ulle > ullx ? ulle : ullx;
+
+#pragma omp atomic compare
+  ullx = ulle < ullx ? ulle : ullx;
+
+#pragma omp atomic compare
+  ullx = ullx > ulle ? ulle : ullx;
+
+#pragma omp atomic compare
+  ullx = ullx < ulle ? ulle : ullx;
+
+#pragma omp atomic compare
+  ullx = ullx == ulle ? ulld : ullx;
+
+#pragma omp atomic compare
+  if (ulle > ullx) { ullx = ulle; }
+
+#pragma omp atomic compare
+  if (ulle < ullx) { ullx = ulle; }
+
+#pragma omp atomic compare
+  if (ullx > ulle) { ullx = ulle; }
+
+#pragma omp atomic compare
+  if (ullx < ulle) { ullx = ulle; }
+
+#pragma omp atomic compare
+  if (ullx == ulle) { ullx = ulld; }
+
+// }
+}
+
+#endif
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -6135,6 +6135,8 @@
 
 void OMPClauseWriter::VisitOMPCaptureClause(OMPCaptureClause *) {}
 
+void OMPClauseWriter::VisitOMPCompareClause(OMPCompareClause *) {}
+
 void OMPClauseWriter::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseWriter::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
Index: clang/lib/Serialization/ASTReader.cpp
===================================================================
--- clang/lib/Serialization/ASTReader.cpp
+++ clang/lib/Serialization/ASTReader.cpp
@@ -11748,6 +11748,9 @@
   case llvm::omp::OMPC_capture:
     C = new (Context) OMPCaptureClause();
     break;
+  case llvm::omp::OMPC_compare:
+    C = new (Context) OMPCompareClause();
+    break;
   case llvm::omp::OMPC_seq_cst:
     C = new (Context) OMPSeqCstClause();
     break;
@@ -12100,6 +12103,8 @@
 
 void OMPClauseReader::VisitOMPCaptureClause(OMPCaptureClause *) {}
 
+void OMPClauseReader::VisitOMPCompareClause(OMPCompareClause *) {}
+
 void OMPClauseReader::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseReader::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -9381,6 +9381,13 @@
   return C;
 }
 
+template <typename Derived>
+OMPClause *
+TreeTransform<Derived>::TransformOMPCompareClause(OMPCompareClause *C) {
+  // No need to rebuild this clause, no template-dependent parameters.
+  return C;
+}
+
 template <typename Derived>
 OMPClause *
 TreeTransform<Derived>::TransformOMPSeqCstClause(OMPSeqCstClause *C) {
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -6274,6 +6274,7 @@
       case OMPC_write:
       case OMPC_update:
       case OMPC_capture:
+      case OMPC_compare:
       case OMPC_seq_cst:
       case OMPC_acq_rel:
       case OMPC_acquire:
@@ -10514,8 +10515,367 @@
   bool checkBinaryOperation(BinaryOperator *AtomicBinOp, unsigned DiagId = 0,
                             unsigned NoteId = 0);
 };
+
+/// Profile \p S, ignoring parentheses and implicit casts, into a node id.
+llvm::FoldingSetNodeID getNodeId(ASTContext &Context, const Expr *S) {
+  llvm::FoldingSetNodeID Result;
+  S->IgnoreParenImpCasts()->Profile(Result, Context, /*Canonical=*/true);
+  return Result;
+}
+
+/// Return true if both expressions profile identically (see getNodeId).
+bool checkIfTwoExprsAreSame(ASTContext &Context, const Expr *LHS,
+                            const Expr *RHS) {
+  return getNodeId(Context, LHS) == getNodeId(Context, RHS);
+}
+
+// OpenMP 5.1 [2.19.7 atomic Construct] cond-update-stmt has one of the forms:
+//   if (expr ordop x) { x = expr; }
+//   if (x ordop expr) { x = expr; }
+//   if (x == e) { x = d; }
+// On success, sets X, E, D (for '==' only) and C (the condition) and returns
+// true; on failure returns false and the out-parameters may be partially set.
+bool checkIfCondUpdateStmt(ASTContext &Context, IfStmt *IS, Expr *&X, Expr *&E,
+                           Expr *&D, Expr *&C) {
+  auto *Cond = dyn_cast<BinaryOperator>(IS->getCond()->IgnoreParenImpCasts());
+  auto *Then = dyn_cast<BinaryOperator>(IS->getThen());
+  // The then-branch may also be a compound statement with a single statement.
+  if (!Then && isa<CompoundStmt>(IS->getThen())) {
+    auto *CS = cast<CompoundStmt>(IS->getThen());
+    if (CS->size() != 1)
+      return false;
+    Then = dyn_cast<BinaryOperator>(CS->body_front());
+  }
+  // The then-branch must be a plain assignment (not e.g. 'x += e' or 'x < e').
+  if (!Cond || !Then || Then->getOpcode() != BO_Assign)
+    return false;
+
+  // The if-stmt cannot have an else branch or an init-statement.
+  if (IS->getElse() || IS->getInit())
+    return false;
+
+  // Only ==, <, > are supported
+  if (Cond->getOpcode() != BO_EQ && Cond->getOpcode() != BO_LT &&
+      Cond->getOpcode() != BO_GT)
+    return false;
+
+  X = Then->getLHS();
+  C = Cond;
+
+  llvm::FoldingSetNodeID XId = getNodeId(Context, X);
+  llvm::FoldingSetNodeID LHSId = getNodeId(Context, Cond->getLHS());
+  llvm::FoldingSetNodeID RHSId = getNodeId(Context, Cond->getRHS());
+
+  if (Cond->getOpcode() == BO_EQ) {
+    // if (x == e) { x = d; }
+    D = Then->getRHS();
+    if (LHSId == XId)
+      E = Cond->getRHS();
+    else if (RHSId == XId)
+      E = Cond->getLHS();
+    else
+      return false;
+  } else {
+    // if (expr ordop x) { x = expr; }
+    // if (x ordop expr) { x = expr; }
+    E = Then->getRHS();
+
+    llvm::FoldingSetNodeID EId = getNodeId(Context, E);
+    if (LHSId == EId) {
+      if (RHSId != XId)
+        return false;
+    } else if (RHSId == EId) {
+      if (LHSId != XId)
+        return false;
+    } else
+      return false;
+  }
+
+  return true;
+}
+
+/// Helper class validating the statement of an 'omp atomic compare' construct.
+class OpenMPAtomicCompareChecker {
+protected:
+  Sema &SemaRef;
+  ASTContext &Context;
+  /// 'x' lvalue part of the source atomic expression.
+  Expr *X = nullptr;
+  /// Depending on the form of the statement, this is either:
+  /// 'e' rvalue part of the source atomic expression, or
+  /// 'expr' rvalue part of the source atomic expression.
+  Expr *E = nullptr;
+  /// 'd' rvalue part of the source atomic expression; null unless 'x == e'.
+  Expr *D = nullptr;
+  /// 'cond' rvalue part of the source atomic expression, which is in one of the
+  /// following forms:
+  /// expr ordop x
+  /// x ordop expr
+  /// x == e
+  /// When 'x' is the RHS of an ordop, updateCondExpr() replaces this with the
+  /// equivalent 'x ordop expr' so that 'x' is always the LHS.
+  Expr *C = nullptr;
+
+public:
+  /// Error kinds; the value is passed to note_omp_atomic_compare as an int.
+  enum class ErrorTy {
+    /// Compound statement does not contain exactly one statement
+    NotOneSubstatement = 0,
+    /// Not a cond-update statement (ref. spec 5.1)
+    NotCondUpdateStatement,
+    /// Not an assignment statement
+    NotAssignmentStatement,
+    /// Not a conditional statement
+    NotCondStatement,
+    /// Not a binary operator
+    NotBinaryOperator,
+    /// Not required ordop (ordop must be <, >, or ==)
+    NotRequiredOrderOp,
+    /// 'x' is not the false expression of the conditional operator
+    NotAtFalseExpr,
+    /// 'x' is not an operand of the condition
+    NotInCondExpr,
+    /// Condition is not of the form 'x ordop e' or 'e ordop x'
+    NotAOrdopEExpr,
+    /// Not an lvalue expression of scalar type
+    NotLValueScalarExpr,
+    /// Not a scalar type expression
+    NotScalarExpr,
+    /// Not an integral expression
+    NotIntegerExpr,
+    /// No error
+    NoError
+  };
+
+  /// Error descriptor type which will be returned to Sema
+  struct ErrorDescTy {
+    ErrorTy ErrorNo = ErrorTy::NoError;
+    SourceLocation ErrorLoc, NoteLoc;
+    SourceRange ErrorRange, NoteRange;
+  };
+
+  OpenMPAtomicCompareChecker(Sema &S)
+      : SemaRef(S), Context(S.getASTContext()) {}
+
+  /// Note: if `checkStatement` returns false, the following getters may return
+  /// null or partially-populated values that must not be relied upon.
+  /// Return the 'x' lvalue part of the source atomic expression.
+  Expr *getX() const { return X; }
+  /// Return the 'e' or 'expr' rvalue part of the source atomic expression.
+  Expr *getE() const { return E; }
+  /// Return the 'd' rvalue part of the source atomic expression.
+  Expr *getD() const { return D; }
+  /// Return the 'cond' rvalue part of the source atomic expression.
+  Expr *getCondExpr() const { return C; }
+
+  /// Check whether the given statement conforms to the requirements of `atomic
+  /// compare`, and set `x`, `e`, `d`, `expr` and the condition accordingly.
+  /// if (expr ordop x) { x = expr; }
+  /// if (x ordop expr) { x = expr; }
+  /// if (x == e) { x = d; }
+  /// x = expr ordop x ? expr : x;
+  /// x = x ordop expr ? expr : x;
+  /// x = x == e ? d : x;
+  bool checkStatement(Stmt *S);
+
+  /// Return the error descriptor that will guide the error message emission.
+  /// Only meaningful after `checkStatement` returned false; after a successful
+  /// check on a fresh checker, ErrorNo is still ErrorTy::NoError.
+  ErrorDescTy getErrorDesc() const { return ErrorDesc; }
+
+protected:
+  /// Error descriptor
+  ErrorDescTy ErrorDesc;
+  /// Check if all results conform with spec in terms of lvalue/rvalue and
+  /// scalar type.
+  bool checkResult();
+  /// Update {expr ordop x} to {x ordop expr} for easier codegen if needed.
+  void updateCondExpr();
+};
 } // namespace
 
+void OpenMPAtomicCompareChecker::updateCondExpr() {
+  // Rewrite 'expr ordop x' as 'x ordop expr' so that 'x' is always the LHS.
+  // No correctness checks happen here; checkResult() validates types later.
+  auto *Cond = cast<BinaryOperator>(C);
+  BinaryOperatorKind Opc = Cond->getOpcode();
+  // '==' is symmetric and needs no rewriting; likewise if 'x' is already LHS.
+  bool AlreadyCanonical =
+      Opc == BO_EQ || checkIfTwoExprsAreSame(Context, X, Cond->getLHS());
+  if (AlreadyCanonical)
+    return;
+
+  // 'expr < x' is 'x > expr' and vice versa, so the operator is mirrored.
+  ExprResult Result = SemaRef.CreateBuiltinBinOp(
+      Cond->getOperatorLoc(), Opc == BO_LT ? BO_GT : BO_LT, X, E);
+  assert(!Result.isInvalid());
+  C = Result.get();
+}
+
+bool OpenMPAtomicCompareChecker::checkStatement(Stmt *S) {
+  auto *CS = dyn_cast<CompoundStmt>(S);
+  if (CS) {
+    if (CS->size() != 1) {
+      ErrorDesc.ErrorNo = ErrorTy::NotOneSubstatement;
+      ErrorDesc.NoteRange = ErrorDesc.ErrorRange = CS->getSourceRange();
+      ErrorDesc.NoteLoc = ErrorDesc.ErrorLoc = CS->getBeginLoc();
+      return false;
+    }
+    S = CS->body_front();
+  }
+
+  // Check if the statement is in one of the following forms (cond-update-stmt):
+  // if (expr ordop x) { x = expr; }
+  // if (x ordop expr) { x = expr; }
+  // if (x == e) { x = d; }
+  if (auto *IS = dyn_cast<IfStmt>(S)) {
+    if (!checkIfCondUpdateStmt(Context, IS, X, E, D, C)) {
+      ErrorDesc.ErrorNo = ErrorTy::NotCondUpdateStatement;
+      ErrorDesc.NoteRange = ErrorDesc.ErrorRange = IS->getSourceRange();
+      ErrorDesc.NoteLoc = ErrorDesc.ErrorLoc = IS->getBeginLoc();
+      return false;
+    }
+
+    updateCondExpr();
+
+    return checkResult();
+  }
+
+  // Check if the statement is in one of the following forms (cond-expr-stmt):
+  // x = expr ordop x ? expr : x;
+  // x = x ordop expr ? expr : x;
+  // x = x == e ? d : x;
+  auto *BO = dyn_cast<BinaryOperator>(S);
+  if (!BO || BO->getOpcode() != BO_Assign) {
+    ErrorDesc.ErrorNo = ErrorTy::NotAssignmentStatement;
+    ErrorDesc.NoteRange = ErrorDesc.ErrorRange = S->getSourceRange();
+    ErrorDesc.NoteLoc = ErrorDesc.ErrorLoc = S->getBeginLoc();
+    return false;
+  }
+
+  X = BO->getLHS();
+
+  auto *CO = dyn_cast<ConditionalOperator>(BO->getRHS()->IgnoreParenImpCasts());
+  if (!CO) {
+    ErrorDesc.ErrorNo = ErrorTy::NotCondStatement;
+    ErrorDesc.NoteRange = ErrorDesc.ErrorRange = BO->getRHS()->getSourceRange();
+    ErrorDesc.NoteLoc = ErrorDesc.ErrorLoc = BO->getRHS()->getBeginLoc();
+    return false;
+  }
+
+  if (!checkIfTwoExprsAreSame(Context, X, CO->getFalseExpr())) {
+    ErrorDesc.ErrorNo = ErrorTy::NotAtFalseExpr;
+    ErrorDesc.NoteRange = ErrorDesc.ErrorRange =
+        CO->getFalseExpr()->getSourceRange();
+    ErrorDesc.NoteLoc = ErrorDesc.ErrorLoc = CO->getFalseExpr()->getExprLoc();
+    return false;
+  }
+  // Look through parentheses so that '(x > e) ? e : x' is accepted too.
+  auto *Cond = dyn_cast<BinaryOperator>(CO->getCond()->IgnoreParenImpCasts());
+  if (!Cond) {
+    ErrorDesc.ErrorNo = ErrorTy::NotBinaryOperator;
+    ErrorDesc.NoteRange = ErrorDesc.ErrorRange =
+        CO->getCond()->getSourceRange();
+    ErrorDesc.NoteLoc = ErrorDesc.ErrorLoc = CO->getCond()->getBeginLoc();
+    return false;
+  }
+
+  // Only ==, <, > are supported
+  if (Cond->getOpcode() != BO_LT && Cond->getOpcode() != BO_GT &&
+      Cond->getOpcode() != BO_EQ) {
+    ErrorDesc.ErrorNo = ErrorTy::NotRequiredOrderOp;
+    ErrorDesc.NoteRange = ErrorDesc.ErrorRange = Cond->getSourceRange();
+    ErrorDesc.NoteLoc = ErrorDesc.ErrorLoc = Cond->getOperatorLoc();
+    return false;
+  }
+
+  if (Cond->getOpcode() == BO_EQ) {
+    if (checkIfTwoExprsAreSame(Context, X, Cond->getLHS()))
+      E = Cond->getRHS();
+    else if (checkIfTwoExprsAreSame(Context, X, Cond->getRHS()))
+      E = Cond->getLHS();
+    else {
+      ErrorDesc.ErrorNo = ErrorTy::NotInCondExpr;
+      ErrorDesc.NoteRange = ErrorDesc.ErrorRange = Cond->getSourceRange();
+      ErrorDesc.NoteLoc = ErrorDesc.ErrorLoc = Cond->getExprLoc();
+      return false;
+    }
+    D = CO->getTrueExpr();
+  } else {
+    E = CO->getTrueExpr();
+    if (!((checkIfTwoExprsAreSame(Context, X, Cond->getLHS()) &&
+           checkIfTwoExprsAreSame(Context, E, Cond->getRHS())) ||
+          (checkIfTwoExprsAreSame(Context, X, Cond->getRHS()) &&
+           checkIfTwoExprsAreSame(Context, E, Cond->getLHS())))) {
+      ErrorDesc.ErrorNo = ErrorTy::NotAOrdopEExpr;
+      ErrorDesc.ErrorRange = Cond->getSourceRange();
+      ErrorDesc.NoteRange = X->getSourceRange();
+      ErrorDesc.ErrorLoc = Cond->getExprLoc();
+      ErrorDesc.NoteLoc = X->getExprLoc();
+      return false;
+    }
+  }
+
+  C = Cond;
+
+  updateCondExpr();
+
+  return checkResult();
+}
+
+bool OpenMPAtomicCompareChecker::checkResult() {
+  // 'x' and 'e' cannot be nullptr
+  assert(X && E && "X and E cannot be nullptr");
+
+  // 'x', 'e' and 'd' have to be of scalar type. 'x' also has to be an lvalue,
+  // but there is no such requirement for 'e' and 'd'.
+  if (!X->isLValue() || !X->getType()->isScalarType()) {
+    ErrorDesc.ErrorNo = ErrorTy::NotLValueScalarExpr;
+    ErrorDesc.NoteLoc = ErrorDesc.ErrorLoc = X->getExprLoc();
+    ErrorDesc.ErrorRange = ErrorDesc.NoteRange = X->getSourceRange();
+    return false;
+  }
+
+  if (!E->getType()->isScalarType()) {
+    ErrorDesc.ErrorNo = ErrorTy::NotScalarExpr;
+    ErrorDesc.NoteLoc = ErrorDesc.ErrorLoc = E->getExprLoc();
+    ErrorDesc.ErrorRange = ErrorDesc.NoteRange = E->getSourceRange();
+    return false;
+  }
+
+  if (D && !D->getType()->isScalarType()) {
+    ErrorDesc.ErrorNo = ErrorTy::NotScalarExpr;
+    ErrorDesc.NoteLoc = ErrorDesc.ErrorLoc = D->getExprLoc();
+    ErrorDesc.ErrorRange = ErrorDesc.NoteRange = D->getSourceRange();
+    return false;
+  }
+
+  // CodeGen currently supports integer types only, so reject other types here.
+  // The lambda and its uses below can be removed once floating-point types are
+  // supported. Its parameter is not named 'E' to avoid shadowing the member.
+  auto &&CheckIfInteger = [&](Expr *Operand) {
+    if (!Operand->getType()->isIntegerType()) {
+      ErrorDesc.ErrorNo = ErrorTy::NotIntegerExpr;
+      ErrorDesc.NoteLoc = ErrorDesc.ErrorLoc = Operand->getExprLoc();
+      ErrorDesc.ErrorRange = ErrorDesc.NoteRange = Operand->getSourceRange();
+      return false;
+    }
+
+    return true;
+  };
+
+  if (!CheckIfInteger(X))
+    return false;
+
+  if (!CheckIfInteger(E))
+    return false;
+
+  if (D && !CheckIfInteger(D))
+    return false;
+
+  return true;
+}
+
 bool OpenMPAtomicUpdateChecker::checkBinaryOperation(
     BinaryOperator *AtomicBinOp, unsigned DiagId, unsigned NoteId) {
   ExprAnalysisErrorCode ErrorFound = NoError;
@@ -10695,24 +11055,30 @@
   OpenMPClauseKind MemOrderKind = OMPC_unknown;
   SourceLocation MemOrderLoc;
   for (const OMPClause *C : Clauses) {
-    if (C->getClauseKind() == OMPC_read || C->getClauseKind() == OMPC_write ||
-        C->getClauseKind() == OMPC_update ||
-        C->getClauseKind() == OMPC_capture) {
-      if (AtomicKind != OMPC_unknown) {
+    switch (C->getClauseKind()) {
+    default: // Other clauses are irrelevant to this loop; ignore them.
+      break;
+    case OMPC_read:
+    case OMPC_write:
+    case OMPC_update:
+    case OMPC_capture:
+    case OMPC_compare:
+      if (AtomicKind == OMPC_unknown) {
+        AtomicKind = C->getClauseKind();
+        AtomicKindLoc = C->getBeginLoc();
+        break;
+      } else {
         Diag(C->getBeginLoc(), diag::err_omp_atomic_several_clauses)
             << SourceRange(C->getBeginLoc(), C->getEndLoc());
         Diag(AtomicKindLoc, diag::note_omp_previous_mem_order_clause)
             << getOpenMPClauseName(AtomicKind);
-      } else {
-        AtomicKind = C->getClauseKind();
-        AtomicKindLoc = C->getBeginLoc();
+        return StmtError();
       }
-    }
-    if (C->getClauseKind() == OMPC_seq_cst ||
-        C->getClauseKind() == OMPC_acq_rel ||
-        C->getClauseKind() == OMPC_acquire ||
-        C->getClauseKind() == OMPC_release ||
-        C->getClauseKind() == OMPC_relaxed) {
+    case OMPC_seq_cst:
+    case OMPC_acq_rel:
+    case OMPC_acquire:
+    case OMPC_release:
+    case OMPC_relaxed:
       if (MemOrderKind != OMPC_unknown) {
         Diag(C->getBeginLoc(), diag::err_omp_several_mem_order_clauses)
             << getOpenMPDirectiveName(OMPD_atomic) << 0
@@ -10723,8 +11089,10 @@
         MemOrderKind = C->getClauseKind();
         MemOrderLoc = C->getBeginLoc();
       }
+      break;
     }
   }
+
   // OpenMP 5.0, 2.17.7 atomic Construct, Restrictions
   // If atomic-clause is read then memory-order-clause must not be acq_rel or
   // release.
@@ -10752,10 +11120,18 @@
   if (auto *EWC = dyn_cast<ExprWithCleanups>(Body))
     Body = EWC->getSubExpr();
 
+  // Stands for 'x' in the spec
   Expr *X = nullptr;
+  // Stands for 'v' in the spec
   Expr *V = nullptr;
+  // Stands for 'd' in the spec
+  Expr *D = nullptr;
+  // Stands for 'e' or 'expr' in the spec
   Expr *E = nullptr;
+  // Stands for update-stmt in the spec
   Expr *UE = nullptr;
+  // Stands for conditional statement in the spec
+  Expr *CE = nullptr;
   bool IsXLHSInRHSPart = false;
   bool IsPostfixUpdate = false;
   // OpenMP [2.12.6, atomic Construct]
@@ -10837,8 +11213,8 @@
     if (ErrorFound != NoError) {
       Diag(ErrorLoc, diag::err_omp_atomic_read_not_expression_statement)
           << ErrorRange;
-      Diag(NoteLoc, diag::note_omp_atomic_read_write) << ErrorFound
-                                                      << NoteRange;
+      Diag(NoteLoc, diag::note_omp_atomic_read_write)
+          << ErrorFound << NoteRange;
       return StmtError();
     }
     if (CurContext->isDependentContext())
@@ -10899,8 +11275,8 @@
     if (ErrorFound != NoError) {
       Diag(ErrorLoc, diag::err_omp_atomic_write_not_expression_statement)
           << ErrorRange;
-      Diag(NoteLoc, diag::note_omp_atomic_read_write) << ErrorFound
-                                                      << NoteRange;
+      Diag(NoteLoc, diag::note_omp_atomic_read_write)
+          << ErrorFound << NoteRange;
       return StmtError();
     }
     if (CurContext->isDependentContext())
@@ -10916,9 +11292,10 @@
     //  x = expr binop x;
     OpenMPAtomicUpdateChecker Checker(*this);
     if (Checker.checkStatement(
-            Body, (AtomicKind == OMPC_update)
-                      ? diag::err_omp_atomic_update_not_expression_statement
-                      : diag::err_omp_atomic_not_expression_statement,
+            Body,
+            (AtomicKind == OMPC_update)
+                ? diag::err_omp_atomic_update_not_expression_statement
+                : diag::err_omp_atomic_not_expression_statement,
             diag::note_omp_atomic_update))
       return StmtError();
     if (!CurContext->isDependentContext()) {
@@ -11132,21 +11509,37 @@
             SourceRange(Body->getBeginLoc(), Body->getBeginLoc());
         ErrorFound = NotACompoundStatement;
       }
-      if (ErrorFound != NoError) {
-        Diag(ErrorLoc, diag::err_omp_atomic_capture_not_compound_statement)
-            << ErrorRange;
-        Diag(NoteLoc, diag::note_omp_atomic_capture) << ErrorFound << NoteRange;
-        return StmtError();
-      }
-      if (CurContext->isDependentContext())
-        UE = V = E = X = nullptr;
     }
+    if (ErrorFound != NoError) {
+      Diag(ErrorLoc, diag::err_omp_atomic_capture_not_compound_statement)
+          << ErrorRange;
+      Diag(NoteLoc, diag::note_omp_atomic_capture) << ErrorFound << NoteRange;
+      return StmtError();
+    }
+    if (CurContext->isDependentContext())
+      UE = V = E = X = nullptr;
+  } else if (AtomicKind == OMPC_compare) {
+    // Operands may be type-dependent; defer the checks to instantiation.
+    OpenMPAtomicCompareChecker Checker(*this);
+    if (!CurContext->isDependentContext() && !Checker.checkStatement(Body)) {
+      OpenMPAtomicCompareChecker::ErrorDescTy Err = Checker.getErrorDesc();
+      Diag(Err.ErrorLoc, diag::err_omp_atomic_compare) << Err.ErrorRange;
+      Diag(Err.NoteLoc, diag::note_omp_atomic_compare)
+          << static_cast<int>(Err.ErrorNo) << Err.NoteRange;
+      return StmtError();
+    }
+    X = Checker.getX();
+    E = Checker.getE();
+    D = Checker.getD();
+    CE = Checker.getCondExpr();
+    if (CurContext->isDependentContext())
+      X = E = D = CE = nullptr;
   }
 
   setFunctionHasBranchProtectedScope();
 
   return OMPAtomicDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                    X, V, E, UE, IsXLHSInRHSPart,
+                                    X, V, E, D, UE, CE, IsXLHSInRHSPart,
                                     IsPostfixUpdate);
 }
 
@@ -13215,6 +13608,7 @@
   case OMPC_write:
   case OMPC_update:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -14019,6 +14413,7 @@
   case OMPC_write:
   case OMPC_update:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -14475,6 +14870,7 @@
   case OMPC_read:
   case OMPC_write:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -14768,6 +15164,7 @@
   case OMPC_write:
   case OMPC_update:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -14954,6 +15351,9 @@
   case OMPC_capture:
     Res = ActOnOpenMPCaptureClause(StartLoc, EndLoc);
     break;
+  case OMPC_compare:
+    Res = ActOnOpenMPCompareClause(StartLoc, EndLoc);
+    break;
   case OMPC_seq_cst:
     Res = ActOnOpenMPSeqCstClause(StartLoc, EndLoc);
     break;
@@ -15099,6 +15499,11 @@
   return new (Context) OMPCaptureClause(StartLoc, EndLoc);
 }
 
+OMPClause *Sema::ActOnOpenMPCompareClause(SourceLocation StartLoc,
+                                          SourceLocation EndLoc) {
+  return new (Context) OMPCompareClause(StartLoc, EndLoc);
+}
+
 OMPClause *Sema::ActOnOpenMPSeqCstClause(SourceLocation StartLoc,
                                          SourceLocation EndLoc) {
   return new (Context) OMPSeqCstClause(StartLoc, EndLoc);
@@ -15567,6 +15972,7 @@
   case OMPC_write:
   case OMPC_update:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
Index: clang/lib/Parse/ParseOpenMP.cpp
===================================================================
--- clang/lib/Parse/ParseOpenMP.cpp
+++ clang/lib/Parse/ParseOpenMP.cpp
@@ -2883,6 +2883,7 @@
   case OMPC_read:
   case OMPC_write:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
Index: clang/lib/CodeGen/CGStmtOpenMP.cpp
===================================================================
--- clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -5566,6 +5566,23 @@
   return std::make_pair(true, RValue::get(Res));
 }
 
+/// Emit a 'cmpxchg' on \p X that compares against \p E and stores \p D on
+/// success, using \p AO as the success ordering. Returns the value 'x' held
+/// before the operation (element 0 of the cmpxchg result); the boolean is
+/// always true because an atomic instruction is always emitted here.
+static std::pair<bool, RValue> emitOMPAtomicCmpXchg(CodeGenFunction &CGF,
+                                                    LValue X, RValue E,
+                                                    RValue D,
+                                                    llvm::AtomicOrdering AO) {
+  // A cmpxchg failure ordering must not have release semantics, so 'AO' can
+  // not be reused verbatim when it is Release or AcquireRelease.
+  llvm::Value *Res = CGF.Builder.CreateAtomicCmpXchg(
+      X.getPointer(CGF), E.getScalarVal(), D.getScalarVal(), AO,
+      llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO));
+  llvm::Value *Old = CGF.Builder.CreateExtractValue(Res, 0);
+  return std::make_pair(true, RValue::get(Old));
+}
+
 std::pair<bool, RValue> CodeGenFunction::EmitOMPAtomicSimpleUpdateExpr(
     LValue X, RValue E, BinaryOperatorKind BO, bool IsXLHSInRHSPart,
     llvm::AtomicOrdering AO, SourceLocation Loc,
@@ -5762,11 +5772,67 @@
   }
 }
 
+/// Emit '#pragma omp atomic compare' for the forms accepted by Sema.
+///
+/// A '==' condition ('x = x == e ? d : x') is lowered to a cmpxchg of 'x'
+/// against \p E with replacement \p D. A '<' or '>' condition is lowered via
+/// emitOMPAtomicRMW. If no atomic instruction can be emitted, an
+/// "unsupported" error is reported rather than silently dropping the update.
+///
+/// \param X  The lvalue 'x' that is updated atomically.
+/// \param E  The 'e'/'expr' operand.
+/// \param D  The 'd' operand; only used for the '==' form.
+/// \param CE The condition; a BinaryOperator after stripping parens and
+///           implicit casts.
+/// \param IsXLHSInRHSPart Not consulted: the position of 'x' in \p CE is
+///           determined here instead.
+static void emitOMPAtomicCompareExpr(CodeGenFunction &CGF,
+                                     llvm::AtomicOrdering AO, const Expr *X,
+                                     const Expr *E, const Expr *D,
+                                     const Expr *CE, bool IsXLHSInRHSPart,
+                                     SourceLocation Loc) {
+  assert(X->isLValue() && "X of 'omp atomic compare' is not lvalue");
+  assert(isa<BinaryOperator>(CE->IgnoreParenImpCasts()) &&
+         "Cond expr in 'atomic compare' must be a binary operator.");
+  const auto *CondBO = cast<BinaryOperator>(CE->IgnoreParenImpCasts());
+
+  LValue XLValue = CGF.EmitLValue(X->IgnoreParenImpCasts());
+  // cmpxchg (and atomicrmw on non-constant values) need operands of x's IR
+  // type. Stripping implicit casts from 'e'/'d' can yield a value of another
+  // width, so emit them as written and convert to the type of 'x'.
+  auto EmitAsXType = [&CGF, X, Loc](const Expr *Op) {
+    return RValue::get(CGF.EmitScalarConversion(
+        CGF.EmitScalarExpr(Op), Op->getType(), X->getType(), Loc));
+  };
+  RValue ERValue = EmitAsXType(E);
+
+  std::pair<bool, RValue> Res;
+  if (CondBO->getOpcode() == BO_EQ) {
+    RValue DRValue = EmitAsXType(D);
+    Res = emitOMPAtomicCmpXchg(CGF, XLValue, ERValue, DRValue, AO);
+  } else {
+    // 'e < x ? e : x' computes min(x, e) while 'x < e ? e : x' computes
+    // max(x, e) (and conversely for '>'), so the RMW kind depends on which
+    // side of ordop 'x' is on. Compare 'x' with the RHS of the condition.
+    // NOTE(review): assumes emitOMPAtomicRMW picks Min for '<' / Max for '>'
+    // exactly when its IsXLHSInRHSPart argument is true -- confirm.
+    llvm::FoldingSetNodeID XID, CondRHSID;
+    X->IgnoreParenImpCasts()->Profile(XID, CGF.getContext(),
+                                      /*Canonical=*/true);
+    CondBO->getRHS()->IgnoreParenImpCasts()->Profile(
+        CondRHSID, CGF.getContext(), /*Canonical=*/true);
+    Res = emitOMPAtomicRMW(CGF, XLValue, ERValue, CondBO->getOpcode(), AO,
+                           /*IsXLHSInRHSPart=*/XID == CondRHSID);
+  }
+  if (!Res.first)
+    CGF.CGM.ErrorUnsupported(CE, "'atomic compare' operation");
+}
+
 static void emitOMPAtomicExpr(CodeGenFunction &CGF, OpenMPClauseKind Kind,
                               llvm::AtomicOrdering AO, bool IsPostfixUpdate,
                               const Expr *X, const Expr *V, const Expr *E,
-                              const Expr *UE, bool IsXLHSInRHSPart,
-                              SourceLocation Loc) {
+                              const Expr *D, const Expr *UE, const Expr *CE,
+                              bool IsXLHSInRHSPart, SourceLocation Loc) {
   switch (Kind) {
   case OMPC_read:
     emitOMPAtomicReadExpr(CGF, AO, X, V, Loc);
@@ -5782,6 +5820,9 @@
     emitOMPAtomicCaptureExpr(CGF, AO, IsPostfixUpdate, V, X, E, UE,
                              IsXLHSInRHSPart, Loc);
     break;
+  case OMPC_compare:
+    emitOMPAtomicCompareExpr(CGF, AO, X, E, D, CE, IsXLHSInRHSPart, Loc);
+    break;
   case OMPC_if:
   case OMPC_final:
   case OMPC_num_threads:
@@ -5919,8 +5960,8 @@
   LexicalScope Scope(*this, S.getSourceRange());
   EmitStopPoint(S.getAssociatedStmt());
   emitOMPAtomicExpr(*this, Kind, AO, S.isPostfixUpdate(), S.getX(), S.getV(),
-                    S.getExpr(), S.getUpdateExpr(), S.isXLHSInRHSPart(),
-                    S.getBeginLoc());
+                    S.getE(), S.getD(), S.getUpdateExpr(), S.getCondExpr(),
+                    S.isXLHSInRHSPart(), S.getBeginLoc());
 }
 
 static void emitCommonOMPTargetDirective(CodeGenFunction &CGF,
Index: clang/lib/Basic/OpenMPKinds.cpp
===================================================================
--- clang/lib/Basic/OpenMPKinds.cpp
+++ clang/lib/Basic/OpenMPKinds.cpp
@@ -153,6 +153,7 @@
   case OMPC_read:
   case OMPC_write:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -396,6 +397,7 @@
   case OMPC_read:
   case OMPC_write:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
Index: clang/lib/AST/StmtProfile.cpp
===================================================================
--- clang/lib/AST/StmtProfile.cpp
+++ clang/lib/AST/StmtProfile.cpp
@@ -546,6 +546,8 @@
 
 void OMPClauseProfiler::VisitOMPCaptureClause(const OMPCaptureClause *) {}
 
+void OMPClauseProfiler::VisitOMPCompareClause(const OMPCompareClause *) {}
+
 void OMPClauseProfiler::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseProfiler::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
Index: clang/lib/AST/StmtOpenMP.cpp
===================================================================
--- clang/lib/AST/StmtOpenMP.cpp
+++ clang/lib/AST/StmtOpenMP.cpp
@@ -805,16 +805,20 @@
                                                    !IsStandalone);
 }
 
-OMPAtomicDirective *OMPAtomicDirective::Create(
-    const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
-    ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, Expr *X, Expr *V,
-    Expr *E, Expr *UE, bool IsXLHSInRHSPart, bool IsPostfixUpdate) {
+OMPAtomicDirective *
+OMPAtomicDirective::Create(const ASTContext &C, SourceLocation StartLoc,
+                           SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
+                           Stmt *AssociatedStmt, Expr *X, Expr *V, Expr *E,
+                           Expr *D, Expr *UE, Expr *CE, bool IsXLHSInRHSPart,
+                           bool IsPostfixUpdate) {
   auto *Dir = createDirective<OMPAtomicDirective>(
-      C, Clauses, AssociatedStmt, /*NumChildren=*/4, StartLoc, EndLoc);
+      C, Clauses, AssociatedStmt, /*NumChildren=*/6, StartLoc, EndLoc);
   Dir->setX(X);
   Dir->setV(V);
-  Dir->setExpr(E);
+  Dir->setE(E);
+  Dir->setD(D);
   Dir->setUpdateExpr(UE);
+  Dir->setCondExpr(CE);
   Dir->IsXLHSInRHSPart = IsXLHSInRHSPart;
   Dir->IsPostfixUpdate = IsPostfixUpdate;
   return Dir;
Index: clang/lib/AST/OpenMPClause.cpp
===================================================================
--- clang/lib/AST/OpenMPClause.cpp
+++ clang/lib/AST/OpenMPClause.cpp
@@ -126,6 +126,7 @@
   case OMPC_write:
   case OMPC_update:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -215,6 +216,7 @@
   case OMPC_write:
   case OMPC_update:
   case OMPC_capture:
+  case OMPC_compare:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -1765,6 +1767,10 @@
   OS << "capture";
 }
 
+void OMPClausePrinter::VisitOMPCompareClause(OMPCompareClause *) {
+  OS << "compare";
+}
+
 void OMPClausePrinter::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   OS << "seq_cst";
 }
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11062,6 +11062,9 @@
   /// Called on well-formed 'capture' clause.
   OMPClause *ActOnOpenMPCaptureClause(SourceLocation StartLoc,
                                       SourceLocation EndLoc);
+  /// Called on well-formed 'compare' clause.
+  OMPClause *ActOnOpenMPCompareClause(SourceLocation StartLoc,
+                                      SourceLocation EndLoc);
   /// Called on well-formed 'seq_cst' clause.
   OMPClause *ActOnOpenMPSeqCstClause(SourceLocation StartLoc,
                                      SourceLocation EndLoc);
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10442,10 +10442,15 @@
   " '{v = x; x = x binop expr;}', '{v = x; x = expr binop x;}', '{x = x binop expr; v = x;}', '{x = expr binop x; v = x;}' or '{v = x; x = expr;}',"
   " '{v = x; x++;}', '{v = x; ++x;}', '{++x; v = x;}', '{x++; v = x;}', '{v = x; x--;}', '{v = x; --x;}', '{--x; v = x;}', '{x--; v = x;}'"
   " where x is an lvalue expression with scalar type">;
+def err_omp_atomic_compare : Error<
+  "the statement for 'atomic compare' must be a compound statement of form '{x = expr ordop x ? expr : x;}', '{x = x ordop expr ? expr : x;}', '{x = x == e ? d : x;}',"
+  " '{if (expr ordop x) { x = expr; }}', '{if (x ordop expr) { x = expr; }}', or '{if (x == e) { x = d; }}' where x is an lvalue expression with scalar type, and expr, e, and d are expressions with scalar type">;
 def note_omp_atomic_capture: Note<
   "%select{expected assignment expression|expected compound statement|expected exactly two expression statements|expected in right hand side of the first expression}0">;
+def note_omp_atomic_compare: Note<
+  "%select{expected compound statement of size one|expected conditional update statement|expected assignment expression|expected conditional expression|expected binary operator|expected '<', '>', or '==' as ordop|expected 'x' as the false expression|expected 'x' in conditional expression|expected 'x ordop e'|expected lvalue scalar expression|expected scalar expression|only integers supported}0">;
 def err_omp_atomic_several_clauses : Error<
-  "directive '#pragma omp atomic' cannot contain more than one 'read', 'write', 'update' or 'capture' clause">;
+  "directive '#pragma omp atomic' cannot contain more than one 'read', 'write', 'update', 'capture', or 'compare' clause">;
 def err_omp_several_mem_order_clauses : Error<
   "directive '#pragma omp %0' cannot contain more than one %select{'seq_cst', 'relaxed', |}1'acq_rel', 'acquire' or 'release' clause">;
 def err_omp_atomic_incompatible_mem_order_clause : Error<
Index: clang/include/clang/AST/StmtOpenMP.h
===================================================================
--- clang/include/clang/AST/StmtOpenMP.h
+++ clang/include/clang/AST/StmtOpenMP.h
@@ -2799,20 +2799,29 @@
     POS_V,
     POS_E,
     POS_UpdateExpr,
+    POS_D,
+    POS_CondExpr
   };
 
   /// Set 'x' part of the associated expression/statement.
   void setX(Expr *X) { Data->getChildren()[DataPositionTy::POS_X] = X; }
+  /// Set 'v' part of the associated expression/statement.
+  void setV(Expr *V) { Data->getChildren()[DataPositionTy::POS_V] = V; }
+  /// Set 'd' part of the associated expression/statement.
+  void setD(Expr *D) { Data->getChildren()[DataPositionTy::POS_D] = D; }
+  /// Set 'e' or 'expr' part of the associated expression/statement.
+  void setE(Expr *E) { Data->getChildren()[DataPositionTy::POS_E] = E; }
   /// Set helper expression of the form
   /// 'OpaqueValueExpr(x) binop OpaqueValueExpr(expr)' or
   /// 'OpaqueValueExpr(expr) binop OpaqueValueExpr(x)'.
   void setUpdateExpr(Expr *UE) {
     Data->getChildren()[DataPositionTy::POS_UpdateExpr] = UE;
   }
-  /// Set 'v' part of the associated expression/statement.
-  void setV(Expr *V) { Data->getChildren()[DataPositionTy::POS_V] = V; }
-  /// Set 'expr' part of the associated expression/statement.
-  void setExpr(Expr *E) { Data->getChildren()[DataPositionTy::POS_E] = E; }
+  /// Set the conditional expression part of the associated expression/statement
+  /// in atomic compare.
+  void setCondExpr(Expr *CE) {
+    Data->getChildren()[DataPositionTy::POS_CondExpr] = CE;
+  }
 
 public:
   /// Creates directive with a list of \a Clauses and 'x', 'v' and 'expr'
@@ -2837,7 +2846,8 @@
   static OMPAtomicDirective *
   Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
          ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, Expr *X, Expr *V,
-         Expr *E, Expr *UE, bool IsXLHSInRHSPart, bool IsPostfixUpdate);
+         Expr *E, Expr *D, Expr *UE, Expr *CE, bool IsXLHSInRHSPart,
+         bool IsPostfixUpdate);
 
   /// Creates an empty directive with the place for \a NumClauses
   /// clauses.
@@ -2855,6 +2865,27 @@
   const Expr *getX() const {
     return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_X]);
   }
+  /// Get 'd' part of the associated expression/statement.
+  Expr *getD() {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_D]);
+  }
+  const Expr *getD() const {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_D]);
+  }
+  /// Get 'e' part of the associated expression/statement.
+  Expr *getE() {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_E]);
+  }
+  const Expr *getE() const {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_E]);
+  }
+  /// Get 'v' part of the associated expression/statement.
+  Expr *getV() {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_V]);
+  }
+  const Expr *getV() const {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_V]);
+  }
   /// Get helper expression of the form
   /// 'OpaqueValueExpr(x) binop OpaqueValueExpr(expr)' or
   /// 'OpaqueValueExpr(expr) binop OpaqueValueExpr(x)'.
@@ -2866,6 +2897,15 @@
     return cast_or_null<Expr>(
         Data->getChildren()[DataPositionTy::POS_UpdateExpr]);
   }
+  /// Get the conditional expression.
+  Expr *getCondExpr() {
+    return cast_or_null<Expr>(
+        Data->getChildren()[DataPositionTy::POS_CondExpr]);
+  }
+  const Expr *getCondExpr() const {
+    return cast_or_null<Expr>(
+        Data->getChildren()[DataPositionTy::POS_CondExpr]);
+  }
   /// Return true if helper update expression has form
   /// 'OpaqueValueExpr(x) binop OpaqueValueExpr(expr)' and false if it has form
   /// 'OpaqueValueExpr(expr) binop OpaqueValueExpr(x)'.
@@ -2873,20 +2913,6 @@
   /// Return true if 'v' expression must be updated to original value of
   /// 'x', false if 'v' must be updated to the new value of 'x'.
   bool isPostfixUpdate() const { return IsPostfixUpdate; }
-  /// Get 'v' part of the associated expression/statement.
-  Expr *getV() {
-    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_V]);
-  }
-  const Expr *getV() const {
-    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_V]);
-  }
-  /// Get 'expr' part of the associated expression/statement.
-  Expr *getExpr() {
-    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_E]);
-  }
-  const Expr *getExpr() const {
-    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_E]);
-  }
 
   static bool classof(const Stmt *T) {
     return T->getStmtClass() == OMPAtomicDirectiveClass;
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3218,6 +3218,11 @@
   return true;
 }
 
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPCompareClause(OMPCompareClause *) {
+  return true;
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   return true;
Index: clang/include/clang/AST/OpenMPClause.h
===================================================================
--- clang/include/clang/AST/OpenMPClause.h
+++ clang/include/clang/AST/OpenMPClause.h
@@ -2149,6 +2149,47 @@
   }
 };
 
+/// This represents 'compare' clause in the '#pragma omp atomic'
+/// directive.
+///
+/// \code
+/// #pragma omp atomic compare
+/// \endcode
+/// In this example directive '#pragma omp atomic' has 'compare' clause.
+class OMPCompareClause : public OMPClause {
+public:
+  /// Build 'compare' clause.
+  ///
+  /// \param StartLoc Starting location of the clause.
+  /// \param EndLoc Ending location of the clause.
+  OMPCompareClause(SourceLocation StartLoc, SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_compare, StartLoc, EndLoc) {}
+
+  /// Build an empty clause.
+  OMPCompareClause()
+      : OMPClause(llvm::omp::OMPC_compare, SourceLocation(), SourceLocation()) {
+  }
+
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_compare;
+  }
+};
+
 /// This represents 'seq_cst' clause in the '#pragma omp atomic'
 /// directive.
 ///
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to