https://github.com/shiltian updated https://github.com/llvm/llvm-project/pull/99732
>From fefe6d36301f93097e73725876e89235a279225d Mon Sep 17 00:00:00 2001 From: Shilei Tian <i...@tianshilei.me> Date: Fri, 19 Jul 2024 22:07:06 -0400 Subject: [PATCH] [Clang][OpenMP] Allow `num_teams` to accept multiple expressions --- clang/include/clang/AST/OpenMPClause.h | 79 +++++++++++------- clang/include/clang/AST/RecursiveASTVisitor.h | 2 +- .../clang/Basic/DiagnosticSemaKinds.td | 1 + clang/include/clang/Sema/SemaOpenMP.h | 3 +- clang/lib/AST/OpenMPClause.cpp | 26 +++++- clang/lib/AST/StmtProfile.cpp | 3 +- clang/lib/CodeGen/CGOpenMPRuntime.cpp | 7 +- clang/lib/CodeGen/CGStmtOpenMP.cpp | 2 +- clang/lib/Parse/ParseOpenMP.cpp | 8 +- clang/lib/Sema/SemaOpenMP.cpp | 80 ++++++++++++++----- clang/lib/Sema/TreeTransform.h | 7 +- clang/lib/Serialization/ASTReader.cpp | 10 ++- clang/lib/Serialization/ASTWriter.cpp | 4 +- clang/test/OpenMP/target_teams_ast_print.cpp | 4 + ...et_teams_distribute_num_teams_messages.cpp | 6 ++ ...ribute_parallel_for_num_teams_messages.cpp | 5 ++ .../test/OpenMP/teams_num_teams_messages.cpp | 7 ++ clang/tools/libclang/CIndex.cpp | 2 +- 18 files changed, 188 insertions(+), 68 deletions(-) diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h index b029c72fa7d8f..50ac1e0ea8db7 100644 --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -6131,43 +6131,54 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>, /// \endcode /// In this example directive '#pragma omp teams' has clause 'num_teams' /// with single expression 'n'. -class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit { - friend class OMPClauseReader; +/// +/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause +/// can accept up to three expressions. 
+/// +/// \code +/// #pragma omp target teams ompx_bare num_teams(x, y, z) +/// \endcode +class OMPNumTeamsClause final + : public OMPVarListClause<OMPNumTeamsClause>, + public OMPClauseWithPreInit, + private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> { + friend OMPVarListClause; + friend TrailingObjects; /// Location of '('. SourceLocation LParenLoc; - /// NumTeams number. - Stmt *NumTeams = nullptr; + OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N) + : OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc, + N), + OMPClauseWithPreInit(this), LParenLoc(LParenLoc) {} - /// Set the NumTeams number. - /// - /// \param E NumTeams number. - void setNumTeams(Expr *E) { NumTeams = E; } + /// Build an empty clause. + OMPNumTeamsClause(unsigned N) + : OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(), + SourceLocation(), SourceLocation(), N), + OMPClauseWithPreInit(this) {} public: - /// Build 'num_teams' clause. + /// Creates clause with a list of variables \a VL. /// - /// \param E Expression associated with this clause. - /// \param HelperE Helper Expression associated with this clause. - /// \param CaptureRegion Innermost OpenMP region where expressions in this - /// clause must be captured. + /// \param C AST context. /// \param StartLoc Starting location of the clause. /// \param LParenLoc Location of '('. /// \param EndLoc Ending location of the clause. - OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion, - SourceLocation StartLoc, SourceLocation LParenLoc, - SourceLocation EndLoc) - : OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc), - OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) { - setPreInitStmt(HelperE, CaptureRegion); - } + /// \param VL List of 'num_teams' expressions.
+ /// \param PreInit Pre-init statement for captured \a VL, or null. + static OMPNumTeamsClause * + Create(const ASTContext &C, OpenMPDirectiveKind CaptureRegion, + SourceLocation StartLoc, SourceLocation LParenLoc, + SourceLocation EndLoc, ArrayRef<Expr *> VL, Stmt *PreInit); - /// Build an empty clause. - OMPNumTeamsClause() - : OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(), - SourceLocation()), - OMPClauseWithPreInit(this) {} + /// Creates an empty clause with \a N variables. + /// + /// \param C AST context. + /// \param N The number of variables. + static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N); /// Sets the location of '('. void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } @@ -6175,16 +6186,22 @@ class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit { /// Returns the location of '('. SourceLocation getLParenLoc() const { return LParenLoc; } - /// Return NumTeams number. - Expr *getNumTeams() { return cast<Expr>(NumTeams); } + /// Return NumTeams expressions. + ArrayRef<Expr *> getNumTeams() { return getVarRefs(); } - /// Return NumTeams number. - Expr *getNumTeams() const { return cast<Expr>(NumTeams); } + /// Return NumTeams expressions.
+ ArrayRef<Expr *> getNumTeams() const { + return const_cast<OMPNumTeamsClause *>(this)->getNumTeams(); + } - child_range children() { return child_range(&NumTeams, &NumTeams + 1); } + child_range children() { + return child_range(reinterpret_cast<Stmt **>(varlist_begin()), + reinterpret_cast<Stmt **>(varlist_end())); + } const_child_range children() const { - return const_child_range(&NumTeams, &NumTeams + 1); + auto Children = const_cast<OMPNumTeamsClause *>(this)->children(); + return const_child_range(Children.begin(), Children.end()); } child_range used_children() { diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index dcf5dbf449f8b..9a6e8a9ea1c7b 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -3793,8 +3793,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) { template <typename Derived> bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause( OMPNumTeamsClause *C) { + TRY_TO(VisitOMPClauseList(C)); TRY_TO(VisitOMPClauseWithPreInit(C)); - TRY_TO(TraverseStmt(C->getNumTeams())); return true; } diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index 581434d33c5c9..8e98aa028db08 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -11639,6 +11639,7 @@ def warn_omp_unterminated_declare_target : Warning< InGroup<SourceUsesOpenMP>; def err_ompx_bare_no_grid : Error< "'ompx_bare' clauses requires explicit grid size via 'num_teams' and 'thread_limit' clauses">; +def err_omp_multi_expr_not_allowed: Error<"only one expression allowed to '%0' clause">; } // end of OpenMP category let CategoryName = "Related Result Type Issue" in { diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h index aa61dae9415e2..703c1511fc3ae 100644 --- 
a/clang/include/clang/Sema/SemaOpenMP.h +++ b/clang/include/clang/Sema/SemaOpenMP.h @@ -1226,7 +1226,8 @@ class SemaOpenMP : public SemaBase { const OMPVarListLocTy &Locs, bool NoDiagnose = false, ArrayRef<Expr *> UnresolvedMappers = std::nullopt); /// Called on well-formed 'num_teams' clause. - OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc, + OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList, + SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc); /// Called on well-formed 'thread_limit' clause. diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp index 042a5df5906ca..9ec2f593b4477 100644 --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -1720,6 +1720,24 @@ const Expr *OMPDoacrossClause::getLoopData(unsigned NumLoop) const { return *It; } +OMPNumTeamsClause *OMPNumTeamsClause::Create( + const ASTContext &C, OpenMPDirectiveKind CaptureRegion, + SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc, + ArrayRef<Expr *> VL, Stmt *PreInit) { + void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size())); + OMPNumTeamsClause *Clause = + new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size()); + Clause->setVarRefs(VL); + Clause->setPreInitStmt(PreInit, CaptureRegion); + return Clause; +} + +OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C, + unsigned N) { + void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N)); + return new (Mem) OMPNumTeamsClause(N); +} + //===----------------------------------------------------------------------===// // OpenMP clauses printing methods //===----------------------------------------------------------------------===// @@ -1977,9 +1995,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) { } void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) { - OS << "num_teams("; - Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0); - OS << ")"; 
+ if (!Node->varlist_empty()) { + OS << "num_teams"; + VisitOMPClauseList(Node, '('); + OS << ")"; + } } void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) { diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index f1e723b4242ee..00ba8e490ac4e 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -843,9 +843,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) { VisitOMPClauseList(C); } void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) { + VisitOMPClauseList(C); VistOMPClauseWithPreInit(C); - if (C->getNumTeams()) - Profiler->VisitStmt(C->getNumTeams()); } void OMPClauseProfiler::VisitOMPThreadLimitClause( const OMPThreadLimitClause *C) { diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp index d869aa3322cce..f229202ae5535 100644 --- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp +++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp @@ -6036,8 +6036,9 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective( dyn_cast_or_null<OMPExecutableDirective>(ChildStmt)) { if (isOpenMPTeamsDirective(NestedDir->getDirectiveKind())) { if (NestedDir->hasClausesOfKind<OMPNumTeamsClause>()) { - const Expr *NumTeams = - NestedDir->getSingleClause<OMPNumTeamsClause>()->getNumTeams(); + const Expr *NumTeams = NestedDir->getSingleClause<OMPNumTeamsClause>() + ->getNumTeams() + .front(); if (NumTeams->isIntegerConstantExpr(CGF.getContext())) if (auto Constant = NumTeams->getIntegerConstantExpr(CGF.getContext())) @@ -6062,7 +6063,7 @@ const Expr *CGOpenMPRuntime::getNumTeamsExprForTargetDirective( case OMPD_target_teams_distribute_parallel_for_simd: { if (D.hasClausesOfKind<OMPNumTeamsClause>()) { const Expr *NumTeams = - D.getSingleClause<OMPNumTeamsClause>()->getNumTeams(); + D.getSingleClause<OMPNumTeamsClause>()->getNumTeams().front(); if (NumTeams->isIntegerConstantExpr(CGF.getContext())) if (auto Constant = 
NumTeams->getIntegerConstantExpr(CGF.getContext())) MinTeamsVal = MaxTeamsVal = Constant->getExtValue(); diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp index b1ac9361957ff..0cb8b7804f644 100644 --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -6859,7 +6859,7 @@ static void emitCommonOMPTeamsDirective(CodeGenFunction &CGF, const auto *NT = S.getSingleClause<OMPNumTeamsClause>(); const auto *TL = S.getSingleClause<OMPThreadLimitClause>(); if (NT || TL) { - const Expr *NumTeams = NT ? NT->getNumTeams() : nullptr; + const Expr *NumTeams = NT ? NT->getNumTeams().front() : nullptr; const Expr *ThreadLimit = TL ? TL->getThreadLimit() : nullptr; CGF.CGM.getOpenMPRuntime().emitNumTeamsClause(CGF, NumTeams, ThreadLimit, diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp index e975e96c5c7e4..50930aa0e9a4a 100644 --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -3098,7 +3098,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind, case OMPC_simdlen: case OMPC_collapse: case OMPC_ordered: - case OMPC_num_teams: case OMPC_thread_limit: case OMPC_priority: case OMPC_grainsize: @@ -3252,6 +3251,13 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind, ? 
ParseOpenMPSimpleClause(CKind, WrongDirective) : ParseOpenMPClause(CKind, WrongDirective); break; + case OMPC_num_teams: + if (!FirstClause) { + Diag(Tok, diag::err_omp_more_one_clause) + << getOpenMPDirectiveName(DKind) << getOpenMPClauseName(CKind) << 0; + ErrorFound = true; + } + [[clang::fallthrough]]; case OMPC_private: case OMPC_firstprivate: case OMPC_lastprivate: diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index 4f50efda155fb..8d25358ef5fa3 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -13004,6 +13004,24 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetUpdateDirective( Clauses, AStmt); } +// This checks whether num_teams clause only has one expression. +static bool checkNumTeamsClauseSingleExpr(SemaBase &SemaRef, + ArrayRef<OMPClause *> Clauses) { + auto NumTeamsClauseItr = + llvm::find_if(Clauses, llvm::IsaPred<OMPNumTeamsClause>); + if (NumTeamsClauseItr != Clauses.end()) { + ArrayRef<const Expr *> NumTeams = + cast<OMPNumTeamsClause>(*NumTeamsClauseItr)->getNumTeams(); + if (NumTeams.size() > 1) { + SemaRef.Diag(NumTeams[1]->getBeginLoc(), + diag::err_omp_multi_expr_not_allowed) + << getOpenMPClauseName(OMPC_num_teams); + return false; + } + } + return true; +} + StmtResult SemaOpenMP::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses, Stmt *AStmt, SourceLocation StartLoc, @@ -13011,6 +13029,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTeamsDirective(ArrayRef<OMPClause *> Clauses, if (!AStmt) return StmtError(); + if (!checkNumTeamsClauseSingleExpr(*this, Clauses)) + return StmtError(); + // Report affected OpenMP target offloading behavior when in HIP lang-mode. 
if (getLangOpts().HIP && (DSAStack->getParentDirective() == OMPD_target)) Diag(StartLoc, diag::warn_hip_omp_target_directives); @@ -13785,6 +13806,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDirective( return StmtError(); } + if (!HasBareClause && !checkNumTeamsClauseSingleExpr(*this, Clauses)) + return StmtError(); + return OMPTargetTeamsDirective::Create(getASTContext(), StartLoc, EndLoc, Clauses, AStmt); } @@ -13795,6 +13819,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeDirective( if (!AStmt) return StmtError(); + if (!checkNumTeamsClauseSingleExpr(*this, Clauses)) + return StmtError(); + CapturedStmt *CS = setBranchProtectedScope(SemaRef, OMPD_target_teams_distribute, AStmt); @@ -13821,6 +13848,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForDirective( if (!AStmt) return StmtError(); + if (!checkNumTeamsClauseSingleExpr(*this, Clauses)) + return StmtError(); + CapturedStmt *CS = setBranchProtectedScope( SemaRef, OMPD_target_teams_distribute_parallel_for, AStmt); @@ -13848,6 +13878,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeParallelForSimdDirective( if (!AStmt) return StmtError(); + if (!checkNumTeamsClauseSingleExpr(*this, Clauses)) + return StmtError(); + CapturedStmt *CS = setBranchProtectedScope( SemaRef, OMPD_target_teams_distribute_parallel_for_simd, AStmt); @@ -13878,6 +13911,9 @@ StmtResult SemaOpenMP::ActOnOpenMPTargetTeamsDistributeSimdDirective( if (!AStmt) return StmtError(); + if (!checkNumTeamsClauseSingleExpr(*this, Clauses)) + return StmtError(); + CapturedStmt *CS = setBranchProtectedScope( SemaRef, OMPD_target_teams_distribute_simd, AStmt); @@ -14925,9 +14961,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, case OMPC_ordered: Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr); break; - case OMPC_num_teams: - Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc); - break; case OMPC_thread_limit: Res = 
ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc); break; @@ -15031,6 +15064,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, case OMPC_affinity: case OMPC_when: case OMPC_bind: + case OMPC_num_teams: default: llvm_unreachable("Clause is not allowed."); } @@ -16894,6 +16928,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind, static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier), ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc); break; + case OMPC_num_teams: + Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc); + break; case OMPC_if: case OMPC_depobj: case OMPC_final: @@ -16924,7 +16961,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind, case OMPC_device: case OMPC_threads: case OMPC_simd: - case OMPC_num_teams: case OMPC_thread_limit: case OMPC_priority: case OMPC_grainsize: @@ -21587,32 +21623,40 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const { return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl(); } -OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams, +OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList, SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc) { - Expr *ValExpr = NumTeams; - Stmt *HelperValStmt = nullptr; - - // OpenMP [teams Constrcut, Restrictions] - // The num_teams expression must evaluate to a positive integer value. - if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams, - /*StrictlyPositive=*/true)) + if (VarList.empty()) return nullptr; + for (Expr *ValExpr : VarList) { + // OpenMP [teams Constrcut, Restrictions] + // The num_teams expression must evaluate to a positive integer value. 
+ if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams, + /*StrictlyPositive=*/true)) + return nullptr; + } + OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective(); OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause( DKind, OMPC_num_teams, getLangOpts().OpenMP); - if (CaptureRegion != OMPD_unknown && - !SemaRef.CurContext->isDependentContext()) { + if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext()) + return OMPNumTeamsClause::Create(getASTContext(), CaptureRegion, StartLoc, + LParenLoc, EndLoc, VarList, + /*PreInit=*/nullptr); + + llvm::MapVector<const Expr *, DeclRefExpr *> Captures; + SmallVector<Expr *, 3> Vars; + for (Expr *ValExpr : VarList) { ValExpr = SemaRef.MakeFullExpr(ValExpr).get(); - llvm::MapVector<const Expr *, DeclRefExpr *> Captures; ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get(); - HelperValStmt = buildPreInits(getASTContext(), Captures); + Vars.push_back(ValExpr); } - return new (getASTContext()) OMPNumTeamsClause( - ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc); + Stmt *PreInit = buildPreInits(getASTContext(), Captures); + return OMPNumTeamsClause::Create(getASTContext(), CaptureRegion, StartLoc, + LParenLoc, EndLoc, Vars, PreInit); } OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit, diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 8d3e1edf7a45d..3fbfb2ec989ce 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -2065,10 +2065,11 @@ class TreeTransform { /// /// By default, performs semantic analysis to build the new statement. /// Subclasses may override this routine to provide different behavior. 
- OMPClause *RebuildOMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc, + OMPClause *RebuildOMPNumTeamsClause(ArrayRef<Expr *> VarList, + SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc) { - return getSema().OpenMP().ActOnOpenMPNumTeamsClause(NumTeams, StartLoc, + return getSema().OpenMP().ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc); } @@ -10872,7 +10873,7 @@ TreeTransform<Derived>::TransformOMPAllocateClause(OMPAllocateClause *C) { template <typename Derived> OMPClause * TreeTransform<Derived>::TransformOMPNumTeamsClause(OMPNumTeamsClause *C) { - ExprResult E = getDerived().TransformExpr(C->getNumTeams()); + ExprResult E = getDerived().TransformExpr(C->getNumTeams().front()); if (E.isInvalid()) return nullptr; return getDerived().RebuildOMPNumTeamsClause( diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp index 86fa96a91932f..0950c165bd662 100644 --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -104,6 +104,7 @@ #include "llvm/ADT/IntrusiveRefCntPtr.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" @@ -10567,7 +10568,7 @@ OMPClause *OMPClauseReader::readClause() { break; } case llvm::omp::OMPC_num_teams: - C = new (Context) OMPNumTeamsClause(); + C = OMPNumTeamsClause::CreateEmpty(Context, Record.readInt()); break; case llvm::omp::OMPC_thread_limit: C = new (Context) OMPThreadLimitClause(); @@ -11355,8 +11356,13 @@ void OMPClauseReader::VisitOMPAllocateClause(OMPAllocateClause *C) { void OMPClauseReader::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) { VisitOMPClauseWithPreInit(C); - C->setNumTeams(Record.readSubExpr()); C->setLParenLoc(Record.readSourceLocation()); + unsigned NumVars = C->varlist_size(); + SmallVector<Expr *, 16> Vars; + Vars.reserve(NumVars); + for ([[maybe_unused]] 
unsigned I : llvm::seq<unsigned>(NumVars)) + Vars.push_back(Record.readSubExpr()); + C->setVarRefs(Vars); } void OMPClauseReader::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) { diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp index f0f9d397f1717..657eb6d3d1cc4 100644 --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -7528,9 +7528,11 @@ void OMPClauseWriter::VisitOMPAllocateClause(OMPAllocateClause *C) { } void OMPClauseWriter::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) { + Record.push_back(C->varlist_size()); VisitOMPClauseWithPreInit(C); - Record.AddStmt(C->getNumTeams()); Record.AddSourceLocation(C->getLParenLoc()); + for (auto *VE : C->varlist()) + Record.AddStmt(VE); } void OMPClauseWriter::VisitOMPThreadLimitClause(OMPThreadLimitClause *C) { diff --git a/clang/test/OpenMP/target_teams_ast_print.cpp b/clang/test/OpenMP/target_teams_ast_print.cpp index 2ff34e4498bfe..1590a996289f8 100644 --- a/clang/test/OpenMP/target_teams_ast_print.cpp +++ b/clang/test/OpenMP/target_teams_ast_print.cpp @@ -115,6 +115,10 @@ int main (int argc, char **argv) { // CHECK-NEXT: #pragma omp target teams ompx_bare num_teams(1) thread_limit(32) a=3; // CHECK-NEXT: a = 3; +#pragma omp target teams ompx_bare num_teams(1, 2, 3) thread_limit(32) +// CHECK-NEXT: #pragma omp target teams ompx_bare num_teams(1,2,3) thread_limit(32) + a=4; +// CHECK-NEXT: a = 4; #pragma omp target teams default(none), private(argc,b) num_teams(f) firstprivate(argv) reduction(| : c, d) reduction(* : e) thread_limit(f+g) // CHECK-NEXT: #pragma omp target teams default(none) private(argc,b) num_teams(f) firstprivate(argv) reduction(|: c,d) reduction(*: e) thread_limit(f + g) foo(); diff --git a/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp b/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp index c0a31fa19b282..e8f898f1f25ee 100644 --- 
a/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp +++ b/clang/test/OpenMP/target_teams_distribute_num_teams_messages.cpp @@ -44,6 +44,9 @@ T tmain(T argc) { #pragma omp target teams distribute num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}} for (int i=0; i<100; i++) foo(); +#pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{only one expression allowed to 'num_teams' clause}} + for (int i=0; i<100; i++) foo(); + return 0; } @@ -85,5 +88,8 @@ int main(int argc, char **argv) { #pragma omp target teams distribute num_teams (3.14) // expected-error {{expression must have integral or unscoped enumeration type, not 'double'}} for (int i=0; i<100; i++) foo(); +#pragma omp target teams distribute num_teams(1, 2, 3) // expected-error {{only one expression allowed to 'num_teams' clause}} + for (int i=0; i<100; i++) foo(); + return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}} } diff --git a/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp b/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp index d80b6ea380b93..2a2f5ae27ac55 100644 --- a/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp +++ b/clang/test/OpenMP/target_teams_distribute_parallel_for_num_teams_messages.cpp @@ -43,6 +43,8 @@ T tmain(T argc) { for (int i=0; i<100; i++) foo(); #pragma omp target teams distribute parallel for num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}} for (int i=0; i<100; i++) foo(); +#pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{only one expression allowed to 'num_teams' clause}} + for (int i=0; i<100; i++) foo(); return 0; } @@ -85,5 +87,8 @@ int main(int argc, char **argv) { #pragma omp target teams distribute parallel 
for num_teams (3.14) // expected-error {{expression must have integral or unscoped enumeration type, not 'double'}} for (int i=0; i<100; i++) foo(); +#pragma omp target teams distribute parallel for num_teams(1, 2, 3) // expected-error {{only one expression allowed to 'num_teams' clause}} + for (int i=0; i<100; i++) foo(); + return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}} } diff --git a/clang/test/OpenMP/teams_num_teams_messages.cpp b/clang/test/OpenMP/teams_num_teams_messages.cpp index 40da396b01069..09429167ee39e 100644 --- a/clang/test/OpenMP/teams_num_teams_messages.cpp +++ b/clang/test/OpenMP/teams_num_teams_messages.cpp @@ -57,6 +57,9 @@ T tmain(T argc) { #pragma omp target #pragma omp teams num_teams(3.14) // expected-error 2 {{expression must have integral or unscoped enumeration type, not 'double'}} foo(); +#pragma omp target +#pragma omp teams num_teams (1, 2, 3) // expected-error {{only one expression allowed to 'num_teams' clause}} + foo(); return 0; } @@ -111,5 +114,9 @@ int main(int argc, char **argv) { #pragma omp teams num_teams (3.14) // expected-error {{expression must have integral or unscoped enumeration type, not 'double'}} foo(); +#pragma omp target +#pragma omp teams num_teams (1, 2, 3) // expected-error {{only one expression allowed to 'num_teams' clause}} + foo(); + return tmain<int, 10>(argc); // expected-note {{in instantiation of function template specialization 'tmain<int, 10>' requested here}} } diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp index 937d7ff09e4ee..d34da8b4eb158 100644 --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -2499,8 +2499,8 @@ void OMPClauseEnqueue::VisitOMPDeviceClause(const OMPDeviceClause *C) { } void OMPClauseEnqueue::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) { + VisitOMPClauseList(C); VisitOMPClauseWithPreInit(C); - Visitor->AddStmt(C->getNumTeams()); } 
void OMPClauseEnqueue::VisitOMPThreadLimitClause( _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits