llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang-codegen Author: Roger Ferrer Ibáñez (rofirrim) <details> <summary>Changes</summary> This is preparatory work for the implementation of `#pragma omp fuse` in https://github.com/llvm/llvm-project/pull/139293 **Note**: this change builds on top of https://github.com/llvm/llvm-project/pull/155848 This change adds an additional class to hold data that will be shared between all loop transformations: those that apply to canonical loop nests (the majority) and those that apply to canonical loop sequences (`fuse` in OpenMP 6.0). This class is not a statement by itself and its goal is to avoid having to replicate information between classes. Also simplify how the "generated loops" information is handled, since we currently only need to know whether it is zero or non-zero. --- Patch is 24.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155849.diff 11 Files Affected: - (modified) clang/include/clang/AST/StmtOpenMP.h (+62-53) - (modified) clang/include/clang/Basic/OpenMPKinds.h (+7) - (modified) clang/include/clang/Basic/StmtNodes.td (+8-6) - (modified) clang/lib/AST/StmtOpenMP.cpp (+7-6) - (modified) clang/lib/AST/StmtProfile.cpp (+7-7) - (modified) clang/lib/Basic/OpenMPKinds.cpp (+7-1) - (modified) clang/lib/CodeGen/CGStmtOpenMP.cpp (+2-1) - (modified) clang/lib/Sema/SemaOpenMP.cpp (+4-2) - (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+7-7) - (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+7-7) - (modified) clang/tools/libclang/CIndex.cpp (+9-9) ``````````diff diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h index 2fb33d3036bca..602a516c0d43f 100644 --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -889,23 +889,24 @@ class OMPLoopBasedDirective : public OMPExecutableDirective { /// Calls the specified callback function for all the loops in \p CurStmt, /// from the outermost to the
innermost. - static bool - doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops, - unsigned NumLoops, - llvm::function_ref<bool(unsigned, Stmt *)> Callback, - llvm::function_ref<void(OMPLoopTransformationDirective *)> - OnTransformationCallback); + static bool doForAllLoops( + Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops, + llvm::function_ref<bool(unsigned, Stmt *)> Callback, + llvm::function_ref<void(OMPCanonicalLoopNestTransformationDirective *)> + OnTransformationCallback); static bool doForAllLoops(const Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops, llvm::function_ref<bool(unsigned, const Stmt *)> Callback, - llvm::function_ref<void(const OMPLoopTransformationDirective *)> + llvm::function_ref< + void(const OMPCanonicalLoopNestTransformationDirective *)> OnTransformationCallback) { auto &&NewCallback = [Callback](unsigned Cnt, Stmt *CurStmt) { return Callback(Cnt, CurStmt); }; auto &&NewTransformCb = - [OnTransformationCallback](OMPLoopTransformationDirective *A) { + [OnTransformationCallback]( + OMPCanonicalLoopNestTransformationDirective *A) { OnTransformationCallback(A); }; return doForAllLoops(const_cast<Stmt *>(CurStmt), TryImperfectlyNestedLoops, @@ -918,7 +919,7 @@ class OMPLoopBasedDirective : public OMPExecutableDirective { doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops, llvm::function_ref<bool(unsigned, Stmt *)> Callback) { - auto &&TransformCb = [](OMPLoopTransformationDirective *) {}; + auto &&TransformCb = [](OMPCanonicalLoopNestTransformationDirective *) {}; return doForAllLoops(CurStmt, TryImperfectlyNestedLoops, NumLoops, Callback, TransformCb); } @@ -955,31 +956,42 @@ class OMPLoopBasedDirective : public OMPExecutableDirective { } }; -/// The base class for all loop transformation directives. 
-class OMPLoopTransformationDirective : public OMPLoopBasedDirective { - friend class ASTStmtReader; +/// Common class of data shared between +/// OMPCanonicalLoopNestTransformationDirective and transformations over +/// canonical loop sequences. +class OMPLoopTransformationDirective { + /// Number of (top-level) generated loops. + /// This value is 1 for most transformations as they only map one loop nest + /// into another. + /// Some loop transformations (like a non-partial 'unroll') may not generate + /// a loop nest, so this would be 0. + /// Some loop transformations (like 'fuse' with looprange and 'split') may + /// generate more than one loop nest, so the value would be > 1. + unsigned NumGeneratedLoops = 1; - /// Number of loops generated by this loop transformation. - unsigned NumGeneratedLoops = 0; +protected: + void setNumGeneratedLoops(unsigned N) { NumGeneratedLoops = N; } + +public: + unsigned getNumGeneratedLoops() const { return NumGeneratedLoops; } +}; + +/// The base class for all transformation directives of canonical loop nests. +class OMPCanonicalLoopNestTransformationDirective + : public OMPLoopBasedDirective, + public OMPLoopTransformationDirective { + friend class ASTStmtReader; protected: - explicit OMPLoopTransformationDirective(StmtClass SC, - OpenMPDirectiveKind Kind, - SourceLocation StartLoc, - SourceLocation EndLoc, - unsigned NumAssociatedLoops) + explicit OMPCanonicalLoopNestTransformationDirective( + StmtClass SC, OpenMPDirectiveKind Kind, SourceLocation StartLoc, + SourceLocation EndLoc, unsigned NumAssociatedLoops) : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {} - /// Set the number of loops generated by this loop transformation. - void setNumGeneratedLoops(unsigned Num) { NumGeneratedLoops = Num; } - public: /// Return the number of associated (consumed) loops. unsigned getNumAssociatedLoops() const { return getLoopsNumber(); } - /// Return the number of loops generated by this loop transformation.
- unsigned getNumGeneratedLoops() const { return NumGeneratedLoops; } - /// Get the de-sugared statements after the loop transformation. /// /// Might be nullptr if either the directive generates no loops and is handled @@ -5545,7 +5557,8 @@ class OMPTargetTeamsDistributeSimdDirective final : public OMPLoopDirective { }; /// This represents the '#pragma omp tile' loop transformation directive. -class OMPTileDirective final : public OMPLoopTransformationDirective { +class OMPTileDirective final + : public OMPCanonicalLoopNestTransformationDirective { friend class ASTStmtReader; friend class OMPExecutableDirective; @@ -5557,11 +5570,9 @@ class OMPTileDirective final : public OMPLoopTransformationDirective { explicit OMPTileDirective(SourceLocation StartLoc, SourceLocation EndLoc, unsigned NumLoops) - : OMPLoopTransformationDirective(OMPTileDirectiveClass, - llvm::omp::OMPD_tile, StartLoc, EndLoc, - NumLoops) { - setNumGeneratedLoops(2 * NumLoops); - } + : OMPCanonicalLoopNestTransformationDirective( + OMPTileDirectiveClass, llvm::omp::OMPD_tile, StartLoc, EndLoc, + NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5622,7 +5633,8 @@ class OMPTileDirective final : public OMPLoopTransformationDirective { }; /// This represents the '#pragma omp stripe' loop transformation directive. 
-class OMPStripeDirective final : public OMPLoopTransformationDirective { +class OMPStripeDirective final + : public OMPCanonicalLoopNestTransformationDirective { friend class ASTStmtReader; friend class OMPExecutableDirective; @@ -5634,11 +5646,9 @@ class OMPStripeDirective final : public OMPLoopTransformationDirective { explicit OMPStripeDirective(SourceLocation StartLoc, SourceLocation EndLoc, unsigned NumLoops) - : OMPLoopTransformationDirective(OMPStripeDirectiveClass, - llvm::omp::OMPD_stripe, StartLoc, EndLoc, - NumLoops) { - setNumGeneratedLoops(2 * NumLoops); - } + : OMPCanonicalLoopNestTransformationDirective( + OMPStripeDirectiveClass, llvm::omp::OMPD_stripe, StartLoc, EndLoc, + NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5702,7 +5712,8 @@ class OMPStripeDirective final : public OMPLoopTransformationDirective { /// #pragma omp unroll /// for (int i = 0; i < 64; ++i) /// \endcode -class OMPUnrollDirective final : public OMPLoopTransformationDirective { +class OMPUnrollDirective final + : public OMPCanonicalLoopNestTransformationDirective { friend class ASTStmtReader; friend class OMPExecutableDirective; @@ -5713,9 +5724,9 @@ class OMPUnrollDirective final : public OMPLoopTransformationDirective { }; explicit OMPUnrollDirective(SourceLocation StartLoc, SourceLocation EndLoc) - : OMPLoopTransformationDirective(OMPUnrollDirectiveClass, - llvm::omp::OMPD_unroll, StartLoc, EndLoc, - 1) {} + : OMPCanonicalLoopNestTransformationDirective(OMPUnrollDirectiveClass, + llvm::omp::OMPD_unroll, + StartLoc, EndLoc, 1) {} /// Set the pre-init statements. void setPreInits(Stmt *PreInits) { @@ -5776,7 +5787,8 @@ class OMPUnrollDirective final : public OMPLoopTransformationDirective { /// for (int i = 0; i < n; ++i) /// ... 
/// \endcode -class OMPReverseDirective final : public OMPLoopTransformationDirective { +class OMPReverseDirective final + : public OMPCanonicalLoopNestTransformationDirective { friend class ASTStmtReader; friend class OMPExecutableDirective; @@ -5788,11 +5800,9 @@ class OMPReverseDirective final : public OMPLoopTransformationDirective { explicit OMPReverseDirective(SourceLocation StartLoc, SourceLocation EndLoc, unsigned NumLoops) - : OMPLoopTransformationDirective(OMPReverseDirectiveClass, - llvm::omp::OMPD_reverse, StartLoc, - EndLoc, NumLoops) { - setNumGeneratedLoops(NumLoops); - } + : OMPCanonicalLoopNestTransformationDirective( + OMPReverseDirectiveClass, llvm::omp::OMPD_reverse, StartLoc, EndLoc, + NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5848,7 +5858,8 @@ class OMPReverseDirective final : public OMPLoopTransformationDirective { /// for (int j = 0; j < n; ++j) /// .. /// \endcode -class OMPInterchangeDirective final : public OMPLoopTransformationDirective { +class OMPInterchangeDirective final + : public OMPCanonicalLoopNestTransformationDirective { friend class ASTStmtReader; friend class OMPExecutableDirective; @@ -5860,11 +5871,9 @@ class OMPInterchangeDirective final : public OMPLoopTransformationDirective { explicit OMPInterchangeDirective(SourceLocation StartLoc, SourceLocation EndLoc, unsigned NumLoops) - : OMPLoopTransformationDirective(OMPInterchangeDirectiveClass, - llvm::omp::OMPD_interchange, StartLoc, - EndLoc, NumLoops) { - setNumGeneratedLoops(NumLoops); - } + : OMPCanonicalLoopNestTransformationDirective( + OMPInterchangeDirectiveClass, llvm::omp::OMPD_interchange, StartLoc, + EndLoc, NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; diff --git a/clang/include/clang/Basic/OpenMPKinds.h b/clang/include/clang/Basic/OpenMPKinds.h index f40db4c13c55a..d3285cd9c6a14 100644 --- a/clang/include/clang/Basic/OpenMPKinds.h +++ 
b/clang/include/clang/Basic/OpenMPKinds.h @@ -365,6 +365,13 @@ bool isOpenMPTaskingDirective(OpenMPDirectiveKind Kind); /// functions bool isOpenMPLoopBoundSharingDirective(OpenMPDirectiveKind Kind); +/// Checks if the specified directive is a loop transformation directive that +/// applies to a canonical loop nest. +/// \param DKind Specified directive. +/// \return True iff the directive is a canonical loop nest transformation. +bool isOpenMPCanonicalLoopNestTransformationDirective( + OpenMPDirectiveKind DKind); + /// Checks if the specified directive is a loop transformation directive. /// \param DKind Specified directive. /// \return True iff the directive is a loop transformation. diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td index c9c173f5c7469..781577549573d 100644 --- a/clang/include/clang/Basic/StmtNodes.td +++ b/clang/include/clang/Basic/StmtNodes.td @@ -227,12 +227,14 @@ def OMPLoopBasedDirective : StmtNode<OMPExecutableDirective, 1>; def OMPLoopDirective : StmtNode<OMPLoopBasedDirective, 1>; def OMPParallelDirective : StmtNode<OMPExecutableDirective>; def OMPSimdDirective : StmtNode<OMPLoopDirective>; -def OMPLoopTransformationDirective : StmtNode<OMPLoopBasedDirective, 1>; -def OMPTileDirective : StmtNode<OMPLoopTransformationDirective>; -def OMPStripeDirective : StmtNode<OMPLoopTransformationDirective>; -def OMPUnrollDirective : StmtNode<OMPLoopTransformationDirective>; -def OMPReverseDirective : StmtNode<OMPLoopTransformationDirective>; -def OMPInterchangeDirective : StmtNode<OMPLoopTransformationDirective>; +def OMPCanonicalLoopNestTransformationDirective + : StmtNode<OMPLoopBasedDirective, 1>; +def OMPTileDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>; +def OMPStripeDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>; +def OMPUnrollDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>; +def OMPReverseDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>; +def
OMPInterchangeDirective + : StmtNode<OMPCanonicalLoopNestTransformationDirective>; def OMPForDirective : StmtNode<OMPLoopDirective>; def OMPForSimdDirective : StmtNode<OMPLoopDirective>; def OMPSectionsDirective : StmtNode<OMPExecutableDirective>; diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp index 2eeb5e45ab511..36ecaf6489ef0 100644 --- a/clang/lib/AST/StmtOpenMP.cpp +++ b/clang/lib/AST/StmtOpenMP.cpp @@ -125,12 +125,13 @@ OMPLoopBasedDirective::tryToFindNextInnerLoop(Stmt *CurStmt, bool OMPLoopBasedDirective::doForAllLoops( Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops, llvm::function_ref<bool(unsigned, Stmt *)> Callback, - llvm::function_ref<void(OMPLoopTransformationDirective *)> + llvm::function_ref<void(OMPCanonicalLoopNestTransformationDirective *)> OnTransformationCallback) { CurStmt = CurStmt->IgnoreContainers(); for (unsigned Cnt = 0; Cnt < NumLoops; ++Cnt) { while (true) { - auto *Dir = dyn_cast<OMPLoopTransformationDirective>(CurStmt); + auto *Dir = + dyn_cast<OMPCanonicalLoopNestTransformationDirective>(CurStmt); if (!Dir) break; @@ -369,11 +370,11 @@ OMPForDirective *OMPForDirective::Create( return Dir; } -Stmt *OMPLoopTransformationDirective::getTransformedStmt() const { +Stmt *OMPCanonicalLoopNestTransformationDirective::getTransformedStmt() const { switch (getStmtClass()) { #define STMT(CLASS, PARENT) #define ABSTRACT_STMT(CLASS) -#define OMPLOOPTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \ +#define OMPCANONICALLOOPNESTTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \ case Stmt::CLASS##Class: \ return static_cast<const CLASS *>(this)->getTransformedStmt(); #include "clang/AST/StmtNodes.inc" @@ -382,11 +383,11 @@ Stmt *OMPLoopTransformationDirective::getTransformedStmt() const { } } -Stmt *OMPLoopTransformationDirective::getPreInits() const { +Stmt *OMPCanonicalLoopNestTransformationDirective::getPreInits() const { switch (getStmtClass()) { #define STMT(CLASS, PARENT) #define ABSTRACT_STMT(CLASS) -#define 
OMPLOOPTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \ +#define OMPCANONICALLOOPNESTTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \ case Stmt::CLASS##Class: \ return static_cast<const CLASS *>(this)->getPreInits(); #include "clang/AST/StmtNodes.inc" diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index 2035fa7635f2a..7a9b7fb431099 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -999,30 +999,30 @@ void StmtProfiler::VisitOMPSimdDirective(const OMPSimdDirective *S) { VisitOMPLoopDirective(S); } -void StmtProfiler::VisitOMPLoopTransformationDirective( - const OMPLoopTransformationDirective *S) { +void StmtProfiler::VisitOMPCanonicalLoopNestTransformationDirective( + const OMPCanonicalLoopNestTransformationDirective *S) { VisitOMPLoopBasedDirective(S); } void StmtProfiler::VisitOMPTileDirective(const OMPTileDirective *S) { - VisitOMPLoopTransformationDirective(S); + VisitOMPCanonicalLoopNestTransformationDirective(S); } void StmtProfiler::VisitOMPStripeDirective(const OMPStripeDirective *S) { - VisitOMPLoopTransformationDirective(S); + VisitOMPCanonicalLoopNestTransformationDirective(S); } void StmtProfiler::VisitOMPUnrollDirective(const OMPUnrollDirective *S) { - VisitOMPLoopTransformationDirective(S); + VisitOMPCanonicalLoopNestTransformationDirective(S); } void StmtProfiler::VisitOMPReverseDirective(const OMPReverseDirective *S) { - VisitOMPLoopTransformationDirective(S); + VisitOMPCanonicalLoopNestTransformationDirective(S); } void StmtProfiler::VisitOMPInterchangeDirective( const OMPInterchangeDirective *S) { - VisitOMPLoopTransformationDirective(S); + VisitOMPCanonicalLoopNestTransformationDirective(S); } void StmtProfiler::VisitOMPForDirective(const OMPForDirective *S) { diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp index 220b31b0f19bc..3f8f64df8702e 100644 --- a/clang/lib/Basic/OpenMPKinds.cpp +++ b/clang/lib/Basic/OpenMPKinds.cpp @@ -717,11 +717,17 @@ bool 
clang::isOpenMPLoopBoundSharingDirective(OpenMPDirectiveKind Kind) { Kind == OMPD_teams_loop || Kind == OMPD_target_teams_loop; } -bool clang::isOpenMPLoopTransformationDirective(OpenMPDirectiveKind DKind) { +bool clang::isOpenMPCanonicalLoopNestTransformationDirective( + OpenMPDirectiveKind DKind) { return DKind == OMPD_tile || DKind == OMPD_unroll || DKind == OMPD_reverse || DKind == OMPD_interchange || DKind == OMPD_stripe; } +bool clang::isOpenMPLoopTransformationDirective(OpenMPDirectiveKind DKind) { + // FIXME: There will be more cases when we implement 'fuse'. + return isOpenMPCanonicalLoopNestTransformationDirective(DKind); +} + bool clang::isOpenMPCombinedParallelADirective(OpenMPDirectiveKind DKind) { return DKind == OMPD_parallel_for || DKind == OMPD_parallel_for_simd || DKind == OMPD_parallel_master || diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp index f6a0ca574a191..6f795b45bc381 100644 --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -1927,7 +1927,8 @@ static void emitBody(CodeGenFunction &CGF, const Stmt *S, const Stmt *NextLoop, return; } if (SimplifiedS == NextLoop) { - if (auto *Dir = dyn_cast<OMPLoopTransformationDirective>(SimplifiedS)) + if (auto *Dir = + dyn_cast<OMPCanonicalLoopNestTransformationDirective>(SimplifiedS)) SimplifiedS = Dir->getTransformedStmt(); if (const auto *CanonLoop = dyn_cast<OMPCanonicalLoop>(SimplifiedS)) SimplifiedS = CanonLoop->getLoopStmt(); diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index 7d800c446b595..a02850c66b4fe 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -4145,7 +4145,8 @@ class DSAAttrChecker final : public StmtVisitor<DSAAttrChecker, void> { VisitSubCaptures(S); } - void VisitOMPLoopTransformationDirective(OMPLoopTransformationDirective *S) { + void VisitOMPCanonicalLoopNestTransformationDirective( + OMPCanonicalLoopNestTransformationDirective *S) { // Loop 
transformation directives do not introduce data sharing VisitStmt(S); } @@ -9748,7 +9749,8 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr, } return false; }, - [&SemaRef, &Captures](OMPLoopTransformationDirective *Transform) { + [&SemaRef, + &Captures](OMPCanonicalLoopNestTransformationDirective *Transform) { Stmt *DependentPreInits = Transform->getPreInits(); if (!DependentPreInits) return; diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp index 3f37dfbc3dea9..13618b4a03d1e 100644 --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2442,30 +2442,30 @@ void ASTStmtReader::VisitOMPSimdDirective(OMPSimdDirective *D) { VisitOMPLoopDirective(D); } -void ASTStmtReader::VisitOMPLoopTransformationDirective( - OMPLoopTransformationDirective *D) { +void ASTStmtReader::VisitOMPCanonicalLoopNestTransformationDirective( + OMPCanonicalLoopNestTransformationDirective *D) { VisitOMPLoopBasedDirective(D); D->setNumG... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/155849 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits