tianshilei1992 updated this revision to Diff 407309.
tianshilei1992 marked 2 inline comments as done.
tianshilei1992 added a comment.

use `LLVM_FALLTHROUGH`


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D116261/new/

https://reviews.llvm.org/D116261

Files:
  clang/include/clang/AST/OpenMPClause.h
  clang/include/clang/AST/RecursiveASTVisitor.h
  clang/lib/AST/OpenMPClause.cpp
  clang/lib/AST/StmtProfile.cpp
  clang/lib/Basic/OpenMPKinds.cpp
  clang/lib/CodeGen/CGStmtOpenMP.cpp
  clang/lib/Parse/ParseOpenMP.cpp
  clang/lib/Sema/SemaOpenMP.cpp
  clang/lib/Sema/TreeTransform.h
  clang/lib/Serialization/ASTReader.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/test/OpenMP/atomic_messages.cpp
  clang/tools/libclang/CIndex.cpp
  llvm/include/llvm/Frontend/OpenMP/OMP.td

Index: llvm/include/llvm/Frontend/OpenMP/OMP.td
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMP.td
+++ llvm/include/llvm/Frontend/OpenMP/OMP.td
@@ -181,6 +181,11 @@
 def OMPC_Update : Clause<"update"> { let clangClass = "OMPUpdateClause"; }
 def OMPC_Capture : Clause<"capture"> { let clangClass = "OMPCaptureClause"; }
 def OMPC_Compare : Clause<"compare"> { let clangClass = "OMPCompareClause"; }
+// Dummy clause kind used internally to represent that both the 'compare' and
+// 'capture' clauses are present on a '#pragma omp atomic' directive.
+def OMPC_CompareCapture : Clause<"compare_capture"> {
+  let clangClass = "OMPCompareCaptureClause";
+}
 def OMPC_SeqCst : Clause<"seq_cst"> { let clangClass = "OMPSeqCstClause"; }
 def OMPC_AcqRel : Clause<"acq_rel"> { let clangClass = "OMPAcqRelClause"; }
 def OMPC_Acquire : Clause<"acquire"> { let clangClass = "OMPAcquireClause"; }
Index: clang/tools/libclang/CIndex.cpp
===================================================================
--- clang/tools/libclang/CIndex.cpp
+++ clang/tools/libclang/CIndex.cpp
@@ -2277,6 +2277,12 @@
 
 void OMPClauseEnqueue::VisitOMPCompareClause(const OMPCompareClause *) {}
 
+// Dummy clause: Sema never builds it, so it can never be enqueued.
+void OMPClauseEnqueue::VisitOMPCompareCaptureClause(
+    const OMPCompareCaptureClause *) {
+  llvm_unreachable("OMPCompareCaptureClause should never be reached");
+}
+
 void OMPClauseEnqueue::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseEnqueue::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
Index: clang/test/OpenMP/atomic_messages.cpp
===================================================================
--- clang/test/OpenMP/atomic_messages.cpp
+++ clang/test/OpenMP/atomic_messages.cpp
@@ -958,3 +958,27 @@
   // expected-note@+1 {{in instantiation of function template specialization 'mixed<int>' requested here}}
   return mixed<int>();
 }
+
+#if _OPENMP >= 202011
+// 'atomic compare' and 'atomic compare capture' are diagnosed as unsupported.
+void compare() {
+  int a, b, c;
+// omp51-error@+1 {{atomic compare is not supported for now}}
+#pragma omp atomic compare
+  {
+    if (a == b)
+      a = c;
+  }
+}
+
+void compare_capture() {
+  int a, b, c, x;
+// omp51-error@+1 {{atomic compare capture is not supported for now}}
+#pragma omp atomic compare capture
+  {
+    x = a;
+    if (a == b)
+      a = c;
+  }
+}
+#endif
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -6295,6 +6295,11 @@
 
 void OMPClauseWriter::VisitOMPCompareClause(OMPCompareClause *) {}
 
+// Dummy clause: Sema never builds it, so there is nothing to serialize.
+void OMPClauseWriter::VisitOMPCompareCaptureClause(OMPCompareCaptureClause *) {
+  llvm_unreachable("OMPCompareCaptureClause should never be reached");
+}
+
 void OMPClauseWriter::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseWriter::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
Index: clang/lib/Serialization/ASTReader.cpp
===================================================================
--- clang/lib/Serialization/ASTReader.cpp
+++ clang/lib/Serialization/ASTReader.cpp
@@ -11786,6 +11786,8 @@
   case llvm::omp::OMPC_compare:
     C = new (Context) OMPCompareClause();
     break;
+  case llvm::omp::OMPC_compare_capture:
+    llvm_unreachable("OMPCompareCaptureClause should never be reached");
   case llvm::omp::OMPC_seq_cst:
     C = new (Context) OMPSeqCstClause();
     break;
@@ -12146,6 +12148,11 @@
 
 void OMPClauseReader::VisitOMPCompareClause(OMPCompareClause *) {}
 
+// Dummy clause: the writer never serializes it, so there is nothing to read.
+void OMPClauseReader::VisitOMPCompareCaptureClause(OMPCompareCaptureClause *) {
+  llvm_unreachable("OMPCompareCaptureClause should never be reached");
+}
+
 void OMPClauseReader::VisitOMPSeqCstClause(OMPSeqCstClause *) {}
 
 void OMPClauseReader::VisitOMPAcqRelClause(OMPAcqRelClause *) {}
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -9476,6 +9476,13 @@
   return C;
 }
 
+// Dummy clause: Sema never builds it, so there is nothing to transform.
+template <typename Derived>
+OMPClause *TreeTransform<Derived>::TransformOMPCompareCaptureClause(
+    OMPCompareCaptureClause *C) {
+  llvm_unreachable("OMPCompareCaptureClause should never be reached");
+}
+
 template <typename Derived>
 OMPClause *
 TreeTransform<Derived>::TransformOMPSeqCstClause(OMPSeqCstClause *C) {
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -35,6 +35,7 @@
 #include "llvm/ADT/IndexedMap.h"
 #include "llvm/ADT/PointerEmbeddedInt.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Frontend/OpenMP/OMPAssume.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
@@ -6365,6 +6366,7 @@
       case OMPC_update:
       case OMPC_capture:
       case OMPC_compare:
+      case OMPC_compare_capture:
       case OMPC_seq_cst:
       case OMPC_acq_rel:
       case OMPC_acquire:
@@ -11315,14 +11317,21 @@
   SourceLocation AtomicKindLoc;
   OpenMPClauseKind MemOrderKind = OMPC_unknown;
   SourceLocation MemOrderLoc;
+  // read/write/update may not be combined with any other atomic kind clause,
+  // whereas compare and capture may appear together. Remember which atomic
+  // kinds were seen so that the compare+capture pair can be detected below.
+  bool MutexClauseEncountered = false;
+  llvm::SmallSet<OpenMPClauseKind, 2> EncounteredAtomicKinds;
   for (const OMPClause *C : Clauses) {
     switch (C->getClauseKind()) {
     case OMPC_read:
     case OMPC_write:
     case OMPC_update:
+      MutexClauseEncountered = true;
+      LLVM_FALLTHROUGH;
     case OMPC_capture:
     case OMPC_compare: {
-      if (AtomicKind != OMPC_unknown) {
+      if (AtomicKind != OMPC_unknown && MutexClauseEncountered) {
         Diag(C->getBeginLoc(), diag::err_omp_atomic_several_clauses)
             << SourceRange(C->getBeginLoc(), C->getEndLoc());
         Diag(AtomicKindLoc, diag::note_omp_previous_mem_order_clause)
@@ -11330,6 +11339,7 @@
       } else {
         AtomicKind = C->getClauseKind();
         AtomicKindLoc = C->getBeginLoc();
+        EncounteredAtomicKinds.insert(C->getClauseKind());
       }
       break;
     }
@@ -11353,10 +11363,16 @@
     // The following clauses are allowed, but we don't need to do anything here.
     case OMPC_hint:
       break;
+    case OMPC_compare_capture:
+      llvm_unreachable("OMPC_compare_capture should never be reached");
     default:
       llvm_unreachable("unknown clause is encountered");
     }
   }
+  // Both 'compare' and 'capture' are present: use the combined dummy kind.
+  if (EncounteredAtomicKinds.contains(OMPC_compare) &&
+      EncounteredAtomicKinds.contains(OMPC_capture))
+    AtomicKind = OMPC_compare_capture;
   // OpenMP 5.0, 2.17.7 atomic Construct, Restrictions
   // If atomic-clause is read then memory-order-clause must not be acq_rel or
   // release.
@@ -11786,6 +11802,13 @@
     }
     // TODO: We don't set X, D, E, etc. here because in code gen we will emit
     // error directly.
+  } else if (AtomicKind == OMPC_compare_capture) {
+    // TODO: 'compare capture' is not supported yet, so emit an error here;
+    // emitOMPAtomicExpr skips code generation for this kind.
+    unsigned DiagID = Diags.getCustomDiagID(
+        DiagnosticsEngine::Error,
+        "atomic compare capture is not supported for now");
+    Diag(AtomicKindLoc, DiagID);
   }
 
   setFunctionHasBranchProtectedScope();
