https://github.com/jhuber6 updated https://github.com/llvm/llvm-project/pull/141142
>From a45dc43315631f28ced9cf5a14890e46e011e6d2 Mon Sep 17 00:00:00 2001 From: Joseph Huber <hube...@outlook.com> Date: Thu, 22 May 2025 16:21:34 -0500 Subject: [PATCH] [OpenMP] Fix atomic compare handling with overloaded operators Summary: When there are overloaded C++ operators in the global namespace the AST node for these is not a `BinaryOperator` but a `CXXOperatorCallExpr`. Modify the uses to handle this case, basically just treating it as a binary expression with two arguments. --- clang/lib/Sema/SemaOpenMP.cpp | 320 ++++++++++++++++++-------- clang/test/OpenMP/atomic_messages.cpp | 31 +++ 2 files changed, 249 insertions(+), 102 deletions(-) diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index f16f841d62edd..a0ad814c366d8 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -11762,52 +11762,98 @@ bool OpenMPAtomicCompareChecker::checkCondUpdateStmt(IfStmt *S, X = BO->getLHS(); - auto *Cond = dyn_cast<BinaryOperator>(S->getCond()); - if (!Cond) { - ErrorInfo.Error = ErrorTy::NotABinaryOp; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange(); - return false; - } - - switch (Cond->getOpcode()) { - case BO_EQ: { - C = Cond; - D = BO->getRHS(); - if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) { - E = Cond->getRHS(); - } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { - E = Cond->getLHS(); - } else { - ErrorInfo.Error = ErrorTy::InvalidComparison; + if (auto *Cond = dyn_cast<BinaryOperator>(S->getCond())) { + switch (Cond->getOpcode()) { + case BO_EQ: { + C = Cond; + D = BO->getRHS(); + if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) { + E = Cond->getRHS(); + } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { + E = Cond->getLHS(); + } else { + ErrorInfo.Error = ErrorTy::InvalidComparison; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); + 
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + return false; + } + break; + } + case BO_LT: + case BO_GT: { + E = BO->getRHS(); + if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) && + checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) { + C = Cond; + } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) && + checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { + C = Cond; + IsXBinopExpr = false; + } else { + ErrorInfo.Error = ErrorTy::InvalidComparison; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + return false; + } + break; + } + default: + ErrorInfo.Error = ErrorTy::InvalidBinaryOp; ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); return false; } - break; - } - case BO_LT: - case BO_GT: { - E = BO->getRHS(); - if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) && - checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) { - C = Cond; - } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) && - checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { - C = Cond; - IsXBinopExpr = false; - } else { - ErrorInfo.Error = ErrorTy::InvalidComparison; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + } else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond())) { + if (Call->getNumArgs() != 2) { + ErrorInfo.Error = ErrorTy::InvalidBinaryOp; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange(); return false; } - break; - } - default: - ErrorInfo.Error = ErrorTy::InvalidBinaryOp; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + switch (Call->getOperator()) { + case 
clang::OverloadedOperatorKind::OO_EqualEqual: { + C = Call; + D = BO->getLHS(); + if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) { + E = Call->getArg(1); + } else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) { + E = Call->getArg(0); + } else { + ErrorInfo.Error = ErrorTy::InvalidComparison; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange(); + return false; + } + break; + } + case clang::OverloadedOperatorKind::OO_Greater: + case clang::OverloadedOperatorKind::OO_Less: { + E = BO->getRHS(); + if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0)) && + checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(1))) { + C = Call; + } else if (checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(0)) && + checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) { + C = Call; + IsXBinopExpr = false; + } else { + ErrorInfo.Error = ErrorTy::InvalidComparison; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange(); + return false; + } + break; + } + default: + ErrorInfo.Error = ErrorTy::InvalidBinaryOp; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange(); + return false; + } + } else { + ErrorInfo.Error = ErrorTy::NotABinaryOp; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange(); return false; } @@ -11856,53 +11902,99 @@ bool OpenMPAtomicCompareChecker::checkCondExprStmt(Stmt *S, return false; } - auto *Cond = dyn_cast<BinaryOperator>(CO->getCond()); - if (!Cond) { - ErrorInfo.Error = ErrorTy::NotABinaryOp; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = - CO->getCond()->getSourceRange(); - return false; - } - - switch (Cond->getOpcode()) { - case 
BO_EQ: { - C = Cond; - D = CO->getTrueExpr(); - if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) { - E = Cond->getRHS(); - } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { - E = Cond->getLHS(); - } else { - ErrorInfo.Error = ErrorTy::InvalidComparison; + if (auto *Cond = dyn_cast<BinaryOperator>(CO->getCond())) { + switch (Cond->getOpcode()) { + case BO_EQ: { + C = Cond; + D = CO->getTrueExpr(); + if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) { + E = Cond->getRHS(); + } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { + E = Cond->getLHS(); + } else { + ErrorInfo.Error = ErrorTy::InvalidComparison; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + return false; + } + break; + } + case BO_LT: + case BO_GT: { + E = CO->getTrueExpr(); + if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) && + checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) { + C = Cond; + } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) && + checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { + C = Cond; + IsXBinopExpr = false; + } else { + ErrorInfo.Error = ErrorTy::InvalidComparison; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + return false; + } + break; + } + default: + ErrorInfo.Error = ErrorTy::InvalidBinaryOp; ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); return false; } - break; - } - case BO_LT: - case BO_GT: { - E = CO->getTrueExpr(); - if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) && - checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) { - C = Cond; - } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) && - checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { - C = Cond; - IsXBinopExpr = false; - } else { - 
ErrorInfo.Error = ErrorTy::InvalidComparison; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + } else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(CO->getCond())) { + if (Call->getNumArgs() != 2) { + ErrorInfo.Error = ErrorTy::InvalidBinaryOp; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange(); return false; } - break; - } - default: - ErrorInfo.Error = ErrorTy::InvalidBinaryOp; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + switch (Call->getOperator()) { + case clang::OverloadedOperatorKind::OO_EqualEqual: { + C = Call; + D = CO->getTrueExpr(); + if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) { + E = Call->getArg(1); + } else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) { + E = Call->getArg(0); + } else { + ErrorInfo.Error = ErrorTy::InvalidComparison; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange(); + return false; + } + break; + } + case clang::OverloadedOperatorKind::OO_Less: + case clang::OverloadedOperatorKind::OO_Greater: { + E = CO->getTrueExpr(); + if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0)) && + checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(1))) { + C = Call; + } else if (checkIfTwoExprsAreSame(ContextRef, E, Call->getArg(0)) && + checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) { + C = Call; + IsXBinopExpr = false; + } else { + ErrorInfo.Error = ErrorTy::InvalidComparison; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange(); + return false; + } + break; + } + default: + ErrorInfo.Error = ErrorTy::InvalidBinaryOp; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc(); + 
ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange(); + return false; + } + } else { + ErrorInfo.Error = ErrorTy::NotABinaryOp; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = + CO->getCond()->getSourceRange(); return false; } @@ -12062,32 +12154,56 @@ bool OpenMPAtomicCompareCaptureChecker::checkForm3(IfStmt *S, X = BO->getLHS(); D = BO->getRHS(); - auto *Cond = dyn_cast<BinaryOperator>(S->getCond()); - if (!Cond) { + if (auto *Cond = dyn_cast<BinaryOperator>(S->getCond())) { + C = Cond; + if (Cond->getOpcode() != BO_EQ) { + ErrorInfo.Error = ErrorTy::NotEQ; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + return false; + } + + if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) { + E = Cond->getRHS(); + } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { + E = Cond->getLHS(); + } else { + ErrorInfo.Error = ErrorTy::InvalidComparison; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); + return false; + } + } else if (auto *Call = dyn_cast<CXXOperatorCallExpr>(S->getCond())) { + C = Call; + if (Call->getNumArgs() != 2) { + ErrorInfo.Error = ErrorTy::InvalidBinaryOp; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange(); + return false; + } + if (Call->getOperator() != clang::OverloadedOperatorKind::OO_EqualEqual) { + ErrorInfo.Error = ErrorTy::NotEQ; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange(); + return false; + } + + if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(0))) { + E = Call->getArg(1); + } else if (checkIfTwoExprsAreSame(ContextRef, X, Call->getArg(1))) { + E = Call->getArg(0); + } else { + ErrorInfo.Error = 
ErrorTy::InvalidComparison; + ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Call->getExprLoc(); + ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Call->getSourceRange(); + return false; + } + } else { ErrorInfo.Error = ErrorTy::NotABinaryOp; ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getCond()->getExprLoc(); ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getCond()->getSourceRange(); return false; } - if (Cond->getOpcode() != BO_EQ) { - ErrorInfo.Error = ErrorTy::NotEQ; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); - return false; - } - - if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) { - E = Cond->getRHS(); - } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) { - E = Cond->getLHS(); - } else { - ErrorInfo.Error = ErrorTy::InvalidComparison; - ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc(); - ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange(); - return false; - } - - C = Cond; if (!S->getElse()) { ErrorInfo.Error = ErrorTy::NoElse; diff --git a/clang/test/OpenMP/atomic_messages.cpp b/clang/test/OpenMP/atomic_messages.cpp index d492f6ee1e896..c4e240a0ebb4e 100644 --- a/clang/test/OpenMP/atomic_messages.cpp +++ b/clang/test/OpenMP/atomic_messages.cpp @@ -991,3 +991,34 @@ int mixed() { // expected-note@+1 {{in instantiation of function template specialization 'mixed<int>' requested here}} return mixed<int>(); } + +#ifdef OMP51 +struct U {}; +struct U operator<(U, U); +struct U operator>(U, U); +struct U operator==(U, U); + +template <typename T> void templated() { + T cx, cv, ce, cd; +#pragma omp atomic compare capture + if (cx == ce) { + cx = cd; + } else { + cv = cx; + } +#pragma omp atomic compare capture + { + cv = cx; + if (ce > cx) { + cx = ce; + } + } +#pragma omp atomic compare capture + { + cv = cx; + if (cx < ce) { + cx = ce; + } + } +} +#endif _______________________________________________ cfe-commits mailing list 
cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits