Skip to content

Commit

Permalink
Use helper methods for a lot of ArgKind checks (python#10793)
Browse files Browse the repository at this point in the history
Part of the motivation here is that I want to make positional-only
arguments be properly reflected in the argument kinds, and having most
of the logic done through helpers will make that easier.
  • Loading branch information
msullivan authored Jul 10, 2021
1 parent 3552971 commit a9f3b5e
Show file tree
Hide file tree
Showing 17 changed files with 83 additions and 75 deletions.
7 changes: 3 additions & 4 deletions mypy/argmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind],
for ai, actual_kind in enumerate(actual_kinds):
if actual_kind == nodes.ARG_POS:
if fi < nformals:
if formal_kinds[fi] in [nodes.ARG_POS, nodes.ARG_OPT,
nodes.ARG_NAMED, nodes.ARG_NAMED_OPT]:
if not formal_kinds[fi].is_star():
formal_to_actual[fi].append(ai)
fi += 1
elif formal_kinds[fi] == nodes.ARG_STAR:
Expand All @@ -52,14 +51,14 @@ def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind],
# Assume that it is an iterable (if it isn't, there will be
# an error later).
while fi < nformals:
if formal_kinds[fi] in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT, nodes.ARG_STAR2):
if formal_kinds[fi].is_named(star=True):
break
else:
formal_to_actual[fi].append(ai)
if formal_kinds[fi] == nodes.ARG_STAR:
break
fi += 1
elif actual_kind in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT):
elif actual_kind.is_named():
assert actual_names is not None, "Internal error: named kinds without names given"
name = actual_names[ai]
if name in formal_names:
Expand Down
26 changes: 11 additions & 15 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode,
ParamSpecExpr,
ArgKind, ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE,
ArgKind, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE,
)
from mypy.literals import literal
from mypy import nodes
Expand Down Expand Up @@ -1111,7 +1111,7 @@ def infer_arg_types_in_context(

for i, actuals in enumerate(formal_to_actual):
for ai in actuals:
if arg_kinds[ai] not in (nodes.ARG_STAR, nodes.ARG_STAR2):
if not arg_kinds[ai].is_star():
res[ai] = self.accept(args[ai], callee.arg_types[i])

# Fill in the rest of the argument types.
Expand Down Expand Up @@ -1364,18 +1364,14 @@ def check_argument_count(self,

# Check for too many or few values for formals.
for i, kind in enumerate(callee.arg_kinds):
if kind == nodes.ARG_POS and (not formal_to_actual[i] and
not is_unexpected_arg_error):
# No actual for a mandatory positional formal.
if kind.is_required() and not formal_to_actual[i] and not is_unexpected_arg_error:
# No actual for a mandatory formal
if messages:
messages.too_few_arguments(callee, context, actual_names)
ok = False
elif kind == nodes.ARG_NAMED and (not formal_to_actual[i] and
not is_unexpected_arg_error):
# No actual for a mandatory named formal
if messages:
argname = callee.arg_names[i] or "?"
messages.missing_named_argument(callee, context, argname)
if kind.is_positional():
messages.too_few_arguments(callee, context, actual_names)
else:
argname = callee.arg_names[i] or "?"
messages.missing_named_argument(callee, context, argname)
ok = False
elif not kind.is_star() and is_duplicate_mapping(
formal_to_actual[i], actual_types, actual_kinds):
Expand All @@ -1385,7 +1381,7 @@ def check_argument_count(self,
if messages:
messages.duplicate_argument_value(callee, i, context)
ok = False
elif (kind in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT) and formal_to_actual[i] and
elif (kind.is_named() and formal_to_actual[i] and
actual_kinds[formal_to_actual[i][0]] not in [nodes.ARG_NAMED, nodes.ARG_STAR2]):
# Positional argument when expecting a keyword argument.
if messages:
Expand Down Expand Up @@ -1925,7 +1921,7 @@ def combine_function_signatures(self, types: Sequence[Type]) -> Union[AnyType, C
for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)):
if new_kind == target_kind:
continue
elif new_kind in (ARG_POS, ARG_OPT) and target_kind in (ARG_POS, ARG_OPT):
elif new_kind.is_positional() and target_kind.is_positional():
new_kinds[i] = ARG_POS
else:
too_complex = True
Expand Down
5 changes: 2 additions & 3 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
is_subtype, is_equivalent, is_subtype_ignoring_tvars, is_proper_subtype,
is_protocol_implementation, find_member
)
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, INVARIANT, COVARIANT, CONTRAVARIANT
from mypy.nodes import INVARIANT, COVARIANT, CONTRAVARIANT
import mypy.typeops
from mypy import state
from mypy import meet
Expand Down Expand Up @@ -536,11 +536,10 @@ def combine_arg_names(t: CallableType, s: CallableType) -> List[Optional[str]]:
"""
num_args = len(t.arg_types)
new_names = []
named = (ARG_NAMED, ARG_NAMED_OPT)
for i in range(num_args):
t_name = t.arg_names[i]
s_name = s.arg_names[i]
if t_name == s_name or t.arg_kinds[i] in named or s.arg_kinds[i] in named:
if t_name == s_name or t.arg_kinds[i].is_named() or s.arg_kinds[i].is_named():
new_names.append(t_name)
else:
new_names.append(None)
Expand Down
8 changes: 4 additions & 4 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1711,12 +1711,12 @@ def format(typ: Type) -> str:
for arg_name, arg_type, arg_kind in zip(
func.arg_names, func.arg_types, func.arg_kinds):
if (arg_kind == ARG_POS and arg_name is None
or verbosity == 0 and arg_kind in (ARG_POS, ARG_OPT)):
or verbosity == 0 and arg_kind.is_positional()):

arg_strings.append(format(arg_type))
else:
constructor = ARG_CONSTRUCTOR_NAMES[arg_kind]
if arg_kind in (ARG_STAR, ARG_STAR2) or arg_name is None:
if arg_kind.is_star() or arg_name is None:
arg_strings.append("{}({})".format(
constructor,
format(arg_type)))
Expand Down Expand Up @@ -1849,7 +1849,7 @@ def [T <: int] f(self, x: int, y: T) -> None
for i in range(len(tp.arg_types)):
if s:
s += ', '
if tp.arg_kinds[i] in (ARG_NAMED, ARG_NAMED_OPT) and not asterisk:
if tp.arg_kinds[i].is_named() and not asterisk:
s += '*, '
asterisk = True
if tp.arg_kinds[i] == ARG_STAR:
Expand All @@ -1861,7 +1861,7 @@ def [T <: int] f(self, x: int, y: T) -> None
if name:
s += name + ': '
s += format_type_bare(tp.arg_types[i])
if tp.arg_kinds[i] in (ARG_OPT, ARG_NAMED_OPT):
if tp.arg_kinds[i].is_optional():
s += ' = ...'

# If we got a "special arg" (i.e: self, cls, etc...), prepend it to the arg list
Expand Down
20 changes: 20 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,6 +1534,26 @@ class ArgKind(Enum):
# In an argument list, keyword-only and also optional
ARG_NAMED_OPT = 5

def is_positional(self, star: bool = False) -> bool:
return (
self == ARG_POS
or self == ARG_OPT
or (star and self == ARG_STAR)
)

def is_named(self, star: bool = False) -> bool:
return (
self == ARG_NAMED
or self == ARG_NAMED_OPT
or (star and self == ARG_STAR2)
)

def is_required(self) -> bool:
return self == ARG_POS or self == ARG_NAMED

def is_optional(self) -> bool:
return self == ARG_OPT or self == ARG_NAMED_OPT

def is_star(self) -> bool:
return self == ARG_STAR or self == ARG_STAR2

Expand Down
4 changes: 2 additions & 2 deletions mypy/plugins/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Dict, NamedTuple, Optional

import mypy.plugin
from mypy.nodes import ARG_OPT, ARG_POS, ARG_STAR2, Argument, FuncItem, Var
from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var
from mypy.plugins.common import add_method_to_class
from mypy.types import AnyType, CallableType, get_proper_type, Type, TypeOfAny, UnboundType

Expand Down Expand Up @@ -65,7 +65,7 @@ def _find_other_type(method: _MethodInfo) -> Type:
cur_pos_arg = 0
other_arg = None
for arg_kind, arg_type in zip(method.type.arg_kinds, method.type.arg_types):
if arg_kind in (ARG_POS, ARG_OPT):
if arg_kind.is_positional():
if cur_pos_arg == first_arg_pos:
other_arg = arg_type
break
Expand Down
4 changes: 2 additions & 2 deletions mypy/plugins/singledispatch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from mypy.messages import format_type
from mypy.plugins.common import add_method_to_class
from mypy.nodes import (
ARG_POS, Argument, Block, ClassDef, SymbolTable, TypeInfo, Var, ARG_STAR, ARG_OPT, Context
ARG_POS, Argument, Block, ClassDef, SymbolTable, TypeInfo, Var, Context
)
from mypy.subtypes import is_subtype
from mypy.types import (
Expand Down Expand Up @@ -98,7 +98,7 @@ def create_singledispatch_function_callback(ctx: FunctionContext) -> Type:
)
return ctx.default_return_type

elif func_type.arg_kinds[0] not in (ARG_POS, ARG_OPT, ARG_STAR):
elif not func_type.arg_kinds[0].is_positional(star=True):
fail(
ctx,
'First argument to singledispatch function must be a positional argument',
Expand Down
4 changes: 2 additions & 2 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def func_helper(self, o: 'mypy.nodes.FuncItem') -> List[object]:
extra: List[Tuple[str, List[mypy.nodes.Var]]] = []
for arg in o.arguments:
kind: mypy.nodes.ArgKind = arg.kind
if kind in (mypy.nodes.ARG_POS, mypy.nodes.ARG_NAMED):
if kind.is_required():
args.append(arg.variable)
elif kind in (mypy.nodes.ARG_OPT, mypy.nodes.ARG_NAMED_OPT):
elif kind.is_optional():
assert arg.initializer is not None
args.append(('default', [arg.variable, arg.initializer]))
elif kind == mypy.nodes.ARG_STAR:
Expand Down
5 changes: 2 additions & 3 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
TupleExpr, ListExpr, ComparisonExpr, CallExpr, IndexExpr, EllipsisExpr,
ClassDef, MypyFile, Decorator, AssignmentStmt, TypeInfo,
IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, Block,
Statement, OverloadedFuncDef, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT
Statement, OverloadedFuncDef, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED,
)
from mypy.stubgenc import generate_stub_for_c_module
from mypy.stubutil import (
Expand Down Expand Up @@ -631,8 +631,7 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False,
if not isinstance(get_proper_type(annotated_type), AnyType):
annotation = ": {}".format(self.print_annotation(annotated_type))
if arg_.initializer:
if kind in (ARG_NAMED, ARG_NAMED_OPT) and not any(arg.startswith('*')
for arg in args):
if kind.is_named() and not any(arg.startswith('*') for arg in args):
args.append('*')
if not annotation:
typename = self.get_str_type_of_node(arg_.initializer, True, False)
Expand Down
14 changes: 7 additions & 7 deletions mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def _verify_arg_default_value(
) -> Iterator[str]:
"""Checks whether argument default values are compatible."""
if runtime_arg.default != inspect.Parameter.empty:
if stub_arg.kind not in (nodes.ARG_OPT, nodes.ARG_NAMED_OPT):
if stub_arg.kind.is_required():
yield (
'runtime argument "{}" has a default value but stub argument does not'.format(
runtime_arg.name
Expand Down Expand Up @@ -363,7 +363,7 @@ def _verify_arg_default_value(
)
)
else:
if stub_arg.kind in (nodes.ARG_OPT, nodes.ARG_NAMED_OPT):
if stub_arg.kind.is_optional():
yield (
'stub argument "{}" has a default value but runtime argument does not'.format(
stub_arg.variable.name
Expand Down Expand Up @@ -406,7 +406,7 @@ def has_default(arg: Any) -> bool:
if isinstance(arg, inspect.Parameter):
return arg.default != inspect.Parameter.empty
if isinstance(arg, nodes.Argument):
return arg.kind in (nodes.ARG_OPT, nodes.ARG_NAMED_OPT)
return arg.kind.is_optional()
raise AssertionError

def get_desc(arg: Any) -> str:
Expand All @@ -433,9 +433,9 @@ def from_funcitem(stub: nodes.FuncItem) -> "Signature[nodes.Argument]":
stub_sig: Signature[nodes.Argument] = Signature()
stub_args = maybe_strip_cls(stub.name, stub.arguments)
for stub_arg in stub_args:
if stub_arg.kind in (nodes.ARG_POS, nodes.ARG_OPT):
if stub_arg.kind.is_positional():
stub_sig.pos.append(stub_arg)
elif stub_arg.kind in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT):
elif stub_arg.kind.is_named():
stub_sig.kwonly[stub_arg.variable.name] = stub_arg
elif stub_arg.kind == nodes.ARG_STAR:
stub_sig.varpos = stub_arg
Expand Down Expand Up @@ -531,9 +531,9 @@ def get_kind(arg_name: str) -> nodes.ArgKind:
initializer=None,
kind=get_kind(arg_name),
)
if arg.kind in (nodes.ARG_POS, nodes.ARG_OPT):
if arg.kind.is_positional():
sig.pos.append(arg)
elif arg.kind in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT):
elif arg.kind.is_named():
sig.kwonly[arg.variable.name] = arg
elif arg.kind == nodes.ARG_STAR:
sig.varpos = arg
Expand Down
8 changes: 4 additions & 4 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
# import mypy.solve
from mypy.nodes import (
FuncBase, Var, Decorator, OverloadedFuncDef, TypeInfo, CONTRAVARIANT, COVARIANT,
ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2

)
from mypy.maptype import map_instance_to_supertype
from mypy.expandtype import expand_type_by_instance
Expand Down Expand Up @@ -950,8 +950,8 @@ def _incompatible(left_arg: Optional[FormalArgument],

i = right_star.pos
assert i is not None
while i < len(left.arg_kinds) and left.arg_kinds[i] in (ARG_POS, ARG_OPT):
if allow_partial_overlap and left.arg_kinds[i] == ARG_OPT:
while i < len(left.arg_kinds) and left.arg_kinds[i].is_positional():
if allow_partial_overlap and left.arg_kinds[i].is_optional():
break

left_by_position = left.argument_by_position(i)
Expand All @@ -970,7 +970,7 @@ def _incompatible(left_arg: Optional[FormalArgument],
right_names = {name for name in right.arg_names if name is not None}
left_only_names = set()
for name, kind in zip(left.arg_names, left.arg_kinds):
if name is None or kind in (ARG_STAR, ARG_STAR2) or name in right_names:
if name is None or kind.is_star() or name in right_names:
continue
left_only_names.add(name)

Expand Down
7 changes: 3 additions & 4 deletions mypy/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
)
from mypy.build import State, Graph
from mypy.nodes import (
ArgKind, ARG_STAR, ARG_NAMED, ARG_STAR2, ARG_NAMED_OPT, FuncDef, MypyFile, SymbolTable,
ArgKind, ARG_STAR, ARG_STAR2, FuncDef, MypyFile, SymbolTable,
Decorator, RefExpr,
SymbolNode, TypeInfo, Expression, ReturnStmt, CallExpr,
reverse_builtin_aliases,
Expand Down Expand Up @@ -479,7 +479,7 @@ def format_args(self,
arg = '*' + arg
elif kind == ARG_STAR2:
arg = '**' + arg
elif kind in (ARG_NAMED, ARG_NAMED_OPT):
elif kind.is_named():
if name:
arg = "%s=%s" % (name, arg)
args.append(arg)
Expand Down Expand Up @@ -763,8 +763,7 @@ def any_score_callable(t: CallableType, is_method: bool, ignore_return: bool) ->

def is_tricky_callable(t: CallableType) -> bool:
"""Is t a callable that we need to put a ... in for syntax reasons?"""
return t.is_ellipsis_args or any(
k in (ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT) for k in t.arg_kinds)
return t.is_ellipsis_args or any(k.is_star() or k.is_named() for k in t.arg_kinds)


class TypeFormatter(TypeStrVisitor):
Expand Down
2 changes: 1 addition & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,7 +778,7 @@ def analyze_callable_args(self, arglist: TypeList) -> Optional[Tuple[List[Type],
assert found.fullname is not None
kind = ARG_KINDS_BY_CONSTRUCTOR[found.fullname]
kinds.append(kind)
if arg.name is not None and kind in {ARG_STAR, ARG_STAR2}:
if arg.name is not None and kind.is_star():
self.fail("{} arguments should not have names".format(
arg.constructor), arg)
return None
Expand Down
Loading

0 comments on commit a9f3b5e

Please sign in to comment.