================
@@ -5965,6 +5967,263 @@ static bool teamsLoopCanBeParallelFor(Stmt *AStmt, Sema &SemaRef) {
   return Checker.teamsLoopCanBeParallelFor();
 }
 
+static Expr *getInitialExprFromCapturedExpr(Expr *Cond) {
+
+  Expr *SubExpr = Cond->IgnoreParenImpCasts();
+
+  if (auto *DeclRef = dyn_cast<DeclRefExpr>(SubExpr)) {
+    if (auto *CapturedExprDecl =
+            dyn_cast<OMPCapturedExprDecl>(DeclRef->getDecl())) {
+
+      // Retrieve the initial expression from the captured expression
+      return CapturedExprDecl->getInit();
+    }
+  }
+  return nullptr;
+}
+
+// The next two functions, cloneAssociatedStmt() and
+// replaceWithNewTraitsOrDirectCall(), transform the call traits.
+// e.g.
+// #pragma omp declare variant(foo_variant_dispatch) match(construct={dispatch})
+// #pragma omp declare variant(foo_variant_allCond) match(user={condition(1)})
+// ..
+//     #pragma omp dispatch nocontext(cond_true)
+//         foo(i, j);
+// is changed to:
+// if (cond_true) {
+//    foo(i,j) // with traits: CodeGen call to foo_variant_allCond(i,j)
+// } else {
+//   #pragma omp dispatch
+//   foo(i,j)  // with traits: CodeGen call to foo_variant_dispatch(i,j)
+// }
+//
+// The next two functions are for:
+// if (cond_true) {
+//   foo(i,j) // with traits: runtime call to foo_variant_allCond(i,j)
+// }
+//
+static Expr *replaceWithNewTraitsOrDirectCall(const ASTContext &Context, Expr *,
+                                              SemaOpenMP *, bool);
+
+static StmtResult cloneAssociatedStmt(const ASTContext &Context, Stmt *StmtP,
+                                      SemaOpenMP *SemaPtr, bool NoContext) {
+  if (CapturedStmt *AssocStmt = dyn_cast<CapturedStmt>(StmtP)) {
+    CapturedDecl *CDecl = AssocStmt->getCapturedDecl();
+    Stmt *AssocExprStmt = AssocStmt->getCapturedStmt();
+    Expr *AssocExpr = dyn_cast<Expr>(AssocExprStmt);
+    Expr *NewCallOrPseudoObjOrBinExpr = replaceWithNewTraitsOrDirectCall(
+        Context, AssocExpr, SemaPtr, NoContext);
+
+    // Copy Current Captured Decl to a New Captured Decl for noting the
+    // Annotation
+    CapturedDecl *NewDecl =
+        CapturedDecl::Create(const_cast<ASTContext &>(Context),
+                             CDecl->getDeclContext(), CDecl->getNumParams());
+    NewDecl->setBody(static_cast<Stmt *>(NewCallOrPseudoObjOrBinExpr));
+    for (unsigned i = 0; i < CDecl->getNumParams(); ++i) {
+      if (i != CDecl->getContextParamPosition())
+        NewDecl->setParam(i, CDecl->getParam(i));
+      else
+        NewDecl->setContextParam(i, CDecl->getContextParam());
+    }
+
+    // Create a New Captured Stmt containing the New Captured Decl
+    SmallVector<CapturedStmt::Capture, 4> Captures;
+    SmallVector<Expr *, 4> CaptureInits;
+    for (auto capture : AssocStmt->captures())
+      Captures.push_back(capture);
+    for (auto capture_init : AssocStmt->capture_inits())
+      CaptureInits.push_back(capture_init);
+    CapturedStmt *NewStmt = CapturedStmt::Create(
+        Context, AssocStmt->getCapturedStmt(),
+        AssocStmt->getCapturedRegionKind(), Captures, CaptureInits, NewDecl,
+        const_cast<RecordDecl *>(AssocStmt->getCapturedRecordDecl()));
+
+    return NewStmt;
+  }
+  return static_cast<Stmt *>(nullptr);
+}
+
+static Expr *replaceWithNewTraitsOrDirectCall(const ASTContext &Context,
+                                              Expr *AssocExpr,
+                                              SemaOpenMP *SemaPtr,
+                                              bool NoContext) {
+  BinaryOperator *BinaryCopyOpr = nullptr;
+  bool IsBinaryOp = false;
+  Expr *PseudoObjExprOrCall = AssocExpr;
+  if (BinaryOperator *BinOprExpr = dyn_cast<BinaryOperator>(AssocExpr)) {
+    IsBinaryOp = true;
+    BinaryCopyOpr = BinaryOperator::Create(
+        Context, BinOprExpr->getLHS(), BinOprExpr->getRHS(),
+        BinOprExpr->getOpcode(), BinOprExpr->getType(),
+        BinOprExpr->getValueKind(), BinOprExpr->getObjectKind(),
+        BinOprExpr->getOperatorLoc(), FPOptionsOverride());
+    PseudoObjExprOrCall = BinaryCopyOpr->getRHS();
+  }
+
+  Expr *CallWithoutInvariants = PseudoObjExprOrCall;
+  // Change PseudoObjectExpr to a direct call
+  if (PseudoObjectExpr *PseudoObjExpr =
+          dyn_cast<PseudoObjectExpr>(PseudoObjExprOrCall))
+    CallWithoutInvariants = *((PseudoObjExpr->semantics_begin()) - 1);
+
+  Expr *FinalCall = CallWithoutInvariants; // For noinvariants clause
+  if (NoContext) {
+    // Cast the expression to a CallExpr before calling ActOnOpenMPCall()
+    CallExpr *CallExprWithinStmt = dyn_cast<CallExpr>(CallWithoutInvariants);
+    int NumArgs = CallExprWithinStmt->getNumArgs();
+    clang::Expr **Args = CallExprWithinStmt->getArgs();
+    // ActOnOpenMPCall() adds traits to a simple function call
+    // e.g. invariant function call traits to "foo(i,j)", if they are present.
+    ExprResult ER = SemaPtr->ActOnOpenMPCall(
+        CallExprWithinStmt, SemaPtr->SemaRef.getCurScope(),
+        CallExprWithinStmt->getBeginLoc(), MultiExprArg(Args, NumArgs),
+        CallExprWithinStmt->getRParenLoc(), static_cast<Expr *>(nullptr));
+    FinalCall = ER.get();
+  }
+
+  if (IsBinaryOp) {
+    BinaryCopyOpr->setRHS(FinalCall);
+    return BinaryCopyOpr;
+  }
+
+  return FinalCall;
+}
+
+static StmtResult combine2Stmts(ASTContext &Context, Stmt *FirstStmt,
+                                Stmt *SecondStmt) {
+
+  llvm::SmallVector<Stmt *, 2> NewCombinedStmtVector;
+  NewCombinedStmtVector.push_back(FirstStmt);
+  NewCombinedStmtVector.push_back(SecondStmt);
+  CompoundStmt *CombinedStmt = CompoundStmt::Create(
----------------
alexey-bataev wrote:

`auto *CombinedStmt = CompoundStmt::Create(`

https://github.com/llvm/llvm-project/pull/117904
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to