From a9f3b5e5b21304c63f93e96e0bfeab2fa2945a00 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Sat, 10 Jul 2021 00:11:01 -0700 Subject: [PATCH] Use helper methods for a lot of ArgKind checks (#10793) Part of the motivation here is that I want to make positional-only arguments be properly reflected in the argument kinds, and having most of the logic done through helpers will make that easier. --- mypy/argmap.py | 7 +++---- mypy/checkexpr.py | 26 +++++++++++--------------- mypy/join.py | 5 ++--- mypy/messages.py | 8 ++++---- mypy/nodes.py | 20 ++++++++++++++++++++ mypy/plugins/functools.py | 4 ++-- mypy/plugins/singledispatch.py | 4 ++-- mypy/strconv.py | 4 ++-- mypy/stubgen.py | 5 ++--- mypy/stubtest.py | 14 +++++++------- mypy/subtypes.py | 8 ++++---- mypy/suggestions.py | 7 +++---- mypy/typeanal.py | 2 +- mypy/types.py | 26 ++++++++++++-------------- mypyc/ir/func_ir.py | 4 ++-- mypyc/irbuild/function.py | 4 ++-- mypyc/irbuild/ll_builder.py | 10 ++++------ 17 files changed, 83 insertions(+), 75 deletions(-) diff --git a/mypy/argmap.py b/mypy/argmap.py index bed41de3c426..d9453aa0b640 100644 --- a/mypy/argmap.py +++ b/mypy/argmap.py @@ -29,8 +29,7 @@ def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind], for ai, actual_kind in enumerate(actual_kinds): if actual_kind == nodes.ARG_POS: if fi < nformals: - if formal_kinds[fi] in [nodes.ARG_POS, nodes.ARG_OPT, - nodes.ARG_NAMED, nodes.ARG_NAMED_OPT]: + if not formal_kinds[fi].is_star(): formal_to_actual[fi].append(ai) fi += 1 elif formal_kinds[fi] == nodes.ARG_STAR: @@ -52,14 +51,14 @@ def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind], # Assume that it is an iterable (if it isn't, there will be # an error later). while fi < nformals: - if formal_kinds[fi] in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT, nodes.ARG_STAR2): + if formal_kinds[fi].is_named(star=True): break else: formal_to_actual[fi].append(ai) if formal_kinds[fi] == nodes.ARG_STAR: break fi += 1 - elif actual_kind in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT): + elif actual_kind.is_named(): assert actual_names is not None, "Internal error: named kinds without names given" name = actual_names[ai] if name in formal_names: diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 2d78adee3e3d..065dfa3e0f69 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -32,7 +32,7 @@ YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr, TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode, ParamSpecExpr, - ArgKind, ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE, + ArgKind, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE, ) from mypy.literals import literal from mypy import nodes @@ -1111,7 +1111,7 @@ def infer_arg_types_in_context( for i, actuals in enumerate(formal_to_actual): for ai in actuals: - if arg_kinds[ai] not in (nodes.ARG_STAR, nodes.ARG_STAR2): + if not arg_kinds[ai].is_star(): res[ai] = self.accept(args[ai], callee.arg_types[i]) # Fill in the rest of the argument types. @@ -1364,18 +1364,14 @@ def check_argument_count(self, # Check for too many or few values for formals. for i, kind in enumerate(callee.arg_kinds): - if kind == nodes.ARG_POS and (not formal_to_actual[i] and - not is_unexpected_arg_error): - # No actual for a mandatory positional formal. + if kind.is_required() and not formal_to_actual[i] and not is_unexpected_arg_error: + # No actual for a mandatory formal if messages: - messages.too_few_arguments(callee, context, actual_names) - ok = False - elif kind == nodes.ARG_NAMED and (not formal_to_actual[i] and - not is_unexpected_arg_error): - # No actual for a mandatory named formal - if messages: - argname = callee.arg_names[i] or "?" - messages.missing_named_argument(callee, context, argname) + if kind.is_positional(): + messages.too_few_arguments(callee, context, actual_names) + else: + argname = callee.arg_names[i] or "?" + messages.missing_named_argument(callee, context, argname) ok = False elif not kind.is_star() and is_duplicate_mapping( formal_to_actual[i], actual_types, actual_kinds): @@ -1385,7 +1381,7 @@ def check_argument_count(self, if messages: messages.duplicate_argument_value(callee, i, context) ok = False - elif (kind in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT) and formal_to_actual[i] and + elif (kind.is_named() and formal_to_actual[i] and actual_kinds[formal_to_actual[i][0]] not in [nodes.ARG_NAMED, nodes.ARG_STAR2]): # Positional argument when expecting a keyword argument. if messages: @@ -1925,7 +1921,7 @@ def combine_function_signatures(self, types: Sequence[Type]) -> Union[AnyType, C for i, (new_kind, target_kind) in enumerate(zip(new_kinds, target.arg_kinds)): if new_kind == target_kind: continue - elif new_kind in (ARG_POS, ARG_OPT) and target_kind in (ARG_POS, ARG_OPT): + elif new_kind.is_positional() and target_kind.is_positional(): new_kinds[i] = ARG_POS else: too_complex = True diff --git a/mypy/join.py b/mypy/join.py index bad256ccf11c..2cbc1a9edc8f 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -14,7 +14,7 @@ is_subtype, is_equivalent, is_subtype_ignoring_tvars, is_proper_subtype, is_protocol_implementation, find_member ) -from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, INVARIANT, COVARIANT, CONTRAVARIANT +from mypy.nodes import INVARIANT, COVARIANT, CONTRAVARIANT import mypy.typeops from mypy import state from mypy import meet @@ -536,11 +536,10 @@ def combine_arg_names(t: CallableType, s: CallableType) -> List[Optional[str]]: """ num_args = len(t.arg_types) new_names = [] - named = (ARG_NAMED, ARG_NAMED_OPT) for i in range(num_args): t_name = t.arg_names[i] s_name = s.arg_names[i] - if t_name == s_name or t.arg_kinds[i] in named or s.arg_kinds[i] in named: + if t_name == s_name or t.arg_kinds[i].is_named() or s.arg_kinds[i].is_named(): new_names.append(t_name) else: new_names.append(None) diff --git a/mypy/messages.py b/mypy/messages.py index 39241925aa25..00567a4f16a8 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -1711,12 +1711,12 @@ def format(typ: Type) -> str: for arg_name, arg_type, arg_kind in zip( func.arg_names, func.arg_types, func.arg_kinds): if (arg_kind == ARG_POS and arg_name is None - or verbosity == 0 and arg_kind in (ARG_POS, ARG_OPT)): + or verbosity == 0 and arg_kind.is_positional()): arg_strings.append(format(arg_type)) else: constructor = ARG_CONSTRUCTOR_NAMES[arg_kind] - if arg_kind in (ARG_STAR, ARG_STAR2) or arg_name is None: + if arg_kind.is_star() or arg_name is None: arg_strings.append("{}({})".format( constructor, format(arg_type))) @@ -1849,7 +1849,7 @@ def [T <: int] f(self, x: int, y: T) -> None for i in range(len(tp.arg_types)): if s: s += ', ' - if tp.arg_kinds[i] in (ARG_NAMED, ARG_NAMED_OPT) and not asterisk: + if tp.arg_kinds[i].is_named() and not asterisk: s += '*, ' asterisk = True if tp.arg_kinds[i] == ARG_STAR: @@ -1861,7 +1861,7 @@ def [T <: int] f(self, x: int, y: T) -> None if name: s += name + ': ' s += format_type_bare(tp.arg_types[i]) - if tp.arg_kinds[i] in (ARG_OPT, ARG_NAMED_OPT): + if tp.arg_kinds[i].is_optional(): s += ' = ...' # If we got a "special arg" (i.e: self, cls, etc...), prepend it to the arg list diff --git a/mypy/nodes.py b/mypy/nodes.py index f9fb6ad13370..ef58ebda22ef 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1534,6 +1534,26 @@ class ArgKind(Enum): # In an argument list, keyword-only and also optional ARG_NAMED_OPT = 5 + def is_positional(self, star: bool = False) -> bool: + return ( + self == ARG_POS + or self == ARG_OPT + or (star and self == ARG_STAR) + ) + + def is_named(self, star: bool = False) -> bool: + return ( + self == ARG_NAMED + or self == ARG_NAMED_OPT + or (star and self == ARG_STAR2) + ) + + def is_required(self) -> bool: + return self == ARG_POS or self == ARG_NAMED + + def is_optional(self) -> bool: + return self == ARG_OPT or self == ARG_NAMED_OPT + def is_star(self) -> bool: return self == ARG_STAR or self == ARG_STAR2 diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index bf71465e1003..0984abe80cee 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -2,7 +2,7 @@ from typing import Dict, NamedTuple, Optional import mypy.plugin -from mypy.nodes import ARG_OPT, ARG_POS, ARG_STAR2, Argument, FuncItem, Var +from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var from mypy.plugins.common import add_method_to_class from mypy.types import AnyType, CallableType, get_proper_type, Type, TypeOfAny, UnboundType @@ -65,7 +65,7 @@ def _find_other_type(method: _MethodInfo) -> Type: cur_pos_arg = 0 other_arg = None for arg_kind, arg_type in zip(method.type.arg_kinds, method.type.arg_types): - if arg_kind in (ARG_POS, ARG_OPT): + if arg_kind.is_positional(): if cur_pos_arg == first_arg_pos: other_arg = arg_type break diff --git a/mypy/plugins/singledispatch.py b/mypy/plugins/singledispatch.py index 6050d8843e04..93ad92f42bb7 100644 --- a/mypy/plugins/singledispatch.py +++ b/mypy/plugins/singledispatch.py @@ -1,7 +1,7 @@ from mypy.messages import format_type from mypy.plugins.common import add_method_to_class from mypy.nodes import ( - ARG_POS, Argument, Block, ClassDef, SymbolTable, TypeInfo, Var, ARG_STAR, ARG_OPT, Context + ARG_POS, Argument, Block, ClassDef, SymbolTable, TypeInfo, Var, Context ) from mypy.subtypes import is_subtype from mypy.types import ( @@ -98,7 +98,7 @@ def create_singledispatch_function_callback(ctx: FunctionContext) -> Type: ) return ctx.default_return_type - elif func_type.arg_kinds[0] not in (ARG_POS, ARG_OPT, ARG_STAR): + elif not func_type.arg_kinds[0].is_positional(star=True): fail( ctx, 'First argument to singledispatch function must be a positional argument', diff --git a/mypy/strconv.py b/mypy/strconv.py index 15649eac9aa0..c63063af0776 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -62,9 +62,9 @@ def func_helper(self, o: 'mypy.nodes.FuncItem') -> List[object]: extra: List[Tuple[str, List[mypy.nodes.Var]]] = [] for arg in o.arguments: kind: mypy.nodes.ArgKind = arg.kind - if kind in (mypy.nodes.ARG_POS, mypy.nodes.ARG_NAMED): + if kind.is_required(): args.append(arg.variable) - elif kind in (mypy.nodes.ARG_OPT, mypy.nodes.ARG_NAMED_OPT): + elif kind.is_optional(): assert arg.initializer is not None args.append(('default', [arg.variable, arg.initializer])) elif kind == mypy.nodes.ARG_STAR: diff --git a/mypy/stubgen.py b/mypy/stubgen.py index bc4fd7cd51ff..e51bfabf438d 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -73,7 +73,7 @@ TupleExpr, ListExpr, ComparisonExpr, CallExpr, IndexExpr, EllipsisExpr, ClassDef, MypyFile, Decorator, AssignmentStmt, TypeInfo, IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, Block, - Statement, OverloadedFuncDef, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT + Statement, OverloadedFuncDef, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, ) from mypy.stubgenc import generate_stub_for_c_module from mypy.stubutil import ( @@ -631,8 +631,7 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False, if not isinstance(get_proper_type(annotated_type), AnyType): annotation = ": {}".format(self.print_annotation(annotated_type)) if arg_.initializer: - if kind in (ARG_NAMED, ARG_NAMED_OPT) and not any(arg.startswith('*') - for arg in args): + if kind.is_named() and not any(arg.startswith('*') for arg in args): args.append('*') if not annotation: typename = self.get_str_type_of_node(arg_.initializer, True, False) diff --git a/mypy/stubtest.py b/mypy/stubtest.py index 9c83080663e9..ac81f3a34604 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -334,7 +334,7 @@ def _verify_arg_default_value( ) -> Iterator[str]: """Checks whether argument default values are compatible.""" if runtime_arg.default != inspect.Parameter.empty: - if stub_arg.kind not in (nodes.ARG_OPT, nodes.ARG_NAMED_OPT): + if stub_arg.kind.is_required(): yield ( 'runtime argument "{}" has a default value but stub argument does not'.format( runtime_arg.name @@ -363,7 +363,7 @@ def _verify_arg_default_value( ) ) else: - if stub_arg.kind in (nodes.ARG_OPT, nodes.ARG_NAMED_OPT): + if stub_arg.kind.is_optional(): yield ( 'stub argument "{}" has a default value but runtime argument does not'.format( stub_arg.variable.name @@ -406,7 +406,7 @@ def has_default(arg: Any) -> bool: if isinstance(arg, inspect.Parameter): return arg.default != inspect.Parameter.empty if isinstance(arg, nodes.Argument): - return arg.kind in (nodes.ARG_OPT, nodes.ARG_NAMED_OPT) + return arg.kind.is_optional() raise AssertionError def get_desc(arg: Any) -> str: @@ -433,9 +433,9 @@ def from_funcitem(stub: nodes.FuncItem) -> "Signature[nodes.Argument]": stub_sig: Signature[nodes.Argument] = Signature() stub_args = maybe_strip_cls(stub.name, stub.arguments) for stub_arg in stub_args: - if stub_arg.kind in (nodes.ARG_POS, nodes.ARG_OPT): + if stub_arg.kind.is_positional(): stub_sig.pos.append(stub_arg) - elif stub_arg.kind in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT): + elif stub_arg.kind.is_named(): stub_sig.kwonly[stub_arg.variable.name] = stub_arg elif stub_arg.kind == nodes.ARG_STAR: stub_sig.varpos = stub_arg @@ -531,9 +531,9 @@ def get_kind(arg_name: str) -> nodes.ArgKind: initializer=None, kind=get_kind(arg_name), ) - if arg.kind in (nodes.ARG_POS, nodes.ARG_OPT): + if arg.kind.is_positional(): sig.pos.append(arg) - elif arg.kind in (nodes.ARG_NAMED, nodes.ARG_NAMED_OPT): + elif arg.kind.is_named(): sig.kwonly[arg.variable.name] = arg elif arg.kind == nodes.ARG_STAR: sig.varpos = arg diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 7b54e0f83c79..cb630c68b62a 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -18,7 +18,7 @@ # import mypy.solve from mypy.nodes import ( FuncBase, Var, Decorator, OverloadedFuncDef, TypeInfo, CONTRAVARIANT, COVARIANT, - ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2 + ) from mypy.maptype import map_instance_to_supertype from mypy.expandtype import expand_type_by_instance @@ -950,8 +950,8 @@ def _incompatible(left_arg: Optional[FormalArgument], i = right_star.pos assert i is not None - while i < len(left.arg_kinds) and left.arg_kinds[i] in (ARG_POS, ARG_OPT): - if allow_partial_overlap and left.arg_kinds[i] == ARG_OPT: + while i < len(left.arg_kinds) and left.arg_kinds[i].is_positional(): + if allow_partial_overlap and left.arg_kinds[i].is_optional(): break left_by_position = left.argument_by_position(i) @@ -970,7 +970,7 @@ def _incompatible(left_arg: Optional[FormalArgument], right_names = {name for name in right.arg_names if name is not None} left_only_names = set() for name, kind in zip(left.arg_names, left.arg_kinds): - if name is None or kind in (ARG_STAR, ARG_STAR2) or name in right_names: + if name is None or kind.is_star() or name in right_names: continue left_only_names.add(name) diff --git a/mypy/suggestions.py b/mypy/suggestions.py index f3602b2fe3d4..24a1e0eb0c3e 100644 --- a/mypy/suggestions.py +++ b/mypy/suggestions.py @@ -37,7 +37,7 @@ ) from mypy.build import State, Graph from mypy.nodes import ( - ArgKind, ARG_STAR, ARG_NAMED, ARG_STAR2, ARG_NAMED_OPT, FuncDef, MypyFile, SymbolTable, + ArgKind, ARG_STAR, ARG_STAR2, FuncDef, MypyFile, SymbolTable, Decorator, RefExpr, SymbolNode, TypeInfo, Expression, ReturnStmt, CallExpr, reverse_builtin_aliases, @@ -479,7 +479,7 @@ def format_args(self, arg = '*' + arg elif kind == ARG_STAR2: arg = '**' + arg - elif kind in (ARG_NAMED, ARG_NAMED_OPT): + elif kind.is_named(): if name: arg = "%s=%s" % (name, arg) args.append(arg) @@ -763,8 +763,7 @@ def any_score_callable(t: CallableType, is_method: bool, ignore_return: bool) -> def is_tricky_callable(t: CallableType) -> bool: """Is t a callable that we need to put a ... in for syntax reasons?""" - return t.is_ellipsis_args or any( - k in (ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT) for k in t.arg_kinds) + return t.is_ellipsis_args or any(k.is_star() or k.is_named() for k in t.arg_kinds) class TypeFormatter(TypeStrVisitor): diff --git a/mypy/typeanal.py b/mypy/typeanal.py index e105e8082cc4..82b2f08e0e1e 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -778,7 +778,7 @@ def analyze_callable_args(self, arglist: TypeList) -> Optional[Tuple[List[Type], assert found.fullname is not None kind = ARG_KINDS_BY_CONSTRUCTOR[found.fullname] kinds.append(kind) - if arg.name is not None and kind in {ARG_STAR, ARG_STAR2}: + if arg.name is not None and kind.is_star(): self.fail("{} arguments should not have names".format( arg.constructor), arg) return None diff --git a/mypy/types.py b/mypy/types.py index a9e652091dff..61910007cde3 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -14,9 +14,8 @@ import mypy.nodes from mypy import state from mypy.nodes import ( - INVARIANT, SymbolNode, ArgKind, - ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT, - FuncDef, + INVARIANT, SymbolNode, FuncDef, + ArgKind, ARG_POS, ARG_STAR, ARG_STAR2, ) from mypy.util import IdMapper from mypy.bogus_type import Bogus @@ -1178,8 +1177,7 @@ def max_possible_positional_args(self) -> int: This takes into account *arg and **kwargs but excludes keyword-only args.""" if self.is_var_arg or self.is_kw_arg: return sys.maxsize - blacklist = (ARG_NAMED, ARG_NAMED_OPT) - return len([kind not in blacklist for kind in self.arg_kinds]) + return sum([kind.is_positional() for kind in self.arg_kinds]) def formal_arguments(self, include_star_args: bool = False) -> Iterator[FormalArgument]: """Yields the formal arguments corresponding to this callable, ignoring *arg and **kwargs. @@ -1192,12 +1190,12 @@ def formal_arguments(self, include_star_args: bool = False) -> Iterator[FormalAr done_with_positional = False for i in range(len(self.arg_types)): kind = self.arg_kinds[i] - if kind in (ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT): + if kind.is_named() or kind.is_star(): done_with_positional = True - if not include_star_args and kind in (ARG_STAR, ARG_STAR2): + if not include_star_args and kind.is_star(): continue - required = kind in (ARG_POS, ARG_NAMED) + required = kind.is_required() pos = None if done_with_positional else i yield FormalArgument( self.arg_names[i], @@ -1212,13 +1210,13 @@ def argument_by_name(self, name: Optional[str]) -> Optional[FormalArgument]: for i, (arg_name, kind, typ) in enumerate( zip(self.arg_names, self.arg_kinds, self.arg_types)): # No more positional arguments after these. - if kind in (ARG_STAR, ARG_STAR2, ARG_NAMED, ARG_NAMED_OPT): + if kind.is_named() or kind.is_star(): seen_star = True - if kind == ARG_STAR or kind == ARG_STAR2: + if kind.is_star(): continue if arg_name == name: position = None if seen_star else i - return FormalArgument(name, position, typ, kind in (ARG_POS, ARG_NAMED)) + return FormalArgument(name, position, typ, kind.is_required()) return self.try_synthesizing_arg_from_kwarg(name) def argument_by_position(self, position: Optional[int]) -> Optional[FormalArgument]: @@ -1231,7 +1229,7 @@ def argument_by_position(self, position: Optional[int]) -> Optional[FormalArgume self.arg_kinds[position], self.arg_types[position], ) - if kind in (ARG_POS, ARG_OPT): + if kind.is_positional(): return FormalArgument(name, position, typ, kind == ARG_POS) else: return self.try_synthesizing_arg_from_vararg(position) @@ -2112,7 +2110,7 @@ def visit_callable_type(self, t: CallableType) -> str: for i in range(len(t.arg_types)): if s != '': s += ', ' - if t.arg_kinds[i] in (ARG_NAMED, ARG_NAMED_OPT) and not bare_asterisk: + if t.arg_kinds[i].is_named() and not bare_asterisk: s += '*, ' bare_asterisk = True if t.arg_kinds[i] == ARG_STAR: @@ -2123,7 +2121,7 @@ def visit_callable_type(self, t: CallableType) -> str: if name: s += name + ': ' s += t.arg_types[i].accept(self) - if t.arg_kinds[i] in (ARG_OPT, ARG_NAMED_OPT): + if t.arg_kinds[i].is_optional(): s += ' =' s = '({})'.format(s) diff --git a/mypyc/ir/func_ir.py b/mypyc/ir/func_ir.py index 604161c31659..aa432dc7fb40 100644 --- a/mypyc/ir/func_ir.py +++ b/mypyc/ir/func_ir.py @@ -3,7 +3,7 @@ from typing import List, Optional, Sequence from typing_extensions import Final -from mypy.nodes import FuncDef, Block, ArgKind, ARG_POS, ARG_OPT, ARG_NAMED_OPT +from mypy.nodes import FuncDef, Block, ArgKind, ARG_POS from mypyc.common import JsonDict from mypyc.ir.ops import ( @@ -26,7 +26,7 @@ def __init__(self, name: str, typ: RType, kind: ArgKind = ARG_POS) -> None: @property def optional(self) -> bool: - return self.kind == ARG_OPT or self.kind == ARG_NAMED_OPT + return self.kind.is_optional() def __repr__(self) -> str: return 'RuntimeArg(name=%s, type=%s, optional=%r)' % (self.name, self.type, self.optional) diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index e70202115888..f1a76462b322 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -14,7 +14,7 @@ from mypy.nodes import ( ClassDef, FuncDef, OverloadedFuncDef, Decorator, Var, YieldFromExpr, AwaitExpr, YieldExpr, - FuncItem, LambdaExpr, SymbolNode, ArgKind, ARG_NAMED, ARG_NAMED_OPT, TypeInfo + FuncItem, LambdaExpr, SymbolNode, ArgKind, TypeInfo ) from mypy.types import CallableType, get_proper_type @@ -649,7 +649,7 @@ def get_args(builder: IRBuilder, rt_args: Sequence[RuntimeArg], line: int) -> Ar fake_vars = [(Var(arg.name), arg.type) for arg in rt_args] args = [builder.read(builder.add_local_reg(var, type, is_arg=True), line) for var, type in fake_vars] - arg_names = [arg.name if arg.kind in (ARG_NAMED, ARG_NAMED_OPT) else None + arg_names = [arg.name if arg.kind.is_named() else None for arg in rt_args] arg_kinds = [concrete_arg_kind(arg.kind) for arg in rt_args] return ArgInfo(args, arg_names, arg_kinds) diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 8b8353133d68..8ffc33cf0cf4 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -325,7 +325,7 @@ def _py_vector_call(self, API should be used instead. """ # We can do this if all args are positional or named (no *args or **kwargs). - if arg_kinds is None or all(kind in (ARG_POS, ARG_NAMED) for kind in arg_kinds): + if arg_kinds is None or all(not kind.is_star() for kind in arg_kinds): if arg_values: # Create a C array containing all arguments as boxed values. array = Register(RArray(object_rprimitive, len(arg_values))) @@ -396,7 +396,7 @@ def _py_vector_method_call(self, Return the return value if successful. Return None if a non-vectorcall API should be used instead. """ - if arg_kinds is None or all(kind in (ARG_POS, ARG_NAMED) for kind in arg_kinds): + if arg_kinds is None or all(not kind.is_star() for kind in arg_kinds): method_name_reg = self.load_str(method_name) array = Register(RArray(object_rprimitive, len(arg_values) + 1)) self_arg = self.coerce(obj, object_rprimitive, line) @@ -485,10 +485,8 @@ def gen_method_call(self, arg_kinds: Optional[List[ArgKind]] = None, arg_names: Optional[List[Optional[str]]] = None) -> Value: """Generate either a native or Python method call.""" - # If arg_kinds contains values other than arg_pos and arg_named, then fallback to - # Python method call. - if (arg_kinds is not None - and not all(kind in (ARG_POS, ARG_NAMED) for kind in arg_kinds)): + # If we have *args, then fallback to Python method call. + if (arg_kinds is not None and any(kind.is_star() for kind in arg_kinds)): return self.py_method_call(base, name, arg_values, base.line, arg_kinds, arg_names) # If the base type is one of ours, do a MethodCall