================
@@ -5965,6 +5967,264 @@ static bool teamsLoopCanBeParallelFor(Stmt *AStmt, Sema &SemaRef) {
   return Checker.teamsLoopCanBeParallelFor();
 }
 
+/// Returns the initializer of the OMPCapturedExprDecl that \p Cond refers to
+/// (after stripping parentheses and implicit casts), or nullptr if \p Cond is
+/// null or is not a reference to such a captured-expression declaration.
+static Expr *getInitialExprFromCapturedExpr(Expr *Cond) {
+  if (!Cond)
+    return nullptr;
+
+  Expr *SubExpr = Cond->IgnoreParenImpCasts();
+  if (auto *DeclRef = dyn_cast<DeclRefExpr>(SubExpr)) {
+    if (auto *CapturedExprDecl =
+            dyn_cast<OMPCapturedExprDecl>(DeclRef->getDecl()))
+      return CapturedExprDecl->getInit();
+  }
+  return nullptr;
+}
+
+static Expr *replaceWithNewTraitsOrDirectCall(const ASTContext &Context,
+                                              Expr *AssocExpr,
+                                              SemaOpenMP *SemaPtr,
+                                              bool NoContext);
+
+/// Clones the CapturedStmt associated with a dispatch directive and rewrites
+/// its call with replaceWithNewTraitsOrDirectCall(), so the original
+/// associated statement is left unmodified.
+///
+/// \returns the new CapturedStmt, or an empty (null but not invalid)
+/// StmtResult if \p StmtP is not a CapturedStmt.
+static StmtResult cloneAssociatedStmt(const ASTContext &Context, Stmt *StmtP,
+                                      SemaOpenMP *SemaPtr, bool NoContext) {
+  auto *AssocStmt = dyn_cast<CapturedStmt>(StmtP);
+  if (!AssocStmt)
+    return StmtResult();
+
+  // The statement captured by a dispatch region is expected to be an
+  // expression (a call, or an assignment whose RHS is a call); cast<> makes
+  // that invariant explicit instead of passing a possibly-null dyn_cast.
+  auto *AssocExpr = cast<Expr>(AssocStmt->getCapturedStmt());
+  Expr *NewExpr =
+      replaceWithNewTraitsOrDirectCall(Context, AssocExpr, SemaPtr, NoContext);
+
+  // Build a fresh CapturedDecl around the rewritten expression. It reuses the
+  // original parameters and keeps the original nothrow property.
+  CapturedDecl *CDecl = AssocStmt->getCapturedDecl();
+  CapturedDecl *NewDecl =
+      CapturedDecl::Create(const_cast<ASTContext &>(Context),
+                           CDecl->getDeclContext(), CDecl->getNumParams());
+  NewDecl->setBody(NewExpr);
+  NewDecl->setNothrow(CDecl->isNothrow());
+  for (unsigned I : llvm::seq<unsigned>(CDecl->getNumParams())) {
+    if (I == CDecl->getContextParamPosition())
+      NewDecl->setContextParam(I, CDecl->getContextParam());
+    else
+      NewDecl->setParam(I, CDecl->getParam(I));
+  }
+
+  SmallVector<CapturedStmt::Capture, 4> Captures(AssocStmt->captures());
+  SmallVector<Expr *, 4> CaptureInits(AssocStmt->capture_inits());
+  // The captured statement and the CapturedDecl body must be the same node
+  // (Sema::ActOnCapturedRegionEnd builds both from one statement). Passing
+  // the original captured statement here would leave the clone with two
+  // different bodies, hiding the rewrite from users of getCapturedStmt().
+  return CapturedStmt::Create(
+      Context, NewExpr, AssocStmt->getCapturedRegionKind(), Captures,
+      CaptureInits, NewDecl,
+      const_cast<RecordDecl *>(AssocStmt->getCapturedRecordDecl()));
+}
+
+/// Rewrites the call in the statement associated with a dispatch directive.
+///
+/// Call traits attached to the call (a PseudoObjectExpr whose syntactic form
+/// is the call as written) are dropped, yielding a direct call. If
+/// \p NoContext is true, the direct call is re-analyzed with
+/// SemaOpenMP::ActOnOpenMPCall() so it picks up traits for a non-dispatch
+/// variant. If \p AssocExpr is a BinaryOperator, only its RHS is rewritten,
+/// on a copy of the operator; the original expression is not modified.
+///
+/// \returns the rewritten expression (never null for a non-null input).
+static Expr *replaceWithNewTraitsOrDirectCall(const ASTContext &Context,
+                                              Expr *AssocExpr,
+                                              SemaOpenMP *SemaPtr,
+                                              bool NoContext) {
+  BinaryOperator *BinaryCopyOpr = nullptr;
+  Expr *PseudoObjExprOrCall = AssocExpr;
+  if (auto *BinOprExpr = dyn_cast<BinaryOperator>(AssocExpr)) {
+    // NOTE(review): BinaryOperator::Create never yields a
+    // CompoundAssignOperator; this assumes `op=` forms were rejected earlier.
+    BinaryCopyOpr = BinaryOperator::Create(
+        Context, BinOprExpr->getLHS(), BinOprExpr->getRHS(),
+        BinOprExpr->getOpcode(), BinOprExpr->getType(),
+        BinOprExpr->getValueKind(), BinOprExpr->getObjectKind(),
+        BinOprExpr->getOperatorLoc(), BinOprExpr->getFPFeatures());
+    PseudoObjExprOrCall = BinaryCopyOpr->getRHS();
+  }
+
+  // Drop attached call traits: the syntactic form of the PseudoObjectExpr is
+  // the plain call as written by the user.
+  Expr *DirectCall = PseudoObjExprOrCall;
+  if (auto *PseudoObjExpr = dyn_cast<PseudoObjectExpr>(PseudoObjExprOrCall))
+    DirectCall = PseudoObjExpr->getSyntacticForm();
+
+  Expr *FinalCall = DirectCall; // Used as-is for the novariants clause.
+  if (NoContext) {
+    // Example of the rewrite performed for the "nocontext" clause:
+    //
+    // #pragma omp declare variant(foo_variant_dispatch)
+    //     match(construct = {dispatch})
+    // #pragma omp declare variant(foo_variant_allCond)
+    //     match(user = {condition(1)})
+    // ...
+    // #pragma omp dispatch nocontext(cond_true)
+    //   foo(i, j); // with traits: CodeGen call to foo_variant_dispatch(i,j)
+    //
+    // becomes:
+    //
+    // if (cond_true) {
+    //   foo(i,j) // with traits: CodeGen call to foo_variant_allCond(i,j)
+    // } else {
+    //   #pragma omp dispatch
+    //   foo(i,j) // with traits: CodeGen call to foo_variant_dispatch(i,j)
+    // }
+    auto *Call = cast<CallExpr>(DirectCall);
+    // ActOnOpenMPCall() attaches the applicable variant traits (if any) to a
+    // plain function call such as "foo(i,j)".
+    ExprResult ER = SemaPtr->ActOnOpenMPCall(
+        Call, SemaPtr->SemaRef.getCurScope(), Call->getBeginLoc(),
+        MultiExprArg(Call->getArgs(), Call->getNumArgs()),
+        Call->getRParenLoc(), /*ExecConfig=*/nullptr);
+    // Keep the direct call if re-analysis produced nothing usable, rather
+    // than installing a null expression into the AST.
+    if (ER.isUsable())
+      FinalCall = ER.get();
+  }
+
+  if (BinaryCopyOpr) {
+    BinaryCopyOpr->setRHS(FinalCall);
+    return BinaryCopyOpr;
+  }
+  return FinalCall;
+}
+
+/// Returns a new CompoundStmt `{ FirstStmt; SecondStmt; }` with invalid
+/// source locations and default floating-point options.
+static StmtResult combine2Stmts(ASTContext &Context, Stmt *FirstStmt,
+                                Stmt *SecondStmt) {
+  Stmt *const Body[] = {FirstStmt, SecondStmt};
+  return CompoundStmt::Create(Context, Body, FPOptionsOverride(),
+                              SourceLocation(), SourceLocation());
+}
+
+/// Returns true if \p Clauses contains at least one clause of type
+/// \p SpecificClause.
+template <typename SpecificClause>
+static bool hasClausesOfKind(ArrayRef<OMPClause *> Clauses) {
+  return llvm::any_of(Clauses, [](const OMPClause *C) {
+    return isa<SpecificClause>(C);
+  });
+}
+
+StmtResult SemaOpenMP::transformDispatchDirective(
+    OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName,
+    OpenMPDirectiveKind CancelRegion, ArrayRef<OMPClause
*> Clauses, + Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc) { + + StmtResult RetValue; + llvm::SmallVector<OMPClause *, 8> DependClauseVector; + for (const OMPDependClause *ConstDependClause : + OMPExecutableDirective::getClausesOfKind<OMPDependClause>(Clauses)) { + auto *DependClause = const_cast<OMPDependClause *>(ConstDependClause); + DependClauseVector.push_back(DependClause); + } + + // #pragma omp dispatch depend() is changed to #pragma omp taskwait depend() + // This is done by calling ActOnOpenMPExecutableDirective() for the + // new taskwait directive. + StmtResult DispatchDepend2taskwait = ActOnOpenMPExecutableDirective( + OMPD_taskwait, DirName, CancelRegion, DependClauseVector, NULL, StartLoc, + EndLoc); + + if (OMPExecutableDirective::getSingleClause<OMPNovariantsClause>(Clauses)) { + + if (OMPExecutableDirective::getSingleClause<OMPNocontextClause>(Clauses)) { + Diag(StartLoc, diag::warn_omp_dispatch_clause_novariants_nocontext); + } + + const OMPNovariantsClause *NoVariantsC = + OMPExecutableDirective::getSingleClause<OMPNovariantsClause>(Clauses); + // #pragma omp dispatch novariants(c2) depend(out: x) + // foo(); + // becomes: + // #pragma omp taskwait depend(out: x) + // if (c2) { + // foo(); + // } else { + // #pragma omp dispatch + // foo(); <--- foo() is replaced with foo_variant() in CodeGen + // } + Expr *Cond = getInitialExprFromCapturedExpr(NoVariantsC->getCondition()); ---------------- alexey-bataev wrote:
Why do you need to do it in sema? https://github.com/llvm/llvm-project/pull/117904 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits