Skip to content

Commit ff50a95

Browse files
committed
Sema: More conservative 'tautological binding' check
Instead of computing the reduced type of the witness upfront and then checking for canonical equality in type matching, check for reduced equality in type matching. This restores the old behavior and prevents us from considering too many protocol extension witnesses, while fixing the request cycle that motivated the change to instead match against the reduced type of the witness. Fixes rdar://problem/122589094, rdar://problem/122596633.
1 parent 5dde2fe commit ff50a95

File tree

4 files changed

+233
-58
lines changed

4 files changed

+233
-58
lines changed

lib/Sema/AssociatedTypeInference.cpp

Lines changed: 117 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,46 +1382,65 @@ enum class InferenceCandidateKind {
13821382
static InferenceCandidateKind checkInferenceCandidate(
13831383
std::pair<AssociatedTypeDecl *, Type> *result,
13841384
NormalProtocolConformance *conformance,
1385-
DeclContext *witnessDC,
1385+
ValueDecl *witness,
13861386
Type selfTy) {
13871387
auto &ctx = selfTy->getASTContext();
13881388

1389+
// The unbound form of `Self.A`.
1390+
auto selfAssocTy = DependentMemberType::get(selfTy, result->first->getName());
1391+
auto genericSig = witness->getInnermostDeclContext()
1392+
->getGenericSignatureOfContext();
1393+
1394+
if (ctx.LangOpts.EnableExperimentalAssociatedTypeInference) {
1395+
// If the witness is in a protocol extension for a completely unrelated
1396+
// protocol that doesn't declare an associated type with the same name as
1397+
// the one we are trying to infer, then it will never be tautological.
1398+
if (!genericSig->isValidTypeParameter(selfAssocTy))
1399+
return InferenceCandidateKind::Good;
1400+
}
1401+
1402+
// A tautological binding is one where the left-hand side has the same
1403+
// reduced type as the right-hand side in the generic signature of the
1404+
// witness.
13891405
auto isTautological = [&](Type t) -> bool {
1406+
if (ctx.LangOpts.EnableExperimentalAssociatedTypeInference) {
1407+
13901408
auto dmt = t->getAs<DependentMemberType>();
13911409
if (!dmt)
13921410
return false;
1393-
if (!associatedTypesAreSameEquivalenceClass(dmt->getAssocType(),
1394-
result->first))
1395-
return false;
13961411

1397-
Type typeInContext;
1398-
if (ctx.LangOpts.EnableExperimentalAssociatedTypeInference) {
1399-
1400-
typeInContext = selfTy;
1412+
return genericSig->areReducedTypeParametersEqual(dmt, selfAssocTy);
14011413

14021414
} else {
14031415

1404-
typeInContext =
1405-
conformance->getDeclContext()->mapTypeIntoContext(conformance->getType());
1406-
1407-
}
1416+
auto dmt = t->getAs<DependentMemberType>();
1417+
if (!dmt)
1418+
return false;
1419+
if (!associatedTypesAreSameEquivalenceClass(dmt->getAssocType(),
1420+
result->first))
1421+
return false;
14081422

1423+
Type typeInContext =
1424+
conformance->getDeclContext()->mapTypeIntoContext(conformance->getType());
14091425
if (!dmt->getBase()->isEqual(typeInContext))
14101426
return false;
14111427

14121428
return true;
1429+
1430+
}
14131431
};
14141432

14151433
// Self.X == Self.X doesn't give us any new information, nor does it
14161434
// immediately fail.
14171435
if (isTautological(result->second)) {
1418-
auto *dmt = result->second->castTo<DependentMemberType>();
1419-
1420-
auto selfAssocTy = DependentMemberType::get(selfTy, dmt->getAssocType());
1436+
// FIXME: This should be getInnermostDeclContext()->getGenericSignature(),
1437+
// but that might introduce new ambiguities in existing code so we need
1438+
// to be careful.
1439+
auto genericSig = witness->getDeclContext()->getGenericSignatureOfContext();
14211440

14221441
// If we have a same-type requirement `Self.X == Self.Y`,
14231442
// introduce a binding `Self.X := Self.Y`.
1424-
for (auto &reqt : witnessDC->getGenericSignatureOfContext().getRequirements()) {
1443+
for (auto &reqt : genericSig.getRequirements()) {
14251444
switch (reqt.getKind()) {
14261445
case RequirementKind::SameShape:
14271446
llvm_unreachable("Same-shape requirement not supported here");
@@ -1432,6 +1451,45 @@ static InferenceCandidateKind checkInferenceCandidate(
14321451
break;
14331452

14341453
case RequirementKind::SameType:
1454+
if (ctx.LangOpts.EnableExperimentalAssociatedTypeInference) {
1455+
1456+
auto matches = [&](Type t) {
1457+
if (auto *dmt = t->getAs<DependentMemberType>()) {
1458+
return (dmt->getName() == result->first->getName() &&
1459+
dmt->getBase()->isEqual(selfTy));
1460+
}
1461+
1462+
return false;
1463+
};
1464+
1465+
// If we have a tautological binding, check if the witness generic
1466+
// signature has a same-type requirement `Self.A == Self.X` or
1467+
// `Self.X == Self.A`, where `A` is an associated type with the same
1468+
// name as the one we're trying to infer, and `X` is some other type
1469+
// parameter.
1470+
Type other;
1471+
if (matches(reqt.getFirstType())) {
1472+
other = reqt.getSecondType();
1473+
} else if (matches(reqt.getSecondType())) {
1474+
other = reqt.getFirstType();
1475+
} else {
1476+
break;
1477+
}
1478+
1479+
if (other->isTypeParameter() &&
1480+
other->getRootGenericParam()->isEqual(selfTy)) {
1481+
result->second = other;
1482+
LLVM_DEBUG(llvm::dbgs() << "++ we can same-type to:\n";
1483+
result->second->dump(llvm::dbgs()));
1484+
return InferenceCandidateKind::Good;
1485+
1486+
}
1487+
1488+
} else {
1489+
1490+
auto *dmt = result->second->castTo<DependentMemberType>();
1491+
auto selfAssocTy = DependentMemberType::get(selfTy, dmt->getAssocType());
1492+
14351493
Type other;
14361494
if (reqt.getFirstType()->isEqual(selfAssocTy)) {
14371495
other = reqt.getSecondType();
@@ -1443,18 +1501,9 @@ static InferenceCandidateKind checkInferenceCandidate(
14431501

14441502
if (auto otherAssoc = other->getAs<DependentMemberType>()) {
14451503
if (otherAssoc->getBase()->isEqual(selfTy)) {
1446-
DependentMemberType *otherDMT;
1447-
if (ctx.LangOpts.EnableExperimentalAssociatedTypeInference) {
1448-
1449-
otherDMT = otherAssoc;
1450-
1451-
} else {
1452-
1453-
otherDMT = DependentMemberType::get(dmt->getBase(),
1504+
auto *otherDMT = DependentMemberType::get(dmt->getBase(),
14541505
otherAssoc->getAssocType());
14551506

1456-
}
1457-
14581507
result->second = result->second.transform([&](Type t) -> Type{
14591508
if (t->isEqual(dmt))
14601509
return otherDMT;
@@ -1465,6 +1514,8 @@ static InferenceCandidateKind checkInferenceCandidate(
14651514
return InferenceCandidateKind::Good;
14661515
}
14671516
}
1517+
1518+
}
14681519
break;
14691520
}
14701521
}
@@ -1609,8 +1660,7 @@ AssociatedTypeInference::getPotentialTypeWitnessesFromRequirement(
16091660
// itself involve unresolved type witnesses.
16101661
if (selfTy) {
16111662
// Handle Self.X := Self.X and Self.X := G<Self.X>.
1612-
switch (checkInferenceCandidate(&result, conformance,
1613-
witness->getDeclContext(), selfTy)) {
1663+
switch (checkInferenceCandidate(&result, conformance, witness, selfTy)) {
16141664
case InferenceCandidateKind::Good:
16151665
// The "good" case is something like `Self.X := Self.Y`.
16161666
break;
@@ -1864,37 +1914,27 @@ static Type getWitnessTypeForMatching(NormalProtocolConformance *conformance,
18641914
auto proto = conformance->getProtocol();
18651915
auto selfTy = proto->getSelfInterfaceType();
18661916

1867-
// Get the reduced type of the witness. This rules our certain tautological
1868-
// inferences below.
1869-
if (ctx.LangOpts.EnableExperimentalAssociatedTypeInference) {
1870-
if (auto genericSig = witness->getInnermostDeclContext()
1871-
->getGenericSignatureOfContext()) {
1872-
type = genericSig.getReducedType(type);
1873-
type = genericSig->getSugaredType(type);
1874-
}
1875-
}
1876-
1877-
// Remap associated types that reference other protocols into this
1878-
// protocol.
1879-
type = type.transformRec([proto](TypeBase *type) -> llvm::Optional<Type> {
1880-
if (auto depMemTy = dyn_cast<DependentMemberType>(type)) {
1881-
if (depMemTy->getAssocType() &&
1882-
depMemTy->getAssocType()->getProtocol() != proto) {
1883-
if (auto *assocType = proto->getAssociatedType(depMemTy->getName())) {
1884-
auto origProto = depMemTy->getAssocType()->getProtocol();
1885-
if (proto->inheritsFrom(origProto))
1886-
return Type(DependentMemberType::get(depMemTy->getBase(),
1887-
assocType));
1917+
if (!ctx.LangOpts.EnableExperimentalAssociatedTypeInference) {
1918+
// Remap associated types that reference other protocols into this
1919+
// protocol.
1920+
auto resultType = Type(type).transformRec([proto](TypeBase *type)
1921+
-> llvm::Optional<Type> {
1922+
if (auto depMemTy = dyn_cast<DependentMemberType>(type)) {
1923+
if (depMemTy->getAssocType() &&
1924+
depMemTy->getAssocType()->getProtocol() != proto) {
1925+
if (auto *assocType = proto->getAssociatedType(depMemTy->getName())) {
1926+
auto origProto = depMemTy->getAssocType()->getProtocol();
1927+
if (proto->inheritsFrom(origProto))
1928+
return Type(DependentMemberType::get(depMemTy->getBase(),
1929+
assocType));
1930+
}
18881931
}
18891932
}
1890-
}
18911933

1892-
return llvm::None;
1893-
});
1894-
1895-
if (!ctx.LangOpts.EnableExperimentalAssociatedTypeInference) {
1896-
auto resultType = type.subst(QueryTypeSubstitutionMap{substitutions},
1897-
LookUpConformanceInModule(module));
1934+
return llvm::None;
1935+
});
1936+
resultType = resultType.subst(QueryTypeSubstitutionMap{substitutions},
1937+
LookUpConformanceInModule(module));
18981938
if (!resultType->hasError()) return resultType;
18991939

19001940
// Map error types with original types *back* to the original, dependent type.
@@ -1916,9 +1956,28 @@ static Type getWitnessTypeForMatching(NormalProtocolConformance *conformance,
19161956
if (!rootParam->isEqual(selfTy))
19171957
return type;
19181958

1959+
// Remap associated types that reference other protocols into this
1960+
// protocol.
1961+
auto substType = Type(type).transformRec([proto](TypeBase *type)
1962+
-> llvm::Optional<Type> {
1963+
if (auto depMemTy = dyn_cast<DependentMemberType>(type)) {
1964+
if (depMemTy->getAssocType() &&
1965+
depMemTy->getAssocType()->getProtocol() != proto) {
1966+
if (auto *assocType = proto->getAssociatedType(depMemTy->getName())) {
1967+
auto origProto = depMemTy->getAssocType()->getProtocol();
1968+
if (proto->inheritsFrom(origProto))
1969+
return Type(DependentMemberType::get(depMemTy->getBase(),
1970+
assocType));
1971+
}
1972+
}
1973+
}
1974+
1975+
return llvm::None;
1976+
});
1977+
19191978
// Replace Self with the concrete conforming type.
1920-
auto substType = Type(type).subst(QueryTypeSubstitutionMap{substitutions},
1921-
LookUpConformanceInModule(module));
1979+
substType = substType.subst(QueryTypeSubstitutionMap{substitutions},
1980+
LookUpConformanceInModule(module));
19221981

19231982
// If we don't have enough type witnesses, leave it abstract.
19241983
if (substType->hasError())
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: %target-typecheck-verify-swift -enable-experimental-associated-type-inference
2+
// RUN: %target-typecheck-verify-swift -disable-experimental-associated-type-inference
3+
4+
protocol P1 {
5+
associatedtype A
6+
7+
func f1(_: C) -> A
8+
func f2(_: A, _: C)
9+
10+
typealias C = S1<Self>
11+
}
12+
13+
struct S1<T> {}
14+
15+
protocol P2: P1 where A == B {
16+
associatedtype B
17+
18+
func g1(_: C) -> B
19+
func g2(_: B, _: C)
20+
}
21+
22+
extension P2 {
23+
func f1(_: C) -> B { fatalError() }
24+
func f2(_: B, _: C) { fatalError() }
25+
}
26+
27+
extension P2 {
28+
func g2(_: B, _: C) {}
29+
}
30+
31+
struct S2: P2 {
32+
func g1(_: C) -> Int {
33+
fatalError()
34+
}
35+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: %target-typecheck-verify-swift -disable-experimental-associated-type-inference
2+
// RUN: %target-typecheck-verify-swift -enable-experimental-associated-type-inference
3+
4+
public protocol P1 {
5+
associatedtype A = Void
6+
7+
func makeA() -> A
8+
func consumeA(a: inout A)
9+
}
10+
11+
extension P1 where A == Void {
12+
// Don't consider this witness in the 'S2: P1' conformance below.
13+
public func makeA() -> A { fatalError() }
14+
}
15+
16+
public struct S1: P1 {}
17+
18+
public protocol P2: P1 where A == B.A {
19+
associatedtype B: P1
20+
var base: B { get }
21+
}
22+
23+
extension P2 {
24+
public func makeA() -> B.A { fatalError() }
25+
public func consumeA(a: inout B.A) {}
26+
}
27+
28+
extension S1: P2 {
29+
public var base: S2 { fatalError() }
30+
}
31+
32+
public struct S2 {}
33+
34+
public struct S3 {}
35+
36+
extension S2: P1 {
37+
public typealias A = S3
38+
}
39+
40+
public protocol P3: P1 where A == S3 {}
41+
42+
extension P3 {
43+
public func makeA() -> A { fatalError() }
44+
public func consumeA(a: inout A) {}
45+
}
46+
47+
extension S2: P3 {}
48+
49+
let x: S3.Type = S2.A.self
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// RUN: %target-typecheck-verify-swift -disable-experimental-associated-type-inference
2+
// RUN: %target-typecheck-verify-swift -enable-experimental-associated-type-inference
3+
4+
public protocol P<A> {
5+
associatedtype A
6+
associatedtype B: P
7+
8+
func makeA() -> A
9+
var b: B { get }
10+
}
11+
12+
extension P where A == B.A {
13+
public func makeA() -> B.A {
14+
fatalError()
15+
}
16+
}
17+
18+
public struct S: P {
19+
public var b: some P<Int> {
20+
return G<Int>()
21+
}
22+
}
23+
24+
public struct G<A>: P {
25+
public func makeA() -> A { fatalError() }
26+
public var b: Never { fatalError() }
27+
}
28+
29+
extension Never: P {
30+
public func makeA() -> Never { fatalError() }
31+
public var b: Never { fatalError() }
32+
}

0 commit comments

Comments
 (0)