Skip to content

Commit b6a0418

Browse files
authored
Merge pull request #61109 from xedin/issue-60958-alt-5.7
[5.7][ConstraintSystem] Use witnesses for `makeIterator` and `next` refs in `for-in` context
2 parents 4058fbf + ef5b0a3 commit b6a0418

File tree

10 files changed

+447
-29
lines changed

10 files changed

+447
-29
lines changed

include/swift/AST/ASTContext.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,12 @@ class ASTContext final {
600600
/// Get AsyncSequence.makeAsyncIterator().
601601
FuncDecl *getAsyncSequenceMakeAsyncIterator() const;
602602

603+
/// Get IteratorProtocol.next().
604+
FuncDecl *getIteratorNext() const;
605+
606+
/// Get AsyncIteratorProtocol.next().
607+
FuncDecl *getAsyncIteratorNext() const;
608+
603609
/// Check whether the standard library provides all the correct
604610
/// intrinsic support for Optional<T>.
605611
///

include/swift/Sema/Constraint.h

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ enum class ConstraintKind : char {
137137
/// name, and the type of that member, when referenced as a value, is the
138138
/// second type.
139139
UnresolvedValueMember,
140+
/// The first type conforms to the protocol in which the member requirement
141+
/// resides. Once the conformance is resolved, the value witness will be
142+
/// determined, and the type of that witness, when referenced as a value,
143+
/// will be bound to the second type.
144+
ValueWitness,
140145
/// The first type can be defaulted to the second (which currently
141146
/// cannot be dependent). This is more like a type property than a
142147
/// relational constraint.
@@ -406,11 +411,18 @@ class Constraint final : public llvm::ilist_node<Constraint>,
406411
/// The type of the member.
407412
Type Second;
408413

409-
/// If non-null, the name of a member of the first type is that
410-
/// being related to the second type.
411-
///
412-
/// Used for ValueMember an UnresolvedValueMember constraints.
413-
DeclNameRef Name;
414+
union {
415+
/// If non-null, the name of a member of the first type is that
416+
/// being related to the second type.
417+
///
418+
/// Used for ValueMember an UnresolvedValueMember constraints.
419+
DeclNameRef Name;
420+
421+
/// If non-null, the member being referenced.
422+
///
423+
/// Used for ValueWitness constraints.
424+
ValueDecl *Ref;
425+
} Member;
414426

415427
/// The DC in which the use appears.
416428
DeclContext *UseDC;
@@ -525,6 +537,12 @@ class Constraint final : public llvm::ilist_node<Constraint>,
525537
FunctionRefKind functionRefKind,
526538
ConstraintLocator *locator);
527539

540+
/// Create a new value witness constraint.
541+
static Constraint *createValueWitness(
542+
ConstraintSystem &cs, ConstraintKind kind, Type first, Type second,
543+
ValueDecl *requirement, DeclContext *useDC,
544+
FunctionRefKind functionRefKind, ConstraintLocator *locator);
545+
528546
/// Create an overload-binding constraint.
529547
static Constraint *createBindOverload(ConstraintSystem &cs, Type type,
530548
OverloadChoice choice,
@@ -672,6 +690,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
672690

673691
case ConstraintKind::ValueMember:
674692
case ConstraintKind::UnresolvedValueMember:
693+
case ConstraintKind::ValueWitness:
675694
case ConstraintKind::PropertyWrapper:
676695
return ConstraintClassification::Member;
677696

@@ -711,6 +730,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
711730

712731
case ConstraintKind::ValueMember:
713732
case ConstraintKind::UnresolvedValueMember:
733+
case ConstraintKind::ValueWitness:
714734
return Member.First;
715735

716736
case ConstraintKind::SyntacticElement:
@@ -732,6 +752,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,
732752

733753
case ConstraintKind::ValueMember:
734754
case ConstraintKind::UnresolvedValueMember:
755+
case ConstraintKind::ValueWitness:
735756
return Member.Second;
736757

737758
default:
@@ -757,13 +778,20 @@ class Constraint final : public llvm::ilist_node<Constraint>,
757778
DeclNameRef getMember() const {
758779
assert(Kind == ConstraintKind::ValueMember ||
759780
Kind == ConstraintKind::UnresolvedValueMember);
760-
return Member.Name;
781+
return Member.Member.Name;
782+
}
783+
784+
/// Retrieve the requirement being referenced by a value witness constraint.
785+
ValueDecl *getRequirement() const {
786+
assert(Kind == ConstraintKind::ValueWitness);
787+
return Member.Member.Ref;
761788
}
762789

763790
/// Determine the kind of function reference we have for a member reference.
764791
FunctionRefKind getFunctionRefKind() const {
765792
if (Kind == ConstraintKind::ValueMember ||
766-
Kind == ConstraintKind::UnresolvedValueMember)
793+
Kind == ConstraintKind::UnresolvedValueMember ||
794+
Kind == ConstraintKind::ValueWitness)
767795
return static_cast<FunctionRefKind>(TheFunctionRefKind);
768796

769797
// Conservative answer: drop all of the labels.
@@ -823,7 +851,8 @@ class Constraint final : public llvm::ilist_node<Constraint>,
823851
/// Retrieve the DC in which the member was used.
824852
DeclContext *getMemberUseDC() const {
825853
assert(Kind == ConstraintKind::ValueMember ||
826-
Kind == ConstraintKind::UnresolvedValueMember);
854+
Kind == ConstraintKind::UnresolvedValueMember ||
855+
Kind == ConstraintKind::ValueWitness);
827856
return Member.UseDC;
828857
}
829858

include/swift/Sema/ConstraintSystem.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4083,6 +4083,26 @@ class ConstraintSystem {
40834083
}
40844084
}
40854085

