Skip to content

Support type inference for defaultdict() #8167

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2813,14 +2813,26 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool
partial_type = PartialType(None, name)
elif isinstance(init_type, Instance):
fullname = init_type.type.fullname
if (isinstance(lvalue, (NameExpr, MemberExpr)) and
is_ref = isinstance(lvalue, RefExpr)
if (is_ref and
(fullname == 'builtins.list' or
fullname == 'builtins.set' or
fullname == 'builtins.dict' or
fullname == 'collections.OrderedDict') and
all(isinstance(t, (NoneType, UninhabitedType))
for t in get_proper_types(init_type.args))):
partial_type = PartialType(init_type.type, name)
elif is_ref and fullname == 'collections.defaultdict':
arg0 = get_proper_type(init_type.args[0])
arg1 = get_proper_type(init_type.args[1])
if (isinstance(arg0, (NoneType, UninhabitedType)) and
isinstance(arg1, Instance) and
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be cleaner to move this check into the helper below.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. It's here since we rely on the narrowed down type below, so we'd need a type check anyway.

self.is_valid_defaultdict_partial_value_type(arg1)):
# Erase type argument, if one exists (this fills in Anys)
arg1 = self.named_type(arg1.type.fullname)
partial_type = PartialType(init_type.type, name, arg1)
else:
return False
else:
return False
else:
Expand All @@ -2829,6 +2841,28 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool
self.partial_types[-1].map[name] = lvalue
return True

def is_valid_defaultdict_partial_value_type(self, t: Instance) -> bool:
    """Check if t can be used as the basis for a partial defaultdict value type.

    Examples:

      * t is 'int' --> True
      * t is 'list[<nothing>]' --> True
      * t is 'dict[...]' --> False (only generic types with a single type
                                    argument supported)
    """
    arg_count = len(t.args)
    if arg_count == 0:
        # Non-generic value type: always acceptable.
        return True
    if arg_count != 1:
        # Two or more type arguments are not supported.
        return False
    only_arg = get_proper_type(t.args[0])
    # TODO: This is too permissive -- TypeVarType is accepted only because
    #       type variables leak in cases like defaultdict(list) due to a bug.
    #       This can result in incorrect types being inferred, but only
    #       in rare cases.
    return isinstance(only_arg, (TypeVarType, UninhabitedType, NoneType))

def set_inferred_type(self, var: Var, lvalue: Lvalue, type: Type) -> None:
"""Store inferred variable type.

Expand Down Expand Up @@ -3018,16 +3052,21 @@ def try_infer_partial_type_from_indexed_assignment(
if partial_types is None:
return
typename = type_type.fullname
if typename == 'builtins.dict' or typename == 'collections.OrderedDict':
if (typename == 'builtins.dict'
or typename == 'collections.OrderedDict'
or typename == 'collections.defaultdict'):
# TODO: Don't infer things twice.
key_type = self.expr_checker.accept(lvalue.index)
value_type = self.expr_checker.accept(rvalue)
if (is_valid_inferred_type(key_type) and
is_valid_inferred_type(value_type)):
if not self.current_node_deferred:
var.type = self.named_generic_type(typename,
[key_type, value_type])
del partial_types[var]
is_valid_inferred_type(value_type) and
not self.current_node_deferred and
not (typename == 'collections.defaultdict' and
var.type.value_type is not None and
not is_equivalent(value_type, var.type.value_type))):
var.type = self.named_generic_type(typename,
[key_type, value_type])
del partial_types[var]

def visit_expression_stmt(self, s: ExpressionStmt) -> None:
self.expr_checker.accept(s.expr, allow_none_return=True, always_allow_any=True)
Expand Down
125 changes: 90 additions & 35 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,42 +567,91 @@ def get_partial_self_var(self, expr: MemberExpr) -> Optional[Var]:
} # type: ClassVar[Dict[str, Dict[str, List[str]]]]

def try_infer_partial_type(self, e: CallExpr) -> None:
if isinstance(e.callee, MemberExpr) and isinstance(e.callee.expr, RefExpr):
var = e.callee.expr.node
if var is None and isinstance(e.callee.expr, MemberExpr):
var = self.get_partial_self_var(e.callee.expr)
if not isinstance(var, Var):
"""Try to make partial type precise from a call."""
if not isinstance(e.callee, MemberExpr):
return
callee = e.callee
if isinstance(callee.expr, RefExpr):
# Call a method with a RefExpr callee, such as 'x.method(...)'.
ret = self.get_partial_var(callee.expr)
if ret is None:
return
partial_types = self.chk.find_partial_types(var)
if partial_types is not None and not self.chk.current_node_deferred:
partial_type = var.type
if (partial_type is None or
not isinstance(partial_type, PartialType) or
partial_type.type is None):
# A partial None type -> can't infer anything.
return
typename = partial_type.type.fullname
methodname = e.callee.name
# Sometimes we can infer a full type for a partial List, Dict or Set type.
# TODO: Don't infer argument expression twice.
if (typename in self.item_args and methodname in self.item_args[typename]
and e.arg_kinds == [ARG_POS]):
item_type = self.accept(e.args[0])
if mypy.checker.is_valid_inferred_type(item_type):
var.type = self.chk.named_generic_type(typename, [item_type])
del partial_types[var]
elif (typename in self.container_args
and methodname in self.container_args[typename]
and e.arg_kinds == [ARG_POS]):
arg_type = get_proper_type(self.accept(e.args[0]))
if isinstance(arg_type, Instance):
arg_typename = arg_type.type.fullname
if arg_typename in self.container_args[typename][methodname]:
if all(mypy.checker.is_valid_inferred_type(item_type)
for item_type in arg_type.args):
var.type = self.chk.named_generic_type(typename,
list(arg_type.args))
del partial_types[var]
var, partial_types = ret
typ = self.try_infer_partial_value_type_from_call(e, callee.name, var)
if typ is not None:
var.type = typ
del partial_types[var]
elif isinstance(callee.expr, IndexExpr) and isinstance(callee.expr.base, RefExpr):
# Call 'x[y].method(...)'; may infer type of 'x' if it's a partial defaultdict.
if callee.expr.analyzed is not None:
return # A special form
base = callee.expr.base
index = callee.expr.index
ret = self.get_partial_var(base)
if ret is None:
return
var, partial_types = ret
partial_type = get_partial_instance_type(var.type)
if partial_type is None or partial_type.value_type is None:
return
value_type = self.try_infer_partial_value_type_from_call(e, callee.name, var)
if value_type is not None:
# Infer key type.
key_type = self.accept(index)
if mypy.checker.is_valid_inferred_type(key_type):
# Store inferred partial type.
assert partial_type.type is not None
typename = partial_type.type.fullname
var.type = self.chk.named_generic_type(typename,
[key_type, value_type])
del partial_types[var]

def get_partial_var(self, ref: RefExpr) -> Optional[Tuple[Var, Dict[Var, Context]]]:
    """Resolve ref to a Var and the partial types mapping that holds it.

    Returns a (var, partial_types) pair, where partial_types is the value
    returned by self.chk.find_partial_types(var). Returns None if ref does
    not resolve to a Var or if no partial types mapping is found for it.
    """
    node = ref.node
    if node is None and isinstance(ref, MemberExpr):
        # No node bound on the expression itself; fall back to the
        # lookup done by get_partial_self_var.
        node = self.get_partial_self_var(ref)
    if isinstance(node, Var):
        found = self.chk.find_partial_types(node)
        if found is not None:
            return node, found
    return None

def try_infer_partial_value_type_from_call(
        self,
        e: CallExpr,
        methodname: str,
        var: Var) -> Optional[Instance]:
    """Try to infer a precise type from a method call such as 'x.append(y)'.

    If the partial type of var carries a value_type, inference targets that
    value type; otherwise it targets the partial type itself.

    Returns the inferred Instance, or None when nothing can be inferred:
    the current node is deferred, var has no partial instance type, the
    method/argument combination is not listed in item_args or
    container_args, or an involved type is not a valid inferred type.
    """
    if self.chk.current_node_deferred:
        return None
    partial = get_partial_instance_type(var.type)
    if partial is None:
        return None
    if partial.value_type:
        typename = partial.value_type.type.fullname
    else:
        assert partial.type is not None
        typename = partial.type.fullname

    # Sometimes we can infer a full type for a partial List, Dict or Set type.
    # TODO: Don't infer argument expression twice.
    if (typename in self.item_args
            and methodname in self.item_args[typename]
            and e.arg_kinds == [ARG_POS]):
        # The single positional argument supplies the item type.
        elem_type = self.accept(e.args[0])
        if not mypy.checker.is_valid_inferred_type(elem_type):
            return None
        return self.chk.named_generic_type(typename, [elem_type])

    if not (typename in self.container_args
            and methodname in self.container_args[typename]
            and e.arg_kinds == [ARG_POS]):
        return None

    # The argument must be an Instance whose type name is listed for this
    # method; its type arguments are reused for the inferred type.
    source = get_proper_type(self.accept(e.args[0]))
    if not isinstance(source, Instance):
        return None
    if source.type.fullname not in self.container_args[typename][methodname]:
        return None
    if not all(mypy.checker.is_valid_inferred_type(arg) for arg in source.args):
        return None
    return self.chk.named_generic_type(typename, list(source.args))

def apply_function_plugin(self,
callee: CallableType,
Expand Down Expand Up @@ -4299,3 +4348,9 @@ def is_operator_method(fullname: Optional[str]) -> bool:
short_name in nodes.op_methods.values() or
short_name in nodes.reverse_op_methods.values() or
short_name in nodes.unary_op_methods.values())


def get_partial_instance_type(t: Optional[Type]) -> Optional[PartialType]:
    """Return t if it is a PartialType whose 'type' is not None, else None."""
    if isinstance(t, PartialType) and t.type is not None:
        return t
    return None
7 changes: 6 additions & 1 deletion mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1763,13 +1763,18 @@ class PartialType(ProperType):
# None for the 'None' partial type; otherwise a generic class
type = None # type: Optional[mypy.nodes.TypeInfo]
var = None # type: mypy.nodes.Var
# For partial defaultdict[K, V], the type V (K is unknown). If V is generic,
# the type argument is Any and will be replaced later.
value_type = None # type: Optional[Instance]

def __init__(self,
type: 'Optional[mypy.nodes.TypeInfo]',
var: 'mypy.nodes.Var') -> None:
var: 'mypy.nodes.Var',
value_type: 'Optional[Instance]' = None) -> None:
super().__init__()
self.type = type
self.var = var
self.value_type = value_type

# Visitor entry point: hands this object to visitor.visit_partial_type
# and returns whatever that call returns.
def accept(self, visitor: 'TypeVisitor[T]') -> T:
return visitor.visit_partial_type(self)
Expand Down
3 changes: 3 additions & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ class ellipsis: pass
# Primitive types are special in generated code.

class int:
@overload
def __init__(self) -> None: pass
@overload
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the change to mypyc needed here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A test case relies on the int() constructor since it uses defaultdict(int).

def __init__(self, x: object, base: int = 10) -> None: pass
def __add__(self, n: int) -> int: pass
def __sub__(self, n: int) -> int: pass
Expand Down
94 changes: 94 additions & 0 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -2976,3 +2976,97 @@ x: Optional[str]
y = filter(None, [x])
reveal_type(y) # N: Revealed type is 'builtins.list[builtins.str*]'
[builtins fixtures/list.pyi]

[case testPartialDefaultDict]
from collections import defaultdict
x = defaultdict(int)
x[''] = 1
reveal_type(x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.int]'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a similar test case where value type is generic, e.g.:

x = defaultdict(list)  # No error
x['a'] = [1, 2, 3]

and

x = defaultdict(list)  # Error here
x['a'] = []


y = defaultdict(int) # E: Need type annotation for 'y'

z = defaultdict(int) # E: Need type annotation for 'z'
z[''] = ''
reveal_type(z) # N: Revealed type is 'collections.defaultdict[Any, Any]'
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictInconsistentValueTypes]
from collections import defaultdict
a = defaultdict(int) # E: Need type annotation for 'a'
a[''] = ''
a[''] = 1
reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.int]'
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictListValue]
# flags: --no-strict-optional
from collections import defaultdict
a = defaultdict(list)
a['x'].append(1)
reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'

b = defaultdict(lambda: [])
b[1].append('x')
reveal_type(b) # N: Revealed type is 'collections.defaultdict[builtins.int, builtins.list[builtins.str]]'
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictListValueStrictOptional]
# flags: --strict-optional
from collections import defaultdict
a = defaultdict(list)
a['x'].append(1)
reveal_type(a) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'

b = defaultdict(lambda: [])
b[1].append('x')
reveal_type(b) # N: Revealed type is 'collections.defaultdict[builtins.int, builtins.list[builtins.str]]'
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictSpecialCases]
from collections import defaultdict
class A:
def f(self) -> None:
self.x = defaultdict(list)
self.x['x'].append(1)
reveal_type(self.x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'
self.y = defaultdict(list) # E: Need type annotation for 'y'
s = self
s.y['x'].append(1)

x = {} # E: Need type annotation for 'x' (hint: "x: Dict[<type>, <type>] = ...")
x['x'].append(1)

y = defaultdict(list) # E: Need type annotation for 'y'
y[[]].append(1)
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictSpecialCases2]
from collections import defaultdict

x = defaultdict(lambda: [1]) # E: Need type annotation for 'x'
x[1].append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
reveal_type(x) # N: Revealed type is 'collections.defaultdict[Any, builtins.list[builtins.int]]'

xx = defaultdict(lambda: {'x': 1}) # E: Need type annotation for 'xx'
xx[1]['z'] = 3
reveal_type(xx) # N: Revealed type is 'collections.defaultdict[Any, builtins.dict[builtins.str, builtins.int]]'

y = defaultdict(dict) # E: Need type annotation for 'y'
y['x'][1] = [3]

z = defaultdict(int) # E: Need type annotation for 'z'
z[1].append('')
reveal_type(z) # N: Revealed type is 'collections.defaultdict[Any, Any]'
[builtins fixtures/dict.pyi]

[case testPartialDefaultDictSpecialCase3]
from collections import defaultdict

x = defaultdict(list)
x['a'] = [1, 2, 3]
reveal_type(x) # N: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int*]]'

y = defaultdict(list) # E: Need type annotation for 'y'
y['a'] = []
reveal_type(y) # N: Revealed type is 'collections.defaultdict[Any, Any]'
[builtins fixtures/dict.pyi]
1 change: 1 addition & 0 deletions test-data/unit/fixtures/dict.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class list(Sequence[T]): # needed by some test cases
def __iter__(self) -> Iterator[T]: pass
def __mul__(self, x: int) -> list[T]: pass
def __contains__(self, item: object) -> bool: pass
def append(self, item: T) -> None: pass

class tuple(Generic[T]): pass
class function: pass
Expand Down
12 changes: 7 additions & 5 deletions test-data/unit/lib-stub/collections.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Iterable, Union, Optional, Dict, TypeVar
from typing import Any, Iterable, Union, Optional, Dict, TypeVar, overload, Optional, Callable

def namedtuple(
typename: str,
Expand All @@ -10,8 +10,10 @@ def namedtuple(
defaults: Optional[Iterable[Any]] = ...
) -> Any: ...

K = TypeVar('K')
V = TypeVar('V')
KT = TypeVar('KT')
VT = TypeVar('VT')

class OrderedDict(Dict[K, V]):
def __setitem__(self, k: K, v: V) -> None: ...
class OrderedDict(Dict[KT, VT]): ...

class defaultdict(Dict[KT, VT]):
def __init__(self, default_factory: Optional[Callable[[], VT]]) -> None: ...
6 changes: 3 additions & 3 deletions test-data/unit/python2eval.test
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,11 @@ if MYPY:
x = b'abc'
[out]

[case testNestedGenericFailedInference]
[case testDefaultDictInference]
from collections import defaultdict
def foo() -> None:
x = defaultdict(list) # type: ignore
x = defaultdict(list)
x['lol'].append(10)
reveal_type(x)
[out]
_testNestedGenericFailedInference.py:5: note: Revealed type is 'collections.defaultdict[Any, builtins.list[Any]]'
_testDefaultDictInference.py:5: note: Revealed type is 'collections.defaultdict[builtins.str, builtins.list[builtins.int]]'