Skip to content

[ConstraintSystem] Use witnesses for makeIterator and next refs in for-in context #61091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,12 @@ class ASTContext final {
/// Get AsyncSequence.makeAsyncIterator().
FuncDecl *getAsyncSequenceMakeAsyncIterator() const;

/// Get IteratorProtocol.next().
FuncDecl *getIteratorNext() const;

/// Get AsyncIteratorProtocol.next().
FuncDecl *getAsyncIteratorNext() const;

/// Check whether the standard library provides all the correct
/// intrinsic support for Optional<T>.
///
Expand Down
45 changes: 37 additions & 8 deletions include/swift/Sema/Constraint.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ enum class ConstraintKind : char {
/// name, and the type of that member, when referenced as a value, is the
/// second type.
UnresolvedValueMember,
/// The first type conforms to the protocol in which the member requirement
/// resides. Once the conformance is resolved, the value witness will be
/// determined, and the type of that witness, when referenced as a value,
/// will be bound to the second type.
ValueWitness,
/// The first type can be defaulted to the second (which currently
/// cannot be dependent). This is more like a type property than a
/// relational constraint.
Expand Down Expand Up @@ -406,11 +411,18 @@ class Constraint final : public llvm::ilist_node<Constraint>,
/// The type of the member.
Type Second;

/// If non-null, the name of a member of the first type is that
/// being related to the second type.
///
/// Used for ValueMember an UnresolvedValueMember constraints.
DeclNameRef Name;
union {
/// If non-null, the name of a member of the first type is that
/// being related to the second type.
///
/// Used for ValueMember an UnresolvedValueMember constraints.
DeclNameRef Name;

/// If non-null, the member being referenced.
///
/// Used for ValueWitness constraints.
ValueDecl *Ref;
} Member;

/// The DC in which the use appears.
DeclContext *UseDC;
Expand Down Expand Up @@ -525,6 +537,12 @@ class Constraint final : public llvm::ilist_node<Constraint>,
FunctionRefKind functionRefKind,
ConstraintLocator *locator);

/// Create a new value witness constraint.
static Constraint *createValueWitness(
ConstraintSystem &cs, ConstraintKind kind, Type first, Type second,
ValueDecl *requirement, DeclContext *useDC,
FunctionRefKind functionRefKind, ConstraintLocator *locator);

/// Create an overload-binding constraint.
static Constraint *createBindOverload(ConstraintSystem &cs, Type type,
OverloadChoice choice,
Expand Down Expand Up @@ -672,6 +690,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,

case ConstraintKind::ValueMember:
case ConstraintKind::UnresolvedValueMember:
case ConstraintKind::ValueWitness:
case ConstraintKind::PropertyWrapper:
return ConstraintClassification::Member;

Expand Down Expand Up @@ -711,6 +730,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,

case ConstraintKind::ValueMember:
case ConstraintKind::UnresolvedValueMember:
case ConstraintKind::ValueWitness:
return Member.First;

case ConstraintKind::SyntacticElement:
Expand All @@ -732,6 +752,7 @@ class Constraint final : public llvm::ilist_node<Constraint>,

case ConstraintKind::ValueMember:
case ConstraintKind::UnresolvedValueMember:
case ConstraintKind::ValueWitness:
return Member.Second;

default:
Expand All @@ -757,13 +778,20 @@ class Constraint final : public llvm::ilist_node<Constraint>,
DeclNameRef getMember() const {
assert(Kind == ConstraintKind::ValueMember ||
Kind == ConstraintKind::UnresolvedValueMember);
return Member.Name;
return Member.Member.Name;
}

/// Retrieve the requirement being referenced by a value witness constraint.
ValueDecl *getRequirement() const {
assert(Kind == ConstraintKind::ValueWitness);
return Member.Member.Ref;
}

/// Determine the kind of function reference we have for a member reference.
FunctionRefKind getFunctionRefKind() const {
if (Kind == ConstraintKind::ValueMember ||
Kind == ConstraintKind::UnresolvedValueMember)
Kind == ConstraintKind::UnresolvedValueMember ||
Kind == ConstraintKind::ValueWitness)
return static_cast<FunctionRefKind>(TheFunctionRefKind);

// Conservative answer: drop all of the labels.
Expand Down Expand Up @@ -823,7 +851,8 @@ class Constraint final : public llvm::ilist_node<Constraint>,
/// Retrieve the DC in which the member was used.
DeclContext *getMemberUseDC() const {
assert(Kind == ConstraintKind::ValueMember ||
Kind == ConstraintKind::UnresolvedValueMember);
Kind == ConstraintKind::UnresolvedValueMember ||
Kind == ConstraintKind::ValueWitness);
return Member.UseDC;
}

Expand Down
20 changes: 20 additions & 0 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -4354,6 +4354,26 @@ class ConstraintSystem {
}
}

/// Add a value witness constraint to the constraint system.
void addValueWitnessConstraint(
Type baseTy, ValueDecl *requirement, Type memberTy, DeclContext *useDC,
FunctionRefKind functionRefKind, ConstraintLocatorBuilder locator) {
assert(baseTy);
assert(memberTy);
assert(requirement);
assert(useDC);
switch (simplifyValueWitnessConstraint(
ConstraintKind::ValueWitness, baseTy, requirement, memberTy, useDC,
functionRefKind, TMF_GenerateConstraints, locator)) {
case SolutionKind::Unsolved:
llvm_unreachable("Unsolved result when generating constraints!");

case SolutionKind::Solved:
case SolutionKind::Error:
break;
}
}

/// Add an explicit conversion constraint (e.g., \c 'x as T').
///
/// \param fromType The type of the expression being converted.
Expand Down
81 changes: 61 additions & 20 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,12 @@ struct ASTContext::Implementation {
/// The declaration of 'AsyncSequence.makeAsyncIterator()'.
FuncDecl *MakeAsyncIterator = nullptr;

/// The declaration of 'IteratorProtocol.next()'.
FuncDecl *IteratorNext = nullptr;

/// The declaration of 'AsyncIteratorProtocol.next()'.
FuncDecl *AsyncIteratorNext = nullptr;

/// The declaration of Swift.Optional<T>.Some.
EnumElementDecl *OptionalSomeDecl = nullptr;

Expand Down Expand Up @@ -806,31 +812,40 @@ FuncDecl *ASTContext::getPlusFunctionOnString() const {
return getImpl().PlusFunctionOnString;
}

FuncDecl *ASTContext::getSequenceMakeIterator() const {
if (getImpl().MakeIterator) {
return getImpl().MakeIterator;
}

auto proto = getProtocol(KnownProtocolKind::Sequence);
if (!proto)
return nullptr;

for (auto result : proto->lookupDirect(Id_makeIterator)) {
static FuncDecl *lookupRequirement(ProtocolDecl *proto,
Identifier requirement) {
for (auto result : proto->lookupDirect(requirement)) {
if (result->getDeclContext() != proto)
continue;

if (auto func = dyn_cast<FuncDecl>(result)) {
if (func->getParameters()->size() != 0)
continue;

getImpl().MakeIterator = func;
return func;
}
}

return nullptr;
}

FuncDecl *ASTContext::getSequenceMakeIterator() const {
if (getImpl().MakeIterator) {
return getImpl().MakeIterator;
}

auto proto = getProtocol(KnownProtocolKind::Sequence);
if (!proto)
return nullptr;

if (auto *func = lookupRequirement(proto, Id_makeIterator)) {
getImpl().MakeIterator = func;
return func;
}

return nullptr;
}

FuncDecl *ASTContext::getAsyncSequenceMakeAsyncIterator() const {
if (getImpl().MakeAsyncIterator) {
return getImpl().MakeAsyncIterator;
Expand All @@ -840,17 +855,43 @@ FuncDecl *ASTContext::getAsyncSequenceMakeAsyncIterator() const {
if (!proto)
return nullptr;

for (auto result : proto->lookupDirect(Id_makeAsyncIterator)) {
if (result->getDeclContext() != proto)
continue;
if (auto *func = lookupRequirement(proto, Id_makeAsyncIterator)) {
getImpl().MakeAsyncIterator = func;
return func;
}

if (auto func = dyn_cast<FuncDecl>(result)) {
if (func->getParameters()->size() != 0)
continue;
return nullptr;
}

getImpl().MakeAsyncIterator = func;
return func;
}
FuncDecl *ASTContext::getIteratorNext() const {
if (getImpl().IteratorNext) {
return getImpl().IteratorNext;
}

auto proto = getProtocol(KnownProtocolKind::IteratorProtocol);
if (!proto)
return nullptr;

if (auto *func = lookupRequirement(proto, Id_next)) {
getImpl().IteratorNext = func;
return func;
}

return nullptr;
}

FuncDecl *ASTContext::getAsyncIteratorNext() const {
if (getImpl().AsyncIteratorNext) {
return getImpl().AsyncIteratorNext;
}

auto proto = getProtocol(KnownProtocolKind::AsyncIteratorProtocol);
if (!proto)
return nullptr;

if (auto *func = lookupRequirement(proto, Id_next)) {
getImpl().AsyncIteratorNext = func;
return func;
}

return nullptr;
Expand Down
1 change: 1 addition & 0 deletions lib/Sema/CSBindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1473,6 +1473,7 @@ void PotentialBindings::infer(Constraint *constraint) {

case ConstraintKind::ValueMember:
case ConstraintKind::UnresolvedValueMember:
case ConstraintKind::ValueWitness:
case ConstraintKind::PropertyWrapper: {
// If current type variable represents a member type of some reference,
// it would be bound once member is resolved either to a actual member
Expand Down
15 changes: 15 additions & 0 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4057,6 +4057,21 @@ generateForEachStmtConstraints(
AwaitExpr::createImplicit(ctx, /*awaitLoc=*/SourceLoc(), nextCall);
}

// The iterator type must conform to IteratorProtocol.
{
ProtocolDecl *iteratorProto = TypeChecker::getProtocol(
cs.getASTContext(), stmt->getForLoc(),
isAsync ? KnownProtocolKind::AsyncIteratorProtocol
: KnownProtocolKind::IteratorProtocol);
if (!iteratorProto)
return None;

cs.setContextualType(
nextRef->getBase(),
TypeLoc::withoutLoc(iteratorProto->getDeclaredInterfaceType()),
CTP_ForEachSequence);
}

SolutionApplicationTarget nextTarget(nextCall, dc, CTP_Unused,
/*contextualType=*/Type(),
/*isDiscarded=*/false);
Expand Down
Loading