Skip to content

Commit 1ade4f2

Browse files
authored
[ty] Avoid unnecessarily widening generic specializations (#20875)
## Summary

Ignore the type context when specializing a generic call if it leads to an unnecessarily wide return type. For example, [the example mentioned here](#20796 (comment)) works as expected after this change:

```py
def id[T](x: T) -> T:
    return x

def _(i: int):
    x: int | None = id(i)
    y: int | None = i
    reveal_type(x) # revealed: int
    reveal_type(y) # revealed: int
```

I also extended our usage of `filter_disjoint_elements` to tuple and typed-dict inference, which resolves astral-sh/ty#1266.
1 parent 8dad58d commit 1ade4f2

File tree

8 files changed

+155
-57
lines changed

8 files changed

+155
-57
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,7 @@ k: list[tuple[list[int], ...]] | None = [([],), ([1, 2], [3, 4]), ([5], [6], [7]
190190
reveal_type(k) # revealed: list[tuple[list[int], ...]]
191191

192192
l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] | None = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"])
193-
# TODO: this should be `tuple[list[int], list[Any | int], list[Any | int], list[str]]`
194-
reveal_type(l) # revealed: tuple[list[Unknown | int], list[Unknown | int], list[Unknown | int], list[Unknown | str]]
193+
reveal_type(l) # revealed: tuple[list[int], list[Any | int], list[Any | int], list[str]]
195194

196195
type IntList = list[int]
197196

@@ -416,13 +415,14 @@ a = f("a")
416415
reveal_type(a) # revealed: list[Literal["a"]]
417416

418417
b: list[int | Literal["a"]] = f("a")
419-
reveal_type(b) # revealed: list[int | Literal["a"]]
418+
reveal_type(b) # revealed: list[Literal["a"] | int]
420419

421420
c: list[int | str] = f("a")
422-
reveal_type(c) # revealed: list[int | str]
421+
reveal_type(c) # revealed: list[str | int]
423422

424423
d: list[int | tuple[int, int]] = f((1, 2))
425-
reveal_type(d) # revealed: list[int | tuple[int, int]]
424+
# TODO: We could avoid reordering the union elements here.
425+
reveal_type(d) # revealed: list[tuple[int, int] | int]
426426

427427
e: list[int] = f(True)
428428
reveal_type(e) # revealed: list[int]
@@ -437,8 +437,49 @@ def f2[T: int](x: T) -> T:
437437
return x
438438

439439
i: int = f2(True)
440-
reveal_type(i) # revealed: int
440+
reveal_type(i) # revealed: Literal[True]
441441

442442
j: int | str = f2(True)
443443
reveal_type(j) # revealed: Literal[True]
444444
```
445+
446+
Types are not widened unnecessarily:
447+
448+
```py
449+
def id[T](x: T) -> T:
450+
return x
451+
452+
def lst[T](x: T) -> list[T]:
453+
return [x]
454+
455+
def _(i: int):
456+
a: int | None = i
457+
b: int | None = id(i)
458+
c: int | str | None = id(i)
459+
reveal_type(a) # revealed: int
460+
reveal_type(b) # revealed: int
461+
reveal_type(c) # revealed: int
462+
463+
a: list[int | None] | None = [i]
464+
b: list[int | None] | None = id([i])
465+
c: list[int | None] | int | None = id([i])
466+
reveal_type(a) # revealed: list[int | None]
467+
# TODO: these should reveal `list[int | None]`
468+
# we currently do not use the call expression annotation as type context for argument inference
469+
reveal_type(b) # revealed: list[Unknown | int]
470+
reveal_type(c) # revealed: list[Unknown | int]
471+
472+
a: list[int | None] | None = [i]
473+
b: list[int | None] | None = lst(i)
474+
c: list[int | None] | int | None = lst(i)
475+
reveal_type(a) # revealed: list[int | None]
476+
reveal_type(b) # revealed: list[int | None]
477+
reveal_type(c) # revealed: list[int | None]
478+
479+
a: list | None = []
480+
b: list | None = id([])
481+
c: list | int | None = id([])
482+
reveal_type(a) # revealed: list[Unknown]
483+
reveal_type(b) # revealed: list[Unknown]
484+
reveal_type(c) # revealed: list[Unknown]
485+
```

crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class Member:
1111
role: str = field(default="user")
1212
tag: str | None = field(default=None, init=False)
1313

14-
# revealed: (self: Member, name: str, role: str = str) -> None
14+
# revealed: (self: Member, name: str, role: str = Literal["user"]) -> None
1515
reveal_type(Member.__init__)
1616

1717
alice = Member(name="Alice", role="admin")
@@ -37,7 +37,7 @@ class Data:
3737
content: list[int] = field(default_factory=list)
3838
timestamp: datetime = field(default_factory=datetime.now, init=False)
3939

40-
# revealed: (self: Data, content: list[int] = list[int]) -> None
40+
# revealed: (self: Data, content: list[int] = Unknown) -> None
4141
reveal_type(Data.__init__)
4242

4343
data = Data([1, 2, 3])
@@ -64,7 +64,7 @@ class Person:
6464
role: str = field(default="user", kw_only=True)
6565

6666
# TODO: this would ideally show a default value of `None` for `age`
67-
# revealed: (self: Person, name: str, *, age: int | None = int | None, role: str = str) -> None
67+
# revealed: (self: Person, name: str, *, age: int | None = None, role: str = Literal["user"]) -> None
6868
reveal_type(Person.__init__)
6969

7070
alice = Person(role="admin", name="Alice")

crates/ty_python_semantic/resources/mdtest/typed_dict.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,7 @@ grandchild: Node = {"name": "grandchild", "parent": child}
907907

908908
nested: Node = {"name": "n1", "parent": {"name": "n2", "parent": {"name": "n3", "parent": None}}}
909909

910-
# TODO: this should be an error (invalid type for `name` in innermost node)
910+
# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Node`: value of type `Literal[3]`"
911911
nested_invalid: Node = {"name": "n1", "parent": {"name": "n2", "parent": {"name": 3, "parent": None}}}
912912
```
913913

crates/ty_python_semantic/src/types.rs

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,24 +1233,37 @@ impl<'db> Type<'db> {
12331233
if yes { self.negate(db) } else { *self }
12341234
}
12351235

1236-
/// Remove the union elements that are not related to `target`.
1237-
pub(crate) fn filter_disjoint_elements(
1236+
/// If the type is a union, filters union elements based on the provided predicate.
1237+
///
1238+
/// Otherwise, returns the type unchanged.
1239+
pub(crate) fn filter_union(
12381240
self,
12391241
db: &'db dyn Db,
1240-
target: Type<'db>,
1241-
inferable: InferableTypeVars<'_, 'db>,
1242+
f: impl FnMut(&Type<'db>) -> bool,
12421243
) -> Type<'db> {
12431244
if let Type::Union(union) = self {
1244-
union.filter(db, |elem| {
1245-
!elem
1246-
.when_disjoint_from(db, target, inferable)
1247-
.is_always_satisfied()
1248-
})
1245+
union.filter(db, f)
12491246
} else {
12501247
self
12511248
}
12521249
}
12531250

1251+
/// If the type is a union, removes union elements that are disjoint from `target`.
1252+
///
1253+
/// Otherwise, returns the type unchanged.
1254+
pub(crate) fn filter_disjoint_elements(
1255+
self,
1256+
db: &'db dyn Db,
1257+
target: Type<'db>,
1258+
inferable: InferableTypeVars<'_, 'db>,
1259+
) -> Type<'db> {
1260+
self.filter_union(db, |elem| {
1261+
!elem
1262+
.when_disjoint_from(db, target, inferable)
1263+
.is_always_satisfied()
1264+
})
1265+
}
1266+
12541267
/// Returns the fallback instance type that a literal is an instance of, or `None` if the type
12551268
/// is not a literal.
12561269
pub(crate) fn literal_fallback_instance(self, db: &'db dyn Db) -> Option<Type<'db>> {
@@ -11185,9 +11198,9 @@ impl<'db> UnionType<'db> {
1118511198
pub(crate) fn filter(
1118611199
self,
1118711200
db: &'db dyn Db,
11188-
filter_fn: impl FnMut(&&Type<'db>) -> bool,
11201+
mut f: impl FnMut(&Type<'db>) -> bool,
1118911202
) -> Type<'db> {
11190-
Self::from_elements(db, self.elements(db).iter().filter(filter_fn))
11203+
Self::from_elements(db, self.elements(db).iter().filter(|ty| f(ty)))
1119111204
}
1119211205

1119311206
pub(crate) fn map_with_boundness(

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2524,20 +2524,23 @@ struct ArgumentTypeChecker<'a, 'db> {
25242524
argument_matches: &'a [MatchedArgument<'db>],
25252525
parameter_tys: &'a mut [Option<Type<'db>>],
25262526
call_expression_tcx: &'a TypeContext<'db>,
2527+
return_ty: Type<'db>,
25272528
errors: &'a mut Vec<BindingError<'db>>,
25282529

25292530
inferable_typevars: InferableTypeVars<'db, 'db>,
25302531
specialization: Option<Specialization<'db>>,
25312532
}
25322533

25332534
impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
2535+
#[expect(clippy::too_many_arguments)]
25342536
fn new(
25352537
db: &'db dyn Db,
25362538
signature: &'a Signature<'db>,
25372539
arguments: &'a CallArguments<'a, 'db>,
25382540
argument_matches: &'a [MatchedArgument<'db>],
25392541
parameter_tys: &'a mut [Option<Type<'db>>],
25402542
call_expression_tcx: &'a TypeContext<'db>,
2543+
return_ty: Type<'db>,
25412544
errors: &'a mut Vec<BindingError<'db>>,
25422545
) -> Self {
25432546
Self {
@@ -2547,6 +2550,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25472550
argument_matches,
25482551
parameter_tys,
25492552
call_expression_tcx,
2553+
return_ty,
25502554
errors,
25512555
inferable_typevars: InferableTypeVars::None,
25522556
specialization: None,
@@ -2588,25 +2592,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25882592
// TODO: Use the list of inferable typevars from the generic context of the callable.
25892593
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
25902594

2591-
// Note that we infer the annotated type _before_ the arguments if this call is part of
2592-
// an annotated assignment, to closer match the order of any unions written in the type
2593-
// annotation.
2594-
if let Some(return_ty) = self.signature.return_ty
2595-
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
2596-
{
2597-
match call_expression_tcx {
2598-
// A type variable is not a useful type-context for expression inference, and applying it
2599-
// to the return type can lead to confusing unions in nested generic calls.
2600-
Type::TypeVar(_) => {}
2601-
2602-
_ => {
2603-
// Ignore any specialization errors here, because the type context is only used as a hint
2604-
// to infer a more assignable return type.
2605-
let _ = builder.infer(return_ty, call_expression_tcx);
2606-
}
2607-
}
2608-
}
2609-
26102595
let parameters = self.signature.parameters();
26112596
for (argument_index, adjusted_argument_index, _, argument_type) in
26122597
self.enumerate_argument_types()
@@ -2631,7 +2616,41 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26312616
}
26322617
}
26332618

2634-
self.specialization = Some(builder.build(generic_context, *self.call_expression_tcx));
2619+
// Build the specialization first without inferring the type context.
2620+
let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx);
2621+
let isolated_return_ty = self
2622+
.return_ty
2623+
.apply_specialization(self.db, isolated_specialization);
2624+
2625+
let mut try_infer_tcx = || {
2626+
let return_ty = self.signature.return_ty?;
2627+
let call_expression_tcx = self.call_expression_tcx.annotation?;
2628+
2629+
// A type variable is not a useful type-context for expression inference, and applying it
2630+
// to the return type can lead to confusing unions in nested generic calls.
2631+
if call_expression_tcx.is_type_var() {
2632+
return None;
2633+
}
2634+
2635+
// If the return type is already assignable to the annotated type, we can ignore the
2636+
// type context and prefer the narrower inferred type.
2637+
if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) {
2638+
return None;
2639+
}
2640+
2641+
// TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an
2642+
// annotated assignment, to closer match the order of any unions written in the type annotation.
2643+
builder.infer(return_ty, call_expression_tcx).ok()?;
2644+
2645+
// Otherwise, build the specialization again after inferring the type context.
2646+
let specialization = builder.build(generic_context, *self.call_expression_tcx);
2647+
let return_ty = return_ty.apply_specialization(self.db, specialization);
2648+
2649+
Some((Some(specialization), return_ty))
2650+
};
2651+
2652+
(self.specialization, self.return_ty) =
2653+
try_infer_tcx().unwrap_or((Some(isolated_specialization), isolated_return_ty));
26352654
}
26362655

26372656
fn check_argument_type(
@@ -2826,8 +2845,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
28262845
}
28272846
}
28282847

2829-
fn finish(self) -> (InferableTypeVars<'db, 'db>, Option<Specialization<'db>>) {
2830-
(self.inferable_typevars, self.specialization)
2848+
fn finish(
2849+
self,
2850+
) -> (
2851+
InferableTypeVars<'db, 'db>,
2852+
Option<Specialization<'db>>,
2853+
Type<'db>,
2854+
) {
2855+
(self.inferable_typevars, self.specialization, self.return_ty)
28312856
}
28322857
}
28332858

@@ -2985,18 +3010,16 @@ impl<'db> Binding<'db> {
29853010
&self.argument_matches,
29863011
&mut self.parameter_tys,
29873012
call_expression_tcx,
3013+
self.return_ty,
29883014
&mut self.errors,
29893015
);
29903016

29913017
// If this overload is generic, first see if we can infer a specialization of the function
29923018
// from the arguments that were passed in.
29933019
checker.infer_specialization();
2994-
29953020
checker.check_argument_types();
2996-
(self.inferable_typevars, self.specialization) = checker.finish();
2997-
if let Some(specialization) = self.specialization {
2998-
self.return_ty = self.return_ty.apply_specialization(db, specialization);
2999-
}
3021+
3022+
(self.inferable_typevars, self.specialization, self.return_ty) = checker.finish();
30003023
}
30013024

30023025
pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) {

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,6 +1229,7 @@ impl<'db> SpecializationBuilder<'db> {
12291229
let tcx = tcx_specialization.and_then(|specialization| {
12301230
specialization.get(self.db, variable.bound_typevar)
12311231
});
1232+
12321233
ty = ty.map(|ty| ty.promote_literals(self.db, TypeContext::new(tcx)));
12331234
}
12341235

@@ -1251,7 +1252,7 @@ impl<'db> SpecializationBuilder<'db> {
12511252
pub(crate) fn infer(
12521253
&mut self,
12531254
formal: Type<'db>,
1254-
mut actual: Type<'db>,
1255+
actual: Type<'db>,
12551256
) -> Result<(), SpecializationError<'db>> {
12561257
if formal == actual {
12571258
return Ok(());
@@ -1282,9 +1283,11 @@ impl<'db> SpecializationBuilder<'db> {
12821283
return Ok(());
12831284
}
12841285

1285-
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`.
1286-
// So, here we remove the union elements that are not related to `formal`.
1287-
actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);
1286+
// Remove the union elements that are not related to `formal`.
1287+
//
1288+
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T`
1289+
// to `int`.
1290+
let actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);
12881291

12891292
match (formal, actual) {
12901293
// TODO: We haven't implemented a full unification solver yet. If typevars appear in

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ impl<'db> TypeContext<'db> {
391391
.and_then(|ty| ty.known_specialization(db, known_class))
392392
}
393393

394-
pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
394+
pub(crate) fn map(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
395395
Self {
396396
annotation: self.annotation.map(f),
397397
}

0 commit comments

Comments
 (0)