Meinersbur created this revision. Meinersbur added projects: OpenMP, clang. Herald added subscribers: dexonsmith, martong, arphaman, zzheng, guansong, yaxunl. Meinersbur requested review of this revision. Herald added a reviewer: jdoerfert. Herald added subscribers: llvm-commits, cfe-commits, sstefan1. Herald added a project: LLVM.
This patch implements the unroll directive introduced in OpenMP 5.1. It follows the approach from D76342 <https://reviews.llvm.org/D76342> for the tile directive (i.e. AST-based, not using the OpenMPIRBuilder). It tries to use `llvm.loop.unroll.*` metadata where possible, but has to fall back to an AST representation of the outer loop if the partially unrolled generated loop is associated with another directive (because it needs to compute the number of iterations). This is work in progress; tests and some diagnostics are still missing. Repository: rG LLVM Github Monorepo https://reviews.llvm.org/D99459 Files: clang/include/clang-c/Index.h clang/include/clang/AST/OpenMPClause.h clang/include/clang/AST/RecursiveASTVisitor.h clang/include/clang/AST/StmtOpenMP.h clang/include/clang/Basic/StmtNodes.td clang/include/clang/Sema/Sema.h clang/include/clang/Serialization/ASTBitCodes.h clang/lib/AST/OpenMPClause.cpp clang/lib/AST/StmtOpenMP.cpp clang/lib/AST/StmtPrinter.cpp clang/lib/AST/StmtProfile.cpp clang/lib/Basic/OpenMPKinds.cpp clang/lib/CodeGen/CGOpenMPRuntime.cpp clang/lib/CodeGen/CGStmt.cpp clang/lib/CodeGen/CGStmtOpenMP.cpp clang/lib/CodeGen/CodeGenFunction.h clang/lib/Parse/ParseOpenMP.cpp clang/lib/Sema/SemaExceptionSpec.cpp clang/lib/Sema/SemaOpenMP.cpp clang/lib/Sema/TreeTransform.h clang/lib/Serialization/ASTReader.cpp clang/lib/Serialization/ASTReaderStmt.cpp clang/lib/Serialization/ASTWriter.cpp clang/lib/Serialization/ASTWriterStmt.cpp clang/lib/StaticAnalyzer/Core/ExprEngine.cpp clang/test/OpenMP/unroll_ast_print.cpp clang/test/OpenMP/unroll_codegen.cpp clang/tools/libclang/CIndex.cpp clang/tools/libclang/CXCursor.cpp llvm/include/llvm/Frontend/OpenMP/OMP.td
Index: llvm/include/llvm/Frontend/OpenMP/OMP.td =================================================================== --- llvm/include/llvm/Frontend/OpenMP/OMP.td +++ llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -69,6 +69,8 @@ let flangClass = "OmpObjectList"; } def OMPC_Sizes: Clause<"sizes"> { let clangClass = "OMPSizesClause"; } +def OMPC_Full: Clause<"full"> { let clangClass = "OMPFullClause"; } +def OMPC_Partial: Clause<"partial"> { let clangClass = "OMPPartialClause"; } def OMPC_FirstPrivate : Clause<"firstprivate"> { let clangClass = "OMPFirstprivateClause"; let flangClass = "OmpObjectList"; @@ -381,6 +383,12 @@ VersionedClause<OMPC_Sizes, 51>, ]; } +def OMP_Unroll : Directive<"unroll"> { + let allowedOnceClauses = [ + VersionedClause<OMPC_Full, 51>, + VersionedClause<OMPC_Partial, 51>, + ]; +} def OMP_For : Directive<"for"> { let allowedClauses = [ VersionedClause<OMPC_Private>, Index: clang/tools/libclang/CXCursor.cpp =================================================================== --- clang/tools/libclang/CXCursor.cpp +++ clang/tools/libclang/CXCursor.cpp @@ -651,6 +651,9 @@ case Stmt::OMPTileDirectiveClass: K = CXCursor_OMPTileDirective; break; + case Stmt::OMPUnrollDirectiveClass: + K = CXCursor_OMPUnrollDirective; + break; case Stmt::OMPForDirectiveClass: K = CXCursor_OMPForDirective; break; Index: clang/tools/libclang/CIndex.cpp =================================================================== --- clang/tools/libclang/CIndex.cpp +++ clang/tools/libclang/CIndex.cpp @@ -2045,6 +2045,7 @@ void VisitOMPParallelDirective(const OMPParallelDirective *D); void VisitOMPSimdDirective(const OMPSimdDirective *D); void VisitOMPTileDirective(const OMPTileDirective *D); + void VisitOMPUnrollDirective(const OMPUnrollDirective *D); void VisitOMPForDirective(const OMPForDirective *D); void VisitOMPForSimdDirective(const OMPForSimdDirective *D); void VisitOMPSectionsDirective(const OMPSectionsDirective *D); @@ -2219,10 +2220,27 @@ } void 
OMPClauseEnqueue::VisitOMPSizesClause(const OMPSizesClause *C) { - for (auto E : C->getSizesRefs()) + for (Expr* E : C->getSizesRefs()) Visitor->AddStmt(E); } + + + + +void OMPClauseEnqueue::VisitOMPFullClause(const OMPFullClause *C) {} + + +void OMPClauseEnqueue::VisitOMPPartialClause(const OMPPartialClause *C) { + Visitor->AddStmt(C->getFactor()); +} + + + + + + + void OMPClauseEnqueue::VisitOMPAllocatorClause(const OMPAllocatorClause *C) { Visitor->AddStmt(C->getAllocator()); } @@ -2872,6 +2890,12 @@ VisitOMPLoopBasedDirective(D); } + +void EnqueueVisitor::VisitOMPUnrollDirective(const OMPUnrollDirective *D) { + VisitOMPLoopBasedDirective(D); +} + + void EnqueueVisitor::VisitOMPForDirective(const OMPForDirective *D) { VisitOMPLoopDirective(D); } @@ -5550,6 +5574,8 @@ return cxstring::createRef("OMPSimdDirective"); case CXCursor_OMPTileDirective: return cxstring::createRef("OMPTileDirective"); + case CXCursor_OMPUnrollDirective: + return cxstring::createRef("OMPUnrollDirective"); case CXCursor_OMPForDirective: return cxstring::createRef("OMPForDirective"); case CXCursor_OMPForSimdDirective: Index: clang/test/OpenMP/unroll_codegen.cpp =================================================================== --- /dev/null +++ clang/test/OpenMP/unroll_codegen.cpp @@ -0,0 +1,48 @@ +// Check code generation +// RUN: %clang_cc1 -verify -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -emit-llvm %s -o - | FileCheck %s --check-prefix=IR + +// Check same results after serialization round-trip +// RUN: %clang_cc1 -verify -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -emit-pch -o %t %s +// RUN: %clang_cc1 -verify -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -include-pch %t -emit-llvm %s -o - | FileCheck %s --check-prefix=IR +// expected-no-diagnostics + +#ifndef HEADER +#define HEADER + +// placeholder for loop body code. +extern "C" void body(...) 
{} + +#if 0 +void func_unroll(int n) { + #pragma omp unroll + for (int i = 7; i < n; i += 3) + body(i); +} + + + +void func_unroll_full() { + #pragma omp unroll full + for (int i = 7; i < 17; i += 3) + body(i); +} +#endif + +void func_unroll_partial() { + #pragma omp unroll partial + for (int i = 7; i < 789; i += 3) + body(i); +} +#if 0 + +void func_unroll_partial_factor() { + #pragma omp unroll partial(4) + for (int i = 7; i < 789; i += 3) + body(i); +} +#endif + + + + +#endif /* HEADER */ Index: clang/test/OpenMP/unroll_ast_print.cpp =================================================================== --- /dev/null +++ clang/test/OpenMP/unroll_ast_print.cpp @@ -0,0 +1,107 @@ +// Check no warnings/errors +// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -fsyntax-only -verify %s +// expected-no-diagnostics + +// Check AST and unparsing +// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -ast-dump %s | FileCheck %s --check-prefix=DUMP +// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -ast-print %s | FileCheck %s --check-prefix=PRINT --match-full-lines + +// Check same results after serialization round-trip +// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -emit-pch -o %t %s +// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -include-pch %t -ast-dump-all %s | FileCheck %s --check-prefix=DUMP +// RUN: %clang_cc1 -triple x86_64-pc-linux-gnu -fopenmp -fopenmp-version=51 -include-pch %t -ast-print %s | FileCheck %s --check-prefix=PRINT --match-full-lines + +#ifndef HEADER +#define HEADER + +// placeholder for loop body code. 
+extern "C" void body(...); + + + +// PRINT-LABEL: void func_unroll() { +// DUMP-LABEL: FunctionDecl {{.*}} func_unroll +void func_unroll() { + // PRINT: #pragma omp unroll + // DUMP: OMPUnrollDirective + #pragma omp unroll + // PRINT-NEXT: for (int i = 7; i < 17; i += 3) + // DUMP-NEXT: ForStmt + for (int i = 7; i < 17; i += 3) + // PRINT-NEXT: body(i); + // DUMP: CallExpr + body(i); +} + + +// PRINT-LABEL: void func_unroll_full() { +// DUMP-LABEL: FunctionDecl {{.*}} func_unroll_full +void func_unroll_full() { + // PRINT: #pragma omp unroll full + // DUMP: OMPUnrollDirective + // DUMP-NEXT: OMPFullClause + #pragma omp unroll full + // PRINT-NEXT: for (int i = 7; i < 17; i += 3) + // DUMP-NEXT: ForStmt + for (int i = 7; i < 17; i += 3) + // PRINT-NEXT: body(i); + // DUMP: CallExpr + body(i); +} + + +// PRINT-LABEL: void func_unroll_partial() { +// DUMP-LABEL: FunctionDecl {{.*}} func_unroll_partial +void func_unroll_partial() { + // PRINT: #pragma omp unroll partial + // DUMP: OMPUnrollDirective + // DUMP-NEXT: OMPPartialClause + // DUMP-NEXT: <<<NULL>>> + #pragma omp unroll partial + // PRINT-NEXT: for (int i = 7; i < 17; i += 3) + // DUMP-NEXT: ForStmt + for (int i = 7; i < 17; i += 3) + // PRINT: body(i); + // DUMP: CallExpr + body(i); +} + + +// PRINT-LABEL: void func_unroll_partial_factor() { +// DUMP-LABEL: FunctionDecl {{.*}} func_unroll_partial_factor +void func_unroll_partial_factor() { + // PRINT: #pragma omp unroll partial(4) + // DUMP: OMPUnrollDirective + // DUMP-NEXT: OMPPartialClause + // DUMP-NEXT: ConstantExpr + // DUMP-NEXT: value: Int 4 + // DUMP-NEXT: IntegerLiteral {{.*}} 4 + #pragma omp unroll partial(4) + // PRINT-NEXT: for (int i = 7; i < 17; i += 3) + // DUMP-NEXT: ForStmt + for (int i = 7; i < 17; i += 3) + // PRINT-NEXT: body(i); + // DUMP: CallExpr + body(i); +} + + +// PRINT-LABEL: void func_unroll_partial_factor_for() { +// DUMP-LABEL: FunctionDecl {{.*}} func_unroll_partial_factor_for +void func_unroll_partial_factor_for() { + // 
PRINT: #pragma omp for + // DUMP: OMPForDirective + #pragma omp for + // PRINT: #pragma omp unroll partial(2) + // DUMP: OMPUnrollDirective + // DUMP-NEXT: OMPPartialClause + #pragma omp unroll partial(2) + // PRINT-NEXT: for (int i = 7; i < 17; i += 3) + // DUMP: ForStmt + for (int i = 7; i < 17; i += 3) + // PRINT-NEXT: body(i); + // DUMP: CallExpr + body(i); +} + +#endif Index: clang/lib/StaticAnalyzer/Core/ExprEngine.cpp =================================================================== --- clang/lib/StaticAnalyzer/Core/ExprEngine.cpp +++ clang/lib/StaticAnalyzer/Core/ExprEngine.cpp @@ -1294,6 +1294,7 @@ case Stmt::OMPTargetTeamsDistributeParallelForSimdDirectiveClass: case Stmt::OMPTargetTeamsDistributeSimdDirectiveClass: case Stmt::OMPTileDirectiveClass: + case Stmt::OMPUnrollDirectiveClass: case Stmt::CapturedStmtClass: { const ExplodedNode *node = Bldr.generateSink(S, Pred, Pred->getState()); Engine.addAbortedBlock(node, currBldrCtx->getBlock()); Index: clang/lib/Serialization/ASTWriterStmt.cpp =================================================================== --- clang/lib/Serialization/ASTWriterStmt.cpp +++ clang/lib/Serialization/ASTWriterStmt.cpp @@ -2210,6 +2210,11 @@ Code = serialization::STMT_OMP_TILE_DIRECTIVE; } +void ASTStmtWriter::VisitOMPUnrollDirective(OMPUnrollDirective *D) { + VisitOMPLoopBasedDirective(D); + Code = serialization::STMT_OMP_UNROLL_DIRECTIVE; +} + void ASTStmtWriter::VisitOMPForDirective(OMPForDirective *D) { VisitOMPLoopDirective(D); Record.writeBool(D->hasCancel()); Index: clang/lib/Serialization/ASTWriter.cpp =================================================================== --- clang/lib/Serialization/ASTWriter.cpp +++ clang/lib/Serialization/ASTWriter.cpp @@ -6128,6 +6128,13 @@ Record.AddSourceLocation(C->getLParenLoc()); } +void OMPClauseWriter::VisitOMPFullClause(OMPFullClause *C) {} + +void OMPClauseWriter::VisitOMPPartialClause(OMPPartialClause *C) { + Record.AddStmt(C->getFactor()); + 
Record.AddSourceLocation(C->getLParenLoc()); +} + void OMPClauseWriter::VisitOMPAllocatorClause(OMPAllocatorClause *C) { Record.AddStmt(C->getAllocator()); Record.AddSourceLocation(C->getLParenLoc()); Index: clang/lib/Serialization/ASTReaderStmt.cpp =================================================================== --- clang/lib/Serialization/ASTReaderStmt.cpp +++ clang/lib/Serialization/ASTReaderStmt.cpp @@ -2287,7 +2287,7 @@ void ASTStmtReader::VisitOMPLoopBasedDirective(OMPLoopBasedDirective *D) { VisitStmt(D); - // Field CollapsedNum was read in ReadStmtFromStream. + // Field NumAssociatedLoops was read in ReadStmtFromStream. Record.skipInts(1); VisitOMPExecutableDirective(D); } @@ -2310,6 +2310,10 @@ VisitOMPLoopBasedDirective(D); } +void ASTStmtReader::VisitOMPUnrollDirective(OMPUnrollDirective *D) { + VisitOMPLoopBasedDirective(D); +} + void ASTStmtReader::VisitOMPForDirective(OMPForDirective *D) { VisitOMPLoopDirective(D); D->setHasCancel(Record.readBool()); @@ -3170,6 +3174,14 @@ break; } + case STMT_OMP_UNROLL_DIRECTIVE: { + unsigned NumLoops = Record[ASTStmtReader::NumStmtFields]; + assert(NumLoops == 1); + unsigned NumClauses = Record[ASTStmtReader::NumStmtFields + 1]; + S = OMPUnrollDirective::CreateEmpty(Context, NumClauses); + break; + } + case STMT_OMP_FOR_DIRECTIVE: { unsigned CollapsedNum = Record[ASTStmtReader::NumStmtFields]; unsigned NumClauses = Record[ASTStmtReader::NumStmtFields + 1]; Index: clang/lib/Serialization/ASTReader.cpp =================================================================== --- clang/lib/Serialization/ASTReader.cpp +++ clang/lib/Serialization/ASTReader.cpp @@ -11745,6 +11745,12 @@ C = OMPSizesClause::CreateEmpty(Context, NumSizes); break; } + case llvm::omp::OMPC_full: + C = OMPFullClause::CreateEmpty(Context); + break; + case llvm::omp::OMPC_partial: + C = OMPPartialClause::CreateEmpty(Context); + break; case llvm::omp::OMPC_allocator: C = new (Context) OMPAllocatorClause(); break; @@ -12042,6 +12048,13 @@ 
C->setLParenLoc(Record.readSourceLocation()); } +void OMPClauseReader::VisitOMPFullClause(OMPFullClause *C) {} + +void OMPClauseReader::VisitOMPPartialClause(OMPPartialClause *C) { + C->setFactor(Record.readSubExpr()); + C->setLParenLoc(Record.readSourceLocation()); +} + void OMPClauseReader::VisitOMPAllocatorClause(OMPAllocatorClause *C) { C->setAllocator(Record.readExpr()); C->setLParenLoc(Record.readSourceLocation()); Index: clang/lib/Sema/TreeTransform.h =================================================================== --- clang/lib/Sema/TreeTransform.h +++ clang/lib/Sema/TreeTransform.h @@ -1633,6 +1633,18 @@ return getSema().ActOnOpenMPSizesClause(Sizes, StartLoc, LParenLoc, EndLoc); } + OMPClause *RebuildOMPFullClause(SourceLocation StartLoc, + SourceLocation EndLoc) { + return getSema().ActOnOpenMPFullClause(StartLoc, EndLoc); + } + + OMPClause *RebuildOMPPartialClause(Expr *Factor, SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + return getSema().ActOnOpenMPPartialClause(Factor, StartLoc, LParenLoc, + EndLoc); + } + /// Build a new OpenMP 'allocator' clause. /// /// By default, performs semantic analysis to build the new OpenMP clause. 
@@ -8441,6 +8453,17 @@ return Res; } +template <typename Derived> +StmtResult +TreeTransform<Derived>::TransformOMPUnrollDirective(OMPUnrollDirective *D) { + DeclarationNameInfo DirName; + getDerived().getSema().StartOpenMPDSABlock(D->getDirectiveKind(), DirName, + nullptr, D->getBeginLoc()); + StmtResult Res = getDerived().TransformOMPExecutableDirective(D); + getDerived().getSema().EndOpenMPDSABlock(Res.get()); + return Res; +} + template <typename Derived> StmtResult TreeTransform<Derived>::TransformOMPForDirective(OMPForDirective *D) { @@ -9108,6 +9131,28 @@ C->getLParenLoc(), C->getEndLoc()); } +template <typename Derived> +OMPClause *TreeTransform<Derived>::TransformOMPFullClause(OMPFullClause *C) { + if (!getDerived().AlwaysRebuild()) + return C; + return RebuildOMPFullClause(C->getBeginLoc(), C->getEndLoc()); +} + +template <typename Derived> +OMPClause * +TreeTransform<Derived>::TransformOMPPartialClause(OMPPartialClause *C) { + ExprResult T = getDerived().TransformExpr(C->getFactor()); + if (T.isInvalid()) + return nullptr; + Expr *Factor = T.get(); + bool Changed = Factor != C->getFactor(); + + if (!Changed && !getDerived().AlwaysRebuild()) + return C; + return RebuildOMPPartialClause(Factor, C->getBeginLoc(), C->getLParenLoc(), + C->getEndLoc()); +} + template <typename Derived> OMPClause * TreeTransform<Derived>::TransformOMPCollapseClause(OMPCollapseClause *C) { Index: clang/lib/Sema/SemaOpenMP.cpp =================================================================== --- clang/lib/Sema/SemaOpenMP.cpp +++ clang/lib/Sema/SemaOpenMP.cpp @@ -3804,6 +3804,11 @@ VisitStmt(S); } + void VisitOMPUnrollDirective(OMPUnrollDirective *S) { + // #pragma omp unroll does not introduce data sharing. 
+ VisitStmt(S); + } + void VisitStmt(Stmt *S) { for (Stmt *C : S->children()) { if (C) { @@ -3969,6 +3974,7 @@ case OMPD_section: case OMPD_master: case OMPD_tile: + case OMPD_unroll: break; case OMPD_simd: case OMPD_for: @@ -5825,6 +5831,10 @@ Res = ActOnOpenMPTileDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc); break; + case OMPD_unroll: + Res = ActOnOpenMPUnrollDirective(ClausesWithImplicit, AStmt, StartLoc, + EndLoc); + break; case OMPD_for: Res = ActOnOpenMPForDirective(ClausesWithImplicit, AStmt, StartLoc, EndLoc, VarsWithInheritedDSA); @@ -12438,6 +12448,35 @@ Context, StartLoc, EndLoc, NestedLoopCount, Clauses, AStmt, B); } +bool Sema::checkTransformableLoopNest( + OpenMPDirectiveKind Kind, Stmt *AStmt, int NumLoops, + SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers, + Stmt *&Body, SmallVectorImpl<Stmt *> &OriginalInits) { + return OMPLoopBasedDirective::doForAllLoops( + AStmt->IgnoreContainers(), /*TryImperfectlyNestedLoops=*/false, NumLoops, + [this, &LoopHelpers, &Body, &OriginalInits, Kind](unsigned Cnt, + Stmt *CurStmt) { + VarsWithInheritedDSAType TmpDSA; + unsigned SingleNumLoops = + checkOpenMPLoop(Kind, nullptr, nullptr, CurStmt, *this, *DSAStack, + TmpDSA, LoopHelpers[Cnt]); + if (SingleNumLoops == 0) + return true; + assert(SingleNumLoops == 1 && "Expect single loop iteration space"); + if (auto *For = dyn_cast<ForStmt>(CurStmt)) { + OriginalInits.push_back(For->getInit()); + Body = For->getBody(); + } else { + assert(isa<CXXForRangeStmt>(CurStmt) && + "Expected canonical for or range-based for loops."); + auto *CXXFor = cast<CXXForRangeStmt>(CurStmt); + OriginalInits.push_back(CXXFor->getBeginStmt()); + Body = CXXFor->getBody(); + } + return false; + }); +} + StmtResult Sema::ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses, Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc) { @@ -12458,30 +12497,8 @@ SmallVector<OMPLoopBasedDirective::HelperExprs, 4> LoopHelpers(NumLoops); Stmt *Body = nullptr; 
SmallVector<Stmt *, 4> OriginalInits; - if (!OMPLoopBasedDirective::doForAllLoops( - AStmt->IgnoreContainers(), /*TryImperfectlyNestedLoops=*/false, - NumLoops, - [this, &LoopHelpers, &Body, &OriginalInits](unsigned Cnt, - Stmt *CurStmt) { - VarsWithInheritedDSAType TmpDSA; - unsigned SingleNumLoops = - checkOpenMPLoop(OMPD_tile, nullptr, nullptr, CurStmt, *this, - *DSAStack, TmpDSA, LoopHelpers[Cnt]); - if (SingleNumLoops == 0) - return true; - assert(SingleNumLoops == 1 && "Expect single loop iteration space"); - if (auto *For = dyn_cast<ForStmt>(CurStmt)) { - OriginalInits.push_back(For->getInit()); - Body = For->getBody(); - } else { - assert(isa<CXXForRangeStmt>(CurStmt) && - "Expected canonical for or range-based for loops."); - auto *CXXFor = cast<CXXForRangeStmt>(CurStmt); - OriginalInits.push_back(CXXFor->getBeginStmt()); - Body = CXXFor->getBody(); - } - return false; - })) + if (!checkTransformableLoopNest(OMPD_tile, AStmt, NumLoops, LoopHelpers, Body, + OriginalInits)) return StmtError(); // Delay tiling to when template is completely instantiated. @@ -12666,6 +12683,243 @@ buildPreInits(Context, PreInits)); } +StmtResult Sema::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses, + Stmt *AStmt, + SourceLocation StartLoc, + SourceLocation EndLoc) { + auto FullClauses = + OMPExecutableDirective::getClausesOfKind<OMPFullClause>(Clauses); + const OMPFullClause *FullClause = nullptr; + if (!FullClauses.empty()) { + assert(hasSingleElement(FullClauses)); + FullClause = *FullClauses.begin(); + } + + auto PartialClauses = + OMPExecutableDirective::getClausesOfKind<OMPPartialClause>(Clauses); + const OMPPartialClause *PartialClause = nullptr; + if (!PartialClauses.empty()) { + assert(hasSingleElement(PartialClauses)); + PartialClause = *PartialClauses.begin(); + } + + assert(!(FullClause && PartialClause)); + + // Empty statement should only be possible if there already was an error. 
+ if (!AStmt) + return StmtError(); + + if (!PartialClause) + return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, + nullptr, nullptr); + + Stmt *TransformedStmt = nullptr; + // Stmt* PreInits = nullptr; + + constexpr unsigned NumLoops = 1; + SmallVector<OMPLoopBasedDirective::HelperExprs, NumLoops> LoopHelpers( + NumLoops); + Stmt *Body = nullptr; + SmallVector<Stmt *, 4> OriginalInits; + if (!checkTransformableLoopNest(OMPD_tile, AStmt, NumLoops, LoopHelpers, Body, + OriginalInits)) + return StmtError(); + auto &LoopHelper = LoopHelpers.front(); + auto &OriginalInit = OriginalInits.front(); + + // Delay unrolling to when template is completely instantiated. + if (CurContext->isDependentContext()) + return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, + nullptr, nullptr); + + uint64_t Factor = 0; + Expr *FactorExpr = PartialClause->getFactor(); + if (FactorExpr) { + llvm::APSInt FactorInt; + VerifyIntegerConstantExpression(FactorExpr, &FactorInt); + Factor = FactorInt.getZExtValue(); + } else { + CanQualType FactorTy = Context.IntTy; + FactorExpr = new (Context) IntegerLiteral( + Context, llvm::APInt(Context.getIntWidth(FactorTy), 0), FactorTy, {}); + } + + // Collection of generated variable declaration. + SmallVector<Decl *, 4> PreInits; + + // Create iteration variables for the generated loops. 
+ SmallVector<VarDecl *, 4> FloorIndVars; + SmallVector<VarDecl *, 4> TileIndVars; + FloorIndVars.resize(NumLoops); + TileIndVars.resize(NumLoops); + for (unsigned I = 0; I < NumLoops; ++I) { + OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I]; + if (auto *PI = cast_or_null<DeclStmt>(LoopHelper.PreInits)) + PreInits.append(PI->decl_begin(), PI->decl_end()); + assert(LoopHelper.Counters.size() == 1 && + "Expect single-dimensional loop iteration space"); + auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters.front()); + std::string OrigVarName = OrigCntVar->getNameInfo().getAsString(); + DeclRefExpr *IterVarRef = cast<DeclRefExpr>(LoopHelper.IterationVarRef); + QualType CntTy = IterVarRef->getType(); + + // Iteration variable for the floor (i.e. outer) loop. + { + std::string FloorCntName = + (Twine(".floor_") + llvm::utostr(I) + ".iv." + OrigVarName).str(); + VarDecl *FloorCntDecl = + buildVarDecl(*this, {}, CntTy, FloorCntName, nullptr, OrigCntVar); + FloorIndVars[I] = FloorCntDecl; + } + + // Iteration variable for the tile (i.e. inner) loop. + { + std::string TileCntName = + (Twine(".tile_") + llvm::utostr(I) + ".iv." + OrigVarName).str(); + + // Reuse the iteration variable created by checkOpenMPLoop. It is also + // used by the expressions to derive the original iteration variable's + // value from the logical iteration number. + auto *TileCntDecl = cast<VarDecl>(IterVarRef->getDecl()); + TileCntDecl->setDeclName(&PP.getIdentifierTable().get(TileCntName)); + TileIndVars[I] = TileCntDecl; + } + if (auto *PI = dyn_cast_or_null<DeclStmt>(OriginalInits[I])) + PreInits.append(PI->decl_begin(), PI->decl_end()); + // Gather declarations for the data members used as counters. + for (Expr *CounterRef : LoopHelper.Counters) { + auto *CounterDecl = cast<DeclRefExpr>(CounterRef)->getDecl(); + if (isa<OMPCapturedExprDecl>(CounterDecl)) + PreInits.push_back(CounterDecl); + } + } + + // Once the original iteration values are set, append the innermost body. 
+ Stmt *Inner = Body; + + // Create tile loops from the inside to the outside. + for (int I = NumLoops - 1; I >= 0; --I) { + OMPLoopBasedDirective::HelperExprs &LoopHelper = LoopHelpers[I]; + Expr *NumIterations = LoopHelper.NumIterations; + auto *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]); + QualType CntTy = OrigCntVar->getType(); + Expr *DimTileSize = FactorExpr; + Scope *CurScope = getCurScope(); + + // Commonly used variables. + DeclRefExpr *TileIV = buildDeclRefExpr(*this, TileIndVars[I], CntTy, + OrigCntVar->getExprLoc()); + DeclRefExpr *FloorIV = buildDeclRefExpr(*this, FloorIndVars[I], CntTy, + OrigCntVar->getExprLoc()); + + // For init-statement: auto .tile.iv = .floor.iv + AddInitializerToDecl(TileIndVars[I], DefaultLvalueConversion(FloorIV).get(), + /*DirectInit=*/false); + Decl *CounterDecl = TileIndVars[I]; + StmtResult InitStmt = new (Context) + DeclStmt(DeclGroupRef::Create(Context, &CounterDecl, 1), + OrigCntVar->getBeginLoc(), OrigCntVar->getEndLoc()); + if (!InitStmt.isUsable()) + return StmtError(); + + // For cond-expression: .tile.iv < min(.floor.iv + DimTileSize, + // NumIterations) + ExprResult EndOfTile = BuildBinOp(CurScope, LoopHelper.Cond->getExprLoc(), + BO_Add, FloorIV, DimTileSize); + if (!EndOfTile.isUsable()) + return StmtError(); + ExprResult IsPartialTile = + BuildBinOp(CurScope, LoopHelper.Cond->getExprLoc(), BO_LT, + NumIterations, EndOfTile.get()); + if (!IsPartialTile.isUsable()) + return StmtError(); + ExprResult MinTileAndIterSpace = ActOnConditionalOp( + LoopHelper.Cond->getBeginLoc(), LoopHelper.Cond->getEndLoc(), + IsPartialTile.get(), NumIterations, EndOfTile.get()); + if (!MinTileAndIterSpace.isUsable()) + return StmtError(); + ExprResult CondExpr = BuildBinOp(CurScope, LoopHelper.Cond->getExprLoc(), + BO_LT, TileIV, MinTileAndIterSpace.get()); + if (!CondExpr.isUsable()) + return StmtError(); + + // For incr-statement: ++.tile.iv + ExprResult IncrStmt = + BuildUnaryOp(CurScope, LoopHelper.Inc->getExprLoc(), 
UO_PreInc, TileIV); + if (!IncrStmt.isUsable()) + return StmtError(); + + // Statements to set the original iteration variable's value from the + // logical iteration number. + // Generated for loop is: + // Original_for_init; + // for (auto .tile.iv = .floor.iv; .tile.iv < min(.floor.iv + DimTileSize, + // NumIterations); ++.tile.iv) { + // Original_Body; + // Original_counter_update; + // } + // FIXME: If the innermost body is an loop itself, inserting these + // statements stops it being recognized as a perfectly nested loop (e.g. + // for applying tiling again). If this is the case, sink the expressions + // further into the inner loop. + SmallVector<Stmt *, 4> BodyParts; + BodyParts.append(LoopHelper.Updates.begin(), LoopHelper.Updates.end()); + BodyParts.push_back(Inner); + Inner = CompoundStmt::Create(Context, BodyParts, Inner->getBeginLoc(), + Inner->getEndLoc()); + Inner = new (Context) + ForStmt(Context, InitStmt.get(), CondExpr.get(), nullptr, + IncrStmt.get(), Inner, LoopHelper.Init->getBeginLoc(), + LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc()); + } + + // Create floor loops from the inside to the outside. + for (int I = NumLoops - 1; I >= 0; --I) { + auto &LoopHelper = LoopHelpers[I]; + Expr *NumIterations = LoopHelper.NumIterations; + DeclRefExpr *OrigCntVar = cast<DeclRefExpr>(LoopHelper.Counters[0]); + QualType CntTy = OrigCntVar->getType(); + Expr *DimTileSize = FactorExpr; + Scope *CurScope = getCurScope(); + + // Commonly used variables. 
+ DeclRefExpr *FloorIV = buildDeclRefExpr(*this, FloorIndVars[I], CntTy, + OrigCntVar->getExprLoc()); + + // For init-statement: auto .floor.iv = 0 + AddInitializerToDecl( + FloorIndVars[I], + ActOnIntegerConstant(LoopHelper.Init->getExprLoc(), 0).get(), + /*DirectInit=*/false); + Decl *CounterDecl = FloorIndVars[I]; + StmtResult InitStmt = new (Context) + DeclStmt(DeclGroupRef::Create(Context, &CounterDecl, 1), + OrigCntVar->getBeginLoc(), OrigCntVar->getEndLoc()); + if (!InitStmt.isUsable()) + return StmtError(); + + // For cond-expression: .floor.iv < NumIterations + ExprResult CondExpr = BuildBinOp(CurScope, LoopHelper.Cond->getExprLoc(), + BO_LT, FloorIV, NumIterations); + if (!CondExpr.isUsable()) + return StmtError(); + + // For incr-statement: .floor.iv += DimTileSize + ExprResult IncrStmt = BuildBinOp(CurScope, LoopHelper.Inc->getExprLoc(), + BO_AddAssign, FloorIV, DimTileSize); + if (!IncrStmt.isUsable()) + return StmtError(); + + Inner = new (Context) + ForStmt(Context, InitStmt.get(), CondExpr.get(), nullptr, + IncrStmt.get(), Inner, LoopHelper.Init->getBeginLoc(), + LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc()); + } + + return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, + Inner, buildPreInits(Context, PreInits)); +} + OMPClause *Sema::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, Expr *Expr, SourceLocation StartLoc, SourceLocation LParenLoc, @@ -12717,6 +12971,9 @@ case OMPC_detach: Res = ActOnOpenMPDetachClause(Expr, StartLoc, LParenLoc, EndLoc); break; + case OMPC_partial: + Res = ActOnOpenMPPartialClause(Expr, StartLoc, LParenLoc, EndLoc); + break; case OMPC_device: case OMPC_if: case OMPC_default: @@ -12919,6 +13176,7 @@ case OMPD_end_declare_target: case OMPD_teams: case OMPD_tile: + case OMPD_unroll: case OMPD_for: case OMPD_sections: case OMPD_section: @@ -12996,6 +13254,7 @@ case OMPD_teams: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_for: case OMPD_for_simd: case OMPD_sections: @@ 
-13076,6 +13335,7 @@ case OMPD_end_declare_target: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_for: case OMPD_for_simd: case OMPD_sections: @@ -13154,6 +13414,7 @@ case OMPD_end_declare_target: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_for: case OMPD_for_simd: case OMPD_sections: @@ -13233,6 +13494,7 @@ case OMPD_end_declare_target: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_sections: case OMPD_section: case OMPD_single: @@ -13311,6 +13573,7 @@ case OMPD_end_declare_target: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_for: case OMPD_for_simd: case OMPD_sections: @@ -13388,6 +13651,7 @@ case OMPD_end_declare_target: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_for: case OMPD_for_simd: case OMPD_sections: @@ -13468,6 +13732,7 @@ case OMPD_end_declare_target: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_for: case OMPD_for_simd: case OMPD_sections: @@ -14139,6 +14404,25 @@ SizeExprs); } +OMPClause *Sema::ActOnOpenMPFullClause(SourceLocation StartLoc, + SourceLocation EndLoc) { + return OMPFullClause::Create(Context, StartLoc, EndLoc); +} + +OMPClause *Sema::ActOnOpenMPPartialClause(Expr *FactorExpr, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc) { + if (FactorExpr) { + ExprResult FactorResult = VerifyPositiveIntegerConstantInClause( + FactorExpr, OMPC_partial, /*StrictlyPositive=*/true); + FactorExpr = AssertSuccess(FactorResult); + } + + return OMPPartialClause::Create(Context, StartLoc, LParenLoc, EndLoc, + FactorExpr); +} + OMPClause *Sema::ActOnOpenMPSingleExprWithArgClause( OpenMPClauseKind Kind, ArrayRef<unsigned> Argument, Expr *Expr, SourceLocation StartLoc, SourceLocation LParenLoc, @@ -14437,6 +14721,12 @@ case OMPC_destroy: Res = ActOnOpenMPDestroyClause(StartLoc, EndLoc); break; + case OMPC_full: + Res = ActOnOpenMPFullClause(StartLoc, EndLoc); + break; + case OMPC_partial: + Res = ActOnOpenMPPartialClause(nullptr, 
StartLoc, {}, EndLoc); + break; case OMPC_if: case OMPC_final: case OMPC_num_threads: Index: clang/lib/Sema/SemaExceptionSpec.cpp =================================================================== --- clang/lib/Sema/SemaExceptionSpec.cpp +++ clang/lib/Sema/SemaExceptionSpec.cpp @@ -1460,6 +1460,7 @@ case Stmt::OMPSectionsDirectiveClass: case Stmt::OMPSimdDirectiveClass: case Stmt::OMPTileDirectiveClass: + case Stmt::OMPUnrollDirectiveClass: case Stmt::OMPSingleDirectiveClass: case Stmt::OMPTargetDataDirectiveClass: case Stmt::OMPTargetDirectiveClass: Index: clang/lib/Parse/ParseOpenMP.cpp =================================================================== --- clang/lib/Parse/ParseOpenMP.cpp +++ clang/lib/Parse/ParseOpenMP.cpp @@ -2154,6 +2154,7 @@ case OMPD_parallel: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_task: case OMPD_taskyield: case OMPD_barrier: @@ -2389,6 +2390,7 @@ case OMPD_parallel: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_for: case OMPD_for_simd: case OMPD_sections: @@ -2773,6 +2775,7 @@ case OMPC_allocator: case OMPC_depobj: case OMPC_detach: + case OMPC_partial: // OpenMP [2.5, Restrictions] // At most one num_threads clause can appear on the directive. // OpenMP [2.8.1, simd construct, Restrictions] @@ -2801,7 +2804,8 @@ ErrorFound = true; } - if (CKind == OMPC_ordered && PP.LookAhead(/*N=*/0).isNot(tok::l_paren)) + if ((CKind == OMPC_ordered || CKind == OMPC_partial) && + PP.LookAhead(/*N=*/0).isNot(tok::l_paren)) Clause = ParseOpenMPClause(CKind, WrongDirective); else Clause = ParseOpenMPSingleExprClause(CKind, WrongDirective); @@ -2865,6 +2869,7 @@ case OMPC_reverse_offload: case OMPC_dynamic_allocators: case OMPC_destroy: + case OMPC_full: // OpenMP [2.7.1, Restrictions, p. 9] // Only one ordered clause can appear on a loop directive. // OpenMP [2.7.1, Restrictions, C/C++, p. 
4] @@ -2941,7 +2946,7 @@ SkipUntil(tok::comma, tok::annot_pragma_openmp_end, StopBeforeMatch); break; default: - break; + llvm_unreachable("Unhandled clause"); } return ErrorFound ? nullptr : Clause; } Index: clang/lib/CodeGen/CodeGenFunction.h =================================================================== --- clang/lib/CodeGen/CodeGenFunction.h +++ clang/lib/CodeGen/CodeGenFunction.h @@ -3417,6 +3417,7 @@ void EmitOMPParallelDirective(const OMPParallelDirective &S); void EmitOMPSimdDirective(const OMPSimdDirective &S); void EmitOMPTileDirective(const OMPTileDirective &S); + void EmitOMPUnrollDirective(const OMPUnrollDirective &S); void EmitOMPForDirective(const OMPForDirective &S); void EmitOMPForSimdDirective(const OMPForSimdDirective &S); void EmitOMPSectionsDirective(const OMPSectionsDirective &S); Index: clang/lib/CodeGen/CGStmtOpenMP.cpp =================================================================== --- clang/lib/CodeGen/CGStmtOpenMP.cpp +++ clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -176,6 +176,8 @@ PreInits = cast_or_null<DeclStmt>(LD->getPreInits()); } else if (const auto *Tile = dyn_cast<OMPTileDirective>(&S)) { PreInits = cast_or_null<DeclStmt>(Tile->getPreInits()); + } else if (const auto *Unroll = dyn_cast<OMPUnrollDirective>(&S)) { + PreInits = cast_or_null<DeclStmt>(Unroll->getPreInits()); } else { llvm_unreachable("Unknown loop-based directive kind."); } @@ -1803,6 +1805,8 @@ SimplifiedS); if (auto *Dir = dyn_cast<OMPTileDirective>(SimplifiedS)) SimplifiedS = Dir->getTransformedStmt(); + if (auto *Dir = dyn_cast<OMPUnrollDirective>(SimplifiedS)) + SimplifiedS = Dir->getTransformedStmt(); if (const auto *CanonLoop = dyn_cast<OMPCanonicalLoop>(SimplifiedS)) SimplifiedS = CanonLoop->getLoopStmt(); if (const auto *For = dyn_cast<ForStmt>(SimplifiedS)) { @@ -2561,6 +2565,45 @@ EmitStmt(S.getTransformedStmt()); } +void CodeGenFunction::EmitOMPUnrollDirective(const OMPUnrollDirective &S) { + // This function is only called if the unrolled loop is not 
consumed by any + // other loop-associated construct. Such a loop-associated construct will have + // used the transformed AST. + + auto FullClauses = S.getClausesOfKind<OMPFullClause>(); + const OMPFullClause *FullClause = nullptr; + if (!FullClauses.empty()) { + assert(hasSingleElement(FullClauses)); + FullClause = *FullClauses.begin(); + } + + auto PartialClauses = S.getClausesOfKind<OMPPartialClause>(); + const OMPPartialClause *PartialClause = nullptr; + if (!PartialClauses.empty()) { + assert(hasSingleElement(PartialClauses)); + PartialClause = *PartialClauses.begin(); + } + + uint64_t Factor = 0; + if (PartialClause) { + if (Expr *FactorExpr = PartialClause->getFactor()) { + RValue FactorRVal = EmitAnyExpr(FactorExpr, AggValueSlot::ignored(), + /*ignoreResult=*/true); + Factor = + cast<llvm::ConstantInt>(FactorRVal.getScalarVal())->getZExtValue(); + assert(Factor >= 1 && "One positive factors are valid"); + } + } + + // OMPTransformDirectiveScopeRAII UnrollScope(*this, &S); + LoopStack.setUnrollState(LoopAttributes::Enable); + if (Factor >= 1) + LoopStack.setUnrollCount(Factor); + else if (FullClause) + LoopStack.setUnrollState(LoopAttributes::Full); + EmitStmt(S.getAssociatedStmt()); +} + void CodeGenFunction::EmitOMPOuterLoop( bool DynamicOrOrdered, bool IsMonotonic, const OMPLoopDirective &S, CodeGenFunction::OMPPrivateScope &LoopScope, Index: clang/lib/CodeGen/CGStmt.cpp =================================================================== --- clang/lib/CodeGen/CGStmt.cpp +++ clang/lib/CodeGen/CGStmt.cpp @@ -206,6 +206,9 @@ case Stmt::OMPTileDirectiveClass: EmitOMPTileDirective(cast<OMPTileDirective>(*S)); break; + case Stmt::OMPUnrollDirectiveClass: + EmitOMPUnrollDirective(cast<OMPUnrollDirective>(*S)); + break; case Stmt::OMPForDirectiveClass: EmitOMPForDirective(cast<OMPForDirective>(*S)); break; Index: clang/lib/CodeGen/CGOpenMPRuntime.cpp =================================================================== --- clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ 
clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -6636,6 +6636,7 @@ case OMPD_task: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_sections: case OMPD_section: case OMPD_single: @@ -6954,6 +6955,7 @@ case OMPD_task: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_sections: case OMPD_section: case OMPD_single: @@ -9525,6 +9527,7 @@ case OMPD_task: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_sections: case OMPD_section: case OMPD_single: @@ -10349,6 +10352,7 @@ case OMPD_task: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_sections: case OMPD_section: case OMPD_single: @@ -11032,6 +11036,7 @@ case OMPD_task: case OMPD_simd: case OMPD_tile: + case OMPD_unroll: case OMPD_sections: case OMPD_section: case OMPD_single: Index: clang/lib/Basic/OpenMPKinds.cpp =================================================================== --- clang/lib/Basic/OpenMPKinds.cpp +++ clang/lib/Basic/OpenMPKinds.cpp @@ -448,7 +448,8 @@ DKind == OMPD_target_teams_distribute || DKind == OMPD_target_teams_distribute_parallel_for || DKind == OMPD_target_teams_distribute_parallel_for_simd || - DKind == OMPD_target_teams_distribute_simd || DKind == OMPD_tile; + DKind == OMPD_target_teams_distribute_simd || DKind == OMPD_tile || + DKind == OMPD_unroll; } bool clang::isOpenMPWorksharingDirective(OpenMPDirectiveKind DKind) { @@ -576,7 +577,7 @@ } bool clang::isOpenMPLoopTransformationDirective(OpenMPDirectiveKind DKind) { - return DKind == OMPD_tile; + return DKind == OMPD_tile || DKind == OMPD_unroll; } void clang::getOpenMPCaptureRegions( @@ -663,6 +664,7 @@ CaptureRegions.push_back(OMPD_unknown); break; case OMPD_tile: + case OMPD_unroll: // loop transformations do not introduce captures. 
break; case OMPD_threadprivate: Index: clang/lib/AST/StmtProfile.cpp =================================================================== --- clang/lib/AST/StmtProfile.cpp +++ clang/lib/AST/StmtProfile.cpp @@ -463,11 +463,18 @@ } void OMPClauseProfiler::VisitOMPSizesClause(const OMPSizesClause *C) { - for (auto E : C->getSizesRefs()) + for (Expr *E : C->getSizesRefs()) if (E) Profiler->VisitExpr(E); } +void OMPClauseProfiler::VisitOMPFullClause(const OMPFullClause *C) {} + +void OMPClauseProfiler::VisitOMPPartialClause(const OMPPartialClause *C) { + if (Expr *Factor = C->getFactor()) + Profiler->VisitExpr(Factor); +} + void OMPClauseProfiler::VisitOMPAllocatorClause(const OMPAllocatorClause *C) { if (C->getAllocator()) Profiler->VisitStmt(C->getAllocator()); @@ -878,6 +885,10 @@ VisitOMPLoopBasedDirective(S); } +void StmtProfiler::VisitOMPUnrollDirective(const OMPUnrollDirective *S) { + VisitOMPLoopBasedDirective(S); +} + void StmtProfiler::VisitOMPForDirective(const OMPForDirective *S) { VisitOMPLoopDirective(S); } Index: clang/lib/AST/StmtPrinter.cpp =================================================================== --- clang/lib/AST/StmtPrinter.cpp +++ clang/lib/AST/StmtPrinter.cpp @@ -669,6 +669,11 @@ PrintOMPExecutableDirective(Node); } +void StmtPrinter::VisitOMPUnrollDirective(OMPUnrollDirective *Node) { + Indent() << "#pragma omp unroll"; + PrintOMPExecutableDirective(Node); +} + void StmtPrinter::VisitOMPForDirective(OMPForDirective *Node) { Indent() << "#pragma omp for"; PrintOMPExecutableDirective(Node); Index: clang/lib/AST/StmtOpenMP.cpp =================================================================== --- clang/lib/AST/StmtOpenMP.cpp +++ clang/lib/AST/StmtOpenMP.cpp @@ -127,10 +127,16 @@ llvm::function_ref<bool(unsigned, Stmt *)> Callback) { CurStmt = CurStmt->IgnoreContainers(); for (unsigned Cnt = 0; Cnt < NumLoops; ++Cnt) { - if (auto *Dir = dyn_cast<OMPTileDirective>(CurStmt)) - CurStmt = Dir->getTransformedStmt(); - if (auto *CanonLoop = 
dyn_cast<OMPCanonicalLoop>(CurStmt)) - CurStmt = CanonLoop->getLoopStmt(); + while (true) { + if (auto *Dir = dyn_cast<OMPTileDirective>(CurStmt)) + CurStmt = Dir->getTransformedStmt(); + else if (auto *Dir = dyn_cast<OMPUnrollDirective>(CurStmt)) + CurStmt = Dir->getTransformedStmt(); + else if (auto *CanonLoop = dyn_cast<OMPCanonicalLoop>(CurStmt)) + CurStmt = CanonLoop->getLoopStmt(); + else + break; + } if (Callback(Cnt, CurStmt)) return false; // Move on to the next nested for loop, or to the loop body. @@ -355,6 +361,25 @@ SourceLocation(), SourceLocation(), NumLoops); } +OMPUnrollDirective * +OMPUnrollDirective::Create(const ASTContext &C, SourceLocation StartLoc, + SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses, + Stmt *AssociatedStmt, Stmt *TransformedStmt, + Stmt *PreInits) { + OMPUnrollDirective *Dir = createDirective<OMPUnrollDirective>( + C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc); + Dir->setTransformedStmt(TransformedStmt); + Dir->setPreInits(PreInits); + return Dir; +} + +OMPUnrollDirective *OMPUnrollDirective::CreateEmpty(const ASTContext &C, + unsigned NumClauses) { + return createEmptyDirective<OMPUnrollDirective>( + C, NumClauses, /*HasAssociatedStmt=*/true, TransformedStmtOffset + 1, + SourceLocation(), SourceLocation()); +} + OMPForSimdDirective * OMPForSimdDirective::Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, unsigned CollapsedNum, Index: clang/lib/AST/OpenMPClause.cpp =================================================================== --- clang/lib/AST/OpenMPClause.cpp +++ clang/lib/AST/OpenMPClause.cpp @@ -922,6 +922,36 @@ return new (Mem) OMPSizesClause(NumSizes); } +OMPFullClause *OMPFullClause::Create(const ASTContext &C, + SourceLocation StartLoc, + SourceLocation EndLoc) { + OMPFullClause *Clause = CreateEmpty(C); + Clause->setLocStart(StartLoc); + Clause->setLocEnd(EndLoc); + return Clause; +} + +OMPFullClause *OMPFullClause::CreateEmpty(const ASTContext &C) { + 
return new (C) OMPFullClause(); +} + +OMPPartialClause *OMPPartialClause::Create(const ASTContext &C, + SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc, + Expr *Factor) { + OMPPartialClause *Clause = CreateEmpty(C); + Clause->setLocStart(StartLoc); + Clause->setLParenLoc(LParenLoc); + Clause->setLocEnd(EndLoc); + Clause->setFactor(Factor); + return Clause; +} + +OMPPartialClause *OMPPartialClause::CreateEmpty(const ASTContext &C) { + return new (C) OMPPartialClause(); +} + OMPAllocateClause * OMPAllocateClause::Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation LParenLoc, Expr *Allocator, @@ -1561,6 +1591,18 @@ OS << ")"; } +void OMPClausePrinter::VisitOMPFullClause(OMPFullClause *Node) { OS << "full"; } + +void OMPClausePrinter::VisitOMPPartialClause(OMPPartialClause *Node) { + OS << "partial"; + + if (Expr *Factor = Node->getFactor()) { + OS << '('; + Factor->printPretty(OS, nullptr, Policy, 0); + OS << ')'; + } +} + void OMPClausePrinter::VisitOMPAllocatorClause(OMPAllocatorClause *Node) { OS << "allocator("; Node->getAllocator()->printPretty(OS, nullptr, Policy, 0); Index: clang/include/clang/Serialization/ASTBitCodes.h =================================================================== --- clang/include/clang/Serialization/ASTBitCodes.h +++ clang/include/clang/Serialization/ASTBitCodes.h @@ -1888,6 +1888,7 @@ STMT_OMP_PARALLEL_DIRECTIVE, STMT_OMP_SIMD_DIRECTIVE, STMT_OMP_TILE_DIRECTIVE, + STMT_OMP_UNROLL_DIRECTIVE, STMT_OMP_FOR_DIRECTIVE, STMT_OMP_FOR_SIMD_DIRECTIVE, STMT_OMP_SECTIONS_DIRECTIVE, Index: clang/include/clang/Sema/Sema.h =================================================================== --- clang/include/clang/Sema/Sema.h +++ clang/include/clang/Sema/Sema.h @@ -32,6 +32,7 @@ #include "clang/AST/NSAPI.h" #include "clang/AST/PrettyPrinter.h" #include "clang/AST/StmtCXX.h" +#include "clang/AST/StmtOpenMP.h" #include "clang/AST/TypeLoc.h" #include "clang/AST/TypeOrdering.h" #include 
"clang/Basic/BitmaskEnum.h" @@ -10234,6 +10235,11 @@ MapT &Map, unsigned Selector = 0, SourceRange SrcRange = SourceRange()); + bool checkTransformableLoopNest( + OpenMPDirectiveKind Kind, Stmt *AStmt, int NumLoops, + SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers, + Stmt *&Body, SmallVectorImpl<Stmt *> &OriginalInits); + /// Helper to keep information about the current `omp begin/end declare /// variant` nesting. struct OMPDeclareVariantScope { @@ -10530,6 +10536,11 @@ StmtResult ActOnOpenMPTileDirective(ArrayRef<OMPClause *> Clauses, Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc); + /// Called on well-formed '#pragma omp unroll' after parsing of its clauses + /// and the associated statement. + StmtResult ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses, + Stmt *AStmt, SourceLocation StartLoc, + SourceLocation EndLoc); /// Called on well-formed '\#pragma omp for' after parsing /// of the associated statement. StmtResult @@ -10871,6 +10882,14 @@ SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc); + + OMPClause *ActOnOpenMPFullClause(SourceLocation StartLoc, + SourceLocation EndLoc); + + OMPClause *ActOnOpenMPPartialClause(Expr *FactorExpr, SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc); + /// Called on well-formed 'collapse' clause. 
OMPClause *ActOnOpenMPCollapseClause(Expr *NumForLoops, SourceLocation StartLoc, Index: clang/include/clang/Basic/StmtNodes.td =================================================================== --- clang/include/clang/Basic/StmtNodes.td +++ clang/include/clang/Basic/StmtNodes.td @@ -223,6 +223,7 @@ def OMPParallelDirective : StmtNode<OMPExecutableDirective>; def OMPSimdDirective : StmtNode<OMPLoopDirective>; def OMPTileDirective : StmtNode<OMPLoopBasedDirective>; +def OMPUnrollDirective : StmtNode<OMPLoopBasedDirective>; def OMPForDirective : StmtNode<OMPLoopDirective>; def OMPForSimdDirective : StmtNode<OMPLoopDirective>; def OMPSectionsDirective : StmtNode<OMPExecutableDirective>; Index: clang/include/clang/AST/StmtOpenMP.h =================================================================== --- clang/include/clang/AST/StmtOpenMP.h +++ clang/include/clang/AST/StmtOpenMP.h @@ -5034,6 +5034,78 @@ } }; +/// This represents the '#pragma omp unroll' loop transformation directive. +class OMPUnrollDirective final : public OMPLoopBasedDirective { + friend class ASTStmtReader; + friend class OMPExecutableDirective; + + /// Default list of offsets. + enum { + PreInitsOffset = 0, + TransformedStmtOffset, + }; + + explicit OMPUnrollDirective(SourceLocation StartLoc, SourceLocation EndLoc) + : OMPLoopBasedDirective(OMPUnrollDirectiveClass, llvm::omp::OMPD_unroll, + StartLoc, EndLoc, 1) {} + + void setPreInits(Stmt *PreInits) { + Data->getChildren()[PreInitsOffset] = PreInits; + } + + void setTransformedStmt(Stmt *S) { + Data->getChildren()[TransformedStmtOffset] = S; + } + +public: + /// Create a new AST node representation for '#pragma omp unroll'. + /// + /// \param C Context of the AST. + /// \param StartLoc Location of the introducer (e.g. the 'omp' token). + /// \param EndLoc Location of the directive's end (e.g. the tok::eod). + /// \param Clauses The directive's clauses. The code generation in + /// this patch inspects the 'full' and 'partial' + /// clauses.
+ /// \param AssociatedStmt The outermost associated loop. + /// \param TransformedStmt The loop nest after unrolling, or nullptr in + /// dependent contexts. + /// \param PreInits Helper preinits statements for the loop nest. + static OMPUnrollDirective * + Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, + ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, + Stmt *TransformedStmt, Stmt *PreInits); + + /// Build an empty '#pragma omp unroll' AST node for deserialization. + /// + /// \param C Context of the AST. + /// \param NumClauses Number of clauses to allocate. + /// There is no loop-count parameter for this directive. + static OMPUnrollDirective *CreateEmpty(const ASTContext &C, + unsigned NumClauses); + + /// Gets/sets the associated loops after unrolling. + /// + /// This is in de-sugared format stored as a CompoundStmt. + /// + /// \code + /// for (...) + /// ... + /// \endcode + /// + /// Note that if the generated loops become associated loops of another + /// directive, they may need to be hoisted before them. + Stmt *getTransformedStmt() const { + return Data->getChildren()[TransformedStmtOffset]; + } + + /// Return preinits statement. + Stmt *getPreInits() const { return Data->getChildren()[PreInitsOffset]; } + + static bool classof(const Stmt *T) { + return T->getStmtClass() == OMPUnrollDirectiveClass; + } +}; + /// This represents '#pragma omp scan' directive.
/// /// \code Index: clang/include/clang/AST/RecursiveASTVisitor.h =================================================================== --- clang/include/clang/AST/RecursiveASTVisitor.h +++ clang/include/clang/AST/RecursiveASTVisitor.h @@ -2810,6 +2810,9 @@ DEF_TRAVERSE_STMT(OMPTileDirective, { TRY_TO(TraverseOMPExecutableDirective(S)); }) +DEF_TRAVERSE_STMT(OMPUnrollDirective, + { TRY_TO(TraverseOMPExecutableDirective(S)); }) + DEF_TRAVERSE_STMT(OMPForDirective, { TRY_TO(TraverseOMPExecutableDirective(S)); }) @@ -3057,6 +3060,17 @@ return true; } +template <typename Derived> +bool RecursiveASTVisitor<Derived>::VisitOMPFullClause(OMPFullClause *C) { + return true; +} + +template <typename Derived> +bool RecursiveASTVisitor<Derived>::VisitOMPPartialClause(OMPPartialClause *C) { + TRY_TO(TraverseStmt(C->getFactor())); + return true; +} + template <typename Derived> bool RecursiveASTVisitor<Derived>::VisitOMPCollapseClause(OMPCollapseClause *C) { Index: clang/include/clang/AST/OpenMPClause.h =================================================================== --- clang/include/clang/AST/OpenMPClause.h +++ clang/include/clang/AST/OpenMPClause.h @@ -888,6 +888,106 @@ } }; +class OMPFullClause final : public OMPClause { + friend class OMPClauseReader; + + /// Build an empty clause. + explicit OMPFullClause() + : OMPClause(llvm::omp::OMPC_full, SourceLocation(), SourceLocation()) {} + +public: + /// Build a 'full' AST node. + /// + /// \param C Context of the AST. + /// \param StartLoc Starting location of the clause. + /// \param EndLoc Ending location of the clause. + /// + /// The clause takes no argument, so no '(' location is recorded. + static OMPFullClause *Create(const ASTContext &C, SourceLocation StartLoc, + SourceLocation EndLoc); + + /// Build an empty 'full' AST node for deserialization. + /// + /// \param C Context of the AST. + /// The clause's source locations start out default-constructed.
+ static OMPFullClause *CreateEmpty(const ASTContext &C); + + child_range children() { return {child_iterator(), child_iterator()}; } + const_child_range children() const { + return {const_child_iterator(), const_child_iterator()}; + } + + child_range used_children() { + return child_range(child_iterator(), child_iterator()); + } + const_child_range used_children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + + static bool classof(const OMPClause *T) { + return T->getClauseKind() == llvm::omp::OMPC_full; + } +}; + +class OMPPartialClause final : public OMPClause { + friend class OMPClauseReader; + + /// Location of '('. + SourceLocation LParenLoc; + + Stmt *Factor; + + /// Build an empty clause. + explicit OMPPartialClause() + : OMPClause(llvm::omp::OMPC_partial, SourceLocation(), SourceLocation()) { + } + +public: + /// Build a 'partial' AST node. + /// + /// \param C Context of the AST. + /// \param StartLoc Starting location of the clause. + /// \param LParenLoc Location of '('. + /// \param EndLoc Ending location of the clause. + /// \param Factor The unroll factor expression, or nullptr if omitted. + static OMPPartialClause *Create(const ASTContext &C, SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc, Expr *Factor); + + /// Build an empty 'partial' AST node for deserialization. + /// + /// \param C Context of the AST. + /// Note: the constructor does not initialize Factor; set it afterwards. + static OMPPartialClause *CreateEmpty(const ASTContext &C); + + /// Sets the location of '('. + void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } + + /// Returns the location of '('.
+ SourceLocation getLParenLoc() const { return LParenLoc; } + + Expr *getFactor() const { return cast_or_null<Expr>(Factor); } + + void setFactor(Expr *E) { Factor = E; } + + child_range children() { return child_range(&Factor, &Factor + 1); } + + const_child_range children() const { + return const_child_range(&Factor, &Factor + 1); + } + + child_range used_children() { + return child_range(child_iterator(), child_iterator()); + } + const_child_range used_children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + + static bool classof(const OMPClause *T) { + return T->getClauseKind() == llvm::omp::OMPC_partial; + } +}; + /// This represents 'collapse' clause in the '#pragma omp ...' /// directive. /// Index: clang/include/clang-c/Index.h =================================================================== --- clang/include/clang-c/Index.h +++ clang/include/clang-c/Index.h @@ -2576,7 +2576,11 @@ */ CXCursor_OMPCanonicalLoop = 289, - CXCursor_LastStmt = CXCursor_OMPCanonicalLoop, + /** OpenMP unroll directive. + */ + CXCursor_OMPUnrollDirective = 290, + + CXCursor_LastStmt = CXCursor_OMPUnrollDirective, /** * Cursor that represents the translation unit itself.
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits