https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/102715
>From fd59f45bbe6912e3683f4a684fccdcc459bdd58a Mon Sep 17 00:00:00 2001 From: Shilei Tian <i...@tianshilei.me> Date: Fri, 9 Aug 2024 22:14:36 -0400 Subject: [PATCH] [Clang][Sema][OpenMP] Allow `thread_limit` to accept multiple expressions --- clang/docs/OpenMPSupport.rst | 3 +- clang/docs/ReleaseNotes.rst | 5 +- clang/include/clang/AST/OpenMPClause.h | 81 +++++++++++-------- clang/include/clang/AST/RecursiveASTVisitor.h | 2 +- clang/include/clang/Sema/SemaOpenMP.h | 2 +- clang/lib/AST/OpenMPClause.cpp | 26 +++++- clang/lib/AST/StmtProfile.cpp | 3 +- clang/lib/CodeGen/CGOpenMPRuntime.cpp | 15 ++-- clang/lib/CodeGen/CGStmtOpenMP.cpp | 4 +- clang/lib/Parse/ParseOpenMP.cpp | 2 +- clang/lib/Sema/SemaOpenMP.cpp | 58 ++++++++----- clang/lib/Sema/TreeTransform.h | 19 +++-- clang/lib/Serialization/ASTReader.cpp | 9 ++- clang/lib/Serialization/ASTWriter.cpp | 4 +- clang/test/OpenMP/target_teams_ast_print.cpp | 4 +- ...et_teams_distribute_num_teams_messages.cpp | 12 +++ ...ribute_parallel_for_num_teams_messages.cpp | 5 ++ .../test/OpenMP/teams_num_teams_messages.cpp | 7 ++ clang/tools/libclang/CIndex.cpp | 2 +- 19 files changed, 181 insertions(+), 82 deletions(-) diff --git a/clang/docs/OpenMPSupport.rst b/clang/docs/OpenMPSupport.rst index 3fc74cdd07f71c..cdbd69520e5bb5 100644 --- a/clang/docs/OpenMPSupport.rst +++ b/clang/docs/OpenMPSupport.rst @@ -363,7 +363,8 @@ considered for standardization. 
Please post on the | device extension | `'ompx_bare' clause on 'target teams' construct | :good:`prototyped` | #66844, #70612 | | | <https://www.osti.gov/servlets/purl/2205717>`_ | | | +------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+ -| device extension | Multi-dim 'num_teams' clause on 'target teams ompx_bare' construct | :good:`partial` | #99732, #101407 | +| device extension | Multi-dim 'num_teams' and 'thread_limit' clause on 'target teams ompx_bare' | :good:`partial` | #99732, #101407, #102715 | +| | construct | | | +------------------------------+-----------------------------------------------------------------------------------+--------------------------+--------------------------------------------------------+ .. _Discourse forums (Runtimes - OpenMP category): https://discourse.llvm.org/c/runtimes/openmp/35 diff --git a/clang/docs/ReleaseNotes.rst b/clang/docs/ReleaseNotes.rst index 351b41b1c0c588..602f3edaf121cb 100644 --- a/clang/docs/ReleaseNotes.rst +++ b/clang/docs/ReleaseNotes.rst @@ -360,8 +360,9 @@ Improvements ^^^^^^^^^^^^ - Improve the handling of mapping array-section for struct containing nested structs with user defined mappers -- `num_teams` now accepts multiple expressions when it is used along in ``target teams ompx_bare`` construct. - This allows the target region to be launched with multi-dim grid on GPUs. +- `num_teams` and `thread_limit` now accept multiple expressions when they are used + in a ``target teams ompx_bare`` construct. This allows the target region + to be launched with multi-dim grid on GPUs. 
Additional Information ====================== diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h index 1e830b14727c19..c1b9e0dbafb6c3 100644 --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -6462,44 +6462,55 @@ class OMPNumTeamsClause final /// \endcode /// In this example directive '#pragma omp teams' has clause 'thread_limit' /// with single expression 'n'. -class OMPThreadLimitClause : public OMPClause, public OMPClauseWithPreInit { - friend class OMPClauseReader; +/// +/// When 'ompx_bare' clause exists on a 'target' directive, 'thread_limit' +/// clause can accept up to three expressions. +/// +/// \code +/// #pragma omp target teams ompx_bare thread_limit(x, y, z) +/// \endcode +class OMPThreadLimitClause final + : public OMPVarListClause<OMPThreadLimitClause>, + public OMPClauseWithPreInit, + private llvm::TrailingObjects<OMPThreadLimitClause, Expr *> { + friend OMPVarListClause; + friend TrailingObjects; /// Location of '('. SourceLocation LParenLoc; - /// ThreadLimit number. - Stmt *ThreadLimit = nullptr; + OMPThreadLimitClause(const ASTContext &C, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc, + unsigned N) + : OMPVarListClause(llvm::omp::OMPC_thread_limit, StartLoc, LParenLoc, + EndLoc, N), + OMPClauseWithPreInit(this) {} - /// Set the ThreadLimit number. - /// - /// \param E ThreadLimit number. - void setThreadLimit(Expr *E) { ThreadLimit = E; } + /// Build an empty clause. + OMPThreadLimitClause(unsigned N) + : OMPVarListClause(llvm::omp::OMPC_thread_limit, SourceLocation(), + SourceLocation(), SourceLocation(), N), + OMPClauseWithPreInit(this) {} public: - /// Build 'thread_limit' clause. + /// Creates clause with a list of variables \a VL. /// - /// \param E Expression associated with this clause. - /// \param HelperE Helper Expression associated with this clause. 
- /// \param CaptureRegion Innermost OpenMP region where expressions in this - /// clause must be captured. + /// \param C AST context. /// \param StartLoc Starting location of the clause. /// \param LParenLoc Location of '('. /// \param EndLoc Ending location of the clause. - OMPThreadLimitClause(Expr *E, Stmt *HelperE, - OpenMPDirectiveKind CaptureRegion, - SourceLocation StartLoc, SourceLocation LParenLoc, - SourceLocation EndLoc) - : OMPClause(llvm::omp::OMPC_thread_limit, StartLoc, EndLoc), - OMPClauseWithPreInit(this), LParenLoc(LParenLoc), ThreadLimit(E) { - setPreInitStmt(HelperE, CaptureRegion); - } + /// \param VL List of references to the variables. + /// \param PreInit Pre-init statement for the clause's captured expressions. + static OMPThreadLimitClause * + Create(const ASTContext &C, OpenMPDirectiveKind CaptureRegion, + SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation EndLoc, ArrayRef<Expr *> VL, Stmt *PreInit); - /// Build an empty clause. - OMPThreadLimitClause() - : OMPClause(llvm::omp::OMPC_thread_limit, SourceLocation(), - SourceLocation()), - OMPClauseWithPreInit(this) {} + /// Creates an empty clause with \a N variables. + /// + /// \param C AST context. + /// \param N The number of variables. + static OMPThreadLimitClause *CreateEmpty(const ASTContext &C, unsigned N); /// Sets the location of '('. void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } @@ -6507,16 +6518,22 @@ class OMPThreadLimitClause : public OMPClause, public OMPClauseWithPreInit { /// Returns the location of '('. SourceLocation getLParenLoc() const { return LParenLoc; } - /// Return ThreadLimit number. - Expr *getThreadLimit() { return cast<Expr>(ThreadLimit); } + /// Return ThreadLimit expressions. + ArrayRef<Expr *> getThreadLimit() { return getVarRefs(); } - /// Return ThreadLimit number. - Expr *getThreadLimit() const { return cast<Expr>(ThreadLimit); } + /// Return ThreadLimit expressions. 
+ ArrayRef<Expr *> getThreadLimit() const { + return const_cast<OMPThreadLimitClause *>(this)->getThreadLimit(); + } - child_range children() { return child_range(&ThreadLimit, &ThreadLimit + 1); } + child_range children() { + return child_range(reinterpret_cast<Stmt **>(varlist_begin()), + reinterpret_cast<Stmt **>(varlist_end())); + } const_child_range children() const { - return const_child_range(&ThreadLimit, &ThreadLimit + 1); + auto Children = const_cast<OMPThreadLimitClause *>(this)->children(); + return const_child_range(Children.begin(), Children.end()); } child_range used_children() { diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index b505c746cc7dc2..2b35997bd539ac 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -3836,8 +3836,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause( template <typename Derived> bool RecursiveASTVisitor<Derived>::VisitOMPThreadLimitClause( OMPThreadLimitClause *C) { + TRY_TO(VisitOMPClauseList(C)); TRY_TO(VisitOMPClauseWithPreInit(C)); - TRY_TO(TraverseStmt(C->getThreadLimit())); return true; } diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h index 0ceb5fc07765c4..e55731212c4a41 100644 --- a/clang/include/clang/Sema/SemaOpenMP.h +++ b/clang/include/clang/Sema/SemaOpenMP.h @@ -1264,7 +1264,7 @@ class SemaOpenMP : public SemaBase { SourceLocation LParenLoc, SourceLocation EndLoc); /// Called on well-formed 'thread_limit' clause. 
- OMPClause *ActOnOpenMPThreadLimitClause(Expr *ThreadLimit, + OMPClause *ActOnOpenMPThreadLimitClause(ArrayRef<Expr *> VarList, SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc); diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp index 6bdc86f6167920..7e73c076239410 100644 --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -1773,6 +1773,24 @@ OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C, return new (Mem) OMPNumTeamsClause(N); } +OMPThreadLimitClause *OMPThreadLimitClause::Create( + const ASTContext &C, OpenMPDirectiveKind CaptureRegion, + SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc, + ArrayRef<Expr *> VL, Stmt *PreInit) { + void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size())); + OMPThreadLimitClause *Clause = + new (Mem) OMPThreadLimitClause(C, StartLoc, LParenLoc, EndLoc, VL.size()); + Clause->setVarRefs(VL); + Clause->setPreInitStmt(PreInit, CaptureRegion); + return Clause; +} + +OMPThreadLimitClause *OMPThreadLimitClause::CreateEmpty(const ASTContext &C, + unsigned N) { + void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N)); + return new (Mem) OMPThreadLimitClause(N); +} + //===----------------------------------------------------------------------===// // OpenMP clauses printing methods //===----------------------------------------------------------------------===// @@ -2081,9 +2099,11 @@ void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) { } void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) { - OS << "thread_limit("; - Node->getThreadLimit()->printPretty(OS, nullptr, Policy, 0); - OS << ")"; + if (!Node->varlist_empty()) { + OS << "thread_limit"; + VisitOMPClauseList(Node, '('); + OS << ")"; + } } void OMPClausePrinter::VisitOMPPriorityClause(OMPPriorityClause *Node) { diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index bf46984e94a85d..35d8b0706fe3ce 
100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -862,9 +862,8 @@ void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) { } void OMPClauseProfiler::VisitOMPThreadLimitClause( const OMPThreadLimitClause *C) { + VisitOMPClauseList(C); VistOMPClauseWithPreInit(C); - if (C->getThreadLimit()) - Profiler->VisitStmt(C->getThreadLimit()); } void OMPClauseProfiler::VisitOMPPriorityClause(const OMPPriorityClause *C) { VistOMPClauseWithPreInit(C); diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp index be8ab2d121277e..8c5e4aa9c037e2 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -6332,7 +6332,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective( CGOpenMPInnerExprInfo CGInfo(CGF, *CS); CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(CGF, &CGInfo); CodeGenFunction::LexicalScope Scope( - CGF, ThreadLimitClause->getThreadLimit()->getSourceRange()); + CGF, + ThreadLimitClause->getThreadLimit().front()->getSourceRange()); if (const auto *PreInit = cast_or_null<DeclStmt>(ThreadLimitClause->getPreInitStmt())) { for (const auto *I : PreInit->decls()) { @@ -6349,7 +6350,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective( } } if (ThreadLimitClause) - CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr); + CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(), + ThreadLimitExpr); if (const auto *Dir = dyn_cast_or_null<OMPExecutableDirective>(Child)) { if (isOpenMPTeamsDirective(Dir->getDirectiveKind()) && !isOpenMPDistributeDirective(Dir->getDirectiveKind())) { @@ -6370,7 +6372,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective( if (D.hasClausesOfKind<OMPThreadLimitClause>()) { CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF); const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>(); - CheckForConstExpr(ThreadLimitClause->getThreadLimit(), 
ThreadLimitExpr); + CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(), + ThreadLimitExpr); } const CapturedStmt *CS = D.getInnermostCapturedStmt(); getNumThreads(CGF, CS, NTPtr, UpperBound, UpperBoundOnly, CondVal); @@ -6388,7 +6391,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective( if (D.hasClausesOfKind<OMPThreadLimitClause>()) { CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF); const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>(); - CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr); + CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(), + ThreadLimitExpr); } getNumThreads(CGF, D.getInnermostCapturedStmt(), NTPtr, UpperBound, UpperBoundOnly, CondVal); @@ -6424,7 +6428,8 @@ const Expr *CGOpenMPRuntime::getNumThreadsExprForTargetDirective( if (D.hasClausesOfKind<OMPThreadLimitClause>()) { CodeGenFunction::RunCleanupsScope ThreadLimitScope(CGF); const auto *ThreadLimitClause = D.getSingleClause<OMPThreadLimitClause>(); - CheckForConstExpr(ThreadLimitClause->getThreadLimit(), ThreadLimitExpr); + CheckForConstExpr(ThreadLimitClause->getThreadLimit().front(), + ThreadLimitExpr); } if (D.hasClausesOfKind<OMPNumThreadsClause>()) { CodeGenFunction::RunCleanupsScope NumThreadsScope(CGF); diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp index 6841ceb3b41548..8afe2abf2cc494 100644 --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -5259,7 +5259,7 @@ void CodeGenFunction::EmitOMPTargetTaskBasedDirective( // enclosing this target region. This will indirectly set the thread_limit // for every applicable construct within target region. 
CGF.CGM.getOpenMPRuntime().emitThreadLimitClause( - CGF, TL->getThreadLimit(), S.getBeginLoc()); + CGF, TL->getThreadLimit().front(), S.getBeginLoc()); } BodyGen(CGF); }; @@ -6860,7 +6860,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF, const auto *TL = S.getSingleClause<OMPThreadLimitClause>(); if (NT || TL) { const Expr *NumTeams = NT ? NT->getNumTeams().front() : nullptr; - const Expr *ThreadLimit = TL ? TL->getThreadLimit() : nullptr; + const Expr *ThreadLimit = TL ? TL->getThreadLimit().front() : nullptr; CGF.CGM.getOpenMPRuntime().emitNumTeamsClause(CGF, NumTeams, ThreadLimit, S.getBeginLoc()); diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp index 5732ee7add7c03..61aa72c30a4654 100644 --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -3175,7 +3175,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind, case OMPC_simdlen: case OMPC_collapse: case OMPC_ordered: - case OMPC_thread_limit: case OMPC_priority: case OMPC_grainsize: case OMPC_num_tasks: @@ -3332,6 +3331,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind, : ParseOpenMPClause(CKind, WrongDirective); break; case OMPC_num_teams: + case OMPC_thread_limit: if (!FirstClause) { Diag(Tok, diag::err_omp_more_one_clause) << getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0; diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index b5978ddde24651..87d81dfaad601b 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -13061,6 +13061,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses, return StmtError(); if (!checkNumExprsInClause<OMPNumTeamsClause>( + *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) || + !checkNumExprsInClause<OMPThreadLimitClause>( *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) return StmtError(); @@ -13843,7 +13845,9 @@ StmtResult 
SemaOpenMP::ActOnOpenMPTargetTeamsDirective( ? diag::err_ompx_more_than_three_expr_not_allowed : diag::err_omp_multi_expr_not_allowed; if (!checkNumExprsInClause<OMPNumTeamsClause>(*this, Clauses, - ClauseMaxNumExprs, DiagNo)) + ClauseMaxNumExprs, DiagNo) || + !checkNumExprsInClause<OMPThreadLimitClause>(*this, Clauses, + ClauseMaxNumExprs, DiagNo)) return StmtError(); return OMPTargetTeamsDirective::Create(getASTContext(), StartLoc, EndLoc, @@ -13857,6 +13861,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeDirective( return StmtError(); if (!checkNumExprsInClause<OMPNumTeamsClause>( + *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) || + !checkNumExprsInClause<OMPThreadLimitClause>( *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) return StmtError(); @@ -13887,6 +13893,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForDirective( return StmtError(); if (!checkNumExprsInClause<OMPNumTeamsClause>( + *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) || + !checkNumExprsInClause<OMPThreadLimitClause>( *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) return StmtError(); @@ -13918,6 +13926,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForSimdDirective( return StmtError(); if (!checkNumExprsInClause<OMPNumTeamsClause>( + *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) || + !checkNumExprsInClause<OMPThreadLimitClause>( *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) return StmtError(); @@ -13952,6 +13962,8 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective( return StmtError(); if (!checkNumExprsInClause<OMPNumTeamsClause>( + *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed) || + !checkNumExprsInClause<OMPThreadLimitClause>( *this, Clauses, /*MaxNum=*/1, diag::err_omp_multi_expr_not_allowed)) return StmtError(); @@ -15002,9 +15014,6 @@ OMPClause 
*SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, case OMPC_ordered: Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr); break; - case OMPC_thread_limit: - Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc); - break; case OMPC_priority: Res = ActOnOpenMPPriorityClause(Expr, StartLoc, LParenLoc, EndLoc); break; @@ -15109,6 +15118,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, case OMPC_when: case OMPC_bind: case OMPC_num_teams: + case OMPC_thread_limit: default: llvm_unreachable("Clause is not allowed."); } @@ -16975,6 +16985,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind, case OMPC_num_teams: Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc); break; + case OMPC_thread_limit: + Res = ActOnOpenMPThreadLimitClause(VarList, StartLoc, LParenLoc, EndLoc); + break; case OMPC_if: case OMPC_depobj: case OMPC_final: @@ -17005,7 +17018,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind, case OMPC_device: case OMPC_threads: case OMPC_simd: - case OMPC_thread_limit: case OMPC_priority: case OMPC_grainsize: case OMPC_nogroup: @@ -21919,32 +21931,40 @@ OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList, LParenLoc, EndLoc, Vars, PreInit); } -OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit, +OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(ArrayRef<Expr *> VarList, SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc) { - Expr *ValExpr = ThreadLimit; - Stmt *HelperValStmt = nullptr; - - // OpenMP [teams Constrcut, Restrictions] - // The thread_limit expression must evaluate to a positive integer value. 
- if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_thread_limit, - /*StrictlyPositive=*/true)) + if (VarList.empty()) return nullptr; + for (Expr *ValExpr : VarList) { + // OpenMP [teams Construct, Restrictions] + // The thread_limit expression must evaluate to a positive integer value. + if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_thread_limit, + /*StrictlyPositive=*/true)) + return nullptr; + } + OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective(); OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause( DKind, OMPC_thread_limit, getLangOpts().OpenMP); - if (CaptureRegion != OMPD_unknown && - !SemaRef.CurContext->isDependentContext()) { + if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext()) + return OMPThreadLimitClause::Create(getASTContext(), CaptureRegion, + StartLoc, LParenLoc, EndLoc, VarList, + /*PreInit=*/nullptr); + + llvm::MapVector<const Expr *, DeclRefExpr *> Captures; + SmallVector<Expr *, 3> Vars; + for (Expr *ValExpr : VarList) { ValExpr = SemaRef.MakeFullExpr(ValExpr).get(); - llvm::MapVector<const Expr *, DeclRefExpr *> Captures; ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get(); - HelperValStmt = buildPreInits(getASTContext(), Captures); + Vars.push_back(ValExpr); } - return new (getASTContext()) OMPThreadLimitClause( - ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc); + Stmt *PreInit = buildPreInits(getASTContext(), Captures); + return OMPThreadLimitClause::Create(getASTContext(), CaptureRegion, StartLoc, + LParenLoc, EndLoc, Vars, PreInit); } OMPClause *SemaOpenMP::ActOnOpenMPPriorityClause(Expr *Priority, diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 8f6f30434af65e..78ec964037dfe9 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -2091,12 +2091,12 @@ class TreeTransform { /// /// By default, performs semantic analysis to build the new statement. 
/// Subclasses may override this routine to provide different behavior. - OMPClause *RebuildOMPThreadLimitClause(Expr *ThreadLimit, + OMPClause *RebuildOMPThreadLimitClause(ArrayRef<Expr *> VarList, SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc) { - return getSema().OpenMP().ActOnOpenMPThreadLimitClause( - ThreadLimit, StartLoc, LParenLoc, EndLoc); + return getSema().OpenMP().ActOnOpenMPThreadLimitClause(VarList, StartLoc, + LParenLoc, EndLoc); } /// Build a new OpenMP 'priority' clause. @@ -11028,11 +11028,16 @@ TreeTransform<Derived>::TransformOMPNumTeamsClause(OMPNumTeamsClause *C) { template <typename Derived> OMPClause * TreeTransform<Derived>::TransformOMPThreadLimitClause(OMPThreadLimitClause *C) { - ExprResult E = getDerived().TransformExpr(C->getThreadLimit()); - if (E.isInvalid()) - return nullptr; + llvm::SmallVector<Expr *, 3> Vars; + Vars.reserve(C->varlist_size()); + for (auto *VE : C->varlist()) { + ExprResult EVar = getDerived().TransformExpr(cast<Expr>(VE)); + if (EVar.isInvalid()) + return nullptr; + Vars.push_back(EVar.get()); + } return getDerived().RebuildOMPThreadLimitClause( - E.get(), C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc()); + Vars, C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc()); } template <typename Derived> diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp index ad8d6c336f2780..e1d554ee7db224 100644 --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -10645,7 +10645,7 @@ OMPClause *OMPClauseReader::readClause() { C = OMPNumTeamsClause::CreateEmpty(Context, Record.readInt()); break; case llvm::omp::OMPC_thread_limit: - C = new (Context) OMPThreadLimitClause(); + C = OMPThreadLimitClause::CreateEmpty(Context, Record.readInt()); break; case llvm::omp::OMPC_priority: C = new (Context) OMPPriorityClause(); @@ -11477,8 +11477,13 @@ void OMPClauseReader::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) { void 
OMPClauseReader::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) { VisitOMPClauseWithPreInit(C); - C->setThreadLimit(Record.readSubExpr()); C->setLParenLoc(Record.readSourceLocation()); + unsigned NumVars = C->varlist_size(); + SmallVector<Expr *, 16> Vars; + Vars.reserve(NumVars); + for (auto _ : llvm::seq<unsigned>(NumVars)) + Vars.push_back(Record.readSubExpr()); + C->setVarRefs(Vars); } void OMPClauseReader::VisitOMPPriorityClause(OMPPriorityClause *C) { diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp index 25e50e4bdc5f80..b5d487465541b8 100644 --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -7589,9 +7589,11 @@ void OMPClauseWriter::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) { } void OMPClauseWriter::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) { + Record.push_back(C->varlist_size()); VisitOMPClauseWithPreInit(C); - Record.AddStmt(C->getThreadLimit()); Record.AddSourceLocation(C->getLParenLoc()); + for (auto *VE : C->varlist()) + Record.AddStmt(VE); } void OMPClauseWriter::VisitOMPPriorityClause(OMPPriorityClause *C) { diff --git a/clang/test/OpenMP/target_teams_ast_print.cpp b/clang/test/OpenMP/target_teams_ast_print.cpp index 1590a996289f8f..ca5d26822ec96d 100644 --- a/clang/test/OpenMP/target_teams_ast_print.cpp +++ b/clang/test/OpenMP/target_teams_ast_print.cpp @@ -115,8 +115,8 @@ int main (int argc, char **argv) { // CHECK-NEXT: #pragma omp target teams ompx_bare num_teams(1) thread_limit(32) a=3; // CHECK-NEXT: a = 3; -#pragma omp target teams ompx_bare num_teams(1, 2, 3) thread_limit(32) -// CHECK-NEXT: #pragma omp target teams ompx_bare num_teams(1,2,3) thread_limit(32) +#pragma omp target teams ompx_bare num_teams(1, 2, 3) thread_limit(2, 4, 6) +// CHECK-NEXT: #pragma omp target teams ompx_bare num_teams(1,2,3) thread_limit(2,4,6) a=4; // CHECK-NEXT: a = 4; #pragma omp target teams default(none), private(argc,b) num_teams(f) firstprivate(argv) reduction(| : 
c, d) reduction(* : e) thread_limit(f+g) diff --git a/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp b/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp index b489e6a860d672..8bf388f0b5da98 100644 --- a/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp +++ b/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp @@ -47,9 +47,15 @@ T tmain(T argc) { #pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}} for (int i=0; i<100; i++) foo(); +#pragma omp target teams distribute thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}} + for (int i=0; i<100; i++) foo(); + #pragma omp target teams ompx_bare num_teams(1, 2, 3, 4) thread_limit(1) // expected-error {{at most three expressions are allowed in 'num_teams' clause in 'target teams ompx_bare' construct}} for (int i=0; i<100; i++) foo(); +#pragma omp target teams ompx_bare num_teams(1) thread_limit(1, 2, 3, 4) // expected-error {{at most three expressions are allowed in 'thread_limit' clause in 'target teams ompx_bare' construct}} + for (int i=0; i<100; i++) foo(); + return 0; } @@ -94,8 +100,14 @@ int main(int argc, char **argv) { #pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}} for (int i=0; i<100; i++) foo(); +#pragma omp target teams distribute thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}} + for (int i=0; i<100; i++) foo(); + #pragma omp target teams ompx_bare num_teams(1, 2, 3, 4) thread_limit(1) // expected-error {{at most three expressions are allowed in 'num_teams' clause in 'target teams ompx_bare' construct}} for (int i=0; i<100; i++) foo(); +#pragma omp target teams ompx_bare num_teams(1) thread_limit(1, 2, 3, 4) // expected-error {{at most three expressions are allowed in 'thread_limit' clause in 'target teams 
ompx_bare' construct}} + for (int i=0; i<100; i++) foo(); + return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}} } diff --git a/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp b/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp index fa6e8f5887f834..092e0137d250d8 100644 --- a/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp +++ b/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp @@ -45,6 +45,8 @@ T tmain(T argc) { for (int i=0; i<100; i++) foo(); #pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}} for (int i=0; i<100; i++) foo(); +#pragma omp target teams distribute parallel for thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}} + for (int i=0; i<100; i++) foo(); return 0; } @@ -90,5 +92,8 @@ int main(int argc, char **argv) { #pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}} for (int i=0; i<100; i++) foo(); +#pragma omp target teams distribute parallel for thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}} + for (int i=0; i<100; i++) foo(); + return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}} } diff --git a/clang/test/OpenMP/teams_num_teams_messages.cpp b/clang/test/OpenMP/teams_num_teams_messages.cpp index 0cfecc5e117438..615bf0be0d8147 100644 --- a/clang/test/OpenMP/teams_num_teams_messages.cpp +++ b/clang/test/OpenMP/teams_num_teams_messages.cpp @@ -60,6 +60,9 @@ T tmain(T argc) { #pragma omp target #pragma omp teams num_teams (1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}} foo(); 
+#pragma omp target +#pragma omp teams thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}} + foo(); return 0; } @@ -118,5 +121,9 @@ int main(int argc, char **argv) { #pragma omp teams num_teams (1, 2, 3) // expected-error {{only one expression allowed in 'num_teams' clause}} foo(); +#pragma omp target +#pragma omp teams thread_limit(1, 2, 3) // expected-error {{only one expression allowed in 'thread_limit' clause}} + foo(); + return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}} } diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp index 48b34e025729c8..66636f2c665feb 100644 --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -2522,8 +2522,8 @@ void OMPClauseEnqueue::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) { void OMPClauseEnqueue::VisitOMPThreadLimitClause( const OMPThreadLimitClause *C) { + VisitOMPClauseList(C); VisitOMPClauseWithPreInit(C); - Visitor->AddStmt(C->getThreadLimit()); } void OMPClauseEnqueue::VisitOMPPriorityClause(const OMPPriorityClause *C) { _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits