================
@@ -5965,6 +5967,269 @@ static bool teamsLoopCanBeParallelFor(Stmt *AStmt, Sema &SemaRef) {
   return Checker.teamsLoopCanBeParallelFor();
 }
 
+/// Look through parentheses and implicit casts in \p Cond. If what remains is
+/// a DeclRefExpr naming an OMPCapturedExprDecl, return that declaration's
+/// initializer; in every other case return nullptr.
+static Expr *getInitialExprFromCapturedExpr(Expr *Cond) {
+  auto *Ref = dyn_cast<DeclRefExpr>(Cond->IgnoreParenImpCasts());
+  if (!Ref)
+    return nullptr;
+
+  // Only references to captured-expression declarations carry an initializer
+  // we are interested in.
+  auto *Captured = dyn_cast<OMPCapturedExprDecl>(Ref->getDecl());
+  return Captured ? Captured->getInit() : nullptr;
+}
+
+// Forward declaration: cloneAssociatedStmt() below calls this helper, which is
+// defined after it.
+static Expr *replaceWithNewTraitsOrDirectCall(const ASTContext &Context,
+                                              Expr *AssocExpr,
+                                              SemaOpenMP *SemaPtr,
+                                              bool NoContext);
+
+/// cloneAssociatedStmt() clones the CapturedStmt associated with a directive
+/// and rewrites its captured expression with
+/// replaceWithNewTraitsOrDirectCall(). The original associated statement is
+/// left unmodified.
+///
+/// \param Context   AST context used to allocate the new nodes.
+/// \param StmtP     The directive's associated statement.
+/// \param SemaPtr   Forwarded unchanged to replaceWithNewTraitsOrDirectCall().
+/// \param NoContext Forwarded unchanged to replaceWithNewTraitsOrDirectCall().
+/// \returns The new CapturedStmt, or an empty StmtResult when \p StmtP is not
+///          a CapturedStmt or its captured statement is not an expression.
+static StmtResult cloneAssociatedStmt(const ASTContext &Context, Stmt *StmtP,
+                                      SemaOpenMP *SemaPtr, bool NoContext) {
+  auto *AssocStmt = dyn_cast<CapturedStmt>(StmtP);
+  if (!AssocStmt)
+    return StmtResult();
+
+  // replaceWithNewTraitsOrDirectCall() immediately dyn_casts its expression
+  // argument, which requires a non-null value, so do not hand it a null
+  // pointer when the captured statement is not an expression.
+  auto *AssocExpr = dyn_cast<Expr>(AssocStmt->getCapturedStmt());
+  if (!AssocExpr)
+    return StmtResult();
+
+  Expr *NewExpr = replaceWithNewTraitsOrDirectCall(Context, AssocExpr,
+                                                   SemaPtr, NoContext);
+
+  // Build a fresh CapturedDecl that reuses the original parameters (the
+  // context parameter stays at its original position) but whose body is the
+  // rewritten expression.
+  CapturedDecl *CDecl = AssocStmt->getCapturedDecl();
+  CapturedDecl *NewDecl =
+      CapturedDecl::Create(const_cast<ASTContext &>(Context),
+                           CDecl->getDeclContext(), CDecl->getNumParams());
+  NewDecl->setBody(NewExpr);
+  for (unsigned I : llvm::seq<unsigned>(CDecl->getNumParams())) {
+    if (I == CDecl->getContextParamPosition())
+      NewDecl->setContextParam(I, CDecl->getContextParam());
+    else
+      NewDecl->setParam(I, CDecl->getParam(I));
+  }
+
+  // The captured sub-statement must also be the rewritten expression, so that
+  // it matches the CapturedDecl body (as it does when Sema first builds a
+  // CapturedStmt). Passing the original statement here would let any
+  // consumer of getCapturedStmt() still see the un-rewritten call.
+  SmallVector<CapturedStmt::Capture, 4> Captures(AssocStmt->captures());
+  SmallVector<Expr *, 4> CaptureInits(AssocStmt->capture_inits());
+  return CapturedStmt::Create(
+      Context, NewExpr, AssocStmt->getCapturedRegionKind(), Captures,
+      CaptureInits, NewDecl,
+      const_cast<RecordDecl *>(AssocStmt->getCapturedRecordDecl()));
+}
+
+/// replaceWithNewTraitsOrDirectCall() is for transforming the call traits.
+/// Call traits associated with a function call are removed and replaced with
+/// a direct call. For clause "nocontext" only, the direct call is then
+/// modified to have call traits for a non-dispatch variant.
+static Expr *replaceWithNewTraitsOrDirectCall(const ASTContext &Context,
+                                              Expr *AssocExpr,
+                                              SemaOpenMP *SemaPtr,
+                                              bool NoContext) {
+  BinaryOperator *BinaryCopyOpr = nullptr;
+  bool IsBinaryOp = false;
+  Expr *PseudoObjExprOrCall = AssocExpr;
+  if (auto *BinOprExpr = dyn_cast<BinaryOperator>(AssocExpr)) {
+    IsBinaryOp = true;
+    BinaryCopyOpr = BinaryOperator::Create(
+        Context, BinOprExpr->getLHS(), BinOprExpr->getRHS(),
+        BinOprExpr->getOpcode(), BinOprExpr->getType(),
+        BinOprExpr->getValueKind(), BinOprExpr->getObjectKind(),
+        BinOprExpr->getOperatorLoc(), FPOptionsOverride());
+    PseudoObjExprOrCall = BinaryCopyOpr->getRHS();
+  }
+
+  Expr *CallWithoutInvariants = PseudoObjExprOrCall;
+  // Change PseudoObjectExpr to a direct call
+  if (auto *PseudoObjExpr = dyn_cast<PseudoObjectExpr>(PseudoObjExprOrCall))
+    CallWithoutInvariants = *((PseudoObjExpr->semantics_begin()) - 1);
+
+  Expr *FinalCall = CallWithoutInvariants; // For noinvariants clause
+  if (NoContext) {
+    // example to explain the changes done for "nocontext" clause:
+    //
+    // #pragma omp declare variant(foo_variant_dispatch)
+    //                                  match(construct = {dispatch})
+    // #pragma omp declare variant(foo_variant_allCond)
+    //                                 match(user = {condition(1)})
+    // ...
+    //     #pragma omp dispatch nocontext(cond_true)
+    //         foo(i, j); // with traits: CodeGen call to
+    //         foo_variant_dispatch(i,j)
+    // dispatch construct is changed to:
+    // if (cond_true) {
+    //    foo(i,j) // with traits: CodeGen call to foo_variant_allCond(i,j)
+    // } else {
+    //   #pragma omp dispatch
+    //   foo(i,j)  // with traits: CodeGen call to foo_variant_dispatch(i,j)
+    // }
----------------
alexey-bataev wrote:

I assume all these helpers can be directly codegened instead of building them 
in sema

https://github.com/llvm/llvm-project/pull/117904
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to