llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Younan Zhang (zyn0217) <details> <summary>Changes</summary> (This is one step towards tweaking `getTemplateInstantiationArgs()`, as discussed in https://github.com/llvm/llvm-project/pull/102922.) We don't always substitute into default arguments while transforming a function parameter. In that case, we preserve the uninstantiated expression until later (e.g. when building a CXXDefaultArgExpr) and instantiate it there. For member function instantiation, this approach used to cause a problem: the default argument of an out-of-line member function specialization could not be instantiated properly. This is because, in `getTemplateInstantiationArgs()`, we would stop visiting a function's declaration context if the function is a specialization of a member template. For example, ```cpp template <class T> struct S { template <class U> void f(T = sizeof(T)); }; template <> template <class U> void S<int>::f(int) {} ``` The default argument `sizeof(T)` that lexically appears inside the declaration would be copied to the function declaration in the class template specialization `S<int>`, as well as to the function's out-of-line definition. We use template arguments collected from the out-of-line function definition when substituting into the default arguments. We would therefore give up the traversal after reaching the function, ending up with only a single level of template arguments: those of `f` itself. However, the default argument here can still reference the template parameters of the primary template (here, `T`), hence the error. In fact, this is similar to constraint checking in some respects: we actually want the "whole" template arguments relative to the primary template, not those relative to the function definition. So this patch adds another flag to `getTemplateInstantiationArgs()` to request that behavior. This patch also consolidates the tests for default arguments and removes some unnecessary tests. 
--- Full diff: https://github.com/llvm/llvm-project/pull/104911.diff 5 Files Affected: - (modified) clang/include/clang/Sema/Sema.h (+8-1) - (modified) clang/lib/Sema/SemaTemplateInstantiate.cpp (+8-15) - (modified) clang/lib/Sema/SemaTemplateInstantiateDecl.cpp (+6-4) - (modified) clang/test/SemaTemplate/default-arguments.cpp (+55) - (modified) clang/test/SemaTemplate/default-parm-init.cpp (-186) ``````````diff diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index 2ec6367eccea01..84df847726e6d2 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -13053,12 +13053,19 @@ class Sema final : public SemaBase { /// ForConstraintInstantiation indicates we should continue looking when /// encountering a lambda generic call operator, and continue looking for /// arguments on an enclosing class template. + /// + /// \param SkipForSpecialization when specified, any template specializations + /// in a traversal would be ignored. + /// \param ForDefaultArgumentSubstitution indicates we should continue looking + /// when encountering a specialized member function template, rather than + /// returning immediately. MultiLevelTemplateArgumentList getTemplateInstantiationArgs( const NamedDecl *D, const DeclContext *DC = nullptr, bool Final = false, std::optional<ArrayRef<TemplateArgument>> Innermost = std::nullopt, bool RelativeToPrimary = false, const FunctionDecl *Pattern = nullptr, bool ForConstraintInstantiation = false, - bool SkipForSpecialization = false); + bool SkipForSpecialization = false, + bool ForDefaultArgumentSubstitution = false); /// RAII object to handle the state changes required to synthesize /// a function body. 
diff --git a/clang/lib/Sema/SemaTemplateInstantiate.cpp b/clang/lib/Sema/SemaTemplateInstantiate.cpp index de470739ab78e7..feed797de838dd 100644 --- a/clang/lib/Sema/SemaTemplateInstantiate.cpp +++ b/clang/lib/Sema/SemaTemplateInstantiate.cpp @@ -255,7 +255,8 @@ HandleClassTemplateSpec(const ClassTemplateSpecializationDecl *ClassTemplSpec, Response HandleFunction(Sema &SemaRef, const FunctionDecl *Function, MultiLevelTemplateArgumentList &Result, const FunctionDecl *Pattern, bool RelativeToPrimary, - bool ForConstraintInstantiation) { + bool ForConstraintInstantiation, + bool ForDefaultArgumentSubstitution) { // Add template arguments from a function template specialization. if (!RelativeToPrimary && Function->getTemplateSpecializationKindForInstantiation() == @@ -285,7 +286,8 @@ Response HandleFunction(Sema &SemaRef, const FunctionDecl *Function, // If this function was instantiated from a specialized member that is // a function template, we're done. assert(Function->getPrimaryTemplate() && "No function template?"); - if (Function->getPrimaryTemplate()->isMemberSpecialization()) + if (!ForDefaultArgumentSubstitution && + Function->getPrimaryTemplate()->isMemberSpecialization()) return Response::Done(); // If this function is a generic lambda specialization, we are done. @@ -467,7 +469,7 @@ MultiLevelTemplateArgumentList Sema::getTemplateInstantiationArgs( const NamedDecl *ND, const DeclContext *DC, bool Final, std::optional<ArrayRef<TemplateArgument>> Innermost, bool RelativeToPrimary, const FunctionDecl *Pattern, bool ForConstraintInstantiation, - bool SkipForSpecialization) { + bool SkipForSpecialization, bool ForDefaultArgumentSubstitution) { assert((ND || DC) && "Can't find arguments for a decl if one isn't provided"); // Accumulate the set of template argument lists in this structure. 
MultiLevelTemplateArgumentList Result; @@ -509,7 +511,8 @@ MultiLevelTemplateArgumentList Sema::getTemplateInstantiationArgs( SkipForSpecialization); } else if (const auto *Function = dyn_cast<FunctionDecl>(CurDecl)) { R = HandleFunction(*this, Function, Result, Pattern, RelativeToPrimary, - ForConstraintInstantiation); + ForConstraintInstantiation, + ForDefaultArgumentSubstitution); } else if (const auto *Rec = dyn_cast<CXXRecordDecl>(CurDecl)) { R = HandleRecordDecl(*this, Rec, Result, Context, ForConstraintInstantiation); @@ -3229,7 +3232,6 @@ bool Sema::SubstDefaultArgument( // default argument expression appears. ContextRAII SavedContext(*this, FD); std::unique_ptr<LocalInstantiationScope> LIS; - MultiLevelTemplateArgumentList NewTemplateArgs = TemplateArgs; if (ForCallExpr) { // When instantiating a default argument due to use in a call expression, @@ -3242,19 +3244,10 @@ bool Sema::SubstDefaultArgument( /*ForDefinition*/ false); if (addInstantiatedParametersToScope(FD, PatternFD, *LIS, TemplateArgs)) return true; - const FunctionTemplateDecl *PrimaryTemplate = FD->getPrimaryTemplate(); - if (PrimaryTemplate && PrimaryTemplate->isOutOfLine()) { - TemplateArgumentList *CurrentTemplateArgumentList = - TemplateArgumentList::CreateCopy(getASTContext(), - TemplateArgs.getInnermost()); - NewTemplateArgs = getTemplateInstantiationArgs( - FD, FD->getDeclContext(), /*Final=*/false, - CurrentTemplateArgumentList->asArray(), /*RelativeToPrimary=*/true); - } } runWithSufficientStackSpace(Loc, [&] { - Result = SubstInitializer(PatternExpr, NewTemplateArgs, + Result = SubstInitializer(PatternExpr, TemplateArgs, /*DirectInit*/ false); }); } diff --git a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp index f93cd113988ae4..ad2ad3b1d1a790 100644 --- a/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp +++ b/clang/lib/Sema/SemaTemplateInstantiateDecl.cpp @@ -4659,10 +4659,12 @@ bool Sema::InstantiateDefaultArgument(SourceLocation 
CallLoc, FunctionDecl *FD, // // template<typename T> // A<T> Foo(int a = A<T>::FooImpl()); - MultiLevelTemplateArgumentList TemplateArgs = - getTemplateInstantiationArgs(FD, FD->getLexicalDeclContext(), - /*Final=*/false, /*Innermost=*/std::nullopt, - /*RelativeToPrimary=*/true); + MultiLevelTemplateArgumentList TemplateArgs = getTemplateInstantiationArgs( + FD, FD->getLexicalDeclContext(), + /*Final=*/false, /*Innermost=*/std::nullopt, + /*RelativeToPrimary=*/true, /*Pattern=*/nullptr, + /*ForConstraintInstantiation=*/false, /*SkipForSpecialization=*/false, + /*ForDefaultArgumentSubstitution=*/true); if (SubstDefaultArgument(CallLoc, Param, TemplateArgs, /*ForCallExpr*/ true)) return true; diff --git a/clang/test/SemaTemplate/default-arguments.cpp b/clang/test/SemaTemplate/default-arguments.cpp index d5d9687cc90f49..c90787c4255a4a 100644 --- a/clang/test/SemaTemplate/default-arguments.cpp +++ b/clang/test/SemaTemplate/default-arguments.cpp @@ -229,3 +229,58 @@ namespace unevaluated { template<int = 0> int f(int = a); // expected-warning 0-1{{extension}} int k = sizeof(f()); } + +#if __cplusplus >= 201103L +namespace GH68490 { + +template <typename T> struct Problem { + template <typename U> + constexpr int UseAlignOf(int param = alignof(U)) const; + + template <typename U> + constexpr int UseSizeOf(int param = sizeof(T)) const; +}; + +template <typename T> struct Problem<T *> { + template <typename U> + constexpr int UseAlignOf(int param = alignof(U)) const; + + template <typename U> + constexpr int UseSizeOf(int param = sizeof(T)) const; +}; + +template <typename T> +template <typename U> +constexpr int Problem<T *>::UseAlignOf(int param) const { + return 2 * param; +} + +template <typename T> +template <typename U> +constexpr int Problem<T *>::UseSizeOf(int param) const { + return 2 * param; +} + +template <> +template <typename T> +constexpr int Problem<int>::UseAlignOf(int param) const { + return param; +} + +template <> +template <typename T> +constexpr int 
Problem<int>::UseSizeOf(int param) const { + return param; +} + +void foo() { + static_assert(Problem<int>().UseAlignOf<char>() == alignof(char), ""); + static_assert(Problem<int>().UseSizeOf<char>() == sizeof(char), ""); + // expected-error@-1 {{failed}} expected-note@-1 {{evaluates to '4 == 1'}} + static_assert(Problem<short *>().UseAlignOf<char>() == 2U * alignof(char), ""); + static_assert(Problem<short *>().UseSizeOf<char>() == 2U * sizeof(char), ""); + // expected-error@-1 {{failed}} expected-note@-1 {{evaluates to '4 == 2'}} +} + +} // namespace GH68490 +#endif diff --git a/clang/test/SemaTemplate/default-parm-init.cpp b/clang/test/SemaTemplate/default-parm-init.cpp index 73ba8998df6a98..d1f407ad15c677 100644 --- a/clang/test/SemaTemplate/default-parm-init.cpp +++ b/clang/test/SemaTemplate/default-parm-init.cpp @@ -2,189 +2,3 @@ // RUN: %clang_cc1 -fsyntax-only -std=c++20 -verify %s // expected-no-diagnostics -namespace std { - -template<typename Signature> class function; - -template<typename R, typename... Args> class invoker_base { -public: - virtual ~invoker_base() { } - virtual R invoke(Args...) = 0; - virtual invoker_base* clone() = 0; -}; - -template<typename F, typename R, typename... Args> -class functor_invoker : public invoker_base<R, Args...> { -public: - explicit functor_invoker(const F& f) : f(f) { } - R invoke(Args... args) { return f(args...); } - functor_invoker* clone() { return new functor_invoker(f); } - -private: - F f; -}; - -template<typename R, typename... 
Args> -class function<R (Args...)> { -public: - typedef R result_type; - function() : invoker (0) { } - function(const function& other) : invoker(0) { - if (other.invoker) - invoker = other.invoker->clone(); - } - - template<typename F> function(const F& f) : invoker(0) { - invoker = new functor_invoker<F, R, Args...>(f); - } - - ~function() { - if (invoker) - delete invoker; - } - - function& operator=(const function& other) { - function(other).swap(*this); - return *this; - } - - template<typename F> - function& operator=(const F& f) { - function(f).swap(*this); - return *this; - } - - void swap(function& other) { - invoker_base<R, Args...>* tmp = invoker; - invoker = other.invoker; - other.invoker = tmp; - } - - result_type operator()(Args... args) const { - return invoker->invoke(args...); - } - -private: - invoker_base<R, Args...>* invoker; -}; - -} - -template<typename TemplateParam> -struct Problem { - template<typename FunctionTemplateParam> - constexpr int FuncAlign(int param = alignof(FunctionTemplateParam)); - - template<typename FunctionTemplateParam> - constexpr int FuncSizeof(int param = sizeof(FunctionTemplateParam)); - - template<typename FunctionTemplateParam> - constexpr int FuncAlign2(int param = alignof(TemplateParam)); - - template<typename FunctionTemplateParam> - constexpr int FuncSizeof2(int param = sizeof(TemplateParam)); -}; - -template<typename TemplateParam> -struct Problem<TemplateParam*> { - template<typename FunctionTemplateParam> - constexpr int FuncAlign(int param = alignof(FunctionTemplateParam)); - - template<typename FunctionTemplateParam> - constexpr int FuncSizeof(int param = sizeof(FunctionTemplateParam)); - - template<typename FunctionTemplateParam> - constexpr int FuncAlign2(int param = alignof(TemplateParam)); - - template<typename FunctionTemplateParam> - constexpr int FuncSizeof2(int param = sizeof(TemplateParam)); -}; - -template<typename TemplateParam> -template<typename FunctionTemplateParam> -constexpr int 
Problem<TemplateParam*>::FuncAlign(int param) { - return 2U*param; -} - -template<typename TemplateParam> -template<typename FunctionTemplateParam> -constexpr int Problem<TemplateParam*>::FuncSizeof(int param) { - return 2U*param; -} - -template<typename TemplateParam> -template<typename FunctionTemplateParam> -constexpr int Problem<TemplateParam*>::FuncAlign2(int param) { - return 2U*param; -} - -template<typename TemplateParam> -template<typename FunctionTemplateParam> -constexpr int Problem<TemplateParam*>::FuncSizeof2(int param) { - return 2U*param; -} - -template <> -template<typename FunctionTemplateParam> -constexpr int Problem<int>::FuncAlign(int param) { - return param; -} - -template <> -template<typename FunctionTemplateParam> -constexpr int Problem<int>::FuncSizeof(int param) { - return param; -} - -template <> -template<typename FunctionTemplateParam> -constexpr int Problem<int>::FuncAlign2(int param) { - return param; -} - -template <> -template<typename FunctionTemplateParam> -constexpr int Problem<int>::FuncSizeof2(int param) { - return param; -} - -void foo() { - Problem<int> p = {}; - static_assert(p.FuncAlign<char>() == alignof(char)); - static_assert(p.FuncSizeof<char>() == sizeof(char)); - static_assert(p.FuncAlign2<char>() == alignof(int)); - static_assert(p.FuncSizeof2<char>() == sizeof(int)); - Problem<short*> q = {}; - static_assert(q.FuncAlign<char>() == 2U * alignof(char)); - static_assert(q.FuncSizeof<char>() == 2U * sizeof(char)); - static_assert(q.FuncAlign2<char>() == 2U *alignof(short)); - static_assert(q.FuncSizeof2<char>() == 2U * sizeof(short)); -} - -template <typename T> -class A { - public: - void run( - std::function<void(T&)> f1 = [](auto&&) {}, - std::function<void(T&)> f2 = [](auto&&) {}); - private: - class Helper { - public: - explicit Helper(std::function<void(T&)> f2) : f2_(f2) {} - std::function<void(T&)> f2_; - }; -}; - -template <typename T> -void A<T>::run(std::function<void(T&)> f1, - std::function<void(T&)> f2) { 
- Helper h(f2); -} - -struct B {}; - -int main() { - A<B> a; - a.run([&](auto& l) {}); - return 0; -} `````````` </details> https://github.com/llvm/llvm-project/pull/104911 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits