Skip to content

Try to derive a type witness in a known conformance before attempting associated type inference #32578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 14 additions & 41 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -729,34 +729,6 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
return structDecl;
}

/// Add a typealias declaration with the given name and underlying target
/// struct type to the given source nominal declaration context.
static void addAssociatedTypeAliasDecl(Identifier name, DeclContext *sourceDC,
StructDecl *target,
ASTContext &Context) {
auto *nominal = sourceDC->getSelfNominalTypeDecl();
assert(nominal && "Expected `DeclContext` to be a nominal type");
auto lookup = nominal->lookupDirect(name);
assert(lookup.size() < 2 &&
"Expected at most one associated type named member");
// If implicit type declaration with the given name already exists in source
// struct, return it.
if (lookup.size() == 1) {
auto existingTypeDecl = dyn_cast<TypeDecl>(lookup.front());
assert(existingTypeDecl && existingTypeDecl->isImplicit() &&
"Expected lookup result to be an implicit type declaration");
return;
}
// Otherwise, create a new typealias.
auto *aliasDecl = new (Context)
TypeAliasDecl(SourceLoc(), SourceLoc(), name, SourceLoc(), {}, sourceDC);
aliasDecl->setUnderlyingType(target->getDeclaredInterfaceType());
aliasDecl->setImplicit();
aliasDecl->setGenericSignature(sourceDC->getGenericSignatureOfContext());
cast<IterableDeclContext>(sourceDC->getAsDecl())->addMember(aliasDecl);
aliasDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
};

/// Diagnose stored properties in the nominal that do not have an explicit
/// `@noDerivative` attribute, but either:
/// - Do not conform to `Differentiable`.
Expand Down Expand Up @@ -842,7 +814,7 @@ static void checkAndDiagnoseImplicitNoDerivative(ASTContext &Context,
}

/// Get or synthesize `TangentVector` struct type.
static Type
static std::pair<Type, TypeDecl *>
getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) {
auto *parentDC = derived.getConformanceContext();
auto *nominal = derived.Nominal;
Expand All @@ -852,28 +824,28 @@ getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) {
auto *tangentStruct =
getOrSynthesizeTangentVectorStruct(derived, C.Id_TangentVector);
if (!tangentStruct)
return nullptr;
return std::make_pair(nullptr, nullptr);

// Check and emit warnings for implicit `@noDerivative` members.
checkAndDiagnoseImplicitNoDerivative(C, nominal, parentDC);
// Add `TangentVector` typealias for `TangentVector` struct.
addAssociatedTypeAliasDecl(C.Id_TangentVector, tangentStruct, tangentStruct,
C);

// Return the `TangentVector` struct type.
return parentDC->mapTypeIntoContext(
tangentStruct->getDeclaredInterfaceType());
return std::make_pair(
parentDC->mapTypeIntoContext(
tangentStruct->getDeclaredInterfaceType()),
tangentStruct);
}

/// Synthesize the `TangentVector` struct type.
static Type
static std::pair<Type, TypeDecl *>
deriveDifferentiable_TangentVectorStruct(DerivedConformance &derived) {
auto *parentDC = derived.getConformanceContext();
auto *nominal = derived.Nominal;

// If nominal type can derive `TangentVector` as the contextual `Self` type,
// return it.
if (canDeriveTangentVectorAsSelf(nominal, parentDC))
return parentDC->getSelfTypeInContext();
return std::make_pair(parentDC->getSelfTypeInContext(), nullptr);

// Otherwise, get or synthesize `TangentVector` struct type.
return getOrSynthesizeTangentVectorStructType(derived);
Expand Down Expand Up @@ -914,16 +886,17 @@ ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) {
return nullptr;
}

Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
std::pair<Type, TypeDecl *>
DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
// Diagnose unknown requirements.
if (requirement->getBaseName() != Context.Id_TangentVector) {
Context.Diags.diagnose(requirement->getLoc(),
diag::broken_differentiable_requirement);
return nullptr;
return std::make_pair(nullptr, nullptr);
}
// Diagnose conformances in disallowed contexts.
if (checkAndDiagnoseDisallowedContext(requirement))
return nullptr;
return std::make_pair(nullptr, nullptr);

// Start an error diagnostic before attempting derivation.
// If derivation succeeds, cancel the diagnostic.
Expand All @@ -939,5 +912,5 @@ Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
}

// Otherwise, return nullptr.
return nullptr;
return std::make_pair(nullptr, nullptr);
}
3 changes: 2 additions & 1 deletion lib/Sema/DerivedConformances.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ class DerivedConformance {
/// Derive a Differentiable type witness for a nominal type.
///
/// \returns the derived member, which will also be added to the type.
Type deriveDifferentiable(AssociatedTypeDecl *assocType);
std::pair<Type, TypeDecl *>
deriveDifferentiable(AssociatedTypeDecl *assocType);

/// Derive a CaseIterable requirement for an enum if it has no associated
/// values for any of its cases.
Expand Down
15 changes: 8 additions & 7 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5579,29 +5579,30 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC,
llvm_unreachable("unknown derivable protocol kind");
}

Type TypeChecker::deriveTypeWitness(DeclContext *DC,
NominalTypeDecl *TypeDecl,
AssociatedTypeDecl *AssocType) {
std::pair<Type, TypeDecl *>
TypeChecker::deriveTypeWitness(DeclContext *DC,
NominalTypeDecl *TypeDecl,
AssociatedTypeDecl *AssocType) {
auto *protocol = cast<ProtocolDecl>(AssocType->getDeclContext());

auto knownKind = protocol->getKnownProtocolKind();

if (!knownKind)
return nullptr;
return std::make_pair(nullptr, nullptr);

auto Decl = DC->getInnermostDeclarationDeclContext();

DerivedConformance derived(TypeDecl->getASTContext(), Decl, TypeDecl,
protocol);
switch (*knownKind) {
case KnownProtocolKind::RawRepresentable:
return derived.deriveRawRepresentable(AssocType);
return std::make_pair(derived.deriveRawRepresentable(AssocType), nullptr);
case KnownProtocolKind::CaseIterable:
return derived.deriveCaseIterable(AssocType);
return std::make_pair(derived.deriveCaseIterable(AssocType), nullptr);
case KnownProtocolKind::Differentiable:
return derived.deriveDifferentiable(AssocType);
default:
return nullptr;
return std::make_pair(nullptr, nullptr);
}
}

Expand Down
8 changes: 3 additions & 5 deletions lib/Sema/TypeCheckProtocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -822,15 +822,13 @@ class AssociatedTypeInference {

/// Compute the "derived" type witness for an associated type that is
/// known to the compiler.
Type computeDerivedTypeWitness(AssociatedTypeDecl *assocType);
std::pair<Type, TypeDecl *>
computeDerivedTypeWitness(AssociatedTypeDecl *assocType);

/// Compute a type witness without using a specific potential witness,
/// e.g., using a fixed type (from a refined protocol), default type
/// on an associated type, or deriving the type.
///
/// \param allowDerived Whether to allow "derived" type witnesses.
Type computeAbstractTypeWitness(AssociatedTypeDecl *assocType,
bool allowDerived);
Type computeAbstractTypeWitness(AssociatedTypeDecl *assocType);

/// Substitute the current type witnesses into the given interface type.
Type substCurrentTypeWitnesses(Type type);
Expand Down
53 changes: 29 additions & 24 deletions lib/Sema/TypeCheckProtocolInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -868,38 +868,37 @@ Type AssociatedTypeInference::computeDefaultTypeWitness(
return defaultType;
}

Type AssociatedTypeInference::computeDerivedTypeWitness(
std::pair<Type, TypeDecl *>
AssociatedTypeInference::computeDerivedTypeWitness(
AssociatedTypeDecl *assocType) {
if (adoptee->hasError())
return Type();
return std::make_pair(Type(), nullptr);

// Can we derive conformances for this protocol and adoptee?
NominalTypeDecl *derivingTypeDecl = adoptee->getAnyNominal();
if (!DerivedConformance::derivesProtocolConformance(dc, derivingTypeDecl,
proto))
return Type();
return std::make_pair(Type(), nullptr);

// Try to derive the type witness.
Type derivedType =
TypeChecker::deriveTypeWitness(dc, derivingTypeDecl, assocType);
if (!derivedType)
return Type();
auto result = TypeChecker::deriveTypeWitness(dc, derivingTypeDecl, assocType);
if (!result.first)
return std::make_pair(Type(), nullptr);

// Make sure that the derived type is sane.
if (checkTypeWitness(derivedType, assocType, conformance)) {
// Make sure that the derived type satisfies requirements.
if (checkTypeWitness(result.first, assocType, conformance)) {
/// FIXME: Diagnose based on this.
failedDerivedAssocType = assocType;
failedDerivedWitness = derivedType;
return Type();
failedDerivedWitness = result.first;
return std::make_pair(Type(), nullptr);
}

return derivedType;
return result;
}

Type
AssociatedTypeInference::computeAbstractTypeWitness(
AssociatedTypeDecl *assocType,
bool allowDerived) {
AssociatedTypeDecl *assocType) {
// We don't have a type witness for this associated type, so go
// looking for more options.
if (Type concreteType = computeFixedTypeWitness(assocType))
Expand All @@ -909,12 +908,6 @@ AssociatedTypeInference::computeAbstractTypeWitness(
if (Type defaultType = computeDefaultTypeWitness(assocType))
return defaultType;

// If we can derive a type witness, do so.
if (allowDerived) {
if (Type derivedType = computeDerivedTypeWitness(assocType))
return derivedType;
}

// If there is a generic parameter of the named type, use that.
if (auto genericSig = dc->getGenericSignatureOfContext()) {
for (auto gp : genericSig->getInnermostGenericParams()) {
Expand Down Expand Up @@ -1197,8 +1190,7 @@ void AssociatedTypeInference::findSolutionsRec(

// Try to compute the type without the aid of a specific potential
// witness.
if (Type type = computeAbstractTypeWitness(assocType,
/*allowDerived=*/true)) {
if (Type type = computeAbstractTypeWitness(assocType)) {
if (type->hasError()) {
recordMissing();
return;
Expand Down Expand Up @@ -1880,10 +1872,23 @@ auto AssociatedTypeInference::solve(ConformanceChecker &checker)
continue;

case ResolveWitnessResult::Missing:
// Note that we haven't resolved this associated type yet.
unresolvedAssocTypes.insert(assocType);
// We did not find the witness via name lookup. Try to derive
// it below.
break;
}

// Finally, try to derive the witness if we know how.
auto derivedType = computeDerivedTypeWitness(assocType);
if (derivedType.first) {
checker.recordTypeWitness(assocType,
derivedType.first->mapTypeOutOfContext(),
derivedType.second);
continue;
}

// We failed to derive the witness. We're going to go on to try
// to infer it from potential value witnesses next.
unresolvedAssocTypes.insert(assocType);
}

// Result variable to use for returns so that we get NRVO.
Expand Down
5 changes: 3 additions & 2 deletions lib/Sema/TypeChecker.h
Original file line number Diff line number Diff line change
Expand Up @@ -911,8 +911,9 @@ ValueDecl *deriveProtocolRequirement(DeclContext *DC,
/// Derive an implicit type witness for the given associated type in
/// the conformance of the given nominal type to some known
/// protocol.
Type deriveTypeWitness(DeclContext *DC, NominalTypeDecl *nominal,
AssociatedTypeDecl *assocType);
std::pair<Type, TypeDecl *>
deriveTypeWitness(DeclContext *DC, NominalTypeDecl *nominal,
AssociatedTypeDecl *assocType);

/// \name Name lookup
///
Expand Down
5 changes: 4 additions & 1 deletion test/Sema/enum_raw_representable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ var doubles: [Double] = serialize([Bar.a, .b, .c])
var foos: [Foo] = deserialize([1, 2, 3])
var bars: [Bar] = deserialize([1.2, 3.4, 5.6])

// Infer RawValue from witnesses.
// We reject enums where the raw type stated in the inheritance clause does not
// match the types of the witnesses.
enum Color : Int {
case red
case blue
Expand All @@ -56,11 +57,13 @@ enum Color : Int {
}

var rawValue: Double {
// expected-error@-1 {{invalid redeclaration of synthesized implementation for protocol requirement 'rawValue'}}
return 1.0
}
}

var colorRaw: Color.RawValue = 7.5
// expected-error@-1 {{cannot convert value of type 'Double' to specified type 'Color.RawValue' (aka 'Int')}}

// Mismatched case types

Expand Down
13 changes: 13 additions & 0 deletions test/Sema/enum_raw_representable_circularity.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: %target-typecheck-verify-swift

// This used to fail with "reference to invalid associated type 'RawValue' of type 'E'"
_ = E(rawValue: 123)

enum E : Int {
case a = 123

init?(rawValue: RawValue) {
self = .a
}
}