tianshilei1992 updated this revision to Diff 405379.
tianshilei1992 added a comment.

rebase


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D118632/new/

https://reviews.llvm.org/D118632

Files:
  clang/include/clang/AST/StmtOpenMP.h
  clang/lib/AST/StmtOpenMP.cpp
  clang/lib/CodeGen/CGStmtOpenMP.cpp
  clang/lib/Sema/SemaOpenMP.cpp
  llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Index: llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1357,7 +1357,6 @@
   /// if (e == x) { x = d; } (this one is not in the spec)
   ///
   /// \param Loc          The insert and source location description.
-  /// \param AllocIP      Instruction to create AllocaInst before.
   /// \param X            The target atomic pointer to be updated.
   /// \param E            The expected value ('e') for forms that use an
   ///                     equality comparison or an expression ('expr') for
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -11372,6 +11372,8 @@
   Expr *V = nullptr;
   Expr *E = nullptr;
   Expr *UE = nullptr;
+  Expr *D = nullptr;
+  Expr *CE = nullptr;
   bool IsXLHSInRHSPart = false;
   bool IsPostfixUpdate = false;
   // OpenMP [2.12.6, atomic Construct]
@@ -11768,17 +11770,18 @@
           << ErrorInfo.Error << ErrorInfo.NoteRange;
       return StmtError();
     }
-    // TODO: For now we emit an error here and in emitOMPAtomicExpr we ignore
-    // code gen.
-    unsigned DiagID = Diags.getCustomDiagID(
-        DiagnosticsEngine::Error, "atomic compare is not supported for now");
-    Diag(AtomicKindLoc, DiagID);
+    X = Checker.getX();
+    E = Checker.getE();
+    D = Checker.getD();
+    CE = Checker.getCond();
+    // We reuse this bool variable to tell if it is in the form 'x ordop expr'.
+    IsXLHSInRHSPart = Checker.isXBinopExpr();
   }
 
   setFunctionHasBranchProtectedScope();
 
   return OMPAtomicDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                    X, V, E, UE, IsXLHSInRHSPart,
+                                    X, V, E, UE, D, CE, IsXLHSInRHSPart,
                                     IsPostfixUpdate);
 }
 
Index: clang/lib/CodeGen/CGStmtOpenMP.cpp
===================================================================
--- clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -6011,11 +6011,52 @@
   }
 }
 
+static void emitOMPAtomicCompareExpr(CodeGenFunction &CGF,
+                                     llvm::AtomicOrdering AO, const Expr *X,
+                                     const Expr *E, const Expr *D,
+                                     const Expr *CE, bool IsXBinopExpr,
+                                     SourceLocation Loc) {
+
+  llvm::OpenMPIRBuilder &OMPBuilder =
+      CGF.CGM.getOpenMPRuntime().getOMPBuilder();
+  // llvm::OpenMPIRBuilder::InsertPointTy AllocaIP(
+  //     CGF.AllocaInsertPt->getParent(), CGF.AllocaInsertPt->getIterator());
+
+  OMPAtomicCompareOp Op;
+  assert(isa<BinaryOperator>(CE) && "CE is not a BinaryOperator");
+  switch (cast<BinaryOperator>(CE)->getOpcode()) {
+  case BO_EQ:
+    Op = OMPAtomicCompareOp::EQ;
+    break;
+  case BO_LT:
+    Op = OMPAtomicCompareOp::MIN;
+    break;
+  case BO_GT:
+    Op = OMPAtomicCompareOp::MAX;
+    break;
+  default:
+    llvm_unreachable("unsupported atomic compare binary operator");
+  }
+
+  LValue XLVal = CGF.EmitLValue(X);
+  llvm::Value *XPtr = XLVal.getPointer(CGF);
+  llvm::Value *EVal = CGF.EmitScalarExpr(E);
+  llvm::Value *DVal = D ? CGF.EmitScalarExpr(D) : nullptr;
+
+  llvm::OpenMPIRBuilder::AtomicOpValue XOpVal{
+      XPtr, XPtr->getType()->getPointerElementType(),
+      X->getType().isVolatileQualified(),
+      X->getType()->hasSignedIntegerRepresentation()};
+
+  CGF.Builder.restoreIP(OMPBuilder.createAtomicCompare(
+      CGF.Builder, XOpVal, EVal, DVal, AO, Op, IsXBinopExpr));
+}
+
 static void emitOMPAtomicExpr(CodeGenFunction &CGF, OpenMPClauseKind Kind,
                               llvm::AtomicOrdering AO, bool IsPostfixUpdate,
                               const Expr *X, const Expr *V, const Expr *E,
-                              const Expr *UE, bool IsXLHSInRHSPart,
-                              SourceLocation Loc) {
+                              const Expr *UE, const Expr *D, const Expr *CE,
+                              bool IsXLHSInRHSPart, SourceLocation Loc) {
   switch (Kind) {
   case OMPC_read:
     emitOMPAtomicReadExpr(CGF, AO, X, V, Loc);
@@ -6032,7 +6073,7 @@
                              IsXLHSInRHSPart, Loc);
     break;
   case OMPC_compare:
-    // Do nothing here as we already emit an error.
+    emitOMPAtomicCompareExpr(CGF, AO, X, E, D, CE, IsXLHSInRHSPart, Loc);
     break;
   case OMPC_if:
   case OMPC_final:
@@ -6178,8 +6219,8 @@
   LexicalScope Scope(*this, S.getSourceRange());
   EmitStopPoint(S.getAssociatedStmt());
   emitOMPAtomicExpr(*this, Kind, AO, S.isPostfixUpdate(), S.getX(), S.getV(),
-                    S.getExpr(), S.getUpdateExpr(), S.isXLHSInRHSPart(),
-                    S.getBeginLoc());
+                    S.getExpr(), S.getUpdateExpr(), S.getD(), S.getCondExpr(),
+                    S.isXLHSInRHSPart(), S.getBeginLoc());
 }
 
 static void emitCommonOMPTargetDirective(CodeGenFunction &CGF,
Index: clang/lib/AST/StmtOpenMP.cpp
===================================================================
--- clang/lib/AST/StmtOpenMP.cpp
+++ clang/lib/AST/StmtOpenMP.cpp
@@ -863,16 +863,20 @@
                                                    !IsStandalone);
 }
 
-OMPAtomicDirective *OMPAtomicDirective::Create(
-    const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
-    ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, Expr *X, Expr *V,
-    Expr *E, Expr *UE, bool IsXLHSInRHSPart, bool IsPostfixUpdate) {
+OMPAtomicDirective *
+OMPAtomicDirective::Create(const ASTContext &C, SourceLocation StartLoc,
+                           SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
+                           Stmt *AssociatedStmt, Expr *X, Expr *V, Expr *E,
+                           Expr *UE, Expr *D, Expr *Cond, bool IsXLHSInRHSPart,
+                           bool IsPostfixUpdate) {
   auto *Dir = createDirective<OMPAtomicDirective>(
-      C, Clauses, AssociatedStmt, /*NumChildren=*/4, StartLoc, EndLoc);
+      C, Clauses, AssociatedStmt, /*NumChildren=*/6, StartLoc, EndLoc);
   Dir->setX(X);
   Dir->setV(V);
   Dir->setExpr(E);
   Dir->setUpdateExpr(UE);
+  Dir->setD(D);
+  Dir->setCond(Cond);
   Dir->IsXLHSInRHSPart = IsXLHSInRHSPart;
   Dir->IsPostfixUpdate = IsPostfixUpdate;
   return Dir;
Index: clang/include/clang/AST/StmtOpenMP.h
===================================================================
--- clang/include/clang/AST/StmtOpenMP.h
+++ clang/include/clang/AST/StmtOpenMP.h
@@ -2863,6 +2863,8 @@
     POS_V,
     POS_E,
     POS_UpdateExpr,
+    POS_D,
+    POS_Cond,
   };
 
   /// Set 'x' part of the associated expression/statement.
@@ -2877,6 +2879,10 @@
   void setV(Expr *V) { Data->getChildren()[DataPositionTy::POS_V] = V; }
   /// Set 'expr' part of the associated expression/statement.
   void setExpr(Expr *E) { Data->getChildren()[DataPositionTy::POS_E] = E; }
+  /// Set 'd' part of the associated expression/statement.
+  void setD(Expr *D) { Data->getChildren()[DataPositionTy::POS_D] = D; }
+  /// Set conditional expression in `atomic compare`.
+  void setCond(Expr *C) { Data->getChildren()[DataPositionTy::POS_Cond] = C; }
 
 public:
   /// Creates directive with a list of \a Clauses and 'x', 'v' and 'expr'
@@ -2894,6 +2900,8 @@
   /// \param UE Helper expression of the form
   /// 'OpaqueValueExpr(x) binop OpaqueValueExpr(expr)' or
   /// 'OpaqueValueExpr(expr) binop OpaqueValueExpr(x)'.
+  /// \param D 'd' part of the associated expression/statement.
+  /// \param Cond Conditional expression in `atomic compare` construct.
   /// \param IsXLHSInRHSPart true if \a UE has the first form and false if the
   /// second.
   /// \param IsPostfixUpdate true if original value of 'x' must be stored in
@@ -2901,7 +2909,8 @@
   static OMPAtomicDirective *
   Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
          ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, Expr *X, Expr *V,
-         Expr *E, Expr *UE, bool IsXLHSInRHSPart, bool IsPostfixUpdate);
+         Expr *E, Expr *UE, Expr *D, Expr *Cond, bool IsXLHSInRHSPart,
+         bool IsPostfixUpdate);
 
   /// Creates an empty directive with the place for \a NumClauses
   /// clauses.
@@ -2951,6 +2960,20 @@
   const Expr *getExpr() const {
     return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_E]);
   }
+  /// Get 'd' part of the associated expression/statement.
+  Expr *getD() {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_D]);
+  }
+  Expr *getD() const {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_D]);
+  }
+  /// Get the conditional expression in `atomic compare`.
+  Expr *getCondExpr() {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_Cond]);
+  }
+  Expr *getCondExpr() const {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_Cond]);
+  }
 
   static bool classof(const Stmt *T) {
     return T->getStmtClass() == OMPAtomicDirectiveClass;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to