Skip to content

Infer type for partial generic type from assignment #8036

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2042,6 +2042,7 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
self.check_assignment_to_multiple_lvalues(lvalue.items, rvalue, rvalue,
infer_lvalue_type)
else:
self.try_infer_partial_generic_type_from_assignment(lvalue, rvalue)
lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue)
# If we're assigning to __getattr__ or similar methods, check that the signature is
# valid.
Expand Down Expand Up @@ -2141,6 +2142,37 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
rvalue_type = remove_instance_last_known_values(rvalue_type)
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)

def try_infer_partial_generic_type_from_assignment(self,
lvalue: Lvalue,
rvalue: Expression) -> None:
"""Try to infer a precise type for partial generic type from assignment.

Example where this happens:

x = []
if foo():
x = [1] # Infer List[int] as type of 'x'
"""
if (isinstance(lvalue, NameExpr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this actually be a NameExpr or RefExpr is fine? I am thinking about self.x = [], please add tests if yes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Partial types are not properly supported with attributes in general, so this is consistent with how things work elsewhere.

and isinstance(lvalue.node, Var)
and isinstance(lvalue.node.type, PartialType)):
var = lvalue.node
typ = lvalue.node.type
if typ.type is None:
return
partial_types = self.find_partial_types(var)
if partial_types is None:
return
rvalue_type = self.expr_checker.accept(rvalue)
rvalue_type = get_proper_type(rvalue_type)
if isinstance(rvalue_type, Instance):
if rvalue_type.type == typ.type:
var.type = rvalue_type
del partial_types[var]
elif isinstance(rvalue_type, AnyType):
var.type = fill_typevars_with_any(typ.type)
del partial_types[var]

def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type],
rvalue: Expression) -> bool:
lvalue_node = lvalue.node
Expand Down
14 changes: 11 additions & 3 deletions mypy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,14 +598,22 @@ def remove_duplicates(self, errors: List[ErrorTuple]) -> List[ErrorTuple]:
i = 0
while i < len(errors):
dup = False
# Use slightly special formatting for member conflicts reporting.
conflicts_notes = False
j = i - 1
while j >= 0 and errors[j][0] == errors[i][0]:
if errors[j][4].strip() == 'Got:':
conflicts_notes = True
j -= 1
j = i - 1
while (j >= 0 and errors[j][0] == errors[i][0] and
errors[j][1] == errors[i][1]):
if (errors[j][3] == errors[i][3] and
# Allow duplicate notes in overload conflicts reporting.
not (errors[i][3] == 'note' and
errors[i][4].strip() in allowed_duplicates
or errors[i][4].strip().startswith('def ')) and
not ((errors[i][3] == 'note' and
errors[i][4].strip() in allowed_duplicates)
or (errors[i][4].strip().startswith('def ') and
conflicts_notes)) and
errors[j][4] == errors[i][4]): # ignore column
dup = True
break
Expand Down
52 changes: 26 additions & 26 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -1367,34 +1367,29 @@ a = []
a.append(1)
a.append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyUsingUpdate]
a = []
a.extend([''])
a.append(0) # E: Argument 1 to "append" of "list" has incompatible type "int"; expected "str"
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyAndNotAnnotated]
a = [] # E: Need type annotation for 'a' (hint: "a: List[<type>] = ...")
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyAndReadBeforeAppend]
a = [] # E: Need type annotation for 'a' (hint: "a: List[<type>] = ...")
if a: pass
a.xyz # E: "List[Any]" has no attribute "xyz"
a.append('')
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyAndIncompleteTypeInAppend]
a = [] # E: Need type annotation for 'a' (hint: "a: List[<type>] = ...")
a.append([])
a() # E: "List[Any]" not callable
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyAndMultipleAssignment]
a, b = [], []
Expand All @@ -1403,15 +1398,13 @@ b.append('')
a() # E: "List[int]" not callable
b() # E: "List[str]" not callable
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyInFunction]
def f() -> None:
a = []
a.append(1)
a.append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyAndNotAnnotatedInFunction]
def f() -> None:
Expand All @@ -1422,7 +1415,6 @@ def g() -> None: pass
a = []
a.append(1)
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyAndReadBeforeAppendInFunction]
def f() -> None:
Expand All @@ -1431,15 +1423,13 @@ def f() -> None:
a.xyz # E: "List[Any]" has no attribute "xyz"
a.append('')
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyInClassBody]
class A:
a = []
a.append(1)
a.append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyAndNotAnnotatedInClassBody]
class A:
Expand All @@ -1449,7 +1439,6 @@ class B:
a = []
a.append(1)
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyInMethod]
class A:
Expand All @@ -1458,14 +1447,12 @@ class A:
a.append(1)
a.append('') # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyAndNotAnnotatedInMethod]
class A:
def f(self) -> None:
a = [] # E: Need type annotation for 'a' (hint: "a: List[<type>] = ...")
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyInMethodViaAttribute]
class A:
Expand All @@ -1475,7 +1462,6 @@ class A:
self.a.append(1)
self.a.append('')
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyInClassBodyAndOverriden]
from typing import List
Expand All @@ -1490,57 +1476,49 @@ class B(A):
def x(self) -> List[int]: # E: Signature of "x" incompatible with supertype "A"
return [123]
[builtins fixtures/list.pyi]
[out]

[case testInferSetInitializedToEmpty]
a = set()
a.add(1)
a.add('') # E: Argument 1 to "add" of "set" has incompatible type "str"; expected "int"
[builtins fixtures/set.pyi]
[out]

[case testInferSetInitializedToEmptyUsingDiscard]
a = set()
a.discard('')
a.add(0) # E: Argument 1 to "add" of "set" has incompatible type "int"; expected "str"
[builtins fixtures/set.pyi]
[out]

[case testInferSetInitializedToEmptyUsingUpdate]
a = set()
a.update({0})
a.add('') # E: Argument 1 to "add" of "set" has incompatible type "str"; expected "int"
[builtins fixtures/set.pyi]
[out]

[case testInferDictInitializedToEmpty]
a = {}
a[1] = ''
a() # E: "Dict[int, str]" not callable
[builtins fixtures/dict.pyi]
[out]

[case testInferDictInitializedToEmptyUsingUpdate]
a = {}
a.update({'': 42})
a() # E: "Dict[str, int]" not callable
[builtins fixtures/dict.pyi]
[out]

[case testInferDictInitializedToEmptyUsingUpdateError]
a = {} # E: Need type annotation for 'a' (hint: "a: Dict[<type>, <type>] = ...")
a.update([1, 2]) # E: Argument 1 to "update" of "dict" has incompatible type "List[int]"; expected "Mapping[Any, Any]"
a() # E: "Dict[Any, Any]" not callable
[builtins fixtures/dict.pyi]
[out]

[case testInferDictInitializedToEmptyAndIncompleteTypeInUpdate]
a = {} # E: Need type annotation for 'a' (hint: "a: Dict[<type>, <type>] = ...")
a[1] = {}
b = {} # E: Need type annotation for 'b' (hint: "b: Dict[<type>, <type>] = ...")
b[{}] = 1
[builtins fixtures/dict.pyi]
[out]

[case testInferDictInitializedToEmptyAndUpdatedFromMethod]
map = {}
Expand All @@ -1557,20 +1535,42 @@ def add():
[case testSpecialCaseEmptyListInitialization]
def f(blocks: Any): # E: Name 'Any' is not defined \
# N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any")
to_process = [] # E: Need type annotation for 'to_process' (hint: "to_process: List[<type>] = ...")
to_process = []
to_process = list(blocks)
[builtins fixtures/list.pyi]
[out]

[case testSpecialCaseEmptyListInitialization2]
def f(blocks: object):
to_process = [] # E: Need type annotation for 'to_process' (hint: "to_process: List[<type>] = ...")
to_process = []
to_process = list(blocks) # E: No overload variant of "list" matches argument type "object" \
# N: Possible overload variant: \
# N: def [T] __init__(self, x: Iterable[T]) -> List[T] \
# N: <1 more non-matching overload not shown>
[builtins fixtures/list.pyi]
[out]

[case testInferListInitializedToEmptyAndAssigned]
a = []
if bool():
a = [1]
reveal_type(a) # N: Revealed type is 'builtins.list[builtins.int*]'

def f():
return [1]
b = []
if bool():
b = f()
reveal_type(b) # N: Revealed type is 'builtins.list[Any]'

d = {}
if bool():
d = {1: 'x'}
reveal_type(d) # N: Revealed type is 'builtins.dict[builtins.int*, builtins.str*]'

dd = {} # E: Need type annotation for 'dd' (hint: "dd: Dict[<type>, <type>] = ...")
if bool():
dd = [1] # E: Incompatible types in assignment (expression has type "List[int]", variable has type "Dict[Any, Any]")
reveal_type(dd) # N: Revealed type is 'builtins.dict[Any, Any]'
[builtins fixtures/dict.pyi]


-- Inferring types of variables first initialized to None (partial types)
Expand Down