https://github.com/jhuber6 updated https://github.com/llvm/llvm-project/pull/141142
>From f2c18ba64744320a8e2a63938b17137a1b6e74d7 Mon Sep 17 00:00:00 2001 From: Joseph Huber <hube...@outlook.com> Date: Thu, 22 May 2025 16:21:34 -0500 Subject: [PATCH] [OpenMP] Fix atomic compare handling with overloaded operators Summary: When there are overloaded C++ operators in the global namespace, the AST node for a comparison such as `x == e` can be a `CXXOperatorCallExpr` rather than a `BinaryOperator`. Update checkCondUpdateStmt, checkCondExprStmt, and checkForm3 to handle this case by treating the operator call's two arguments as the LHS and RHS of a binary comparison. --- clang/lib/Sema/SemaOpenMP.cpp | 162 +++++++++++++++----------- clang/test/OpenMP/atomic_messages.cpp | 31 +++++ 2 files changed, 126 insertions(+), 67 deletions(-) diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index f16f841d62edd..8d580d1968238 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -11763,51 +11763,61 @@ bool OpenMPAtomicCompareChecker::checkCondUpdateStmt(IfStmt *S, X = BO->getLHS(); auto *Cond = dyn_cast<BinaryOperator>(S->getCond()); - if (!Cond) { + auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond()); + Expr *LHS = nullptr; + Expr *RHS = nullptr; + if (Cond) { + LHS = Cond->getLHS(); + RHS = Cond->getRHS(); + } else if (Call) { + LHS = Call->getArg(0); + RHS = Call->getArg(1); + } else { ErrorInfo.Error = ErrorTy::NotABinaryOp; ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc(); ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange(); return false; } - switch (Cond->getOpcode()) { - case BO_EQ: { - C = Cond; + if ((Cond && Cond->getOpcode() == BO_EQ) || + (Call && Call->getOperator() == OverloadedOperatorKind::OO_EqualEqual)) { + C = S->getCond(); D = BO->getRHS(); - if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) { - E = Cond->getRHS(); - } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { - E = Cond->getLHS(); + if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) { + E = RHS; + } else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) { + E = LHS;
} else { ErrorInfo.Error = ErrorTy::InvalidComparison; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = + S->getCond()->getSourceRange(); return false; } - break; - } - case BO_LT: - case BO_GT: { + } else if ((Cond && + (Cond->getOpcode() == BO_LT || Cond->getOpcode() == BO_GT)) || + (Call && + (Call->getOperator() == OverloadedOperatorKind::OO_Less || + Call->getOperator() == OverloadedOperatorKind::OO_Greater))) { E = BO->getRHS(); - if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) && - checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) { - C = Cond; - } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) && - checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { - C = Cond; + if (checkIfTwoExprsAreSame(ContextRef, X, LHS) && + checkIfTwoExprsAreSame(ContextRef, E, RHS)) { + C = S->getCond(); + } else if (checkIfTwoExprsAreSame(ContextRef, E, LHS) && + checkIfTwoExprsAreSame(ContextRef, X, RHS)) { + C = S->getCond(); IsXBinopExpr = false; } else { ErrorInfo.Error = ErrorTy::InvalidComparison; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = + S->getCond()->getSourceRange(); return false; } - break; - } - default: + } else { ErrorInfo.Error = ErrorTy::InvalidBinaryOp; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange(); return false; } @@ -11857,52 +11867,60 @@ bool 
OpenMPAtomicCompareChecker::checkCondExprStmt(Stmt *S, } auto *Cond = dyn_cast<BinaryOperator>(CO->getCond()); - if (!Cond) { + auto *Call = dyn_cast<CXXOperatorCallExpr>(CO->getCond()); + Expr *LHS = nullptr; + Expr *RHS = nullptr; + if (Cond) { + LHS = Cond->getLHS(); + RHS = Cond->getRHS(); + } else if (Call) { + LHS = Call->getArg(0); + RHS = Call->getArg(1); + } else { ErrorInfo.Error = ErrorTy::NotABinaryOp; ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = - CO->getCond()->getSourceRange(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CO->getCond()->getSourceRange(); return false; } - switch (Cond->getOpcode()) { - case BO_EQ: { - C = Cond; + if ((Cond && Cond->getOpcode() == BO_EQ) || + (Call && Call->getOperator() == OverloadedOperatorKind::OO_EqualEqual)) { + C = CO->getCond(); D = CO->getTrueExpr(); - if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) { - E = Cond->getRHS(); - } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { - E = Cond->getLHS(); + if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) { + E = RHS; + } else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) { + E = LHS; } else { ErrorInfo.Error = ErrorTy::InvalidComparison; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CO->getCond()->getSourceRange(); return false; } - break; - } - case BO_LT: - case BO_GT: { + } else if ((Cond && + (Cond->getOpcode() == BO_LT || Cond->getOpcode() == BO_GT)) || + (Call && + (Call->getOperator() == OverloadedOperatorKind::OO_Less || + Call->getOperator() == OverloadedOperatorKind::OO_Greater))) { + E = CO->getTrueExpr(); - if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) && - checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) { - C = Cond; - } else if 
(checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) && - checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { - C = Cond; + if (checkIfTwoExprsAreSame(ContextRef, X, LHS) && + checkIfTwoExprsAreSame(ContextRef, E, RHS)) { + C = CO->getCond(); + } else if (checkIfTwoExprsAreSame(ContextRef, E, LHS) && + checkIfTwoExprsAreSame(ContextRef, X, RHS)) { + C = CO->getCond(); IsXBinopExpr = false; } else { ErrorInfo.Error = ErrorTy::InvalidComparison; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CO->getCond()->getSourceRange(); return false; } - break; - } - default: + } else { ErrorInfo.Error = ErrorTy::InvalidBinaryOp; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CO->getCond()->getSourceRange(); return false; } @@ -12063,31 +12081,41 @@ bool OpenMPAtomicCompareCaptureChecker::checkForm3(IfStmt *S, D = BO->getRHS(); auto *Cond = dyn_cast<BinaryOperator>(S->getCond()); - if (!Cond) { + auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond()); + Expr *LHS = nullptr; + Expr *RHS = nullptr; + if (Cond) { + LHS = Cond->getLHS(); + RHS = Cond->getRHS(); + } else if (Call) { + LHS = Call->getArg(0); + RHS = Call->getArg(1); + } else { ErrorInfo.Error = ErrorTy::NotABinaryOp; ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc(); ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange(); return false; } - if (Cond->getOpcode() != BO_EQ) { + if ((Cond && Cond->getOpcode() != BO_EQ) || + (Call && Call->getOperator() != OverloadedOperatorKind::OO_EqualEqual)) { ErrorInfo.Error = ErrorTy::NotEQ; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = 
Cond->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange(); return false; } - if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) { - E = Cond->getRHS(); - } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { - E = Cond->getLHS(); + if (checkIfTwoExprsAreSame(ContextRef, X, LHS)) { + E = RHS; + } else if (checkIfTwoExprsAreSame(ContextRef, X, RHS)) { + E = LHS; } else { ErrorInfo.Error = ErrorTy::InvalidComparison; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange(); return false; } - C = Cond; + C = S->getCond(); if (!S->getElse()) { ErrorInfo.Error = ErrorTy::NoElse; diff --git a/clang/test/OpenMP/atomic_messages.cpp b/clang/test/OpenMP/atomic_messages.cpp index d492f6ee1e896..c4e240a0ebb4e 100644 --- a/clang/test/OpenMP/atomic_messages.cpp +++ b/clang/test/OpenMP/atomic_messages.cpp @@ -991,3 +991,34 @@ int mixed() { // expected-note@+1 {{in instantiation of function template specialization 'mixed<int>' requested here}} return mixed<int>(); } + +#ifdef OMP51 +struct U {}; +struct U operator<(U, U); +struct U operator>(U, U); +struct U operator==(U, U); + +template <typename T> void templated() { + T cx, cv, ce, cd; +#pragma omp atomic compare capture + if (cx == ce) { + cx = cd; + } else { + cv = cx; + } +#pragma omp atomic compare capture + { + cv = cx; + if (ce > cx) { + cx = ce; + } + } +#pragma omp atomic compare capture + { + cv = cx; + if (cx < ce) { + cx = ce; + } + } +} +#endif _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org 
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits