Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Polymorphic inference: support for parameter specifications and lambdas #15837

Merged
merged 15 commits into from
Aug 15, 2023
11 changes: 7 additions & 4 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
AnyType,
CallableType,
Instance,
Parameters,
ParamSpecType,
PartialType,
TupleType,
Expand Down Expand Up @@ -109,9 +108,13 @@ def apply_generic_arguments(
if param_spec is not None:
nt = id_to_type.get(param_spec.id)
if nt is not None:
nt = get_proper_type(nt)
if isinstance(nt, (CallableType, Parameters)):
callable = callable.expand_param_spec(nt)
# ParamSpec expansion is special-cased, so we need to always expand callable
# as a whole, not expanding arguments individually.
callable = expand_type(callable, id_to_type)
assert isinstance(callable, CallableType)
return callable.copy_modified(
variables=[tv for tv in tvars if tv.id not in id_to_type]
)

# Apply arguments to argument types.
var_arg = callable.var_arg()
Expand Down
13 changes: 10 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4280,12 +4280,14 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
return_type = self.return_types[-1]
return_type = get_proper_type(return_type)

is_lambda = isinstance(self.scope.top_function(), LambdaExpr)
if isinstance(return_type, UninhabitedType):
self.fail(message_registry.NO_RETURN_EXPECTED, s)
return
# Avoid extra error messages for failed inference in lambdas
if not is_lambda or not return_type.ambiguous:
self.fail(message_registry.NO_RETURN_EXPECTED, s)
return

if s.expr:
is_lambda = isinstance(self.scope.top_function(), LambdaExpr)
declared_none_return = isinstance(return_type, NoneType)
declared_any_return = isinstance(return_type, AnyType)

Expand Down Expand Up @@ -7366,6 +7368,11 @@ def visit_erased_type(self, t: ErasedType) -> bool:
# This can happen inside a lambda.
return True

def visit_type_var(self, t: TypeVarType) -> bool:
# This is needed to prevent leaking into partial types during
# multi-step type inference.
return t.id.is_meta_var()


class SetNothingToAny(TypeTranslator):
"""Replace all ambiguous <nothing> types with Any (to avoid spurious extra errors)."""
Expand Down
123 changes: 109 additions & 14 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
from mypy.checkstrformat import StringFormatterChecker
from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars
from mypy.errors import ErrorWatcher, report_internal_error
from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars
from mypy.expandtype import (
expand_type,
expand_type_by_instance,
freshen_all_functions_type_vars,
freshen_function_type_vars,
)
from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments
from mypy.literals import literal
from mypy.maptype import map_instance_to_supertype
Expand Down Expand Up @@ -122,6 +127,7 @@
false_only,
fixup_partial_type,
function_type,
get_all_type_vars,
get_type_vars,
is_literal_type_like,
make_simplified_union,
Expand All @@ -145,6 +151,7 @@
LiteralValue,
NoneType,
Overloaded,
Parameters,
ParamSpecFlavor,
ParamSpecType,
PartialType,
Expand All @@ -167,6 +174,7 @@
get_proper_types,
has_recursive_types,
is_named_instance,
remove_dups,
split_with_prefix_and_suffix,
)
from mypy.types_utils import (
Expand Down Expand Up @@ -1570,6 +1578,16 @@ def check_callable_call(
lambda i: self.accept(args[i]),
)

# This is tricky: return type may contain its own type variables, like in
# def [S] (S) -> def [T] (T) -> tuple[S, T], so we need to update their ids
# to avoid possible id clashes if this call itself appears in a generic
# function body.
ret_type = get_proper_type(callee.ret_type)
if isinstance(ret_type, CallableType) and ret_type.variables:
fresh_ret_type = freshen_all_functions_type_vars(callee.ret_type)
freeze_all_type_vars(fresh_ret_type)
callee = callee.copy_modified(ret_type=fresh_ret_type)

if callee.is_generic():
need_refresh = any(
isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables
Expand All @@ -1588,7 +1606,7 @@ def check_callable_call(
lambda i: self.accept(args[i]),
)
callee = self.infer_function_type_arguments(
callee, args, arg_kinds, formal_to_actual, context
callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context
)
if need_refresh:
formal_to_actual = map_actuals_to_formals(
Expand Down Expand Up @@ -1855,6 +1873,8 @@ def infer_function_type_arguments_using_context(
# def identity(x: T) -> T: return x
#
# expects_literal(identity(3)) # Should type-check
# TODO: we may want to add similar exception if all arguments are lambdas, since
# in this case external context is almost everything we have.
if not is_generic_instance(ctx) and not is_literal_type_like(ctx):
return callable.copy_modified()
args = infer_type_arguments(callable.variables, ret_type, erased_ctx)
Expand All @@ -1876,7 +1896,9 @@ def infer_function_type_arguments(
callee_type: CallableType,
args: list[Expression],
arg_kinds: list[ArgKind],
arg_names: Sequence[str | None] | None,
formal_to_actual: list[list[int]],
need_refresh: bool,
context: Context,
) -> CallableType:
"""Infer the type arguments for a generic callee type.
Expand Down Expand Up @@ -1918,7 +1940,14 @@ def infer_function_type_arguments(
if 2 in arg_pass_nums:
# Second pass of type inference.
(callee_type, inferred_args) = self.infer_function_type_arguments_pass2(
callee_type, args, arg_kinds, formal_to_actual, inferred_args, context
callee_type,
args,
arg_kinds,
arg_names,
formal_to_actual,
inferred_args,
need_refresh,
context,
)

if (
Expand All @@ -1944,6 +1973,17 @@ def infer_function_type_arguments(
or set(get_type_vars(a)) & set(callee_type.variables)
for a in inferred_args
):
if need_refresh:
# Technically we need to refresh formal_to_actual after *each* inference pass,
# since each pass can expand ParamSpec or TypeVarTuple. Although such situations
# are very rare, not doing this can cause crashes.
formal_to_actual = map_actuals_to_formals(
arg_kinds,
arg_names,
callee_type.arg_kinds,
callee_type.arg_names,
lambda a: self.accept(args[a]),
)
# If the regular two-phase inference didn't work, try inferring type
# variables while allowing for polymorphic solutions, i.e. for solutions
# potentially involving free variables.
Expand Down Expand Up @@ -1991,8 +2031,10 @@ def infer_function_type_arguments_pass2(
callee_type: CallableType,
args: list[Expression],
arg_kinds: list[ArgKind],
arg_names: Sequence[str | None] | None,
formal_to_actual: list[list[int]],
old_inferred_args: Sequence[Type | None],
need_refresh: bool,
context: Context,
) -> tuple[CallableType, list[Type | None]]:
"""Perform second pass of generic function type argument inference.
Expand All @@ -2014,6 +2056,14 @@ def infer_function_type_arguments_pass2(
if isinstance(arg, (NoneType, UninhabitedType)) or has_erased_component(arg):
inferred_args[i] = None
callee_type = self.apply_generic_arguments(callee_type, inferred_args, context)
if need_refresh:
formal_to_actual = map_actuals_to_formals(
arg_kinds,
arg_names,
callee_type.arg_kinds,
callee_type.arg_names,
lambda a: self.accept(args[a]),
)

arg_types = self.infer_arg_types_in_context(callee_type, args, arg_kinds, formal_to_actual)

Expand Down Expand Up @@ -4674,8 +4724,22 @@ def infer_lambda_type_using_context(
# they must be considered as indeterminate. We use ErasedType since it
# does not affect type inference results (it is for purposes like this
# only).
callable_ctx = get_proper_type(replace_meta_vars(ctx, ErasedType()))
assert isinstance(callable_ctx, CallableType)
if self.chk.options.new_type_inference:
# With new type inference we can preserve argument types even if they
# are generic, since new inference algorithm can handle constraints
# like S <: T (we still erase return type since it's ultimately unknown).
extra_vars = []
for arg in ctx.arg_types:
meta_vars = [tv for tv in get_all_type_vars(arg) if tv.id.is_meta_var()]
extra_vars.extend([tv for tv in meta_vars if tv not in extra_vars])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think extra_vars could be a set maybe? That would mean an IMO simpler comprehension.

I'm also not sure why ctx.variables is guaranteed to not include these new variables.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC I use this logic with lists for variables here (and in few other places) to have stable order. Otherwise tests will randomly fail on reveal_type() (and it is generally good to have predictable stable order for comparison purposes).

callable_ctx = ctx.copy_modified(
ret_type=replace_meta_vars(ctx.ret_type, ErasedType()),
variables=list(ctx.variables) + extra_vars,
)
else:
erased_ctx = replace_meta_vars(ctx, ErasedType())
assert isinstance(erased_ctx, ProperType) and isinstance(erased_ctx, CallableType)
callable_ctx = erased_ctx

# The callable_ctx may have a fallback of builtins.type if the context
# is a constructor -- but this fallback doesn't make sense for lambdas.
Expand Down Expand Up @@ -5632,18 +5696,28 @@ def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None:
self.bound_tvars: set[TypeVarLikeType] = set()
self.seen_aliases: set[TypeInfo] = set()

def visit_callable_type(self, t: CallableType) -> Type:
found_vars = set()
def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]:
found_vars = []
for arg in t.arg_types:
found_vars |= set(get_type_vars(arg)) & self.poly_tvars
for tv in get_all_type_vars(arg):
if isinstance(tv, ParamSpecType):
normalized: TypeVarLikeType = tv.copy_modified(
flavor=ParamSpecFlavor.BARE, prefix=Parameters([], [], [])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why drop the prefix?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have something like Callable[[P], Foo[Concatenate[T, P]]], then get_all_type_vars() will return P twice, so we need to normalize.

)
else:
normalized = tv
if normalized in self.poly_tvars and normalized not in self.bound_tvars:
found_vars.append(normalized)
return remove_dups(found_vars)

found_vars -= self.bound_tvars
self.bound_tvars |= found_vars
def visit_callable_type(self, t: CallableType) -> Type:
found_vars = self.collect_vars(t)
self.bound_tvars |= set(found_vars)
result = super().visit_callable_type(t)
self.bound_tvars -= found_vars
self.bound_tvars -= set(found_vars)

assert isinstance(result, ProperType) and isinstance(result, CallableType)
result.variables = list(result.variables) + list(found_vars)
result.variables = list(result.variables) + found_vars
return result

def visit_type_var(self, t: TypeVarType) -> Type:
Expand All @@ -5652,8 +5726,9 @@ def visit_type_var(self, t: TypeVarType) -> Type:
return super().visit_type_var(t)

def visit_param_spec(self, t: ParamSpecType) -> Type:
# TODO: Support polymorphic apply for ParamSpec.
raise PolyTranslationError()
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_param_spec(t)

def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
# TODO: Support polymorphic apply for TypeVarTuple.
Expand All @@ -5669,6 +5744,26 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type:
raise PolyTranslationError()

def visit_instance(self, t: Instance) -> Type:
if t.type.has_param_spec_type:
# We need this special-casing to preserve the possibility to store a
# generic function in an instance type. Things like
# forall T . Foo[[x: T], T]
# are not really expressible in current type system, but this looks like
# a useful feature, so let's keep it.
param_spec_index = next(
i for (i, tv) in enumerate(t.type.defn.type_vars) if isinstance(tv, ParamSpecType)
)
p = get_proper_type(t.args[param_spec_index])
if isinstance(p, Parameters):
found_vars = self.collect_vars(p)
self.bound_tvars |= set(found_vars)
new_args = [a.accept(self) for a in t.args]
self.bound_tvars -= set(found_vars)

repl = new_args[param_spec_index]
assert isinstance(repl, ProperType) and isinstance(repl, Parameters)
repl.variables = list(repl.variables) + list(found_vars)
return t.copy_modified(args=new_args)
# There is the same problem with callback protocols as with aliases
# (callback protocols are essentially more flexible aliases to callables).
# Note: consider supporting bindings in instances, e.g. LRUCache[[x: T], T].
Expand Down
Loading