Skip to content

[clang] CTAD: use index and depth to retrieve template parameter for TemplateParamsReferencedInTemplateArgumentList #98013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 44 additions & 22 deletions clang/lib/Sema/SemaTemplate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "clang/Sema/SemaInternal.h"
#include "clang/Sema/Template.h"
#include "clang/Sema/TemplateDeduction.h"
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
Expand Down Expand Up @@ -2751,23 +2752,48 @@ struct ConvertConstructorToDeductionGuideTransform {
}
};

unsigned getTemplateParameterDepth(NamedDecl *TemplateParam) {
if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(TemplateParam))
return TTP->getDepth();
if (auto *TTP = dyn_cast<TemplateTemplateParmDecl>(TemplateParam))
return TTP->getDepth();
if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(TemplateParam))
return NTTP->getDepth();
llvm_unreachable("Unhandled template parameter types");
}

unsigned getTemplateParameterIndex(NamedDecl *TemplateParam) {
if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(TemplateParam))
return TTP->getIndex();
if (auto *TTP = dyn_cast<TemplateTemplateParmDecl>(TemplateParam))
return TTP->getIndex();
if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(TemplateParam))
return NTTP->getIndex();
llvm_unreachable("Unhandled template parameter types");
}

// Find all template parameters that appear in the given DeducedArgs.
// Return the indices of the template parameters in the TemplateParams.
SmallVector<unsigned> TemplateParamsReferencedInTemplateArgumentList(
ArrayRef<NamedDecl *> TemplateParams,
const TemplateParameterList *TemplateParamsList,
ArrayRef<TemplateArgument> DeducedArgs) {
struct TemplateParamsReferencedFinder
: public RecursiveASTVisitor<TemplateParamsReferencedFinder> {
llvm::DenseSet<NamedDecl *> TemplateParams;
llvm::DenseSet<const NamedDecl *> ReferencedTemplateParams;
const TemplateParameterList *TemplateParamList;
llvm::BitVector ReferencedTemplateParams;

TemplateParamsReferencedFinder(ArrayRef<NamedDecl *> TemplateParams)
: TemplateParams(TemplateParams.begin(), TemplateParams.end()) {}
TemplateParamsReferencedFinder(
const TemplateParameterList *TemplateParamList)
: TemplateParamList(TemplateParamList),
ReferencedTemplateParams(TemplateParamList->size()) {}

bool VisitTemplateTypeParmType(TemplateTypeParmType *TTP) {
MarkAppeared(TTP->getDecl());
// We use the index and depth to retrieve the corresponding template
// parameter from the parameter list, which is more robost.
Mark(TTP->getDepth(), TTP->getIndex());
return true;
}

bool VisitDeclRefExpr(DeclRefExpr *DRE) {
MarkAppeared(DRE->getFoundDecl());
return true;
Expand All @@ -2780,16 +2806,22 @@ SmallVector<unsigned> TemplateParamsReferencedInTemplateArgumentList(
}

void MarkAppeared(NamedDecl *ND) {
if (TemplateParams.contains(ND))
ReferencedTemplateParams.insert(ND);
if (llvm::isa<NonTypeTemplateParmDecl, TemplateTypeParmDecl,
TemplateTemplateParmDecl>(ND))
Mark(getTemplateParameterDepth(ND), getTemplateParameterIndex(ND));
}
void Mark(unsigned Depth, unsigned Index) {
if (Index < TemplateParamList->size() &&
TemplateParamList->getParam(Index)->getTemplateDepth() == Depth)
ReferencedTemplateParams.set(Index);
}
};
TemplateParamsReferencedFinder Finder(TemplateParams);
TemplateParamsReferencedFinder Finder(TemplateParamsList);
Finder.TraverseTemplateArguments(DeducedArgs);

SmallVector<unsigned> Results;
for (unsigned Index = 0; Index < TemplateParams.size(); ++Index) {
if (Finder.ReferencedTemplateParams.contains(TemplateParams[Index]))
for (unsigned Index = 0; Index < TemplateParamsList->size(); ++Index) {
if (Finder.ReferencedTemplateParams[Index])
Results.push_back(Index);
}
return Results;
Expand All @@ -2808,16 +2840,6 @@ bool hasDeclaredDeductionGuides(DeclarationName Name, DeclContext *DC) {
return false;
}

unsigned getTemplateParameterDepth(NamedDecl *TemplateParam) {
if (auto *TTP = dyn_cast<TemplateTypeParmDecl>(TemplateParam))
return TTP->getDepth();
if (auto *TTP = dyn_cast<TemplateTemplateParmDecl>(TemplateParam))
return TTP->getDepth();
if (auto *NTTP = dyn_cast<NonTypeTemplateParmDecl>(TemplateParam))
return NTTP->getDepth();
llvm_unreachable("Unhandled template parameter types");
}

NamedDecl *transformTemplateParameter(Sema &SemaRef, DeclContext *DC,
NamedDecl *TemplateParam,
MultiLevelTemplateArgumentList &Args,
Expand Down Expand Up @@ -3149,7 +3171,7 @@ BuildDeductionGuideForTypeAlias(Sema &SemaRef,
}
auto DeducedAliasTemplateParams =
TemplateParamsReferencedInTemplateArgumentList(
AliasTemplate->getTemplateParameters()->asArray(), DeducedArgs);
AliasTemplate->getTemplateParameters(), DeducedArgs);
// All template arguments null by default.
SmallVector<TemplateArgument> TemplateArgsForBuildingFPrime(
F->getTemplateParameters()->size());
Expand Down
55 changes: 55 additions & 0 deletions clang/test/AST/ast-dump-ctad-alias.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,58 @@ BFoo b2(1.0, 2.0);
// CHECK-NEXT: | | |-ParmVarDecl {{.*}} 'type-parameter-0-0'
// CHECK-NEXT: | | `-ParmVarDecl {{.*}} 'type-parameter-0-0'
// CHECK-NEXT: | `-CXXDeductionGuideDecl {{.*}} implicit used <deduction guide for BFoo> 'auto (double, double) -> Foo<double, double>' implicit_instantiation

namespace GH90209 {
// Case 1: type template parameter
template <class Ts>
struct List1 {
List1(int);
};

template <class T1>
struct TemplatedClass1 {
TemplatedClass1(T1);
};

template <class T1>
TemplatedClass1(T1) -> TemplatedClass1<List1<T1>>;

template <class T2>
using ATemplatedClass1 = TemplatedClass1<List1<T2>>;

ATemplatedClass1 test1(1);
// Verify that we have a correct template parameter list for the deduction guide.
//
// CHECK: FunctionTemplateDecl {{.*}} <deduction guide for ATemplatedClass1>
// CHECK-NEXT: |-TemplateTypeParmDecl {{.*}} class depth 0 index 0 T2
// CHECK-NEXT: |-TypeTraitExpr {{.*}} 'bool' __is_deducible

// Case 2: template template parameter
template<typename K> struct Foo{};

template <template<typename> typename Ts>
struct List2 {
List2(int);
};

template <typename T1>
struct TemplatedClass2 {
TemplatedClass2(T1);
};

template <template<typename> typename T1>
TemplatedClass2(T1<int>) -> TemplatedClass2<List2<T1>>;

template <template<typename> typename T2>
using ATemplatedClass2 = TemplatedClass2<List2<T2>>;

List2<Foo> list(1);
ATemplatedClass2 test2(list);
// Verify that we have a correct template parameter list for the deduction guide.
//
// CHECK: FunctionTemplateDecl {{.*}} <deduction guide for ATemplatedClass2>
// CHECK-NEXT: |-TemplateTemplateParmDecl {{.*}} depth 0 index 0 T2
// CHECK-NEXT: | `-TemplateTypeParmDecl {{.*}} typename depth 0 index 0
// CHECK-NEXT: |-TypeTraitExpr {{.*}} 'bool' __is_deducible

} // namespace GH90209
Loading