@@ -1382,46 +1382,65 @@ enum class InferenceCandidateKind {
1382
1382
static InferenceCandidateKind checkInferenceCandidate (
1383
1383
std::pair<AssociatedTypeDecl *, Type> *result,
1384
1384
NormalProtocolConformance *conformance,
1385
- DeclContext *witnessDC ,
1385
+ ValueDecl *witness ,
1386
1386
Type selfTy) {
1387
1387
auto &ctx = selfTy->getASTContext ();
1388
1388
1389
+ // The unbound form of `Self.A`.
1390
+ auto selfAssocTy = DependentMemberType::get (selfTy, result->first ->getName ());
1391
+ auto genericSig = witness->getInnermostDeclContext ()
1392
+ ->getGenericSignatureOfContext ();
1393
+
1394
+ if (ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1395
+ // If the witness is in a protocol extension for a completely unrelated
1396
+ // protocol that doesn't declare an associated type with the same name as
1397
+ // the one we are trying to infer, then it will never be tautological.
1398
+ if (!genericSig->isValidTypeParameter (selfAssocTy))
1399
+ return InferenceCandidateKind::Good;
1400
+ }
1401
+
1402
+ // A tautological binding is one where the left-hand side has the same
1403
+ // reduced type as the right-hand side in the generic signature of the
1404
+ // witness.
1389
1405
auto isTautological = [&](Type t) -> bool {
1406
+ if (ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1407
+
1390
1408
auto dmt = t->getAs <DependentMemberType>();
1391
1409
if (!dmt)
1392
1410
return false ;
1393
- if (!associatedTypesAreSameEquivalenceClass (dmt->getAssocType (),
1394
- result->first ))
1395
- return false ;
1396
1411
1397
- Type typeInContext;
1398
- if (ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1399
-
1400
- typeInContext = selfTy;
1412
+ return genericSig->areReducedTypeParametersEqual (dmt, selfAssocTy);
1401
1413
1402
1414
} else {
1403
1415
1404
- typeInContext =
1405
- conformance->getDeclContext ()->mapTypeIntoContext (conformance->getType ());
1406
-
1407
- }
1416
+ auto dmt = t->getAs <DependentMemberType>();
1417
+ if (!dmt)
1418
+ return false ;
1419
+ if (!associatedTypesAreSameEquivalenceClass (dmt->getAssocType (),
1420
+ result->first ))
1421
+ return false ;
1408
1422
1423
+ Type typeInContext =
1424
+ conformance->getDeclContext ()->mapTypeIntoContext (conformance->getType ());
1409
1425
if (!dmt->getBase ()->isEqual (typeInContext))
1410
1426
return false ;
1411
1427
1412
1428
return true ;
1429
+
1430
+ }
1413
1431
};
1414
1432
1415
1433
// Self.X == Self.X doesn't give us any new information, nor does it
1416
1434
// immediately fail.
1417
1435
if (isTautological (result->second )) {
1418
- auto *dmt = result->second ->castTo <DependentMemberType>();
1419
-
1420
- auto selfAssocTy = DependentMemberType::get (selfTy, dmt->getAssocType ());
1436
+ // FIXME: This should be getInnermostDeclContext()->getGenericSignature(),
1437
+ // but that might introduce new ambiguities in existing code so we need
1438
+ // to be careful.
1439
+ auto genericSig = witness->getDeclContext ()->getGenericSignatureOfContext ();
1421
1440
1422
1441
// If we have a same-type requirement `Self.X == Self.Y`,
1423
1442
// introduce a binding `Self.X := Self.Y`.
1424
- for (auto &reqt : witnessDC-> getGenericSignatureOfContext () .getRequirements ()) {
1443
+ for (auto &reqt : genericSig .getRequirements ()) {
1425
1444
switch (reqt.getKind ()) {
1426
1445
case RequirementKind::SameShape:
1427
1446
llvm_unreachable (" Same-shape requirement not supported here" );
@@ -1432,6 +1451,45 @@ static InferenceCandidateKind checkInferenceCandidate(
1432
1451
break ;
1433
1452
1434
1453
case RequirementKind::SameType:
1454
+ if (ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1455
+
1456
+ auto matches = [&](Type t) {
1457
+ if (auto *dmt = t->getAs <DependentMemberType>()) {
1458
+ return (dmt->getName () == result->first ->getName () &&
1459
+ dmt->getBase ()->isEqual (selfTy));
1460
+ }
1461
+
1462
+ return false ;
1463
+ };
1464
+
1465
+ // If we have a tautological binding, check if the witness generic
1466
+ // signature has a same-type requirement `Self.A == Self.X` or
1467
+ // `Self.X == Self.A`, where `A` is an associated type with the same
1468
+ // name as the one we're trying to infer, and `X` is some other type
1469
+ // parameter.
1470
+ Type other;
1471
+ if (matches (reqt.getFirstType ())) {
1472
+ other = reqt.getSecondType ();
1473
+ } else if (matches (reqt.getSecondType ())) {
1474
+ other = reqt.getFirstType ();
1475
+ } else {
1476
+ break ;
1477
+ }
1478
+
1479
+ if (other->isTypeParameter () &&
1480
+ other->getRootGenericParam ()->isEqual (selfTy)) {
1481
+ result->second = other;
1482
+ LLVM_DEBUG (llvm::dbgs () << " ++ we can same-type to:\n " ;
1483
+ result->second ->dump (llvm::dbgs ()));
1484
+ return InferenceCandidateKind::Good;
1485
+
1486
+ }
1487
+
1488
+ } else {
1489
+
1490
+ auto *dmt = result->second ->castTo <DependentMemberType>();
1491
+ auto selfAssocTy = DependentMemberType::get (selfTy, dmt->getAssocType ());
1492
+
1435
1493
Type other;
1436
1494
if (reqt.getFirstType ()->isEqual (selfAssocTy)) {
1437
1495
other = reqt.getSecondType ();
@@ -1443,18 +1501,9 @@ static InferenceCandidateKind checkInferenceCandidate(
1443
1501
1444
1502
if (auto otherAssoc = other->getAs <DependentMemberType>()) {
1445
1503
if (otherAssoc->getBase ()->isEqual (selfTy)) {
1446
- DependentMemberType *otherDMT;
1447
- if (ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1448
-
1449
- otherDMT = otherAssoc;
1450
-
1451
- } else {
1452
-
1453
- otherDMT = DependentMemberType::get (dmt->getBase (),
1504
+ auto *otherDMT = DependentMemberType::get (dmt->getBase (),
1454
1505
otherAssoc->getAssocType ());
1455
1506
1456
- }
1457
-
1458
1507
result->second = result->second .transform ([&](Type t) -> Type{
1459
1508
if (t->isEqual (dmt))
1460
1509
return otherDMT;
@@ -1465,6 +1514,8 @@ static InferenceCandidateKind checkInferenceCandidate(
1465
1514
return InferenceCandidateKind::Good;
1466
1515
}
1467
1516
}
1517
+
1518
+ }
1468
1519
break ;
1469
1520
}
1470
1521
}
@@ -1609,8 +1660,7 @@ AssociatedTypeInference::getPotentialTypeWitnessesFromRequirement(
1609
1660
// itself involve unresolved type witnesses.
1610
1661
if (selfTy) {
1611
1662
// Handle Self.X := Self.X and Self.X := G<Self.X>.
1612
- switch (checkInferenceCandidate (&result, conformance,
1613
- witness->getDeclContext (), selfTy)) {
1663
+ switch (checkInferenceCandidate (&result, conformance, witness, selfTy)) {
1614
1664
case InferenceCandidateKind::Good:
1615
1665
// The "good" case is something like `Self.X := Self.Y`.
1616
1666
break ;
@@ -1864,37 +1914,27 @@ static Type getWitnessTypeForMatching(NormalProtocolConformance *conformance,
1864
1914
auto proto = conformance->getProtocol ();
1865
1915
auto selfTy = proto->getSelfInterfaceType ();
1866
1916
1867
- // Get the reduced type of the witness. This rules our certain tautological
1868
- // inferences below.
1869
- if (ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1870
- if (auto genericSig = witness->getInnermostDeclContext ()
1871
- ->getGenericSignatureOfContext ()) {
1872
- type = genericSig.getReducedType (type);
1873
- type = genericSig->getSugaredType (type);
1874
- }
1875
- }
1876
-
1877
- // Remap associated types that reference other protocols into this
1878
- // protocol.
1879
- type = type.transformRec ([proto](TypeBase *type) -> llvm::Optional<Type> {
1880
- if (auto depMemTy = dyn_cast<DependentMemberType>(type)) {
1881
- if (depMemTy->getAssocType () &&
1882
- depMemTy->getAssocType ()->getProtocol () != proto) {
1883
- if (auto *assocType = proto->getAssociatedType (depMemTy->getName ())) {
1884
- auto origProto = depMemTy->getAssocType ()->getProtocol ();
1885
- if (proto->inheritsFrom (origProto))
1886
- return Type (DependentMemberType::get (depMemTy->getBase (),
1887
- assocType));
1917
+ if (!ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1918
+ // Remap associated types that reference other protocols into this
1919
+ // protocol.
1920
+ auto resultType = Type (type).transformRec ([proto](TypeBase *type)
1921
+ -> llvm::Optional<Type> {
1922
+ if (auto depMemTy = dyn_cast<DependentMemberType>(type)) {
1923
+ if (depMemTy->getAssocType () &&
1924
+ depMemTy->getAssocType ()->getProtocol () != proto) {
1925
+ if (auto *assocType = proto->getAssociatedType (depMemTy->getName ())) {
1926
+ auto origProto = depMemTy->getAssocType ()->getProtocol ();
1927
+ if (proto->inheritsFrom (origProto))
1928
+ return Type (DependentMemberType::get (depMemTy->getBase (),
1929
+ assocType));
1930
+ }
1888
1931
}
1889
1932
}
1890
- }
1891
1933
1892
- return llvm::None;
1893
- });
1894
-
1895
- if (!ctx.LangOpts .EnableExperimentalAssociatedTypeInference ) {
1896
- auto resultType = type.subst (QueryTypeSubstitutionMap{substitutions},
1897
- LookUpConformanceInModule (module));
1934
+ return llvm::None;
1935
+ });
1936
+ resultType = resultType.subst (QueryTypeSubstitutionMap{substitutions},
1937
+ LookUpConformanceInModule (module));
1898
1938
if (!resultType->hasError ()) return resultType;
1899
1939
1900
1940
// Map error types with original types *back* to the original, dependent type.
@@ -1916,9 +1956,28 @@ static Type getWitnessTypeForMatching(NormalProtocolConformance *conformance,
1916
1956
if (!rootParam->isEqual (selfTy))
1917
1957
return type;
1918
1958
1959
+ // Remap associated types that reference other protocols into this
1960
+ // protocol.
1961
+ auto substType = Type (type).transformRec ([proto](TypeBase *type)
1962
+ -> llvm::Optional<Type> {
1963
+ if (auto depMemTy = dyn_cast<DependentMemberType>(type)) {
1964
+ if (depMemTy->getAssocType () &&
1965
+ depMemTy->getAssocType ()->getProtocol () != proto) {
1966
+ if (auto *assocType = proto->getAssociatedType (depMemTy->getName ())) {
1967
+ auto origProto = depMemTy->getAssocType ()->getProtocol ();
1968
+ if (proto->inheritsFrom (origProto))
1969
+ return Type (DependentMemberType::get (depMemTy->getBase (),
1970
+ assocType));
1971
+ }
1972
+ }
1973
+ }
1974
+
1975
+ return llvm::None;
1976
+ });
1977
+
1919
1978
// Replace Self with the concrete conforming type.
1920
- auto substType = Type (type) .subst (QueryTypeSubstitutionMap{substitutions},
1921
- LookUpConformanceInModule (module));
1979
+ substType = substType .subst (QueryTypeSubstitutionMap{substitutions},
1980
+ LookUpConformanceInModule (module));
1922
1981
1923
1982
// If we don't have enough type witnesses, leave it abstract.
1924
1983
if (substType->hasError ())
0 commit comments