https://github.com/jhuber6 updated 
https://github.com/llvm/llvm-project/pull/141142

>From f2c18ba64744320a8e2a63938b17137a1b6e74d7 Mon Sep 17 00:00:00 2001
From: Joseph Huber <hube...@outlook.com>
Date: Thu, 22 May 2025 16:21:34 -0500
Subject: [PATCH] [OpenMP] Fix atomic compare handling with overloaded
 operators

Summary:
When overloaded C++ operators are declared in the global namespace, the AST
node for a comparison that uses them is not a `BinaryOperator` but a
`CXXOperatorCallExpr`. Update the OpenMP atomic compare checkers to handle
this case by treating the call's two arguments as the left- and right-hand
operands of a binary expression.
---
 clang/lib/Sema/SemaOpenMP.cpp         | 162 +++++++++++++++-----------
 clang/test/OpenMP/atomic_messages.cpp |  31 +++++
 2 files changed, 126 insertions(+), 67 deletions(-)

diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index f16f841d62edd..8d580d1968238 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -11763,51 +11763,63 @@ bool OpenMPAtomicCompareChecker::checkCondUpdateStmt(IfStmt *S,
   X = BO->getLHS();
 
   auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
-  if (!Cond) {
+  auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond());
+  Expr *LHS = nullptr;
+  Expr *RHS = nullptr;
+  if (Cond) {
+    LHS = Cond->getLHS();
+    RHS = Cond->getRHS();
+  } else if (Call && Call->isInfixBinaryOp()) {
+    // Overloaded operators are modeled as calls; accept only two-operand infix
+    // forms so that getArg(1) exists (prefix unary calls have one argument).
+    LHS = Call->getArg(0);
+    RHS = Call->getArg(1);
+  } else {
     ErrorInfo.Error = ErrorTy::NotABinaryOp;
     ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
     ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
     return false;
   }
 
-  switch (Cond->getOpcode()) {
-  case BO_EQ: {
-    C = Cond;
+  if ((Cond && Cond->getOpcode() == BO_EQ) ||
+      (Call && Call->getOperator() == OverloadedOperatorKind::OO_EqualEqual)) {
+    C = S->getCond();
     D = BO->getRHS();
-    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
-      E = Cond->getRHS();
-    } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
-      E = Cond->getLHS();
+    if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) {
+      E = RHS;
+    } else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
+      E = LHS;
     } else {
       ErrorInfo.Error = ErrorTy::InvalidComparison;
-      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
-      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+          S->getCond()->getSourceRange();
       return false;
     }
-    break;
-  }
-  case BO_LT:
-  case BO_GT: {
+  } else if ((Cond &&
+              (Cond->getOpcode() == BO_LT || Cond->getOpcode() == BO_GT)) ||
+             (Call &&
+              (Call->getOperator() == OverloadedOperatorKind::OO_Less ||
+               Call->getOperator() == OverloadedOperatorKind::OO_Greater))) {
     E = BO->getRHS();
-    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
-        checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
-      C = Cond;
-    } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
-               checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
-      C = Cond;
+    if (checkIfTwoExprsAreSame(ContextRef, X, LHS) &&
+        checkIfTwoExprsAreSame(ContextRef, E, RHS)) {
+      C = S->getCond();
+    } else if (checkIfTwoExprsAreSame(ContextRef, E, LHS) &&
+               checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
+      C = S->getCond();
       IsXBinopExpr = false;
     } else {
       ErrorInfo.Error = ErrorTy::InvalidComparison;
-      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
-      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+          S->getCond()->getSourceRange();
       return false;
     }
-    break;
-  }
-  default:
+  } else {
     ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
-    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
-    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
     return false;
   }
 
@@ -11857,52 +11867,65 @@ bool OpenMPAtomicCompareChecker::checkCondExprStmt(Stmt *S,
   }
 
   auto *Cond = dyn_cast<BinaryOperator>(CO->getCond());
