llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Krystian Stasiowski (sdkrystian) <details> <summary>Changes</summary> This patch fixes an infinite recursion bug in `ASTImporter` that occurs when importing the primary template of a class template specialization when the latest redeclaration of that template is a friend declaration in the primary template. --- Patch is 21.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114569.diff 12 Files Affected: - (modified) clang/include/clang/AST/DeclTemplate.h (+5-47) - (modified) clang/lib/AST/ASTImporter.cpp (+2-1) - (modified) clang/lib/AST/Decl.cpp (+5-5) - (modified) clang/lib/AST/DeclCXX.cpp (+2-2) - (modified) clang/lib/AST/DeclTemplate.cpp (+54-2) - (modified) clang/lib/Sema/SemaDecl.cpp (+3-1) - (modified) clang/lib/Sema/SemaInit.cpp (+1-1) - (modified) clang/lib/Sema/SemaTemplateInstantiate.cpp (+7-7) - (modified) clang/test/AST/ast-dump-decl.cpp (+1-1) - (added) clang/test/ASTMerge/class-template-spec/Inputs/class-template-spec.cpp (+47) - (added) clang/test/ASTMerge/class-template-spec/test.cpp (+8) - (modified) clang/test/CXX/temp/temp.spec/temp.expl.spec/p7.cpp (+87) ``````````diff diff --git a/clang/include/clang/AST/DeclTemplate.h b/clang/include/clang/AST/DeclTemplate.h index a572e3380f1655..0ca3fd48e81cf4 100644 --- a/clang/include/clang/AST/DeclTemplate.h +++ b/clang/include/clang/AST/DeclTemplate.h @@ -857,16 +857,6 @@ class RedeclarableTemplateDecl : public TemplateDecl, /// \endcode bool isMemberSpecialization() const { return Common.getInt(); } - /// Determines whether any redeclaration of this template was - /// a specialization of a member template. - bool hasMemberSpecialization() const { - for (const auto *D : redecls()) { - if (D->isMemberSpecialization()) - return true; - } - return false; - } - /// Note that this member template is a specialization. void setMemberSpecialization() { assert(!isMemberSpecialization() && "already a member specialization"); @@ -1965,13 +1955,7 @@ class ClassTemplateSpecializationDecl : public CXXRecordDecl, /// specialization which was specialized by this. llvm::PointerUnion<ClassTemplateDecl *, ClassTemplatePartialSpecializationDecl *> - getSpecializedTemplateOrPartial() const { - if (const auto *PartialSpec = - SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization *>()) - return PartialSpec->PartialSpecialization; - - return SpecializedTemplate.get<ClassTemplateDecl*>(); - } + getSpecializedTemplateOrPartial() const; /// Retrieve the set of template arguments that should be used /// to instantiate members of the class template or class template partial @@ -2208,17 +2192,6 @@ class ClassTemplatePartialSpecializationDecl return InstantiatedFromMember.getInt(); } - /// Determines whether any redeclaration of this this class template partial - /// specialization was a specialization of a member partial specialization. - bool hasMemberSpecialization() const { - for (const auto *D : redecls()) { - if (cast<ClassTemplatePartialSpecializationDecl>(D) - ->isMemberSpecialization()) - return true; - } - return false; - } - /// Note that this member template is a specialization. void setMemberSpecialization() { return InstantiatedFromMember.setInt(true); } @@ -2740,13 +2713,7 @@ class VarTemplateSpecializationDecl : public VarDecl, /// Retrieve the variable template or variable template partial /// specialization which was specialized by this. llvm::PointerUnion<VarTemplateDecl *, VarTemplatePartialSpecializationDecl *> - getSpecializedTemplateOrPartial() const { - if (const auto *PartialSpec = - SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization *>()) - return PartialSpec->PartialSpecialization; - - return SpecializedTemplate.get<VarTemplateDecl *>(); - } + getSpecializedTemplateOrPartial() const; /// Retrieve the set of template arguments that should be used /// to instantiate the initializer of the variable template or variable @@ -2980,18 +2947,6 @@ class VarTemplatePartialSpecializationDecl return InstantiatedFromMember.getInt(); } - /// Determines whether any redeclaration of this this variable template - /// partial specialization was a specialization of a member partial - /// specialization. - bool hasMemberSpecialization() const { - for (const auto *D : redecls()) { - if (cast<VarTemplatePartialSpecializationDecl>(D) - ->isMemberSpecialization()) - return true; - } - return false; - } - /// Note that this member template is a specialization. void setMemberSpecialization() { return InstantiatedFromMember.setInt(true); } @@ -3164,6 +3119,9 @@ class VarTemplateDecl : public RedeclarableTemplateDecl { return makeSpecIterator(getSpecializations(), true); } + /// Merge \p Prev with our RedeclarableTemplateDecl::Common. + void mergePrevDecl(VarTemplateDecl *Prev); + // Implement isa/cast/dyncast support static bool classof(const Decl *D) { return classofKind(D->getKind()); } static bool classofKind(Kind K) { return K == VarTemplate; } diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp index 6e31df691fa104..9d0b77566f6747 100644 --- a/clang/lib/AST/ASTImporter.cpp +++ b/clang/lib/AST/ASTImporter.cpp @@ -6190,7 +6190,8 @@ ExpectedDecl ASTNodeImporter::VisitClassTemplateDecl(ClassTemplateDecl *D) { ExpectedDecl ASTNodeImporter::VisitClassTemplateSpecializationDecl( ClassTemplateSpecializationDecl *D) { ClassTemplateDecl *ClassTemplate; - if (Error Err = importInto(ClassTemplate, D->getSpecializedTemplate())) + if (Error Err = importInto(ClassTemplate, + D->getSpecializedTemplate()->getCanonicalDecl())) return std::move(Err); // Import the context of this declaration. diff --git a/clang/lib/AST/Decl.cpp b/clang/lib/AST/Decl.cpp index 86913763ef9ff5..cd173d17263792 100644 --- a/clang/lib/AST/Decl.cpp +++ b/clang/lib/AST/Decl.cpp @@ -2708,7 +2708,7 @@ VarDecl *VarDecl::getTemplateInstantiationPattern() const { if (isTemplateInstantiation(VDTemplSpec->getTemplateSpecializationKind())) { auto From = VDTemplSpec->getInstantiatedFrom(); if (auto *VTD = From.dyn_cast<VarTemplateDecl *>()) { - while (!VTD->hasMemberSpecialization()) { + while (!VTD->isMemberSpecialization()) { if (auto *NewVTD = VTD->getInstantiatedFromMemberTemplate()) VTD = NewVTD; else @@ -2718,7 +2718,7 @@ VarDecl *VarDecl::getTemplateInstantiationPattern() const { } if (auto *VTPSD = From.dyn_cast<VarTemplatePartialSpecializationDecl *>()) { - while (!VTPSD->hasMemberSpecialization()) { + while (!VTPSD->isMemberSpecialization()) { if (auto *NewVTPSD = VTPSD->getInstantiatedFromMember()) VTPSD = NewVTPSD; else @@ -2732,7 +2732,7 @@ VarDecl *VarDecl::getTemplateInstantiationPattern() const { // If this is the pattern of a variable template, find where it was // instantiated from. FIXME: Is this necessary? if (VarTemplateDecl *VTD = VD->getDescribedVarTemplate()) { - while (!VTD->hasMemberSpecialization()) { + while (!VTD->isMemberSpecialization()) { if (auto *NewVTD = VTD->getInstantiatedFromMemberTemplate()) VTD = NewVTD; else @@ -4153,7 +4153,7 @@ FunctionDecl::getTemplateInstantiationPattern(bool ForDefinition) const { if (FunctionTemplateDecl *Primary = getPrimaryTemplate()) { // If we hit a point where the user provided a specialization of this // template, we're done looking. - while (!ForDefinition || !Primary->hasMemberSpecialization()) { + while (!ForDefinition || !Primary->isMemberSpecialization()) { if (auto *NewPrimary = Primary->getInstantiatedFromMemberTemplate()) Primary = NewPrimary; else @@ -4170,7 +4170,7 @@ FunctionTemplateDecl *FunctionDecl::getPrimaryTemplate() const { if (FunctionTemplateSpecializationInfo *Info = TemplateOrSpecialization .dyn_cast<FunctionTemplateSpecializationInfo*>()) { - return Info->getTemplate(); + return Info->getTemplate()->getMostRecentDecl(); } return nullptr; } diff --git a/clang/lib/AST/DeclCXX.cpp b/clang/lib/AST/DeclCXX.cpp index db0ea62a2323eb..1c92fd9e3ff067 100644 --- a/clang/lib/AST/DeclCXX.cpp +++ b/clang/lib/AST/DeclCXX.cpp @@ -2030,7 +2030,7 @@ const CXXRecordDecl *CXXRecordDecl::getTemplateInstantiationPattern() const { if (auto *TD = dyn_cast<ClassTemplateSpecializationDecl>(this)) { auto From = TD->getInstantiatedFrom(); if (auto *CTD = From.dyn_cast<ClassTemplateDecl *>()) { - while (!CTD->hasMemberSpecialization()) { + while (!CTD->isMemberSpecialization()) { if (auto *NewCTD = CTD->getInstantiatedFromMemberTemplate()) CTD = NewCTD; else @@ -2040,7 +2040,7 @@ const CXXRecordDecl *CXXRecordDecl::getTemplateInstantiationPattern() const { } if (auto *CTPSD = From.dyn_cast<ClassTemplatePartialSpecializationDecl *>()) { - while (!CTPSD->hasMemberSpecialization()) { + while (!CTPSD->isMemberSpecialization()) { if (auto *NewCTPSD = CTPSD->getInstantiatedFromMemberTemplate()) CTPSD = NewCTPSD; else diff --git a/clang/lib/AST/DeclTemplate.cpp b/clang/lib/AST/DeclTemplate.cpp index 755ec72f00bf77..1db02d0d04448c 100644 --- a/clang/lib/AST/DeclTemplate.cpp +++ b/clang/lib/AST/DeclTemplate.cpp @@ -993,7 +993,17 @@ ClassTemplateSpecializationDecl::getSpecializedTemplate() const { if (const auto *PartialSpec = SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization*>()) return PartialSpec->PartialSpecialization->getSpecializedTemplate(); - return SpecializedTemplate.get<ClassTemplateDecl*>(); + return SpecializedTemplate.get<ClassTemplateDecl *>()->getMostRecentDecl(); +} + +llvm::PointerUnion<ClassTemplateDecl *, + ClassTemplatePartialSpecializationDecl *> +ClassTemplateSpecializationDecl::getSpecializedTemplateOrPartial() const { + if (const auto *PartialSpec = + SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization *>()) + return PartialSpec->PartialSpecialization->getMostRecentDecl(); + + return SpecializedTemplate.get<ClassTemplateDecl *>()->getMostRecentDecl(); } SourceRange @@ -1283,6 +1293,39 @@ VarTemplateDecl::newCommon(ASTContext &C) const { return CommonPtr; } +void VarTemplateDecl::mergePrevDecl(VarTemplateDecl *Prev) { + // If we haven't created a common pointer yet, then it can just be created + // with the usual method. + if (!getCommonPtrInternal()) + return; + + Common *ThisCommon = static_cast<Common *>(getCommonPtrInternal()); + Common *PrevCommon = nullptr; + SmallVector<VarTemplateDecl *, 8> PreviousDecls; + for (; Prev; Prev = Prev->getPreviousDecl()) { + if (CommonBase *C = Prev->getCommonPtrInternal()) { + PrevCommon = static_cast<Common *>(C); + break; + } + PreviousDecls.push_back(Prev); + } + + // If the previous redecl chain hasn't created a common pointer yet, then just + // use this common pointer. + if (!PrevCommon) { + for (auto *D : PreviousDecls) + D->setCommonPtr(ThisCommon); + return; + } + + // Ensure we don't leak any important state. + assert(ThisCommon->Specializations.empty() && + ThisCommon->PartialSpecializations.empty() && + "Can't merge incompatible declarations!"); + + setCommonPtr(PrevCommon); +} + VarTemplateSpecializationDecl * VarTemplateDecl::findSpecialization(ArrayRef<TemplateArgument> Args, void *&InsertPos) { @@ -1405,7 +1448,16 @@ VarTemplateDecl *VarTemplateSpecializationDecl::getSpecializedTemplate() const { if (const auto *PartialSpec = SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization *>()) return PartialSpec->PartialSpecialization->getSpecializedTemplate(); - return SpecializedTemplate.get<VarTemplateDecl *>(); + return SpecializedTemplate.get<VarTemplateDecl *>()->getMostRecentDecl(); +} + +llvm::PointerUnion<VarTemplateDecl *, VarTemplatePartialSpecializationDecl *> +VarTemplateSpecializationDecl::getSpecializedTemplateOrPartial() const { + if (const auto *PartialSpec = + SpecializedTemplate.dyn_cast<SpecializedPartialSpecialization *>()) + return PartialSpec->PartialSpecialization->getMostRecentDecl(); + + return SpecializedTemplate.get<VarTemplateDecl *>()->getMostRecentDecl(); } SourceRange VarTemplateSpecializationDecl::getSourceRange() const { diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index c56883a80c1c55..65976a9f30a54b 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -4694,8 +4694,10 @@ void Sema::MergeVarDecl(VarDecl *New, LookupResult &Previous) { // Keep a chain of previous declarations. New->setPreviousDecl(Old); - if (NewTemplate) + if (NewTemplate) { + NewTemplate->mergePrevDecl(OldTemplate); NewTemplate->setPreviousDecl(OldTemplate); + } // Inherit access appropriately. New->setAccess(Old->getAccess()); diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp index 573e90aced3eea..e2a59f63ccf589 100644 --- a/clang/lib/Sema/SemaInit.cpp +++ b/clang/lib/Sema/SemaInit.cpp @@ -9954,7 +9954,7 @@ QualType Sema::DeduceTemplateSpecializationFromInitializer( auto SynthesizeAggrGuide = [&](InitListExpr *ListInit) { auto *Pattern = Template; while (Pattern->getInstantiatedFromMemberTemplate()) { - if (Pattern->hasMemberSpecialization()) + if (Pattern->isMemberSpecialization()) break; Pattern = Pattern->getInstantiatedFromMemberTemplate(); } diff --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp index b63063813f1b56..de0ec0128905ff 100644 --- a/clang/lib/Sema/SemaTemplateInstantiate.cpp +++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp @@ -343,7 +343,7 @@ struct TemplateInstantiationArgumentCollecter // If this function was instantiated from a specialized member that is // a function template, we're done. assert(FD->getPrimaryTemplate() && "No function template?"); - if (FD->getPrimaryTemplate()->hasMemberSpecialization()) + if (FD->getPrimaryTemplate()->isMemberSpecialization()) return Done(); // If this function is a generic lambda specialization, we are done. @@ -442,11 +442,11 @@ struct TemplateInstantiationArgumentCollecter Specialized = CTSD->getSpecializedTemplateOrPartial(); if (auto *CTPSD = Specialized.dyn_cast<ClassTemplatePartialSpecializationDecl *>()) { - if (CTPSD->hasMemberSpecialization()) + if (CTPSD->isMemberSpecialization()) return Done(); } else { auto *CTD = Specialized.get<ClassTemplateDecl *>(); - if (CTD->hasMemberSpecialization()) + if (CTD->isMemberSpecialization()) return Done(); } return UseNextDecl(CTSD); @@ -478,11 +478,11 @@ struct TemplateInstantiationArgumentCollecter Specialized = VTSD->getSpecializedTemplateOrPartial(); if (auto *VTPSD = Specialized.dyn_cast<VarTemplatePartialSpecializationDecl *>()) { - if (VTPSD->hasMemberSpecialization()) + if (VTPSD->isMemberSpecialization()) return Done(); } else { auto *VTD = Specialized.get<VarTemplateDecl *>(); - if (VTD->hasMemberSpecialization()) + if (VTD->isMemberSpecialization()) return Done(); } return UseNextDecl(VTSD); @@ -4141,7 +4141,7 @@ getPatternForClassTemplateSpecialization( CXXRecordDecl *Pattern = nullptr; Specialized = ClassTemplateSpec->getSpecializedTemplateOrPartial(); if (auto *CTD = Specialized.dyn_cast<ClassTemplateDecl *>()) { - while (!CTD->hasMemberSpecialization()) { + while (!CTD->isMemberSpecialization()) { if (auto *NewCTD = CTD->getInstantiatedFromMemberTemplate()) CTD = NewCTD; else @@ -4151,7 +4151,7 @@ getPatternForClassTemplateSpecialization( } else if (auto *CTPSD = Specialized .dyn_cast<ClassTemplatePartialSpecializationDecl *>()) { - while (!CTPSD->hasMemberSpecialization()) { + while (!CTPSD->isMemberSpecialization()) { if (auto *NewCTPSD = CTPSD->getInstantiatedFromMemberTemplate()) CTPSD = NewCTPSD; else diff --git a/clang/test/AST/ast-dump-decl.cpp b/clang/test/AST/ast-dump-decl.cpp index e84241cee922f5..7b998f20944f49 100644 --- a/clang/test/AST/ast-dump-decl.cpp +++ b/clang/test/AST/ast-dump-decl.cpp @@ -530,7 +530,7 @@ namespace testCanonicalTemplate { // CHECK-NEXT: | `-ClassTemplateDecl 0x{{.+}} parent 0x{{.+}} <col:5, col:40> col:40 friend_undeclared TestClassTemplate{{$}} // CHECK-NEXT: | |-TemplateTypeParmDecl 0x{{.+}} <col:14, col:23> col:23 typename depth 1 index 0 T2{{$}} // CHECK-NEXT: | `-CXXRecordDecl 0x{{.+}} parent 0x{{.+}} <col:34, col:40> col:40 class TestClassTemplate{{$}} - // CHECK-NEXT: `-ClassTemplateSpecializationDecl 0x{{.+}} <line:[[@LINE-19]]:3, line:[[@LINE-17]]:3> line:[[@LINE-19]]:31 class TestClassTemplate definition implicit_instantiation{{$}} + // CHECK-NEXT: `-ClassTemplateSpecializationDecl 0x{{.+}} <col:5, col:40> line:[[@LINE-19]]:31 class TestClassTemplate definition implicit_instantiation{{$}} // CHECK-NEXT: |-DefinitionData pass_in_registers empty aggregate standard_layout trivially_copyable pod trivial literal has_constexpr_non_copy_move_ctor can_const_default_init{{$}} // CHECK-NEXT: | |-DefaultConstructor exists trivial constexpr defaulted_is_constexpr{{$}} // CHECK-NEXT: | |-CopyConstructor simple trivial has_const_param implicit_has_const_param{{$}} diff --git a/clang/test/ASTMerge/class-template-spec/Inputs/class-template-spec.cpp b/clang/test/ASTMerge/class-template-spec/Inputs/class-template-spec.cpp new file mode 100644 index 00000000000000..332bf24d25b29d --- /dev/null +++ b/clang/test/ASTMerge/class-template-spec/Inputs/class-template-spec.cpp @@ -0,0 +1,47 @@ +namespace N0 { + template<typename T> + struct A { + template<typename U> + friend struct A; + }; + + template struct A<long>; +} // namespace N0 + +namespace N1 { + template<typename T> + struct A; + + template<typename T> + struct A { + template<typename U> + friend struct A; + }; + + template struct A<long>; +} // namespace N1 + +namespace N2 { + template<typename T> + struct A { + template<typename U> + friend struct A; + }; + + template<typename T> + struct A; + + template struct A<long>; +} // namespace N2 + +namespace N3 { + struct A { + template<typename T> + friend struct B; + }; + + template<typename T> + struct B { }; + + template struct B<long>; +} // namespace N3 diff --git a/clang/test/ASTMerge/class-template-spec/test.cpp b/clang/test/ASTMerge/class-template-spec/test.cpp new file mode 100644 index 00000000000000..adbce483503278 --- /dev/null +++ b/clang/test/ASTMerge/class-template-spec/test.cpp @@ -0,0 +1,8 @@ +// RUN: %clang_cc1 -emit-pch -o %t.1.ast %S/Inputs/class-template-spec.cpp +// RUN: %clang_cc1 -ast-merge %t.1.ast -fsyntax-only -verify %s +// expected-no-diagnostics + +template struct N0::A<short>; +template struct N1::A<short>; +template struct N2::A<short>; +template struct N3::B<short>; diff --git a/clang/test/CXX/temp/temp.spec/temp.expl.spec/p7.cpp b/clang/test/CXX/temp/temp.spec/temp.expl.spec/p7.cpp index 87127366eb58a5..e7e4738032f647 100644 --- a/clang/test/CXX/temp/temp.spec/temp.expl.spec/p7.cpp +++ b/clang/test/CXX/temp/temp.spec/temp.expl.spec/p7.cpp @@ -177,6 +177,93 @@ namespace Defined { static_assert(A<short>::B<int*>::y == 2); } // namespace Defined +namespace Constrained { + template<typename T> + struct A { + template<typename U, bool V> requires V + static constexpr int f(); // expected-note {{declared here}} + + template<typename U, bool V> requires V + static const int x; // expected-note {{declared here}} + + template<typename U, bool V> requires V + static const int x<U*, V>; // expected-note {{declared here}} + + template<typename U, bool V> requires V + struct B; // expected-note {{template is declared here}} + + template<typename U, bool V> requires V + struct B<U*, V>; // expected-note {{template is declared here}} + }; + + template<> + template<typename U, bool V> requires V + constexpr int A<short>::f() { + return A<long>::f<U, V>(); + } + + template<> + template<typename U, bool V> requires V + constexpr int A<short>::x = A<long>::x<U, V>; + + template<> + template<typename U, bool V> requires V + constexpr int A<short>::x<U*, V> = A<long>::x<U*, V>; + + template<> + template<typename U, bool V> requires V + struct A<short>::B<U*, V> { + static constexpr int y = A<long>::B<U*, V>::y; + }; + + template<> + template<typename U, bool V> requires V + struct A<short>::B { + static constexpr int y = A<long>::B<U, V>::y; + }; + + template<> + template<typename U, bool V> requires V + constexpr int A<long>::f() { + return 1; + } + + template<> + template<typename U, bool V> requires V + constexpr int A<long>::x = 1; + + template<> + template<typename U, bool V> requires V + constexpr int A<long>::x<U*, V> = 2; + + template<> + template<typename U, bool V> requires V + struct A<long>::B { + ... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/114569 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits