Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use type information from isinstance checks in comprehensions #2000

Merged
merged 1 commit into from
Aug 9, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 53 additions & 40 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,55 +1531,68 @@ def check_generator_or_comprehension(self, gen: GeneratorExpr,
type_name: str,
id_for_messages: str) -> Type:
"""Type check a generator expression or a list comprehension."""
self.check_for_comp(gen)
with self.chk.binder.frame_context():
self.check_for_comp(gen)

# Infer the type of the list comprehension by using a synthetic generic
# callable type.
tvdef = TypeVarDef('T', -1, [], self.chk.object_type())
tv = TypeVarType(tvdef)
constructor = CallableType(
[tv],
[nodes.ARG_POS],
[None],
self.chk.named_generic_type(type_name, [tv]),
self.chk.named_type('builtins.function'),
name=id_for_messages,
variables=[tvdef])
return self.check_call(constructor,
[gen.left_expr], [nodes.ARG_POS], gen)[0]
# Infer the type of the list comprehension by using a synthetic generic
# callable type.
tvdef = TypeVarDef('T', -1, [], self.chk.object_type())
tv = TypeVarType(tvdef)
constructor = CallableType(
[tv],
[nodes.ARG_POS],
[None],
self.chk.named_generic_type(type_name, [tv]),
self.chk.named_type('builtins.function'),
name=id_for_messages,
variables=[tvdef])
return self.check_call(constructor,
[gen.left_expr], [nodes.ARG_POS], gen)[0]

def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> Type:
"""Type check a dictionary comprehension."""
self.check_for_comp(e)

# Infer the type of the list comprehension by using a synthetic generic
# callable type.
ktdef = TypeVarDef('KT', -1, [], self.chk.object_type())
vtdef = TypeVarDef('VT', -2, [], self.chk.object_type())
kt = TypeVarType(ktdef)
vt = TypeVarType(vtdef)
constructor = CallableType(
[kt, vt],
[nodes.ARG_POS, nodes.ARG_POS],
[None, None],
self.chk.named_generic_type('builtins.dict', [kt, vt]),
self.chk.named_type('builtins.function'),
name='<dictionary-comprehension>',
variables=[ktdef, vtdef])
return self.check_call(constructor,
[e.key, e.value], [nodes.ARG_POS, nodes.ARG_POS], e)[0]
with self.chk.binder.frame_context():
self.check_for_comp(e)

# Infer the type of the list comprehension by using a synthetic generic
# callable type.
ktdef = TypeVarDef('KT', -1, [], self.chk.object_type())
vtdef = TypeVarDef('VT', -2, [], self.chk.object_type())
kt = TypeVarType(ktdef)
vt = TypeVarType(vtdef)
constructor = CallableType(
[kt, vt],
[nodes.ARG_POS, nodes.ARG_POS],
[None, None],
self.chk.named_generic_type('builtins.dict', [kt, vt]),
self.chk.named_type('builtins.function'),
name='<dictionary-comprehension>',
variables=[ktdef, vtdef])
return self.check_call(constructor,
[e.key, e.value], [nodes.ARG_POS, nodes.ARG_POS], e)[0]

def check_for_comp(self, e: Union[GeneratorExpr, DictionaryComprehension]) -> None:
"""Check the for_comp part of comprehensions. That is the part from 'for':
... for x in y if z

Note: This adds the type information derived from the condlists to the current binder.
"""
with self.chk.binder.frame_context():
for index, sequence, conditions in zip(e.indices, e.sequences,
e.condlists):
sequence_type = self.chk.analyze_iterable_item_type(sequence)
self.chk.analyze_index_variables(index, sequence_type, e)
for condition in conditions:
self.accept(condition)
for index, sequence, conditions in zip(e.indices, e.sequences,
e.condlists):
sequence_type = self.chk.analyze_iterable_item_type(sequence)
self.chk.analyze_index_variables(index, sequence_type, e)
for condition in conditions:
self.accept(condition)

# values are only part of the comprehension when all conditions are true
true_map, _ = mypy.checker.find_isinstance_check(
condition, self.chk.type_map,
self.chk.typing_mode_weak()
)

if true_map:
for var, type in true_map.items():
self.chk.binder.push(var, type)

def visit_conditional_expr(self, e: ConditionalExpr) -> Type:
cond_type = self.accept(e.cond)
Expand Down
4 changes: 2 additions & 2 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,7 +1504,7 @@ class GeneratorExpr(Expression):
"""Generator expression ... for ... in ... [ for ... in ... ] [ if ... ]."""

left_expr = None # type: Expression
sequences_expr = None # type: List[Expression]
sequences = None # type: List[Expression]
condlists = None # type: List[List[Expression]]
indices = None # type: List[Expression]

Expand Down Expand Up @@ -1548,7 +1548,7 @@ class DictionaryComprehension(Expression):

key = None # type: Expression
value = None # type: Expression
sequences_expr = None # type: List[Expression]
sequences = None # type: List[Expression]
condlists = None # type: List[List[Expression]]
indices = None # type: List[Expression]

Expand Down
10 changes: 10 additions & 0 deletions test-data/unit/check-isinstance.test
Original file line number Diff line number Diff line change
Expand Up @@ -1155,3 +1155,13 @@ else:
1()
[builtins fixtures/isinstance.py]
[out]
[case testComprehensionIsInstance]
from typing import List, Union
a = [] # type: List[Union[int, str]]
l = [x for x in a if isinstance(x, int)]
g = (x for x in a if isinstance(x, int))
d = {0: x for x in a if isinstance(x, int)}
reveal_type(l) # E: Revealed type is 'builtins.list[builtins.int*]'
reveal_type(g) # E: Revealed type is 'typing.Iterator[builtins.int*]'
reveal_type(d) # E: Revealed type is 'builtins.dict[builtins.int*, builtins.int*]'
[builtins fixtures/isinstancelist.py]
13 changes: 12 additions & 1 deletion test-data/unit/fixtures/isinstancelist.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import builtinclass, Iterable, Iterator, Generic, TypeVar, List
from typing import builtinclass, Iterable, Iterator, Generic, TypeVar, List, Mapping, overload, Tuple

@builtinclass
class object:
Expand All @@ -24,10 +24,21 @@ def __add__(self, x: str) -> str: pass
def __getitem__(self, x: int) -> str: pass

T = TypeVar('T')
KT = TypeVar('KT')
VT = TypeVar('VT')

class list(Iterable[T], Generic[T]):
def __iter__(self) -> Iterator[T]: pass
def __mul__(self, x: int) -> list[T]: pass
def __setitem__(self, x: int, v: T) -> None: pass
def __getitem__(self, x: int) -> T: pass
def __add__(self, x: List[T]) -> T: pass

class dict(Iterable[KT], Mapping[KT, VT], Generic[KT, VT]):
@overload
def __init__(self, **kwargs: VT) -> None: pass
@overload
def __init__(self, arg: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: pass
def __setitem__(self, k: KT, v: VT) -> None: pass
def __iter__(self) -> Iterator[KT]: pass
def update(self, a: Mapping[KT, VT]) -> None: pass