tianshilei1992 updated this revision to Diff 404601.
tianshilei1992 added a comment.

rebase


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D116637/new/

https://reviews.llvm.org/D116637

Files:
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/lib/Sema/SemaOpenMP.cpp

Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -10916,6 +10916,360 @@
   }
   return ErrorFound != NoError;
 }
+
+/// Build the structural profile of expression \a S once any enclosing
+/// parentheses and implicit casts have been stripped. The profile is computed
+/// in canonical mode, so equal IDs mean structurally identical expressions.
+llvm::FoldingSetNodeID getNodeId(ASTContext &Context, const Expr *S) {
+  const Expr *Stripped = S->IgnoreParenImpCasts();
+  llvm::FoldingSetNodeID Result;
+  Stripped->Profile(Result, Context, /*Canonical=*/true);
+  return Result;
+}
+
+/// Return true when \a LHS and \a RHS have identical structural profiles,
+/// ignoring parentheses and implicit casts on both sides.
+bool checkIfTwoExprsAreSame(ASTContext &Context, const Expr *LHS,
+                            const Expr *RHS) {
+  const llvm::FoldingSetNodeID LHSId = getNodeId(Context, LHS);
+  const llvm::FoldingSetNodeID RHSId = getNodeId(Context, RHS);
+  return LHSId == RHSId;
+}
+
+/// Validates the statement associated with '#pragma omp atomic compare' and
+/// records its components ('x', 'e'/'expr', 'd', and the comparison).
+class OpenMPAtomicCompareChecker {
+public:
+  /// All kinds of errors that can occur in `atomic compare`.
+  ///
+  /// The enumerator value is streamed into note_omp_atomic_compare as a
+  /// %select index, so this order must match the order of the alternatives in
+  /// that diagnostic.
+  enum ErrorTy {
+    /// Empty compound statement.
+    NoStmt = 0,
+    /// More than one statement in a compound statement.
+    MoreThanOneStmt,
+    /// Not an assignment binary operator.
+    NotAnAssignment,
+    /// Not a conditional operator.
+    NotCondOp,
+    /// Wrong false expr. According to the spec, 'x' should be at the false
+    /// expression of a conditional expression.
+    WrongFalseExpr,
+    /// The condition of a conditional expression is not a binary operator.
+    NotABinaryOp,
+    /// Invalid binary operator (not <, >, or ==).
+    InvalidBinaryOp,
+    /// Invalid comparison (not x == e, e == x, x ordop expr, or expr ordop x).
+    InvalidComparison,
+    /// X is not an lvalue.
+    XNotLValue,
+    /// Not a scalar integer.
+    NotScalarInteger,
+    /// No error.
+    NoError,
+  };
+
+  /// Diagnostic information filled in when a check fails.
+  struct ErrorInfoTy {
+    /// Which check failed; stays NoError unless a check reports a problem.
+    ErrorTy Error = NoError;
+    /// Location and range for the primary 'atomic compare' error.
+    SourceLocation ErrorLoc;
+    SourceRange ErrorRange;
+    /// Location and range for the explanatory note.
+    SourceLocation NoteLoc;
+    SourceRange NoteRange;
+  };
+
+  /// \p S is only used to obtain the ASTContext needed to compare expressions.
+  explicit OpenMPAtomicCompareChecker(Sema &S)
+      : ContextRef(S.getASTContext()) {}
+
+  /// Check if statement \a S is valid for <tt>atomic compare</tt>.
+  bool checkStmt(Stmt *S, ErrorInfoTy &ErrorInfo);
+
+  Expr *getX() const { return X; }
+  Expr *getE() const { return E; }
+  Expr *getD() const { return D; }
+  Expr *getCond() const { return C; }
+  bool isSignedOp() const { return IsSignedOp; }
+  bool isXBinopExpr() const { return IsXBinopExpr; }
+
+private:
+  /// Reference to ASTContext, used to profile and compare expressions.
+  ASTContext &ContextRef;
+  /// 'x' lvalue part of the source atomic expression.
+  Expr *X = nullptr;
+  /// 'expr' or 'e' rvalue part of the source atomic expression.
+  Expr *E = nullptr;
+  /// 'd' rvalue part of the source atomic expression (only in the '==' forms).
+  Expr *D = nullptr;
+  /// 'cond' part of the source atomic expression. It is in one of the following
+  /// forms:
+  /// expr ordop x
+  /// x ordop expr
+  /// x == e
+  /// e == x
+  Expr *C = nullptr;
+  /// True if the comparison operation is signed.
+  bool IsSignedOp = true;
+  /// True if the cond expr is in the form of 'x ordop expr'.
+  bool IsXBinopExpr = true;
+
+  /// Check if it is a valid conditional update statement (cond-update-stmt).
+  bool checkCondUpdateStmt(IfStmt *S, ErrorInfoTy &ErrorInfo);
+
+  /// Check if it is a valid conditional expression statement (cond-expr-stmt).
+  bool checkCondExprStmt(Stmt *S, ErrorInfoTy &ErrorInfo);
+
+  /// Check if all captured values have right type.
+  bool checkType(ErrorInfoTy &ErrorInfo) const;
+};
+
+/// Validate a cond-update-stmt, i.e. one of
+///   if (expr ordop x) { x = expr; }
+///   if (x ordop expr) { x = expr; }
+///   if (x == e) { x = d; }
+///   if (e == x) { x = d; }
+/// The braces around the assignment are optional. On success X, E, C (and D
+/// for the '==' forms) are recorded; on failure \p ErrorInfo describes why.
+bool OpenMPAtomicCompareChecker::checkCondUpdateStmt(IfStmt *S,
+                                                     ErrorInfoTy &ErrorInfo) {
+  // None of the accepted forms has an 'else' branch; accepting one would
+  // silently drop its code. There is no dedicated error kind for this, so it
+  // is reported as an extra statement.
+  if (Stmt *Else = S->getElse()) {
+    ErrorInfo.Error = ErrorTy::MoreThanOneStmt;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getElseLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Else->getSourceRange();
+    return false;
+  }
+
+  Stmt *Then = S->getThen();
+  if (auto *CS = dyn_cast<CompoundStmt>(Then)) {
+    if (CS->body_empty()) {
+      ErrorInfo.Error = ErrorTy::NoStmt;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CS->getSourceRange();
+      return false;
+    }
+    if (CS->size() > 1) {
+      ErrorInfo.Error = ErrorTy::MoreThanOneStmt;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CS->getBeginLoc();
+      // Highlight the compound statement holding the extra statements.
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CS->getSourceRange();
+      return false;
+    }
+    Then = CS->body_front();
+  }
+
+  // The body must be a plain assignment 'x = ...'.
+  auto *BO = dyn_cast<BinaryOperator>(Then);
+  if (!BO) {
+    ErrorInfo.Error = ErrorTy::NotAnAssignment;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Then->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Then->getSourceRange();
+    return false;
+  }
+  if (BO->getOpcode() != BO_Assign) {
+    ErrorInfo.Error = ErrorTy::NotAnAssignment;
+    ErrorInfo.ErrorLoc = BO->getExprLoc();
+    ErrorInfo.NoteLoc = BO->getOperatorLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = BO->getSourceRange();
+    return false;
+  }
+
+  X = BO->getLHS();
+  IsSignedOp = X->getType()->hasSignedIntegerRepresentation();
+
+  // Defensive: an IfStmt may have no condition expression (presumably only
+  // for 'if consteval'); that can never match any accepted form.
+  Expr *CondExpr = S->getCond();
+  if (!CondExpr) {
+    ErrorInfo.Error = ErrorTy::NotABinaryOp;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getSourceRange();
+    return false;
+  }
+
+  // Look through parentheses so that 'if ((x == e))' matches like 'if (x == e)'.
+  auto *Cond = dyn_cast<BinaryOperator>(CondExpr->IgnoreParenImpCasts());
+  if (!Cond) {
+    ErrorInfo.Error = ErrorTy::NotABinaryOp;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CondExpr->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = CondExpr->getSourceRange();
+    return false;
+  }
+
+  switch (Cond->getOpcode()) {
+  case BO_EQ:
+    // if (x == e) { x = d; }  or  if (e == x) { x = d; }
+    C = Cond;
+    D = BO->getRHS();
+    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
+      E = Cond->getRHS();
+    } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+      E = Cond->getLHS();
+    } else {
+      ErrorInfo.Error = ErrorTy::InvalidComparison;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+      return false;
+    }
+    return true;
+  case BO_LT:
+  case BO_GT:
+    // if (x ordop expr) { x = expr; }  or  if (expr ordop x) { x = expr; }
+    E = BO->getRHS();
+    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
+        checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
+      C = Cond;
+    } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
+               checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+      C = Cond;
+      IsXBinopExpr = false;
+    } else {
+      ErrorInfo.Error = ErrorTy::InvalidComparison;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+      return false;
+    }
+    return true;
+  default:
+    ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+    return false;
+  }
+}
+
+/// Validate a cond-expr-stmt, i.e. one of
+///   x = expr ordop x ? expr : x;
+///   x = x ordop expr ? expr : x;
+///   x = x == e ? d : x;
+///   x = e == x ? d : x;
+/// On success X, E, C (and D for the '==' forms) are recorded; on failure
+/// \p ErrorInfo describes why.
+bool OpenMPAtomicCompareChecker::checkCondExprStmt(Stmt *S,
+                                                   ErrorInfoTy &ErrorInfo) {
+  // The statement must be a plain assignment 'x = ...'.
+  auto *BO = dyn_cast<BinaryOperator>(S);
+  if (!BO) {
+    ErrorInfo.Error = ErrorTy::NotAnAssignment;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = S->getBeginLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = S->getSourceRange();
+    return false;
+  }
+  if (BO->getOpcode() != BO_Assign) {
+    ErrorInfo.Error = ErrorTy::NotAnAssignment;
+    ErrorInfo.ErrorLoc = BO->getExprLoc();
+    ErrorInfo.NoteLoc = BO->getOperatorLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = BO->getSourceRange();
+    return false;
+  }
+
+  X = BO->getLHS();
+  IsSignedOp = X->getType()->hasSignedIntegerRepresentation();
+
+  // The right-hand side must be a conditional operator 'cond ? ... : x'.
+  auto *CO = dyn_cast<ConditionalOperator>(BO->getRHS()->IgnoreParenImpCasts());
+  if (!CO) {
+    ErrorInfo.Error = ErrorTy::NotCondOp;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = BO->getRHS()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = BO->getRHS()->getSourceRange();
+    return false;
+  }
+
+  if (!checkIfTwoExprsAreSame(ContextRef, X, CO->getFalseExpr())) {
+    ErrorInfo.Error = ErrorTy::WrongFalseExpr;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getFalseExpr()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+        CO->getFalseExpr()->getSourceRange();
+    return false;
+  }
+
+  // Look through parentheses (as is already done for the conditional
+  // operator) so that 'x = (x < e) ? e : x;' is accepted like 'x < e ? e : x'.
+  auto *Cond = dyn_cast<BinaryOperator>(CO->getCond()->IgnoreParenImpCasts());
+  if (!Cond) {
+    ErrorInfo.Error = ErrorTy::NotABinaryOp;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = CO->getCond()->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange =
+        CO->getCond()->getSourceRange();
+    return false;
+  }
+
+  switch (Cond->getOpcode()) {
+  case BO_EQ:
+    // x = x == e ? d : x;  or  x = e == x ? d : x;
+    C = Cond;
+    D = CO->getTrueExpr();
+    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS())) {
+      E = Cond->getRHS();
+    } else if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+      E = Cond->getLHS();
+    } else {
+      ErrorInfo.Error = ErrorTy::InvalidComparison;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+      return false;
+    }
+    return true;
+  case BO_LT:
+  case BO_GT:
+    // x = x ordop expr ? expr : x;  or  x = expr ordop x ? expr : x;
+    E = CO->getTrueExpr();
+    if (checkIfTwoExprsAreSame(ContextRef, X, Cond->getLHS()) &&
+        checkIfTwoExprsAreSame(ContextRef, E, Cond->getRHS())) {
+      C = Cond;
+    } else if (checkIfTwoExprsAreSame(ContextRef, E, Cond->getLHS()) &&
+               checkIfTwoExprsAreSame(ContextRef, X, Cond->getRHS())) {
+      C = Cond;
+      IsXBinopExpr = false;
+    } else {
+      ErrorInfo.Error = ErrorTy::InvalidComparison;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+      return false;
+    }
+    return true;
+  default:
+    ErrorInfo.Error = ErrorTy::InvalidBinaryOp;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Cond->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Cond->getSourceRange();
+    return false;
+  }
+}
+
+/// Check that 'x' is an lvalue and that 'x', 'e'/'expr' and (if present) 'd'
+/// have scalar integer type. Must only be called after a successful shape
+/// check, which guarantees X and E are set.
+bool OpenMPAtomicCompareChecker::checkType(ErrorInfoTy &ErrorInfo) const {
+  // 'x' and 'e' cannot be nullptr
+  assert(X && E && "X and E cannot be nullptr");
+
+  // Returns true if \p Val has scalar integer type, otherwise records a
+  // NotScalarInteger error. Instantiation-dependent expressions (inside
+  // templates) are accepted: their real type is unknown until instantiation,
+  // and rejecting them here would reject valid templated code.
+  auto IsScalarIntegerValue = [&ErrorInfo](const Expr *Val) {
+    if (Val->isInstantiationDependent())
+      return true;
+    QualType Ty = Val->getType();
+    if (Ty->isScalarType() && Ty->isIntegerType())
+      return true;
+    ErrorInfo.Error = ErrorTy::NotScalarInteger;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Val->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Val->getSourceRange();
+    return false;
+  };
+
+  // Value category is likewise not reliable for dependent expressions.
+  if (!X->isInstantiationDependent() && !X->isLValue()) {
+    ErrorInfo.Error = ErrorTy::XNotLValue;
+    ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = X->getExprLoc();
+    ErrorInfo.ErrorRange = ErrorInfo.NoteRange = X->getSourceRange();
+    return false;
+  }
+
+  if (!IsScalarIntegerValue(X) || !IsScalarIntegerValue(E))
+    return false;
+
+  // 'd' only exists for the 'x == e' / 'e == x' forms.
+  return !D || IsScalarIntegerValue(D);
+}
+
+/// Entry point: validate the statement associated with 'atomic compare'.
+/// A surrounding compound statement must hold exactly one statement, which is
+/// then checked as either a cond-update-stmt ('if' form) or a cond-expr-stmt
+/// (assignment form), followed by the type checks.
+bool OpenMPAtomicCompareChecker::checkStmt(
+    Stmt *S, OpenMPAtomicCompareChecker::ErrorInfoTy &ErrorInfo) {
+  if (auto *Compound = dyn_cast<CompoundStmt>(S)) {
+    if (Compound->size() != 1) {
+      ErrorInfo.Error = Compound->size() == 0 ? ErrorTy::NoStmt
+                                              : ErrorTy::MoreThanOneStmt;
+      ErrorInfo.ErrorLoc = ErrorInfo.NoteLoc = Compound->getBeginLoc();
+      ErrorInfo.ErrorRange = ErrorInfo.NoteRange = Compound->getSourceRange();
+      return false;
+    }
+    S = Compound->body_front();
+  }
+
+  // cond-update-stmt:
+  //   if (expr ordop x) { x = expr; }
+  //   if (x ordop expr) { x = expr; }
+  //   if (x == e) { x = d; }
+  // cond-expr-stmt:
+  //   x = expr ordop x ? expr : x;
+  //   x = x ordop expr ? expr : x;
+  //   x = x == e ? d : x;
+  auto *IfS = dyn_cast<IfStmt>(S);
+  const bool ShapeOK = IfS ? checkCondUpdateStmt(IfS, ErrorInfo)
+                           : checkCondExprStmt(S, ErrorInfo);
+
+  // checkType asserts that X and E are set, so it only runs once the shape
+  // check has succeeded.
+  return ShapeOK && checkType(ErrorInfo);
+}
 } // namespace
 
 StmtResult Sema::ActOnOpenMPAtomicDirective(ArrayRef<OMPClause *> Clauses,
@@ -11396,6 +11750,15 @@
     if (CurContext->isDependentContext())
       UE = V = E = X = nullptr;
   } else if (AtomicKind == OMPC_compare) {
+    OpenMPAtomicCompareChecker::ErrorInfoTy ErrorInfo;
+    OpenMPAtomicCompareChecker Checker(*this);
+    if (!Checker.checkStmt(Body, ErrorInfo)) {
+      Diag(ErrorInfo.ErrorLoc, diag::err_omp_atomic_compare)
+          << ErrorInfo.ErrorRange;
+      Diag(ErrorInfo.NoteLoc, diag::note_omp_atomic_compare)
+          << ErrorInfo.Error << ErrorInfo.NoteRange;
+      return StmtError();
+    }
     // TODO: For now we emit an error here and in emitOMPAtomicExpr we ignore
     // code gen.
     unsigned DiagID = Diags.getCustomDiagID(
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10519,6 +10519,15 @@
   " where x is an lvalue expression with scalar type">;
 def note_omp_atomic_capture: Note<
   "%select{expected assignment expression|expected compound statement|expected exactly two expression statements|expected in right hand side of the first expression}0">;
+// Emitted when the statement associated with '#pragma omp atomic compare'
+// matches none of the accepted forms; note_omp_atomic_compare explains why.
+def err_omp_atomic_compare : Error<
+  "the statement for 'atomic compare' must be a compound statement of form '{x = expr ordop x ? expr : x;}', '{x = x ordop expr ? expr : x;}',"
+  " '{x = x == e ? d : x;}', '{x = e == x ? d : x;}', or 'if(expr ordop x) {x = expr;}', 'if(x ordop expr) {x = expr;}', 'if(x == e) {x = d;}',"
+  " 'if(e == x) {x = d;}' where 'x' is an lvalue expression with scalar type, 'expr', 'e', and 'd' are expressions with scalar type,"
+  " and 'ordop' is one of '<' or '>'">;
+// The %select index is OpenMPAtomicCompareChecker::ErrorTy (SemaOpenMP.cpp),
+// which is streamed directly into this note; keep both lists in the same order.
+def note_omp_atomic_compare: Note<
+  "%select{expected non-empty compound statement|expected exactly one expression statement|expected assignment statement|expected conditional operator|expected result value to be at false expression|"
+  "expected binary operator in conditional expression|expected '<', '>' or '==' as comparison operator|expected comparison in a form of 'x == e', 'e == x', 'x ordop expr', or 'expr ordop x'|"
+  "expected lvalue for result value|expected scalar integer value}0">;
 def err_omp_atomic_several_clauses : Error<
   "directive '#pragma omp atomic' cannot contain more than one 'read', 'write', 'update', 'capture', or 'compare' clause">;
 def err_omp_several_mem_order_clauses : Error<
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to