@@ -13867,6 +13890,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -14699,6 +14723,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -15161,6 +15186,7 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -15469,6 +15495,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -15660,6 +15687,8 @@
   case OMPC_compare:
     Res = ActOnOpenMPCompareClause(StartLoc, EndLoc);
     break;
+  case OMPC_compare_capture:
+    llvm_unreachable("compare_capture is dummy node");
   case OMPC_seq_cst:
     Res = ActOnOpenMPSeqCstClause(StartLoc, EndLoc);
     break;
@@ -16280,6 +16309,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
Index: clang/lib/Parse/ParseOpenMP.cpp
===================================================================
--- clang/lib/Parse/ParseOpenMP.cpp
+++ clang/lib/Parse/ParseOpenMP.cpp
@@ -3222,6 +3222,9 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  // NOTE(review): Sema::ActOnOpenMPClause treats this dummy kind as
+  // unreachable; confirm this case is only hit when the clause is rejected.
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
Index: clang/lib/CodeGen/CGStmtOpenMP.cpp
===================================================================
--- clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -24,6 +24,7 @@
 #include "clang/AST/StmtVisitor.h"
 #include "clang/Basic/OpenMPKinds.h"
 #include "clang/Basic/PrettyStackTrace.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/BinaryFormat/Dwarf.h"
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
@@ -6038,6 +6039,9 @@
     CGF.CGM.getDiags().Report(DiagID);
     break;
   }
+  case OMPC_compare_capture:
+    // Sema already reports 'compare capture' as unsupported; emit nothing.
+    break;
   case OMPC_if:
   case OMPC_final:
   case OMPC_num_threads:
@@ -6148,19 +6152,22 @@
     AO = llvm::AtomicOrdering::Monotonic;
     MemOrderingSpecified = true;
   }
+  llvm::SmallSet<OpenMPClauseKind, 2> KindsEncountered;
   OpenMPClauseKind Kind = OMPC_unknown;
   for (const OMPClause *C : S.clauses()) {
-    // Find first clause (skip seq_cst|acq_rel|aqcuire|release|relaxed clause,
-    // if it is first).
-    if (C->getClauseKind() != OMPC_seq_cst &&
-        C->getClauseKind() != OMPC_acq_rel &&
-        C->getClauseKind() != OMPC_acquire &&
-        C->getClauseKind() != OMPC_release &&
-        C->getClauseKind() != OMPC_relaxed && C->getClauseKind() != OMPC_hint) {
-      Kind = C->getClauseKind();
-      break;
-    }
+    // Skip the memory-order clauses (seq_cst|acq_rel|acquire|release|relaxed)
+    // and hint; Kind ends up as the last atomic kind clause seen.
+    OpenMPClauseKind K = C->getClauseKind();
+    if (K == OMPC_seq_cst || K == OMPC_acq_rel || K == OMPC_acquire ||
+        K == OMPC_release || K == OMPC_relaxed || K == OMPC_hint)
+      continue;
+    Kind = K;
+    KindsEncountered.insert(K);
   }