-  if (!Cond) {
+  auto *Call = dyn_cast<CXXOperatorCallExpr>(CO->getCond());
+  Expr *LHS = nullptr;
+  Expr *RHS = nullptr;
+  if (Cond) {
+    LHS = Cond->getLHS();
+    RHS = Cond->getRHS();
+  } else if (Call && Call->isInfixBinaryOp()) {
+    // Overloaded operators are modeled as calls; accept only two-operand infix
+    // forms so that getArg(1) exists (prefix unary calls have one argument).
+    LHS = Call->getArg(0);
+    RHS = Call->getArg(1);
+  } else {
     ErrorInfo.Error = ErrorTy::NotABinaryOp;
     ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
     ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
         CO->getCond()->getSourceRange();
     return false;
   }
 
-  switch (Cond->getOpcode()) {
-  case BO_EQ: {
-    C = Cond;
+  if ((Cond && Cond->getOpcode() == BO_EQ) ||
+      (Call && Call->getOperator() == OverloadedOperatorKind::OO_EqualEqual)) {
+    C = CO->getCond();
     D = CO->getTrueExpr();
-    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
-      E = Cond->getRHS();
-    } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
-      E = Cond->getLHS();
+    if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) {
+      E = RHS;
+    } else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
+      E = LHS;
     } else {
       ErrorInfo.Error = ErrorTy::InvalidComparison;
-      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
-      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+          CO->getCond()->getSourceRange();
       return false;
     }
-    break;
-  }
-  case BO_LT:
-  case BO_GT: {
+  } else if ((Cond &&
+              (Cond->getOpcode() == BO_LT || Cond->getOpcode() == BO_GT)) ||
+             (Call &&
+              (Call->getOperator() == OverloadedOperatorKind::OO_Less ||
+               Call->getOperator() == OverloadedOperatorKind::OO_Greater))) {
     E = CO->getTrueExpr();
-    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
-        checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
-      C = Cond;
-    } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
-               checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
-      C = Cond;
+    if (checkIfTwoExprsAreSame(ContextRef, X, LHS) &&
+        checkIfTwoExprsAreSame(ContextRef, E, RHS)) {
+      C = CO->getCond();
+    } else if (checkIfTwoExprsAreSame(ContextRef, E, LHS) &&
+               checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
+      C = CO->getCond();
       IsXBinopExpr = false;
     } else {
       ErrorInfo.Error = ErrorTy::InvalidComparison;
-      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
-      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+          CO->getCond()->getSourceRange();
       return false;
     }
-    break;
-  }
-  default:
+  } else {
     ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
-    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
-    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+        CO->getCond()->getSourceRange();
     return false;
   }
 
@@ -12063,31 +12081,43 @@ bool OpenMPAtomicCompareCaptureChecker::checkForm3(IfStmt *S,
   D = BO->getRHS();
 
   auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
-  if (!Cond) {
+  auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond());
+  Expr *LHS = nullptr;
+  Expr *RHS = nullptr;
+  if (Cond) {
+    LHS = Cond->getLHS();
+    RHS = Cond->getRHS();
+  } else if (Call && Call->isInfixBinaryOp()) {
+    // Overloaded operators are modeled as calls; accept only two-operand infix
+    // forms so that getArg(1) exists (prefix unary calls have one argument).
+    LHS = Call->getArg(0);
+    RHS = Call->getArg(1);
+  } else {
     ErrorInfo.Error = ErrorTy::NotABinaryOp;
     ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
     ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
     return false;
   }
-  if (Cond->getOpcode() != BO_EQ) {
+  if ((Cond && Cond->getOpcode() != BO_EQ) ||
+      (Call && Call->getOperator() != OverloadedOperatorKind::OO_EqualEqual)) {
     ErrorInfo.Error = ErrorTy::NotEQ;
-    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
-    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
     return false;
   }
 
-  if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
-    E = Cond->getRHS();
-  } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
-    E = Cond->getLHS();
+  if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) {
+    E = RHS;
+  } else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) {
+    E = LHS;
   } else {
     ErrorInfo.Error = ErrorTy::InvalidComparison;
-    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
-    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
     return false;
   }
 
-  C = Cond;
+  C = S->getCond();
 
   if (!S->getElse()) {
     ErrorInfo.Error = ErrorTy::NoElse;
diff --git a/clang/test/OpenMP/atomic_messages.cpp b/clang/test/OpenMP/atomic_messages.cpp
index d492f6ee1e896..c4e240a0ebb4e 100644
--- a/clang/test/OpenMP/atomic_messages.cpp
+++ b/clang/test/OpenMP/atomic_messages.cpp
@@ -991,3 +991,37 @@ int mixed() {
   // expected-note@+1 {{in instantiation of function template specialization 'mixed<int>' requested here}}
   return mixed<int>();
 }
+
+#ifdef OMP51
+// Regression test: with overloaded comparison operators declared in the global
+// namespace, the conditions below are CXXOperatorCallExpr nodes rather than
+// BinaryOperator nodes; 'atomic compare' must accept them without errors.
+struct U {};
+struct U operator<(U, U);
+struct U operator>(U, U);
+struct U operator==(U, U);
+
+template <typename T> void templated() {
+  T cx, cv, ce, cd;
+#pragma omp atomic compare capture
+  if (cx == ce) {
+    cx = cd;
+  } else {
+    cv = cx;
+  }
+#pragma omp atomic compare capture
+  {
+    cv = cx;
+    if (ce > cx) {
+      cx = ce;
+    }
+  }
+#pragma omp atomic compare capture
+  {
+    cv = cx;
+    if (cx < ce) {
+      cx = ce;
+    }
+  }
+}
+#endif

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to