Skip to content

Commit df31e1c

Browse files
committed
[ty] Type inference for comprehensions
1 parent 1734ddf commit df31e1c

File tree

6 files changed

+185
-27
lines changed

6 files changed

+185
-27
lines changed

crates/ty_python_semantic/resources/mdtest/comprehensions/basic.md

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,94 @@ async def _():
103103
# revealed: Unknown
104104
[reveal_type(x) async for x in range(3)]
105105
```
106+
107+
## Comprehension expression types
108+
109+
The type of the comprehension expression itself should reflect the inferred element type:
110+
111+
```py
112+
from typing import TypedDict, Literal
113+
114+
# revealed: list[int]
115+
reveal_type([x for x in range(10)])
116+
117+
# revealed: set[int]
118+
reveal_type({x for x in range(10)})
119+
120+
# revealed: dict[int, str]
121+
reveal_type({x: str(x) for x in range(10)})
122+
123+
# revealed: list[tuple[int, Unknown | str]]
124+
reveal_type([(x, y) for x in range(5) for y in ["a", "b", "c"]])
125+
126+
squares: list[int | None] = [x**2 for x in range(10)]
127+
reveal_type(squares) # revealed: list[int | None]
128+
```
129+
130+
Inference for comprehensions takes the type context into account:
131+
132+
```py
133+
# Without type context:
134+
reveal_type([x for x in [1, 2, 3]]) # revealed: list[Unknown | int]
135+
reveal_type({x: "a" for x in [1, 2, 3]}) # revealed: dict[Unknown | int, str]
136+
reveal_type({str(x): x for x in [1, 2, 3]}) # revealed: dict[str, Unknown | int]
137+
reveal_type({x for x in [1, 2, 3]}) # revealed: set[Unknown | int]
138+
139+
# With type context:
140+
xs: list[int] = [x for x in [1, 2, 3]]
141+
reveal_type(xs) # revealed: list[int]
142+
143+
ys: dict[int, str] = {x: str(x) for x in [1, 2, 3]}
144+
reveal_type(ys) # revealed: dict[int, str]
145+
146+
zs: set[int] = {x for x in [1, 2, 3]}
147+
```
148+
149+
This also works for nested comprehensions:
150+
151+
```py
152+
table = [[(x, y) for x in range(3)] for y in range(3)]
153+
reveal_type(table) # revealed: list[list[tuple[int, int]]]
154+
155+
# TODO: no error here
156+
# error: [invalid-assignment]
157+
table_with_content: list[list[tuple[int, int, str | None]]] = [[(x, y, None) for x in range(3)] for y in range(3)]
158+
reveal_type(table_with_content) # revealed: list[list[tuple[int, int, str | None]]]
159+
```
160+
161+
The type context is propagated down into the comprehension:
162+
163+
```py
164+
class Person(TypedDict):
165+
name: str
166+
167+
persons: list[Person] = [{"name": n} for n in ["Alice", "Bob"]]
168+
reveal_type(persons) # revealed: list[Person]
169+
170+
# TODO: This should be an error
171+
invalid: list[Person] = [{"misspelled": n} for n in ["Alice", "Bob"]]
172+
```
173+
174+
We promote literals to avoid overly-precise types in invariant positions:
175+
176+
```py
177+
reveal_type([x for x in ("a", "b", "c")]) # revealed: list[str]
178+
reveal_type({x for x in (1, 2, 3)}) # revealed: set[int]
179+
reveal_type({k: 0 for k in ("a", "b", "c")}) # revealed: dict[str, int]
180+
```
181+
182+
Type context can prevent this promotion from happening:
183+
184+
```py
185+
list_of_literals: list[Literal["a", "b", "c"]] = [x for x in ("a", "b", "c")]
186+
reveal_type(list_of_literals) # revealed: list[Literal["a", "b", "c"]]
187+
188+
dict_with_literal_keys: dict[Literal["a", "b", "c"], int] = {k: 0 for k in ("a", "b", "c")}
189+
reveal_type(dict_with_literal_keys) # revealed: dict[Literal["a", "b", "c"], int]
190+
191+
dict_with_literal_values: dict[str, Literal[1, 2, 3]] = {str(k): k for k in (1, 2, 3)}
192+
reveal_type(dict_with_literal_values) # revealed: dict[str, Literal[1, 2, 3]]
193+
194+
set_with_literals: set[Literal[1, 2, 3]] = {k for k in (1, 2, 3)}
195+
reveal_type(set_with_literals) # revealed: set[Literal[1, 2, 3]]
196+
```

crates/ty_python_semantic/resources/mdtest/literal/collections/dictionary.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,6 @@ reveal_type({"a": 1, "b": (1, 2), "c": (1, 2, 3)})
5151
## Dict comprehensions
5252

5353
```py
54-
# revealed: dict[@Todo(dict comprehension key type), @Todo(dict comprehension value type)]
54+
# revealed: dict[int, int]
5555
reveal_type({x: y for x, y in enumerate(range(42))})
5656
```

crates/ty_python_semantic/resources/mdtest/literal/collections/list.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,5 @@ reveal_type([1, (1, 2), (1, 2, 3)])
4141
## List comprehensions
4242

4343
```py
44-
reveal_type([x for x in range(42)]) # revealed: list[@Todo(list comprehension element type)]
44+
reveal_type([x for x in range(42)]) # revealed: list[int]
4545
```

crates/ty_python_semantic/resources/mdtest/literal/collections/set.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,5 @@ reveal_type({1, (1, 2), (1, 2, 3)})
3535
## Set comprehensions
3636

3737
```py
38-
reveal_type({x for x in range(42)}) # revealed: set[@Todo(set comprehension element type)]
38+
reveal_type({x for x in range(42)}) # revealed: set[int]
3939
```

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 88 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5943,9 +5943,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
59435943
ast::Expr::Set(set) => self.infer_set_expression(set, tcx),
59445944
ast::Expr::Dict(dict) => self.infer_dict_expression(dict, tcx),
59455945
ast::Expr::Generator(generator) => self.infer_generator_expression(generator),
5946-
ast::Expr::ListComp(listcomp) => self.infer_list_comprehension_expression(listcomp),
5947-
ast::Expr::DictComp(dictcomp) => self.infer_dict_comprehension_expression(dictcomp),
5948-
ast::Expr::SetComp(setcomp) => self.infer_set_comprehension_expression(setcomp),
5946+
ast::Expr::ListComp(listcomp) => {
5947+
self.infer_list_comprehension_expression(listcomp, tcx)
5948+
}
5949+
ast::Expr::DictComp(dictcomp) => {
5950+
self.infer_dict_comprehension_expression(dictcomp, tcx)
5951+
}
5952+
ast::Expr::SetComp(setcomp) => self.infer_set_comprehension_expression(setcomp, tcx),
59495953
ast::Expr::Name(name) => self.infer_name_expression(name),
59505954
ast::Expr::Attribute(attribute) => self.infer_attribute_expression(attribute),
59515955
ast::Expr::UnaryOp(unary_op) => self.infer_unary_expression(unary_op),
@@ -6450,52 +6454,115 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
64506454
)
64516455
}
64526456

6453-
fn infer_list_comprehension_expression(&mut self, listcomp: &ast::ExprListComp) -> Type<'db> {
6457+
/// Return a specialization of the collection class (list, dict, set) based on the type context and the inferred
6458+
/// element / key-value types from the comprehension expression.
6459+
fn infer_comprehension_specialization(
6460+
&self,
6461+
collection_class: KnownClass,
6462+
inferred_element_types: &[Type<'db>],
6463+
tcx: TypeContext<'db>,
6464+
) -> Type<'db> {
6465+
// Remove any union elements of that are unrelated to the collection type.
6466+
let tcx = tcx.map(|annotation| {
6467+
annotation.filter_disjoint_elements(
6468+
self.db(),
6469+
collection_class.to_instance(self.db()),
6470+
InferableTypeVars::None,
6471+
)
6472+
});
6473+
6474+
if let Some(annotated_element_types) = tcx
6475+
.known_specialization(self.db(), collection_class)
6476+
.map(|specialization| specialization.types(self.db()))
6477+
&& annotated_element_types
6478+
.iter()
6479+
.zip(inferred_element_types.iter())
6480+
.all(|(annotated, inferred)| inferred.is_assignable_to(self.db(), *annotated))
6481+
{
6482+
collection_class
6483+
.to_specialized_instance(self.db(), annotated_element_types.iter().copied())
6484+
} else {
6485+
collection_class.to_specialized_instance(
6486+
self.db(),
6487+
inferred_element_types
6488+
.iter()
6489+
.map(|ty| ty.promote_literals(self.db(), TypeContext::default())),
6490+
)
6491+
}
6492+
}
6493+
6494+
fn infer_list_comprehension_expression(
6495+
&mut self,
6496+
listcomp: &ast::ExprListComp,
6497+
tcx: TypeContext<'db>,
6498+
) -> Type<'db> {
64546499
let ast::ExprListComp {
64556500
range: _,
64566501
node_index: _,
6457-
elt: _,
6502+
elt,
64586503
generators,
64596504
} = listcomp;
64606505

64616506
self.infer_first_comprehension_iter(generators);
64626507

6463-
KnownClass::List
6464-
.to_specialized_instance(self.db(), [todo_type!("list comprehension element type")])
6508+
let scope_id = self
6509+
.index
6510+
.node_scope(NodeWithScopeRef::ListComprehension(listcomp));
6511+
let scope = scope_id.to_scope_id(self.db(), self.file());
6512+
let inference = infer_scope_types(self.db(), scope);
6513+
let element_type = inference.expression_type(elt.as_ref());
6514+
6515+
self.infer_comprehension_specialization(KnownClass::List, &[element_type], tcx)
64656516
}
64666517

6467-
fn infer_dict_comprehension_expression(&mut self, dictcomp: &ast::ExprDictComp) -> Type<'db> {
6518+
fn infer_dict_comprehension_expression(
6519+
&mut self,
6520+
dictcomp: &ast::ExprDictComp,
6521+
tcx: TypeContext<'db>,
6522+
) -> Type<'db> {
64686523
let ast::ExprDictComp {
64696524
range: _,
64706525
node_index: _,
6471-
key: _,
6472-
value: _,
6526+
key,
6527+
value,
64736528
generators,
64746529
} = dictcomp;
64756530

64766531
self.infer_first_comprehension_iter(generators);
64776532

6478-
KnownClass::Dict.to_specialized_instance(
6479-
self.db(),
6480-
[
6481-
todo_type!("dict comprehension key type"),
6482-
todo_type!("dict comprehension value type"),
6483-
],
6484-
)
6533+
let scope_id = self
6534+
.index
6535+
.node_scope(NodeWithScopeRef::DictComprehension(dictcomp));
6536+
let scope = scope_id.to_scope_id(self.db(), self.file());
6537+
let inference = infer_scope_types(self.db(), scope);
6538+
let key_type = inference.expression_type(key.as_ref());
6539+
let value_type = inference.expression_type(value.as_ref());
6540+
6541+
self.infer_comprehension_specialization(KnownClass::Dict, &[key_type, value_type], tcx)
64856542
}
64866543

6487-
fn infer_set_comprehension_expression(&mut self, setcomp: &ast::ExprSetComp) -> Type<'db> {
6544+
fn infer_set_comprehension_expression(
6545+
&mut self,
6546+
setcomp: &ast::ExprSetComp,
6547+
tcx: TypeContext<'db>,
6548+
) -> Type<'db> {
64886549
let ast::ExprSetComp {
64896550
range: _,
64906551
node_index: _,
6491-
elt: _,
6552+
elt,
64926553
generators,
64936554
} = setcomp;
64946555

64956556
self.infer_first_comprehension_iter(generators);
64966557

6497-
KnownClass::Set
6498-
.to_specialized_instance(self.db(), [todo_type!("set comprehension element type")])
6558+
let scope_id = self
6559+
.index
6560+
.node_scope(NodeWithScopeRef::SetComprehension(setcomp));
6561+
let scope = scope_id.to_scope_id(self.db(), self.file());
6562+
let inference = infer_scope_types(self.db(), scope);
6563+
let element_type = inference.expression_type(elt.as_ref());
6564+
6565+
self.infer_comprehension_specialization(KnownClass::Set, &[element_type], tcx)
64996566
}
65006567

65016568
fn infer_generator_expression_scope(&mut self, generator: &ast::ExprGenerator) {

crates/ty_python_semantic/src/types/infer/builder/type_expression.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
346346
}
347347

348348
ast::Expr::DictComp(dictcomp) => {
349-
self.infer_dict_comprehension_expression(dictcomp);
349+
self.infer_dict_comprehension_expression(dictcomp, TypeContext::default());
350350
self.report_invalid_type_expression(
351351
expression,
352352
format_args!("Dict comprehensions are not allowed in type expressions"),
@@ -355,7 +355,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
355355
}
356356

357357
ast::Expr::ListComp(listcomp) => {
358-
self.infer_list_comprehension_expression(listcomp);
358+
self.infer_list_comprehension_expression(listcomp, TypeContext::default());
359359
self.report_invalid_type_expression(
360360
expression,
361361
format_args!("List comprehensions are not allowed in type expressions"),
@@ -364,7 +364,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
364364
}
365365

366366
ast::Expr::SetComp(setcomp) => {
367-
self.infer_set_comprehension_expression(setcomp);
367+
self.infer_set_comprehension_expression(setcomp, TypeContext::default());
368368
self.report_invalid_type_expression(
369369
expression,
370370
format_args!("Set comprehensions are not allowed in type expressions"),

0 commit comments

Comments
 (0)