+  // 'compare' together with 'capture' forms the combined compare-capture kind.
+  if (KindsEncountered.contains(OMPC_compare) &&
+      KindsEncountered.contains(OMPC_capture))
+    Kind = OMPC_compare_capture;
   if (!MemOrderingSpecified) {
     llvm::AtomicOrdering DefaultOrder =
         CGM.getOpenMPRuntime().getDefaultMemoryOrdering();
Index: clang/lib/Basic/OpenMPKinds.cpp
===================================================================
--- clang/lib/Basic/OpenMPKinds.cpp
+++ clang/lib/Basic/OpenMPKinds.cpp
@@ -168,6 +168,8 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  // Internal dummy kind used when 'compare' and 'capture' are both present.
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -434,6 +436,8 @@
   case OMPC_write:
   case OMPC_capture:
   case OMPC_compare:
+  // Internal dummy kind used when 'compare' and 'capture' are both present.
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
Index: clang/lib/AST/StmtProfile.cpp
===================================================================
--- clang/lib/AST/StmtProfile.cpp
+++ clang/lib/AST/StmtProfile.cpp
@@ -553,6 +553,12 @@
 
 void OMPClauseProfiler::VisitOMPCompareClause(const OMPCompareClause *) {}
 
+// Dummy clause: Sema never builds it, so it can never be profiled.
+void OMPClauseProfiler::VisitOMPCompareCaptureClause(
+    const OMPCompareCaptureClause *) {
+  llvm_unreachable("OMPCompareCaptureClause should never be reached");
+}
+
 void OMPClauseProfiler::VisitOMPSeqCstClause(const OMPSeqCstClause *) {}
 
 void OMPClauseProfiler::VisitOMPAcqRelClause(const OMPAcqRelClause *) {}
Index: clang/lib/AST/OpenMPClause.cpp
===================================================================
--- clang/lib/AST/OpenMPClause.cpp
+++ clang/lib/AST/OpenMPClause.cpp
@@ -127,6 +127,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -219,6 +220,7 @@
   case OMPC_update:
   case OMPC_capture:
   case OMPC_compare:
+  case OMPC_compare_capture:
   case OMPC_seq_cst:
   case OMPC_acq_rel:
   case OMPC_acquire:
@@ -1798,6 +1800,11 @@
   OS << "compare";
 }
 
+// Dummy clause: Sema never builds it, so there is nothing to print.
+void OMPClausePrinter::VisitOMPCompareCaptureClause(OMPCompareCaptureClause *) {
+  llvm_unreachable("OMPCompareCaptureClause should never be reached");
+}
+
 void OMPClausePrinter::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   OS << "seq_cst";
 }
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -3239,6 +3239,13 @@
   return true;
 }
 
+// Dummy clause: Sema never builds it, so it can never be visited.
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPCompareCaptureClause(
+    OMPCompareCaptureClause *) {
+  llvm_unreachable("OMPCompareCaptureClause should never be reached");
+}
+
 template <typename Derived>
 bool RecursiveASTVisitor<Derived>::VisitOMPSeqCstClause(OMPSeqCstClause *) {
   return true;
Index: clang/include/clang/AST/OpenMPClause.h
===================================================================
--- clang/include/clang/AST/OpenMPClause.h
+++ clang/include/clang/AST/OpenMPClause.h
@@ -2266,6 +2266,49 @@
   }
 };
 
+/// This is a dummy clause used to represent that both the 'compare' and the
+/// 'capture' clauses are present on a '#pragma omp atomic' directive. Only
+/// its clause kind (OMPC_compare_capture) is used; Sema never builds it.
+///
+/// \code
+/// #pragma omp atomic compare capture
+/// \endcode
+/// In this example directive '#pragma omp atomic' has 'compare' and 'capture'
+/// clauses.
+class OMPCompareCaptureClause final : public OMPClause {
+public:
+  /// Build 'compare capture' clause.
+  ///
+  /// \param StartLoc Starting location of the clause.
+  /// \param EndLoc Ending location of the clause.
+  OMPCompareCaptureClause(SourceLocation StartLoc, SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_compare_capture, StartLoc, EndLoc) {}
+
+  /// Build an empty clause.
+  OMPCompareCaptureClause()
+      : OMPClause(llvm::omp::OMPC_compare_capture, SourceLocation(),
+                  SourceLocation()) {}
+
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_compare_capture;
+  }
+};
+
 /// This represents 'seq_cst' clause in the '#pragma omp atomic'
 /// directive.
 ///
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to