Skip to content

[CodeCompletion] Explicitly support enum pattern matching #38627

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ IDENTIFIER(decodeIfPresent)
IDENTIFIER(Decoder)
IDENTIFIER(decoder)
IDENTIFIER_(Differentiation)
IDENTIFIER_WITH_NAME(PatternMatchVar, "$match")
IDENTIFIER(dynamicallyCall)
IDENTIFIER(dynamicMember)
IDENTIFIER(Element)
Expand Down
11 changes: 8 additions & 3 deletions include/swift/Sema/CodeCompletionTypeChecking.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,26 @@ namespace swift {
/// formed during expression type-checking.
class UnresolvedMemberTypeCheckCompletionCallback: public TypeCheckCompletionCallback {
public:
struct Result {
struct ExprResult {
Type ExpectedTy;
bool IsImplicitSingleExpressionReturn;
};

private:
CodeCompletionExpr *CompletionExpr;
SmallVector<Result, 4> Results;
SmallVector<ExprResult, 4> ExprResults;
SmallVector<Type, 1> EnumPatternTypes;
bool GotCallback = false;

public:
UnresolvedMemberTypeCheckCompletionCallback(CodeCompletionExpr *CompletionExpr)
: CompletionExpr(CompletionExpr) {}

ArrayRef<Result> getResults() const { return Results; }
ArrayRef<ExprResult> getExprResults() const { return ExprResults; }

/// If we are completing in a pattern matching position, the types of all
/// enums for whose cases are valid as an \c EnumElementPattern.
ArrayRef<Type> getEnumPatternTypes() const { return EnumPatternTypes; }

/// True if at least one solution was passed via the \c sawSolution
/// callback.
Expand Down
76 changes: 59 additions & 17 deletions lib/IDE/CodeCompletion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3312,8 +3312,6 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
getSemanticContext(EED, Reason, dynamicLookupInfo),
expectedTypeContext);
Builder.setAssociatedDecl(EED);
if (HasTypeContext)
Builder.addFlair(CodeCompletionFlairBit::ExpressionSpecific);

addLeadingDot(Builder);
addValueBaseName(Builder, EED->getBaseIdentifier());
Expand Down Expand Up @@ -4372,6 +4370,23 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
addObjCPoundKeywordCompletions(/*needPound=*/true);
}

/// Returns \c true if \p VD is an initializer on the \c Optional or \c
/// Id_OptionalNilComparisonType type from the Swift stdlib.
static bool isInitializerOnOptional(Type T, ValueDecl *VD) {
bool IsOptionalType = false;
IsOptionalType |= static_cast<bool>(T->getOptionalObjectType());
if (auto *NTD = T->getAnyNominal()) {
IsOptionalType |= NTD->getBaseIdentifier() ==
VD->getASTContext().Id_OptionalNilComparisonType;
}
if (IsOptionalType && VD->getModuleContext()->isStdlibModule() &&
isa<ConstructorDecl>(VD)) {
return true;
} else {
return false;
}
}

void getUnresolvedMemberCompletions(Type T) {
if (!T->mayHaveMembers())
return;
Expand All @@ -4389,16 +4404,11 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
// We can only say .foo where foo is a static member of the contextual
// type and has the same type (or if the member is a function, then the
// same result type) as the contextual type.
FilteredDeclConsumer consumer(*this, [=](ValueDecl *VD,
DeclVisibilityKind Reason) {
if (T->getOptionalObjectType() &&
VD->getModuleContext()->isStdlibModule()) {
// In optional context, ignore '.init(<some>)', 'init(nilLiteral:)',
if (isa<ConstructorDecl>(VD))
return false;
}
return true;
});
FilteredDeclConsumer consumer(
*this, [=](ValueDecl *VD, DeclVisibilityKind Reason) {
// In optional context, ignore '.init(<some>)', 'init(nilLiteral:)',
return !isInitializerOnOptional(T, VD);
});

auto baseType = MetatypeType::get(T);
llvm::SaveAndRestore<LookupKind> SaveLook(Kind, LookupKind::ValueExpr);
Expand All @@ -4410,6 +4420,21 @@ class CompletionLookup final : public swift::VisibleDeclConsumer {
/*includeProtocolExtensionMembers*/true);
}

/// Complete all enum members declared on \p T.
void getEnumElementPatternCompletions(Type T) {
if (!isa_and_nonnull<EnumDecl>(T->getAnyNominal()))
return;

auto baseType = MetatypeType::get(T);
llvm::SaveAndRestore<LookupKind> SaveLook(Kind, LookupKind::EnumElement);
llvm::SaveAndRestore<Type> SaveType(ExprType, baseType);
llvm::SaveAndRestore<bool> SaveUnresolved(IsUnresolvedMember, true);
lookupVisibleMemberDecls(*this, baseType, CurrDeclContext,
/*includeInstanceMembers=*/false,
/*includeDerivedRequirements=*/false,
/*includeProtocolExtensionMembers=*/true);
}

void getUnresolvedMemberCompletions(ArrayRef<Type> Types) {
NeedLeadingDot = !HaveDot;

Expand Down Expand Up @@ -6461,8 +6486,8 @@ static void deliverCompletionResults(CodeCompletionContext &CompletionContext,
}

void deliverUnresolvedMemberResults(
ArrayRef<UnresolvedMemberTypeCheckCompletionCallback::Result> Results,
DeclContext *DC, SourceLoc DotLoc,
ArrayRef<UnresolvedMemberTypeCheckCompletionCallback::ExprResult> Results,
ArrayRef<Type> EnumPatternTypes, DeclContext *DC, SourceLoc DotLoc,
ide::CodeCompletionContext &CompletionCtx,
CodeCompletionConsumer &Consumer) {
ASTContext &Ctx = DC->getASTContext();
Expand All @@ -6471,7 +6496,7 @@ void deliverUnresolvedMemberResults(

assert(DotLoc.isValid());
Lookup.setHaveDot(DotLoc);
Lookup.shouldCheckForDuplicates(Results.size() > 1);
Lookup.shouldCheckForDuplicates(Results.size() + EnumPatternTypes.size() > 1);

// Get the canonical versions of the top-level types
SmallPtrSet<CanType, 4> originalTypes;
Expand All @@ -6496,6 +6521,22 @@ void deliverUnresolvedMemberResults(
Lookup.getUnresolvedMemberCompletions(Result.ExpectedTy);
}

// Offer completions when interpreting the pattern match as an
// EnumElementPattern.
for (auto &Ty : EnumPatternTypes) {
Lookup.setExpectedTypes({Ty}, /*IsImplicitSingleExpressionReturn=*/false,
/*expectsNonVoid=*/true);
Lookup.setIdealExpectedType(Ty);

// We can pattern match MyEnum against Optional<MyEnum>
if (Ty->getOptionalObjectType()) {
Type Unwrapped = Ty->lookThroughAllOptionalTypes();
Lookup.getEnumElementPatternCompletions(Unwrapped);
}

Lookup.getEnumElementPatternCompletions(Ty);
}

deliverCompletionResults(CompletionCtx, Lookup, DC, Consumer);
}

Expand Down Expand Up @@ -6608,8 +6649,9 @@ bool CodeCompletionCallbacksImpl::trySolverCompletion(bool MaybeFuncBody) {
Lookup.fallbackTypeCheck(CurDeclContext);

addKeywords(CompletionContext.getResultSink(), MaybeFuncBody);
deliverUnresolvedMemberResults(Lookup.getResults(), CurDeclContext, DotLoc,
CompletionContext, Consumer);
deliverUnresolvedMemberResults(Lookup.getExprResults(),
Lookup.getEnumPatternTypes(), CurDeclContext,
DotLoc, CompletionContext, Consumer);
return true;
}
case CompletionKind::KeyPathExprSwift: {
Expand Down
151 changes: 91 additions & 60 deletions lib/Sema/TypeCheckCodeCompletion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -811,59 +811,11 @@ class CompletionContextFinder : public ASTWalker {

} // end namespace

// Determine if the target expression is the implicit BinaryExpr generated for
// pattern-matching in a switch/if/guard case (<completion> ~= matchValue).
static bool isForPatternMatch(SolutionApplicationTarget &target) {
if (target.getExprContextualTypePurpose() != CTP_Condition)
return false;
Expr *condition = target.getAsExpr();
if (!condition->isImplicit())
return false;
if (auto *BE = dyn_cast<BinaryExpr>(condition)) {
Identifier id;
if (auto *ODRE = dyn_cast<OverloadedDeclRefExpr>(BE->getFn())) {
id = ODRE->getDecls().front()->getBaseIdentifier();
} else if (auto *DRE = dyn_cast<DeclRefExpr>(BE->getFn())) {
id = DRE->getDecl()->getBaseIdentifier();
}
if (id != target.getDeclContext()->getASTContext().Id_MatchOperator)
return false;
return isa<CodeCompletionExpr>(BE->getLHS());
}
return false;
}

/// Remove any solutions from the provided vector that both require fixes and have a
/// score worse than the best.
/// Remove any solutions from the provided vector that both require fixes and
/// have a score worse than the best.
static void filterSolutions(SolutionApplicationTarget &target,
SmallVectorImpl<Solution> &solutions,
CodeCompletionExpr *completionExpr) {
// FIXME: this is only needed because in pattern matching position, the
// code completion expression always becomes an expression pattern, which
// requires the ~= operator to be defined on the type being matched against.
// Pattern matching against an enum doesn't require that however, so valid
// solutions always end up having fixes. This is a problem because there will
// always be a valid solution as well. Optional defines ~= between Optional
// and _OptionalNilComparisonType (which defines a nilLiteral initializer),
// and the matched-against value can implicitly be made Optional if it isn't
// already, so _OptionalNilComparisonType is always a valid solution for the
// completion. That only generates the 'nil' completion, which is rarely what
// the user intends to write in this position and shouldn't be preferred over
// the other formed solutions (which require fixes). We should generate enum
// pattern completions separately, but for now ignore the
// _OptionalNilComparisonType solution.
if (isForPatternMatch(target) && completionExpr) {
solutions.erase(llvm::remove_if(solutions, [&](const Solution &S) {
ASTContext &ctx = S.getConstraintSystem().getASTContext();
if (!S.hasType(completionExpr))
return false;
if (auto ty = S.getResolvedType(completionExpr))
if (auto *NTD = ty->getAnyNominal())
return NTD->getBaseIdentifier() == ctx.Id_OptionalNilComparisonType;
return false;
}), solutions.end());
}

if (solutions.size() <= 1)
return;

Expand Down Expand Up @@ -1286,6 +1238,69 @@ sawSolution(const constraints::Solution &S) {
}
}

/// If the code completion variable occurs in a pattern matching position, we
/// have an AST that looks like this.
/// \code
/// (binary_expr implicit type='$T3'
/// (overloaded_decl_ref_expr function_ref=compound decls=[
/// Swift.(file).~=,
/// Swift.(file).Optional extension.~=])
/// (tuple_expr implicit type='($T1, (OtherEnum))'
/// (code_completion_expr implicit type='$T1')
/// (declref_expr implicit decl=swift_ide_test.(file).foo(x:).$match)))
/// \endcode
/// If the code completion expression occurs in such an AST, return the
/// declaration of the \c $match variable, otherwise return \c nullptr.
VarDecl *getMatchVarIfInPatternMatch(CodeCompletionExpr *CompletionExpr,
ConstraintSystem &CS) {
auto &Context = CS.getASTContext();

TupleExpr *ArgTuple =
dyn_cast_or_null<TupleExpr>(CS.getParentExpr(CompletionExpr));
if (!ArgTuple || !ArgTuple->isImplicit() || ArgTuple->getNumElements() != 2) {
return nullptr;
}

auto Binary = dyn_cast_or_null<BinaryExpr>(CS.getParentExpr(ArgTuple));
if (!Binary || !Binary->isImplicit()) {
return nullptr;
}
Comment on lines +1258 to +1267
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably conflicts with #38836 cc: @hamishknight


auto CalledOperator = Binary->getFn();
if (!CalledOperator || !CalledOperator->isImplicit()) {
return nullptr;
}
// The reference to the ~= operator might be an OverloadedDeclRefExpr or a
// DeclRefExpr, depending on how many ~= operators are viable.
if (auto Overloaded =
dyn_cast_or_null<OverloadedDeclRefExpr>(CalledOperator)) {
if (!llvm::all_of(Overloaded->getDecls(), [&Context](ValueDecl *D) {
return D->getBaseName() == Context.Id_MatchOperator;
})) {
return nullptr;
}
} else if (auto Ref = dyn_cast_or_null<DeclRefExpr>(CalledOperator)) {
if (Ref->getDecl()->getBaseName() != Context.Id_MatchOperator) {
return nullptr;
}
} else {
return nullptr;
}

auto MatchArg = dyn_cast_or_null<DeclRefExpr>(ArgTuple->getElement(1));
if (!MatchArg || !MatchArg->isImplicit()) {
return nullptr;
}

auto MatchVar = MatchArg->getDecl();
if (MatchVar && MatchVar->isImplicit() &&
MatchVar->getBaseName() == Context.Id_PatternMatchVar) {
return dyn_cast<VarDecl>(MatchVar);
} else {
return nullptr;
}
}

void UnresolvedMemberTypeCheckCompletionCallback::
sawSolution(const constraints::Solution &S) {
GotCallback = true;
Expand All @@ -1295,18 +1310,34 @@ sawSolution(const constraints::Solution &S) {
// If the type couldn't be determined (e.g. because there isn't any context
// to derive it from), let's not attempt to do a lookup since it wouldn't
// produce any useful results anyway.
if (!ExpectedTy || ExpectedTy->is<UnresolvedType>())
return;

// If ExpectedTy is a duplicate of any other result, ignore this solution.
if (llvm::any_of(Results, [&](const Result &R) {
return R.ExpectedTy->isEqual(ExpectedTy);
})) {
return;
if (ExpectedTy && !ExpectedTy->is<UnresolvedType>()) {
// If ExpectedTy is a duplicate of any other result, ignore this solution.
if (!llvm::any_of(ExprResults, [&](const ExprResult &R) {
return R.ExpectedTy->isEqual(ExpectedTy);
})) {
bool SingleExprBody =
isImplicitSingleExpressionReturn(CS, CompletionExpr);
ExprResults.push_back({ExpectedTy, SingleExprBody});
}
}

bool SingleExprBody = isImplicitSingleExpressionReturn(CS, CompletionExpr);
Results.push_back({ExpectedTy, SingleExprBody});
if (auto MatchVar = getMatchVarIfInPatternMatch(CompletionExpr, CS)) {
Type MatchVarType;
// If the MatchVar has an explicit type, it's not part of the solution. But
// we can look it up in the constraint system directly.
if (auto T = S.getConstraintSystem().getVarType(MatchVar)) {
MatchVarType = T;
} else {
MatchVarType = S.getResolvedType(MatchVar);
}
if (MatchVarType && !MatchVarType->is<UnresolvedType>()) {
if (!llvm::any_of(EnumPatternTypes, [&](const Type &R) {
return R->isEqual(MatchVarType);
})) {
EnumPatternTypes.push_back(MatchVarType);
}
}
}
}

void KeyPathTypeCheckCompletionCallback::sawSolution(
Expand Down
8 changes: 3 additions & 5 deletions lib/Sema/TypeCheckConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -633,11 +633,9 @@ bool TypeChecker::typeCheckExprPattern(ExprPattern *EP, DeclContext *DC,
PrettyStackTracePattern stackTrace(Context, "type-checking", EP);

// Create a 'let' binding to stand in for the RHS value.
auto *matchVar = new (Context) VarDecl(/*IsStatic*/false,
VarDecl::Introducer::Let,
EP->getLoc(),
Context.getIdentifier("$match"),
DC);
auto *matchVar =
new (Context) VarDecl(/*IsStatic*/ false, VarDecl::Introducer::Let,
EP->getLoc(), Context.Id_PatternMatchVar, DC);
matchVar->setInterfaceType(rhsType->mapTypeOutOfContext());

matchVar->setImplicit();
Expand Down
Loading