4086+
/// Add a value witness constraint to the constraint system.
4087+
void addValueWitnessConstraint(
4088+
Type baseTy, ValueDecl *requirement, Type memberTy, DeclContext *useDC,
4089+
FunctionRefKind functionRefKind, ConstraintLocatorBuilder locator) {
4090+
assert(baseTy);
4091+
assert(memberTy);
4092+
assert(requirement);
4093+
assert(useDC);
4094+
switch (simplifyValueWitnessConstraint(
4095+
ConstraintKind::ValueWitness, baseTy, requirement, memberTy, useDC,
4096+
functionRefKind, TMF_GenerateConstraints, locator)) {
4097+
case SolutionKind::Unsolved:
4098+
llvm_unreachable("Unsolved result when generating constraints!");
4099+
4100+
case SolutionKind::Solved:
4101+
case SolutionKind::Error:
4102+
break;
4103+
}
4104+
}
4105+
40864106
/// Add an explicit conversion constraint (e.g., \c 'x as T').
40874107
///
40884108
/// \param fromType The type of the expression being converted.

lib/AST/ASTContext.cpp

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,12 @@ struct ASTContext::Implementation {
210210
/// The declaration of 'AsyncSequence.makeAsyncIterator()'.
211211
FuncDecl *MakeAsyncIterator = nullptr;
212212

213+
/// The declaration of 'IteratorProtocol.next()'.
214+
FuncDecl *IteratorNext = nullptr;
215+
216+
/// The declaration of 'AsyncIteratorProtocol.next()'.
217+
FuncDecl *AsyncIteratorNext = nullptr;
218+
213219
/// The declaration of Swift.Optional<T>.Some.
214220
EnumElementDecl *OptionalSomeDecl = nullptr;
215221

@@ -779,31 +785,40 @@ FuncDecl *ASTContext::getPlusFunctionOnString() const {
779785
return getImpl().PlusFunctionOnString;
780786
}
781787

782-
FuncDecl *ASTContext::getSequenceMakeIterator() const {
783-
if (getImpl().MakeIterator) {
784-
return getImpl().MakeIterator;
785-
}
786-
787-
auto proto = getProtocol(KnownProtocolKind::Sequence);
788-
if (!proto)
789-
return nullptr;
790-
791-
for (auto result : proto->lookupDirect(Id_makeIterator)) {
788+
static FuncDecl *lookupRequirement(ProtocolDecl *proto,
789+
Identifier requirement) {
790+
for (auto result : proto->lookupDirect(requirement)) {
792791
if (result->getDeclContext() != proto)
793792
continue;
794793

795794
if (auto func = dyn_cast<FuncDecl>(result)) {
796795
if (func->getParameters()->size() != 0)
797796
continue;
798797

799-
getImpl().MakeIterator = func;
800798
return func;
801799
}
802800
}
803801

804802
return nullptr;
805803
}
806804

805+
FuncDecl *ASTContext::getSequenceMakeIterator() const {
806+
if (getImpl().MakeIterator) {
807+
return getImpl().MakeIterator;
808+
}
809+
810+
auto proto = getProtocol(KnownProtocolKind::Sequence);
811+
if (!proto)
812+
return nullptr;
813+
814+
if (auto *func = lookupRequirement(proto, Id_makeIterator)) {
815+
getImpl().MakeIterator = func;
816+
return func;
817+
}
818+
819+
return nullptr;
820+
}
821+
807822
FuncDecl *ASTContext::getAsyncSequenceMakeAsyncIterator() const {
808823
if (getImpl().MakeAsyncIterator) {
809824
return getImpl().MakeAsyncIterator;
@@ -813,17 +828,43 @@ FuncDecl *ASTContext::getAsyncSequenceMakeAsyncIterator() const {
813828
if (!proto)
814829
return nullptr;
815830

816-
for (auto result : proto->lookupDirect(Id_makeAsyncIterator)) {
817-
if (result->getDeclContext() != proto)
818-
continue;
831+
if (auto *func = lookupRequirement(proto, Id_makeAsyncIterator)) {
832+
getImpl().MakeAsyncIterator = func;
833+
return func;
834+
}
819835

820-
if (auto func = dyn_cast<FuncDecl>(result)) {
821-
if (func->getParameters()->size() != 0)
822-
continue;
836+
return nullptr;
837+
}
823838

824-
getImpl().MakeAsyncIterator = func;
825-
return func;
826-
}
839+
FuncDecl *ASTContext::getIteratorNext() const {
840+
if (getImpl().IteratorNext) {
841+
return getImpl().IteratorNext;
842+
}
843+
844+
auto proto = getProtocol(KnownProtocolKind::IteratorProtocol);
845+
if (!proto)
846+
return nullptr;
847+
848+
if (auto *func = lookupRequirement(proto, Id_next)) {
849+
getImpl().IteratorNext = func;
850+
return func;
851+
}
852+
853+
return nullptr;
854+
}
855+
856+
FuncDecl *ASTContext::getAsyncIteratorNext() const {
857+
if (getImpl().AsyncIteratorNext) {
858+
return getImpl().AsyncIteratorNext;
859+
}
860+
861+
auto proto = getProtocol(KnownProtocolKind::AsyncIteratorProtocol);
862+
if (!proto)
863+
return nullptr;
864+
865+
if (auto *func = lookupRequirement(proto, Id_next)) {
866+
getImpl().AsyncIteratorNext = func;
867+
return func;
827868
}
828869

829870
return nullptr;

lib/Sema/CSBindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1467,6 +1467,7 @@ void PotentialBindings::infer(Constraint *constraint) {
14671467

14681468
case ConstraintKind::ValueMember:
14691469
case ConstraintKind::UnresolvedValueMember:
1470+
case ConstraintKind::ValueWitness:
14701471
case ConstraintKind::PropertyWrapper: {
14711472
// If current type variable represents a member type of some reference,
14721473
// it would be bound once member is resolved either to a actual member

lib/Sema/CSGen.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3992,6 +3992,21 @@ generateForEachStmtConstraints(
39923992
AwaitExpr::createImplicit(ctx, /*awaitLoc=*/SourceLoc(), nextCall);
39933993
}
39943994

3995+
// The iterator type must conform to IteratorProtocol.
3996+
{
3997+
ProtocolDecl *iteratorProto = TypeChecker::getProtocol(
3998+
cs.getASTContext(), stmt->getForLoc(),
3999+
isAsync ? KnownProtocolKind::AsyncIteratorProtocol
4000+
: KnownProtocolKind::IteratorProtocol);
4001+
if (!iteratorProto)
4002+
return None;
4003+
4004+
cs.setContextualType(
4005+
nextRef->getBase(),
4006+
TypeLoc::withoutLoc(iteratorProto->getDeclaredInterfaceType()),
4007+
CTP_ForEachSequence);
4008+
}
4009+
39954010
SolutionApplicationTarget nextTarget(nextCall, dc, CTP_Unused,
39964011
/*contextualType=*/Type(),
39974012
/*isDiscarded=*/false);

0 commit comments

Comments
 (0)