https://github.com/jhuber6 updated 
https://github.com/llvm/llvm-project/pull/141142

>From a45dc43315631f28ced9cf5a14890e46e011e6d2 Mon Sep 17 00:00:00 2001
From: Joseph Huber <hube...@outlook.com>
Date: Thu, 22 May 2025 16:21:34 -0500
Subject: [PATCH] [OpenMP] Fix atomic compare handling with overloaded
 operators

Summary:
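When an overloaded C++ operator is visible (for example, declared in the
global namespace), the AST node for the comparison in an `atomic compare`
condition can be a `CXXOperatorCallExpr` rather than a `BinaryOperator`
(e.g. inside a template where the operands are type-dependent). Modify
the uses to handle this case by treating the call as a binary expression
whose two arguments are the operands.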
---
 clang/lib/Sema/SemaOpenMP.cpp         | 320 ++++++++++++++++++--------
 clang/test/OpenMP/atomic_messages.cpp |  31 +++
 2 files changed, 249 insertions(+), 102 deletions(-)

diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index f16f841d62edd..a0ad814c366d8 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -11762,52 +11762,98 @@ bool OpenMPAtomicCompareChecker::checkCondUpdateStmt(IfStmt *S,
 
   X = BO->getLHS();
 
-  auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
-  if (!Cond) {
-    ErrorInfo.Error = ErrorTy::NotABinaryOp;
-    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
-    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
-    return false;
-  }
-
-  switch (Cond->getOpcode()) {
-  case BO_EQ: {
-    C = Cond;
-    D = BO->getRHS();
-    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
-      E = Cond->getRHS();
-    } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
-      E = Cond->getLHS();
-    } else {
-      ErrorInfo.Error = ErrorTy::InvalidComparison;
+  if (auto *Cond = dyn_cast<BinaryOperator>(S->getCond())) {
+    switch (Cond->getOpcode()) {
+    case BO_EQ: {
+      C = Cond;
+      D = BO->getRHS();
+      if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
+        E = Cond->getRHS();
+      } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+        E = Cond->getLHS();
+      } else {
+        ErrorInfo.Error = ErrorTy::InvalidComparison;
+        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+        return false;
+      }
+      break;
+    }
+    case BO_LT:
+    case BO_GT: {
+      E = BO->getRHS();
+      if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
+          checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
+        C = Cond;
+      } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
+                 checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+        C = Cond;
+        IsXBinopExpr = false;
+      } else {
+        ErrorInfo.Error = ErrorTy::InvalidComparison;
+        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+        return false;
+      }
+      break;
+    }
+    default:
+      ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
       ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
       ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
       return false;
     }
-    break;
-  }
-  case BO_LT:
-  case BO_GT: {
-    E = BO->getRHS();
-    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
-        checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
-      C = Cond;
-    } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
-               checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
-      C = Cond;
-      IsXBinopExpr = false;
-    } else {
-      ErrorInfo.Error = ErrorTy::InvalidComparison;
-      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
-      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+  } else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond())) {
+    if (Call->getNumArgs() != 2) {
+      ErrorInfo.Error = ErrorTy::NotABinaryOp; // e.g. a unary operator call.
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
       return false;
     }
-    break;
-  }
-  default:
-    ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
-    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
-    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+    switch (Call->getOperator()) {
+    case clang::OverloadedOperatorKind::OO_EqualEqual: {
+      C = Call;
+      D = BO->getRHS(); // 'd' in 'x = d', as in the BO_EQ case.
+      if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) {
+        E = Call->getArg(1);
+      } else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+        E = Call->getArg(0);
+      } else {
+        ErrorInfo.Error = ErrorTy::InvalidComparison;
+        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+        return false;
+      }
+      break;
+    }
+    case clang::OverloadedOperatorKind::OO_Less:
+    case clang::OverloadedOperatorKind::OO_Greater: {
+      E = BO->getRHS();
+      if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0)) &&
+          checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(1))) {
+        C = Call;
+      } else if (checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(0)) &&
+                 checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+        C = Call;
+        IsXBinopExpr = false;
+      } else {
+        ErrorInfo.Error = ErrorTy::InvalidComparison;
+        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+        return false;
+      }
+      break;
+    }
+    default:
+      ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+      return false;
+    }
+  } else {
+    ErrorInfo.Error = ErrorTy::NotABinaryOp;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
     return false;
   }
 
@@ -11856,53 +11902,99 @@ bool OpenMPAtomicCompareChecker::checkCondExprStmt(Stmt *S,
     return false;
   }
 
-  auto *Cond = dyn_cast<BinaryOperator>(CO->getCond());
-  if (!Cond) {
-    ErrorInfo.Error = ErrorTy::NotABinaryOp;
-    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
-    ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
-        CO->getCond()->getSourceRange();
-    return false;
-  }
-
-  switch (Cond->getOpcode()) {
-  case BO_EQ: {
-    C = Cond;
-    D = CO->getTrueExpr();
-    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
-      E = Cond->getRHS();
-    } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
-      E = Cond->getLHS();
-    } else {
-      ErrorInfo.Error = ErrorTy::InvalidComparison;
+  if (auto *Cond = dyn_cast<BinaryOperator>(CO->getCond())) {
+    switch (Cond->getOpcode()) {
+    case BO_EQ: {
+      C = Cond;
+      D = CO->getTrueExpr();
+      if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
+        E = Cond->getRHS();
+      } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+        E = Cond->getLHS();
+      } else {
+        ErrorInfo.Error = ErrorTy::InvalidComparison;
+        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+        return false;
+      }
+      break;
+    }
+    case BO_LT:
+    case BO_GT: {
+      E = CO->getTrueExpr();
+      if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
+          checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
+        C = Cond;
+      } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
+                 checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+        C = Cond;
+        IsXBinopExpr = false;
+      } else {
+        ErrorInfo.Error = ErrorTy::InvalidComparison;
+        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+        return false;
+      }
+      break;
+    }
+    default:
+      ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
       ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
       ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
       return false;
     }
-    break;
-  }
-  case BO_LT:
-  case BO_GT: {
-    E = CO->getTrueExpr();
-    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
-        checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
-      C = Cond;
-    } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
-               checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
-      C = Cond;
-      IsXBinopExpr = false;
-    } else {
-      ErrorInfo.Error = ErrorTy::InvalidComparison;
-      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
-      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+  } else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(CO->getCond())) {
+    if (Call->getNumArgs() != 2) {
+      ErrorInfo.Error = ErrorTy::NotABinaryOp; // e.g. a unary operator call.
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
       return false;
     }
-    break;
-  }
-  default:
-    ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
-    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
-    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+    switch (Call->getOperator()) {
+    case clang::OverloadedOperatorKind::OO_EqualEqual: {
+      C = Call;
+      D = CO->getTrueExpr();
+      if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) {
+        E = Call->getArg(1);
+      } else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+        E = Call->getArg(0);
+      } else {
+        ErrorInfo.Error = ErrorTy::InvalidComparison;
+        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+        return false;
+      }
+      break;
+    }
+    case clang::OverloadedOperatorKind::OO_Less:
+    case clang::OverloadedOperatorKind::OO_Greater: {
+      E = CO->getTrueExpr();
+      if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0)) &&
+          checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(1))) {
+        C = Call;
+      } else if (checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(0)) &&
+                 checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+        C = Call;
+        IsXBinopExpr = false;
+      } else {
+        ErrorInfo.Error = ErrorTy::InvalidComparison;
+        ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+        ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+        return false;
+      }
+      break;
+    }
+    default:
+      ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+      return false;
+    }
+  } else {
+    ErrorInfo.Error = ErrorTy::NotABinaryOp;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+        CO->getCond()->getSourceRange();
     return false;
   }
 
@@ -12062,32 +12154,56 @@ bool OpenMPAtomicCompareCaptureChecker::checkForm3(IfStmt *S,
   X = BO->getLHS();
   D = BO->getRHS();
 
-  auto *Cond = dyn_cast<BinaryOperator>(S->getCond());
-  if (!Cond) {
+  if (auto *Cond = dyn_cast<BinaryOperator>(S->getCond())) {
+    C = Cond;
+    if (Cond->getOpcode() != BO_EQ) {
+      ErrorInfo.Error = ErrorTy::NotEQ;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+      return false;
+    }
+
+    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
+      E = Cond->getRHS();
+    } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+      E = Cond->getLHS();
+    } else {
+      ErrorInfo.Error = ErrorTy::InvalidComparison;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+      return false;
+    }
+  } else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond())) {
+    C = Call;
+    if (Call->getNumArgs() != 2) {
+      ErrorInfo.Error = ErrorTy::NotABinaryOp; // e.g. a unary operator call.
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+      return false;
+    }
+    if (Call->getOperator() != clang::OverloadedOperatorKind::OO_EqualEqual) {
+      ErrorInfo.Error = ErrorTy::NotEQ;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+      return false;
+    }
+
+    if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) {
+      E = Call->getArg(1);
+    } else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) {
+      E = Call->getArg(0);
+    } else {
+      ErrorInfo.Error = ErrorTy::InvalidComparison;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange();
+      return false;
+    }
+  } else {
     ErrorInfo.Error = ErrorTy::NotABinaryOp;
     ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc();
     ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange();
     return false;
   }
-  if (Cond->getOpcode() != BO_EQ) {
-    ErrorInfo.Error = ErrorTy::NotEQ;
-    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
-    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
-    return false;
-  }
-
-  if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
-    E = Cond->getRHS();
-  } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
-    E = Cond->getLHS();
-  } else {
-    ErrorInfo.Error = ErrorTy::InvalidComparison;
-    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
-    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
-    return false;
-  }
-
-  C = Cond;
 
   if (!S->getElse()) {
     ErrorInfo.Error = ErrorTy::NoElse;
diff --git a/clang/test/OpenMP/atomic_messages.cpp b/clang/test/OpenMP/atomic_messages.cpp
index d492f6ee1e896..c4e240a0ebb4e 100644
--- a/clang/test/OpenMP/atomic_messages.cpp
+++ b/clang/test/OpenMP/atomic_messages.cpp
@@ -991,3 +991,34 @@ int mixed() {
   // expected-note@+1 {{in instantiation of function template specialization 'mixed<int>' requested here}}
   return mixed<int>();
 }
+
+#ifdef OMP51
+struct U {};
+struct U operator<(U, U);
+struct U operator>(U, U);
+struct U operator==(U, U);
+
+template <typename T> void templated() {
+  T cx, cv, ce, cd;
+#pragma omp atomic compare capture
+  if (cx == ce) {
+    cx = cd;
+  } else {
+    cv = cx;
+  }
+#pragma omp atomic compare capture
+  {
+    cv = cx;
+    if (ce > cx) {
+      cx = ce;
+    }
+  }
+#pragma omp atomic compare capture
+  {
+    cv = cx;
+    if (cx < ce) {
+      cx = ce;
+    }
+  }
+}
+#endif

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to