================ @@ -15745,6 +15760,388 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses, buildPreInits(Context, PreInits)); } +// Semantic action for '#pragma omp reverse'. Validates the single associated +// loop and, when not in a dependent (template) context, builds a replacement +// loop that executes the original body for the logical iterations n-1 down to +// 0. In a dependent context the transformed statement and the pre-inits are +// left null until the template is completely instantiated. +StmtResult +SemaOpenMP::ActOnOpenMPReverseDirective(ArrayRef<OMPClause *> Clauses, + Stmt *AStmt, SourceLocation StartLoc, + SourceLocation EndLoc) { + ASTContext &Context = getASTContext(); + Scope *CurScope = SemaRef.getCurScope(); + assert(Clauses.empty() && "reverse directive does not accept any clauses; " + "must have beed checked before"); + + // Empty statement should only be possible if there already was an error. + if (!AStmt) + return StmtError(); + + constexpr unsigned NumLoops = 1; + Stmt *Body = nullptr; + SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers( + NumLoops); + SmallVector<SmallVector<Stmt *, 0>, NumLoops + 1> OriginalInits; + if (!checkTransformableLoopNest(OMPD_reverse, AStmt, NumLoops, LoopHelpers, + Body, OriginalInits)) + return StmtError(); + + // Delay applying the transformation to when template is completely + // instantiated. + if (SemaRef.CurContext->isDependentContext()) + return OMPReverseDirective::Create(Context, StartLoc, EndLoc, Clauses, + AStmt, nullptr, nullptr); + + assert(LoopHelpers.size() == NumLoops && + "Expecting a single-dimensional loop iteration space"); + assert(OriginalInits.size() == NumLoops && + "Expecting a single-dimensional loop iteration space"); + OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers.front(); + + // Find the loop statement. + Stmt *LoopStmt = nullptr; + collectLoopStmts(AStmt, {LoopStmt}); + + // Determine the PreInit declarations.
+ SmallVector<Stmt *> PreInits; + addLoopPreInits(Context, LoopHelper, LoopStmt, OriginalInits[0], PreInits); + + auto *IterationVarRef = cast<DeclRefExpr>(LoopHelper.IterationVarRef); + QualType IVTy = IterationVarRef->getType(); + uint64_t IVWidth = Context.getTypeSize(IVTy); + auto *OrigVar = cast<DeclRefExpr>(LoopHelper.Counters.front()); + + // Iteration variable SourceLocations. + SourceLocation OrigVarLoc = OrigVar->getExprLoc(); + SourceLocation OrigVarLocBegin = OrigVar->getBeginLoc(); + SourceLocation OrigVarLocEnd = OrigVar->getEndLoc(); + + // Locations pointing to the transformation. + SourceLocation TransformLoc = StartLoc; + SourceLocation TransformLocBegin = StartLoc; + SourceLocation TransformLocEnd = EndLoc; + + // Internal variable names. + std::string OrigVarName = OrigVar->getNameInfo().getAsString(); + // NOTE(review): TripCountName is not referenced anywhere below. + std::string TripCountName = (Twine(".tripcount.") + OrigVarName).str(); + std::string ForwardIVName = (Twine(".forward.iv.") + OrigVarName).str(); + std::string ReversedIVName = (Twine(".reversed.iv.") + OrigVarName).str(); + + // LoopHelper.Updates will read the logical iteration number from + // LoopHelper.IterationVarRef, compute the value of the user loop counter of + // that logical iteration from it, then assign it to the user loop counter + // variable. We cannot directly use LoopHelper.IterationVarRef as the + // induction variable of the generated loop because it may cause an underflow: + // \code + // for (unsigned i = 0; i < n; ++i) + // body(i); + // \endcode + // + // Naive reversal (never terminates: for an unsigned i, i >= 0 is always + // true, and --i wraps around to the maximum value after reaching 0): + // \code + // for (unsigned i = n-1; i >= 0; --i) + // body(i); + // \endcode + // + // Instead, we introduce a new iteration variable representing the logical + // iteration counter of the original loop, convert it to the logical iteration + // number of the reversed loop, then let LoopHelper.Updates compute the user's + // loop iteration variable from it.
+ // \code + // for (auto .forward.iv = 0; .forward.iv < n; ++.forward.iv) { + // auto .reversed.iv = n - .forward.iv - 1; + // i = (.reversed.iv + 0) * 1 // LoopHelper.Updates + // body(i); // Body + // } + // \endcode + + // Subexpressions with more than one use. One of the constraints of an AST is + // that every node object must appear at most once, hence we define a lambda + // that creates a new AST node at every use. + CaptureVars CopyTransformer(SemaRef); + auto MakeNumIterations = [&CopyTransformer, &LoopHelper]() -> Expr * { + return AssertSuccess( + CopyTransformer.TransformExpr(LoopHelper.NumIterations)); + }; + + // Create the iteration variable for the forward loop (from 0 to n-1). + VarDecl *ForwardIVDecl = + buildVarDecl(SemaRef, {}, IVTy, ForwardIVName, nullptr, OrigVar); + auto MakeForwardRef = [&SemaRef = this->SemaRef, ForwardIVDecl, IVTy, + OrigVarLoc]() { + return buildDeclRefExpr(SemaRef, ForwardIVDecl, IVTy, OrigVarLoc); + }; + + // Iteration variable for the reversed induction variable (from n-1 downto 0): + // Reuse the iteration variable created by checkOpenMPLoop. LoopHelper.Updates + // already reads it, so initializing it with the reversed logical iteration + // number is enough to drive the user's loop counter.
+ auto *ReversedIVDecl = cast<VarDecl>(IterationVarRef->getDecl()); + ReversedIVDecl->setDeclName( + &SemaRef.PP.getIdentifierTable().get(ReversedIVName)); + + // For init-statement: + // \code + // auto .forward.iv = 0 + // \endcode + IntegerLiteral *Zero = + IntegerLiteral::Create(Context, llvm::APInt::getZero(IVWidth), + ForwardIVDecl->getType(), OrigVarLoc); + SemaRef.AddInitializerToDecl(ForwardIVDecl, Zero, /*DirectInit=*/false); + StmtResult Init = new (Context) + DeclStmt(DeclGroupRef(ForwardIVDecl), OrigVarLocBegin, OrigVarLocEnd); + if (!Init.isUsable()) + return StmtError(); + + // Forward iv cond-expression: + // \code + // .forward.iv < NumIterations + // \endcode + ExprResult Cond = + SemaRef.BuildBinOp(CurScope, LoopHelper.Cond->getExprLoc(), BO_LT, + MakeForwardRef(), MakeNumIterations()); + if (!Cond.isUsable()) + return StmtError(); + + // Forward incr-statement: ++.forward.iv + ExprResult Incr = SemaRef.BuildUnaryOp(CurScope, LoopHelper.Inc->getExprLoc(), + UO_PreInc, MakeForwardRef()); + if (!Incr.isUsable()) + return StmtError(); + + // Reverse the forward-iv: auto .reversed.iv = MakeNumIterations() - 1 - + // .forward.iv + IntegerLiteral *One = IntegerLiteral::Create(Context, llvm::APInt(IVWidth, 1), + IVTy, TransformLoc); + ExprResult Minus = SemaRef.BuildBinOp(CurScope, TransformLoc, BO_Sub, + MakeNumIterations(), One); + if (!Minus.isUsable()) + return StmtError(); + Minus = SemaRef.BuildBinOp(CurScope, TransformLoc, BO_Sub, Minus.get(), + MakeForwardRef()); + if (!Minus.isUsable()) + return StmtError(); + StmtResult InitReversed = new (Context) DeclStmt( + DeclGroupRef(ReversedIVDecl), TransformLocBegin, TransformLocEnd); + if (!InitReversed.isUsable()) + return StmtError(); + SemaRef.AddInitializerToDecl(ReversedIVDecl, Minus.get(), + /*DirectInit=*/false); + + // The new loop body: initialize .reversed.iv, run LoopHelper.Updates to + // derive the user's loop counter from it, declare the loop variable of a + // range-based for, then execute the original body.
+ SmallVector<Stmt *> BodyStmts; + BodyStmts.push_back(InitReversed.get()); + llvm::append_range(BodyStmts, LoopHelper.Updates); + if (auto *CXXRangeFor = dyn_cast<CXXForRangeStmt>(LoopStmt)) + BodyStmts.push_back(CXXRangeFor->getLoopVarStmt()); + BodyStmts.push_back(Body); + auto *ReversedBody = + CompoundStmt::Create(Context, BodyStmts, FPOptionsOverride(), + Body->getBeginLoc(), Body->getEndLoc()); + + // Finally create the reversed For-statement. + // The nullptr argument is the (absent) condition variable. + auto *ReversedFor = new (Context) + ForStmt(Context, Init.get(), Cond.get(), nullptr, Incr.get(), + ReversedBody, LoopHelper.Init->getBeginLoc(), + LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc()); + return OMPReverseDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, + ReversedFor, + buildPreInits(Context, PreInits)); +} + +StmtResult SemaOpenMP::ActOnOpenMPInterchangeDirective( + ArrayRef<OMPClause *> Clauses, Stmt *AStmt, SourceLocation StartLoc, + SourceLocation EndLoc) { + ASTContext &Context = getASTContext(); + DeclContext *CurContext = SemaRef.CurContext; + Scope *CurScope = SemaRef.getCurScope(); + + // Empty statement should only be possible if there already was an error. + if (!AStmt) + return StmtError(); + + // interchange without permutation clause swaps two loops. + const OMPPermutationClause *PermutationClause = + OMPExecutableDirective::getSingleClause<OMPPermutationClause>(Clauses); + size_t NumLoops = PermutationClause ? PermutationClause->getNumLoops() : 2; + + // Verify and diagnose loop nest. + SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops); + Stmt *Body = nullptr; + SmallVector<SmallVector<Stmt *, 0>, 2> OriginalInits; + if (!checkTransformableLoopNest(OMPD_interchange, AStmt, NumLoops, + LoopHelpers, Body, OriginalInits)) + return StmtError(); + + // Delay interchange to when template is completely instantiated.
+ if (CurContext->isDependentContext()) + return OMPInterchangeDirective::Create(Context, StartLoc, EndLoc, Clauses, + NumLoops, AStmt, nullptr, nullptr); + + // An invalid expression in the permutation clause is set to nullptr in + // ActOnOpenMPPermutationClause. + if (PermutationClause && llvm::any_of(PermutationClause->getArgsRefs(), + [](Expr *E) { return !E; })) + return StmtError(); + + assert(LoopHelpers.size() == NumLoops && + "Expecting loop iteration space dimensionaly to match number of " + "affected loops"); + assert(OriginalInits.size() == NumLoops && + "Expecting loop iteration space dimensionaly to match number of " + "affected loops"); + + // Decode the permutation clause. + SmallVector<uint64_t, 2> Permutation; + if (!PermutationClause) { + Permutation = {1, 0}; + } else { + ArrayRef<Expr *> PermArgs = PermutationClause->getArgsRefs(); + llvm::BitVector Flags(PermArgs.size()); + for (Expr *PermArg : PermArgs) { + std::optional<llvm::APSInt> PermCstExpr = + PermArg->getIntegerConstantExpr(Context); + if (!PermCstExpr) + continue; + uint64_t PermInt = PermCstExpr->getZExtValue(); + assert(1 <= PermInt && PermInt <= NumLoops && + "Must be a permutation; diagnostic emitted in " + "ActOnOpenMPPermutationClause"); + if (Flags[PermInt - 1]) { + SourceRange ExprRange(PermArg->getBeginLoc(), PermArg->getEndLoc()); + Diag(PermArg->getExprLoc(), + diag::err_omp_interchange_permutation_value_repeated) + << PermInt << ExprRange; + continue; + } + Flags[PermInt - 1] = true; + + Permutation.push_back(PermInt - 1); + } + + if (Permutation.size() != NumLoops) + return StmtError(); + } + + // Nothing to transform with trivial permutation. + if (NumLoops <= 1 || llvm::all_of(llvm::enumerate(Permutation), [](auto p) { + auto [Idx, Arg] = p; + return Idx == Arg; + })) + return OMPInterchangeDirective::Create(Context, StartLoc, EndLoc, Clauses, + NumLoops, AStmt, AStmt, nullptr); + + // Find the affected loops. 
+ SmallVector<Stmt *> LoopStmts(NumLoops, nullptr); + collectLoopStmts(AStmt, LoopStmts); + + // Collect pre-init statements on the order before the permuation. + SmallVector<Stmt *> PreInits; + for (auto I : llvm::seq<int>(NumLoops)) { + OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I]; + + assert(LoopHelper.Counters.size() == 1 && + "Single-dimensional loop iteration space expected"); + auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters.front()); + + std::string OrigVarName = OrigCntVar->getNameInfo().getAsString(); + addLoopPreInits(Context, LoopHelper, LoopStmts[I], OriginalInits[I], + PreInits); + } + + SmallVector<VarDecl *> PermutedIndVars; + PermutedIndVars.resize(NumLoops); + CaptureVars CopyTransformer(SemaRef); + + // Create the permuted loops from the inside to the outside of the + // interchanged loop nest. Body of the innermost new loop is the original + // innermost body. + Stmt *Inner = Body; + for (auto TargetIdx : llvm::reverse(llvm::seq<int>(NumLoops))) { + // Get the original loop that belongs to this new position. + uint64_t SourceIdx = Permutation[TargetIdx]; + OMPLoopBasedDirective::HelperExprs &SourceHelper = LoopHelpers[SourceIdx]; + Stmt *SourceLoopStmt = LoopStmts[SourceIdx]; + assert(SourceHelper.Counters.size() == 1 && + "Single-dimensional loop iteration space expected"); + auto *OrigCntVar = cast<DeclRefExpr>(SourceHelper.Counters.front()); + + // Normalized loop counter variable: From 0 to n-1, always an integer type. + DeclRefExpr *IterVarRef = cast<DeclRefExpr>(SourceHelper.IterationVarRef); + QualType IVTy = IterVarRef->getType(); + assert(IVTy->isIntegerType() && + "Expected the logical iteration counter to be an integer"); + + std::string OrigVarName = OrigCntVar->getNameInfo().getAsString(); + SourceLocation OrigVarLoc = IterVarRef->getExprLoc(); + + // Make a copy of the NumIterations expression for each use: By the AST + // constraints, every expression object in a DeclContext must be unique. 
+ auto MakeNumIterations = [&CopyTransformer, &SourceHelper]() -> Expr * { + return AssertSuccess( + CopyTransformer.TransformExpr(SourceHelper.NumIterations)); + }; + + // Iteration variable for the permuted loop. Reuse the one from + // checkOpenMPLoop which will also be used to update the original loop + // variable. + std::string PermutedCntName = + (Twine(".permuted_") + llvm::utostr(TargetIdx) + ".iv." + OrigVarName) + .str(); + auto *PermutedCntDecl = cast<VarDecl>(IterVarRef->getDecl()); + PermutedCntDecl->setDeclName( + &SemaRef.PP.getIdentifierTable().get(PermutedCntName)); + PermutedIndVars[TargetIdx] = PermutedCntDecl; + auto MakePermutedRef = [this, PermutedCntDecl, IVTy, OrigVarLoc]() { + return buildDeclRefExpr(SemaRef, PermutedCntDecl, IVTy, OrigVarLoc); + }; + + // For init-statement: + // \code{c} + // auto .permuted_{target}.iv = 0 + // \endcode + ExprResult Zero = SemaRef.ActOnIntegerConstant(OrigVarLoc, 0); + if (!Zero.isUsable()) + return StmtError(); + SemaRef.AddInitializerToDecl(PermutedCntDecl, Zero.get(), + /*DirectInit=*/false); + StmtResult InitStmt = new (Context) + DeclStmt(DeclGroupRef(PermutedCntDecl), OrigCntVar->getBeginLoc(), + OrigCntVar->getEndLoc()); + if (!InitStmt.isUsable()) + return StmtError(); + + // For cond-expression: + // \code{c} + // .permuted_{target}.iv < NumIterations + // \endcode + ExprResult CondExpr = + SemaRef.BuildBinOp(CurScope, SourceHelper.Cond->getExprLoc(), BO_LT, + MakePermutedRef(), MakeNumIterations()); + if (!CondExpr.isUsable()) + return StmtError(); + + // For incr-statement: + // \code{c} + // ++.tile.iv + // \endcode + ExprResult IncrStmt = SemaRef.BuildUnaryOp( + CurScope, SourceHelper.Inc->getExprLoc(), UO_PreInc, MakePermutedRef()); + if (!IncrStmt.isUsable()) + return StmtError(); + + SmallVector<Stmt *, 4> BodyParts; + llvm::append_range(BodyParts, SourceHelper.Updates); ---------------- alexey-bataev wrote:
```suggestion
    SmallVector<Stmt *, 4> BodyParts(SourceHelper.Updates.begin(), SourceHelper.Updates.end());
```

https://github.com/llvm/llvm-project/pull/92030

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits