Skip to content

Commit

Permalink
Basic ParamSpec Concatenate and literal support (#11847)
Browse files Browse the repository at this point in the history
This PR adds a new Parameters proper type to represent ParamSpec parameters 
(more about this in the PR), along with supporting the Concatenate operator.

Closes #11833
Closes #12276
Closes #12257
Refs #8645
External ref python/typeshed#4827

Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 7, 2022
1 parent 4ff8d04 commit 07d8878
Show file tree
Hide file tree
Showing 27 changed files with 1,473 additions and 121 deletions.
7 changes: 7 additions & 0 deletions docs/source/config_file.rst
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,13 @@ section of the command line docs.
from foo import bar
__all__ = ['bar']
.. confval:: strict_concatenate

:type: boolean
:default: False

Make arguments prepended via ``Concatenate`` be truly positional-only.

.. confval:: strict_equality

:type: boolean
Expand Down
4 changes: 2 additions & 2 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mypy.expandtype import expand_type
from mypy.types import (
Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType, get_proper_types,
TypeVarLikeType, ProperType, ParamSpecType, get_proper_type
TypeVarLikeType, ProperType, ParamSpecType, Parameters, get_proper_type
)
from mypy.nodes import Context

Expand Down Expand Up @@ -94,7 +94,7 @@ def apply_generic_arguments(
nt = id_to_type.get(param_spec.id)
if nt is not None:
nt = get_proper_type(nt)
if isinstance(nt, CallableType):
if isinstance(nt, CallableType) or isinstance(nt, Parameters):
callable = callable.expand_param_spec(nt)

# Apply arguments to argument types.
Expand Down
3 changes: 2 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5224,7 +5224,7 @@ def check_subtype(self,
code: Optional[ErrorCode] = None,
outer_context: Optional[Context] = None) -> bool:
"""Generate an error if the subtype is not compatible with supertype."""
if is_subtype(subtype, supertype):
if is_subtype(subtype, supertype, options=self.options):
return True

if isinstance(msg, ErrorMessage):
Expand Down Expand Up @@ -5260,6 +5260,7 @@ def check_subtype(self,
self.msg.note(note, context, code=code)
if note_msg:
self.note(note_msg, context, code=code)
self.msg.maybe_note_concatenate_pos_args(subtype, supertype, context, code=code)
if (isinstance(supertype, Instance) and supertype.type.is_protocol and
isinstance(subtype, (Instance, TupleType, TypedDictType))):
self.msg.report_protocol_problems(subtype, supertype, context, code=code)
Expand Down
2 changes: 1 addition & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,7 +1556,7 @@ def check_arg(self,
isinstance(callee_type.item, Instance) and
(callee_type.item.type.is_abstract or callee_type.item.type.is_protocol)):
self.msg.concrete_only_call(callee_type, context)
elif not is_subtype(caller_type, callee_type):
elif not is_subtype(caller_type, callee_type, options=self.chk.options):
if self.chk.should_suppress_optional_error([caller_type, callee_type]):
return
code = messages.incompatible_argument(n,
Expand Down
82 changes: 75 additions & 7 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType,
UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType,
ProperType, ParamSpecType, get_proper_type, TypeAliasType, is_union_with_any,
UnpackType, callable_with_ellipsis, TUPLE_LIKE_INSTANCE_NAMES,
UnpackType, callable_with_ellipsis, Parameters, TUPLE_LIKE_INSTANCE_NAMES,
)
from mypy.maptype import map_instance_to_supertype
import mypy.subtypes
Expand Down Expand Up @@ -406,6 +406,9 @@ def visit_param_spec(self, template: ParamSpecType) -> List[Constraint]:
def visit_unpack_type(self, template: UnpackType) -> List[Constraint]:
raise NotImplementedError

def visit_parameters(self, template: Parameters) -> List[Constraint]:
raise RuntimeError("Parameters cannot be constrained to")

# Non-leaf types

def visit_instance(self, template: Instance) -> List[Constraint]:
Expand Down Expand Up @@ -446,7 +449,7 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
# N.B: We use zip instead of indexing because the lengths might have
# mismatches during daemon reprocessing.
for tvar, mapped_arg, instance_arg in zip(tvars, mapped.args, instance.args):
# TODO: ParamSpecType
# TODO(PEP612): More ParamSpec work (or is Parameters the only thing accepted)
if isinstance(tvar, TypeVarType):
# The constraints for generic type parameters depend on variance.
# Include constraints from both directions if invariant.
Expand All @@ -456,6 +459,27 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
if tvar.variance != COVARIANT:
res.extend(infer_constraints(
mapped_arg, instance_arg, neg_op(self.direction)))
elif isinstance(tvar, ParamSpecType) and isinstance(mapped_arg, ParamSpecType):
suffix = get_proper_type(instance_arg)

if isinstance(suffix, CallableType):
prefix = mapped_arg.prefix
from_concat = bool(prefix.arg_types) or suffix.from_concatenate
suffix = suffix.copy_modified(from_concatenate=from_concat)

if isinstance(suffix, Parameters) or isinstance(suffix, CallableType):
# no such thing as variance for ParamSpecs
# TODO: is there a case I am missing?
# TODO: constraints between prefixes
prefix = mapped_arg.prefix
suffix = suffix.copy_modified(
suffix.arg_types[len(prefix.arg_types):],
suffix.arg_kinds[len(prefix.arg_kinds):],
suffix.arg_names[len(prefix.arg_names):])
res.append(Constraint(mapped_arg.id, SUPERTYPE_OF, suffix))
elif isinstance(suffix, ParamSpecType):
res.append(Constraint(mapped_arg.id, SUPERTYPE_OF, suffix))

return res
elif (self.direction == SUPERTYPE_OF and
instance.type.has_base(template.type.fullname)):
Expand All @@ -464,7 +488,6 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
# N.B: We use zip instead of indexing because the lengths might have
# mismatches during daemon reprocessing.
for tvar, mapped_arg, template_arg in zip(tvars, mapped.args, template.args):
# TODO: ParamSpecType
if isinstance(tvar, TypeVarType):
# The constraints for generic type parameters depend on variance.
# Include constraints from both directions if invariant.
Expand All @@ -474,6 +497,28 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
if tvar.variance != COVARIANT:
res.extend(infer_constraints(
template_arg, mapped_arg, neg_op(self.direction)))
elif (isinstance(tvar, ParamSpecType) and
isinstance(template_arg, ParamSpecType)):
suffix = get_proper_type(mapped_arg)

if isinstance(suffix, CallableType):
prefix = template_arg.prefix
from_concat = bool(prefix.arg_types) or suffix.from_concatenate
suffix = suffix.copy_modified(from_concatenate=from_concat)

if isinstance(suffix, Parameters) or isinstance(suffix, CallableType):
# no such thing as variance for ParamSpecs
# TODO: is there a case I am missing?
# TODO: constraints between prefixes
prefix = template_arg.prefix

suffix = suffix.copy_modified(
suffix.arg_types[len(prefix.arg_types):],
suffix.arg_kinds[len(prefix.arg_kinds):],
suffix.arg_names[len(prefix.arg_names):])
res.append(Constraint(template_arg.id, SUPERTYPE_OF, suffix))
elif isinstance(suffix, ParamSpecType):
res.append(Constraint(template_arg.id, SUPERTYPE_OF, suffix))
return res
if (template.type.is_protocol and self.direction == SUPERTYPE_OF and
# We avoid infinite recursion for structural subtypes by checking
Expand Down Expand Up @@ -564,11 +609,34 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]:
# Negate direction due to function argument type contravariance.
res.extend(infer_constraints(t, a, neg_op(self.direction)))
else:
# sometimes, it appears we try to get constraints between two paramspec callables?
# TODO: Direction
# TODO: Deal with arguments that come before param spec ones?
res.append(Constraint(param_spec.id,
SUBTYPE_OF,
cactual.copy_modified(ret_type=NoneType())))
# TODO: check the prefixes match
prefix = param_spec.prefix
prefix_len = len(prefix.arg_types)
cactual_ps = cactual.param_spec()

if not cactual_ps:
res.append(Constraint(param_spec.id,
SUBTYPE_OF,
cactual.copy_modified(
arg_types=cactual.arg_types[prefix_len:],
arg_kinds=cactual.arg_kinds[prefix_len:],
arg_names=cactual.arg_names[prefix_len:],
ret_type=NoneType())))
else:
res.append(Constraint(param_spec.id, SUBTYPE_OF, cactual_ps))

# compare prefixes
cactual_prefix = cactual.copy_modified(
arg_types=cactual.arg_types[:prefix_len],
arg_kinds=cactual.arg_kinds[:prefix_len],
arg_names=cactual.arg_names[:prefix_len])

# TODO: see above "FIX" comments for param_spec is None case
# TODO: this assume positional arguments
for t, a in zip(prefix.arg_types, cactual_prefix.arg_types):
res.extend(infer_constraints(t, a, neg_op(self.direction)))

template_ret_type, cactual_ret_type = template.ret_type, cactual.ret_type
if template.type_guard is not None:
Expand Down
5 changes: 4 additions & 1 deletion mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarId, Instance, TypeVarType,
CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType,
DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, ProperType,
get_proper_type, get_proper_types, TypeAliasType, ParamSpecType, UnpackType
get_proper_type, get_proper_types, TypeAliasType, ParamSpecType, Parameters, UnpackType
)
from mypy.nodes import ARG_STAR, ARG_STAR2

Expand Down Expand Up @@ -59,6 +59,9 @@ def visit_type_var(self, t: TypeVarType) -> ProperType:
def visit_param_spec(self, t: ParamSpecType) -> ProperType:
return AnyType(TypeOfAny.special_form)

def visit_parameters(self, t: Parameters) -> ProperType:
raise RuntimeError("Parameters should have been bound to a class")

def visit_unpack_type(self, t: UnpackType) -> ProperType:
raise NotImplementedError

Expand Down
48 changes: 40 additions & 8 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
NoneType, Overloaded, TupleType, TypedDictType, UnionType,
ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId,
FunctionLike, TypeVarType, LiteralType, get_proper_type, ProperType,
TypeAliasType, ParamSpecType, TypeVarLikeType, UnpackType
TypeAliasType, ParamSpecType, TypeVarLikeType, Parameters, ParamSpecFlavor,
UnpackType
)


Expand Down Expand Up @@ -101,15 +102,41 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
repl = get_proper_type(self.variables.get(t.id, t))
if isinstance(repl, Instance):
inst = repl
# Return copy of instance with type erasure flag on.
# TODO: what does prefix mean in this case?
# TODO: why does this case even happen? Instances aren't plural.
return Instance(inst.type, inst.args, line=inst.line, column=inst.column)
elif isinstance(repl, ParamSpecType):
return repl.with_flavor(t.flavor)
return repl.copy_modified(flavor=t.flavor, prefix=t.prefix.copy_modified(
arg_types=t.prefix.arg_types + repl.prefix.arg_types,
arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds,
arg_names=t.prefix.arg_names + repl.prefix.arg_names,
))
elif isinstance(repl, Parameters) or isinstance(repl, CallableType):
# if the paramspec is *P.args or **P.kwargs:
if t.flavor != ParamSpecFlavor.BARE:
assert isinstance(repl, CallableType), "Should not be able to get here."
# Is this always the right thing to do?
param_spec = repl.param_spec()
if param_spec:
return param_spec.with_flavor(t.flavor)
else:
return repl
else:
return Parameters(t.prefix.arg_types + repl.arg_types,
t.prefix.arg_kinds + repl.arg_kinds,
t.prefix.arg_names + repl.arg_names,
variables=[*t.prefix.variables, *repl.variables])
else:
# TODO: should this branch be removed? better not to fail silently
return repl

def visit_unpack_type(self, t: UnpackType) -> Type:
raise NotImplementedError

def visit_parameters(self, t: Parameters) -> Type:
return t.copy_modified(arg_types=self.expand_types(t.arg_types))

def visit_callable_type(self, t: CallableType) -> Type:
param_spec = t.param_spec()
if param_spec is not None:
Expand All @@ -121,13 +148,18 @@ def visit_callable_type(self, t: CallableType) -> Type:
# must expand both of them with all the argument types,
# kinds and names in the replacement. The return type in
# the replacement is ignored.
if isinstance(repl, CallableType):
if isinstance(repl, CallableType) or isinstance(repl, Parameters):
# Substitute *args: P.args, **kwargs: P.kwargs
t = t.expand_param_spec(repl)
# TODO: Substitute remaining arg types
return t.copy_modified(ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self)
if t.type_guard is not None else None))
prefix = param_spec.prefix
# we need to expand the types in the prefix, so might as well
# not get them in the first place
t = t.expand_param_spec(repl, no_prefix=True)
return t.copy_modified(
arg_types=self.expand_types(prefix.arg_types) + t.arg_types,
arg_kinds=prefix.arg_kinds + t.arg_kinds,
arg_names=prefix.arg_names + t.arg_names,
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None))

return t.copy_modified(arg_types=self.expand_types(t.arg_types),
ret_type=t.ret_type.accept(self),
Expand Down
7 changes: 6 additions & 1 deletion mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
CallableType, Instance, Overloaded, TupleType, TypedDictType,
TypeVarType, UnboundType, UnionType, TypeVisitor, LiteralType,
TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny, ParamSpecType,
UnpackType,
Parameters, UnpackType,
)
from mypy.visitor import NodeVisitor
from mypy.lookup import lookup_fully_qualified
Expand Down Expand Up @@ -255,6 +255,11 @@ def visit_param_spec(self, p: ParamSpecType) -> None:
def visit_unpack_type(self, u: UnpackType) -> None:
u.type.accept(self)

def visit_parameters(self, p: Parameters) -> None:
for argt in p.arg_types:
if argt is not None:
argt.accept(self)

def visit_unbound_type(self, o: UnboundType) -> None:
for a in o.args:
a.accept(self)
Expand Down
3 changes: 3 additions & 0 deletions mypy/indirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def visit_param_spec(self, t: types.ParamSpecType) -> Set[str]:
def visit_unpack_type(self, t: types.UnpackType) -> Set[str]:
return t.type.accept(self)

def visit_parameters(self, t: types.Parameters) -> Set[str]:
return self._visit(t.arg_types)

def visit_instance(self, t: types.Instance) -> Set[str]:
out = self._visit(t.args)
if t.type:
Expand Down
8 changes: 7 additions & 1 deletion mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Type, AnyType, NoneType, TypeVisitor, Instance, UnboundType, TypeVarType, CallableType,
TupleType, TypedDictType, ErasedType, UnionType, FunctionLike, Overloaded, LiteralType,
PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, get_proper_type,
ProperType, get_proper_types, TypeAliasType, PlaceholderType, ParamSpecType,
ProperType, get_proper_types, TypeAliasType, PlaceholderType, ParamSpecType, Parameters,
UnpackType
)
from mypy.maptype import map_instance_to_supertype
Expand Down Expand Up @@ -260,6 +260,12 @@ def visit_param_spec(self, t: ParamSpecType) -> ProperType:
def visit_unpack_type(self, t: UnpackType) -> UnpackType:
raise NotImplementedError

def visit_parameters(self, t: Parameters) -> ProperType:
if self.s == t:
return t
else:
return self.default(self.s)

def visit_instance(self, t: Instance) -> ProperType:
if isinstance(self.s, Instance):
if self.instance_joiner is None:
Expand Down
4 changes: 4 additions & 0 deletions mypy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,10 @@ def add_invertible_flag(flag: str,
" non-overlapping types",
group=strictness_group)

add_invertible_flag('--strict-concatenate', default=False, strict_flag=True,
help="Make arguments prepended via Concatenate be truly positional-only",
group=strictness_group)

strict_help = "Strict mode; enables the following flags: {}".format(
", ".join(strict_flag_names))
strictness_group.add_argument(
Expand Down
13 changes: 12 additions & 1 deletion mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType,
UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, LiteralType,
ProperType, get_proper_type, get_proper_types, TypeAliasType, TypeGuardedType,
ParamSpecType, UnpackType,
ParamSpecType, Parameters, UnpackType,
)
from mypy.subtypes import is_equivalent, is_subtype, is_callable_compatible, is_proper_subtype
from mypy.erasetype import erase_type
Expand Down Expand Up @@ -509,6 +509,17 @@ def visit_param_spec(self, t: ParamSpecType) -> ProperType:
def visit_unpack_type(self, t: UnpackType) -> ProperType:
raise NotImplementedError

def visit_parameters(self, t: Parameters) -> ProperType:
# TODO: is this the right variance?
if isinstance(self.s, Parameters) or isinstance(self.s, CallableType):
if len(t.arg_types) != len(self.s.arg_types):
return self.default(self.s)
return t.copy_modified(
arg_types=[meet_types(s_a, t_a) for s_a, t_a in zip(self.s.arg_types, t.arg_types)]
)
else:
return self.default(self.s)

def visit_instance(self, t: Instance) -> ProperType:
if isinstance(self.s, Instance):
if t.type == self.s.type:
Expand Down
Loading

0 comments on commit 07d8878

Please sign in to comment.