Skip to content

Commit 5c817d8

Browse files
committed
avoid unnecessarily widening generic specializations
1 parent 73520e4 commit 5c817d8

File tree

7 files changed

+151
-54
lines changed

7 files changed

+151
-54
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,7 @@ k: list[tuple[list[int], ...]] | None = [([],), ([1, 2], [3, 4]), ([5], [6], [7]
190190
reveal_type(k) # revealed: list[tuple[list[int], ...]]
191191

192192
l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] | None = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"])
193-
# TODO: this should be `tuple[list[int], list[Any | int], list[Any | int], list[str]]`
194-
reveal_type(l) # revealed: tuple[list[Unknown | int], list[Unknown | int], list[Unknown | int], list[Unknown | str]]
193+
reveal_type(l) # revealed: tuple[list[int], list[Any | int], list[Any | int], list[str]]
195194

196195
type IntList = list[int]
197196

@@ -416,13 +415,13 @@ a = f("a")
416415
reveal_type(a) # revealed: list[Literal["a"]]
417416

418417
b: list[int | Literal["a"]] = f("a")
419-
reveal_type(b) # revealed: list[int | Literal["a"]]
418+
reveal_type(b) # revealed: list[Literal["a"] | int]
420419

421420
c: list[int | str] = f("a")
422-
reveal_type(c) # revealed: list[int | str]
421+
reveal_type(c) # revealed: list[str | int]
423422

424423
d: list[int | tuple[int, int]] = f((1, 2))
425-
reveal_type(d) # revealed: list[int | tuple[int, int]]
424+
reveal_type(d) # revealed: list[tuple[int, int] | int]
426425

427426
e: list[int] = f(True)
428427
reveal_type(e) # revealed: list[int]
@@ -437,8 +436,49 @@ def f2[T: int](x: T) -> T:
437436
return x
438437

439438
i: int = f2(True)
440-
reveal_type(i) # revealed: int
439+
reveal_type(i) # revealed: Literal[True]
441440

442441
j: int | str = f2(True)
443442
reveal_type(j) # revealed: Literal[True]
444443
```
444+
445+
Types are not widened unnecessarily:
446+
447+
```py
448+
def id[T](x: T) -> T:
449+
return x
450+
451+
def lst[T](x: T) -> list[T]:
452+
return [x]
453+
454+
def _(i: int):
455+
a: int | None = i
456+
b: int | None = id(i)
457+
c: int | str | None = id(i)
458+
reveal_type(a) # revealed: int
459+
reveal_type(b) # revealed: int
460+
reveal_type(c) # revealed: int
461+
462+
a: list[int | None] | None = [i]
463+
b: list[int | None] | None = id([i])
464+
c: list[int | None] | int | None = id([i])
465+
reveal_type(a) # revealed: list[int | None]
466+
# TODO: these should reveal `list[int | None]`
467+
# we currently do not use the call expression annotation as type context for argument inference
468+
reveal_type(b) # revealed: list[Unknown | int]
469+
reveal_type(c) # revealed: list[Unknown | int]
470+
471+
a: list[int | None] | None = [i]
472+
b: list[int | None] | None = lst(i)
473+
c: list[int | None] | int | None = lst(i)
474+
reveal_type(a) # revealed: list[int | None]
475+
reveal_type(b) # revealed: list[int | None]
476+
reveal_type(c) # revealed: list[int | None]
477+
478+
a: list | None = []
479+
b: list | None = id([])
480+
c: list | int | None = id([])
481+
reveal_type(a) # revealed: list[Unknown]
482+
reveal_type(b) # revealed: list[Unknown]
483+
reveal_type(c) # revealed: list[Unknown]
484+
```

crates/ty_python_semantic/resources/mdtest/typed_dict.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,7 @@ grandchild: Node = {"name": "grandchild", "parent": child}
907907

908908
nested: Node = {"name": "n1", "parent": {"name": "n2", "parent": {"name": "n3", "parent": None}}}
909909

910-
# TODO: this should be an error (invalid type for `name` in innermost node)
910+
# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Node`: value of type `Literal[3]`"
911911
nested_invalid: Node = {"name": "n1", "parent": {"name": "n2", "parent": {"name": 3, "parent": None}}}
912912
```
913913

crates/ty_python_semantic/src/types.rs

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,24 +1202,37 @@ impl<'db> Type<'db> {
12021202
if yes { self.negate(db) } else { *self }
12031203
}
12041204

1205-
/// Remove the union elements that are not related to `target`.
1206-
pub(crate) fn filter_disjoint_elements(
1205+
/// If the type is a union, filters union elements based on the provided predicate.
1206+
///
1207+
/// Otherwise, returns the type unchanged.
1208+
pub(crate) fn filter_union(
12071209
self,
12081210
db: &'db dyn Db,
1209-
target: Type<'db>,
1210-
inferable: InferableTypeVars<'_, 'db>,
1211+
f: impl FnMut(&Type<'db>) -> bool,
12111212
) -> Type<'db> {
12121213
if let Type::Union(union) = self {
1213-
union.filter(db, |elem| {
1214-
!elem
1215-
.when_disjoint_from(db, target, inferable)
1216-
.is_always_satisfied()
1217-
})
1214+
union.filter(db, f)
12181215
} else {
12191216
self
12201217
}
12211218
}
12221219

1220+
/// If the type is a union, removes union elements that are disjoint from `target`.
1221+
///
1222+
/// Otherwise, returns the type unchanged.
1223+
pub(crate) fn filter_disjoint_elements(
1224+
self,
1225+
db: &'db dyn Db,
1226+
target: Type<'db>,
1227+
inferable: InferableTypeVars<'_, 'db>,
1228+
) -> Type<'db> {
1229+
self.filter_union(db, |elem| {
1230+
!elem
1231+
.when_disjoint_from(db, target, inferable)
1232+
.is_always_satisfied()
1233+
})
1234+
}
1235+
12231236
/// Returns the fallback instance type that a literal is an instance of, or `None` if the type
12241237
/// is not a literal.
12251238
pub(crate) fn literal_fallback_instance(self, db: &'db dyn Db) -> Option<Type<'db>> {
@@ -11127,9 +11140,9 @@ impl<'db> UnionType<'db> {
1112711140
pub(crate) fn filter(
1112811141
self,
1112911142
db: &'db dyn Db,
11130-
filter_fn: impl FnMut(&&Type<'db>) -> bool,
11143+
mut f: impl FnMut(&Type<'db>) -> bool,
1113111144
) -> Type<'db> {
11132-
Self::from_elements(db, self.elements(db).iter().filter(filter_fn))
11145+
Self::from_elements(db, self.elements(db).iter().filter(|ty| f(ty)))
1113311146
}
1113411147

1113511148
pub(crate) fn map_with_boundness(

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2490,20 +2490,23 @@ struct ArgumentTypeChecker<'a, 'db> {
24902490
argument_matches: &'a [MatchedArgument<'db>],
24912491
parameter_tys: &'a mut [Option<Type<'db>>],
24922492
call_expression_tcx: &'a TypeContext<'db>,
2493+
return_ty: Type<'db>,
24932494
errors: &'a mut Vec<BindingError<'db>>,
24942495

24952496
inferable_typevars: InferableTypeVars<'db, 'db>,
24962497
specialization: Option<Specialization<'db>>,
24972498
}
24982499

24992500
impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
2501+
#[expect(clippy::too_many_arguments)]
25002502
fn new(
25012503
db: &'db dyn Db,
25022504
signature: &'a Signature<'db>,
25032505
arguments: &'a CallArguments<'a, 'db>,
25042506
argument_matches: &'a [MatchedArgument<'db>],
25052507
parameter_tys: &'a mut [Option<Type<'db>>],
25062508
call_expression_tcx: &'a TypeContext<'db>,
2509+
return_ty: Type<'db>,
25072510
errors: &'a mut Vec<BindingError<'db>>,
25082511
) -> Self {
25092512
Self {
@@ -2513,6 +2516,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25132516
argument_matches,
25142517
parameter_tys,
25152518
call_expression_tcx,
2519+
return_ty,
25162520
errors,
25172521
inferable_typevars: InferableTypeVars::None,
25182522
specialization: None,
@@ -2554,25 +2558,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25542558
// TODO: Use the list of inferable typevars from the generic context of the callable.
25552559
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
25562560

2557-
// Note that we infer the annotated type _before_ the arguments if this call is part of
2558-
// an annotated assignment, to closer match the order of any unions written in the type
2559-
// annotation.
2560-
if let Some(return_ty) = self.signature.return_ty
2561-
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
2562-
{
2563-
match call_expression_tcx {
2564-
// A type variable is not a useful type-context for expression inference, and applying it
2565-
// to the return type can lead to confusing unions in nested generic calls.
2566-
Type::TypeVar(_) => {}
2567-
2568-
_ => {
2569-
// Ignore any specialization errors here, because the type context is only used as a hint
2570-
// to infer a more assignable return type.
2571-
let _ = builder.infer(return_ty, call_expression_tcx);
2572-
}
2573-
}
2574-
}
2575-
25762561
let parameters = self.signature.parameters();
25772562
for (argument_index, adjusted_argument_index, _, argument_type) in
25782563
self.enumerate_argument_types()
@@ -2597,7 +2582,41 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25972582
}
25982583
}
25992584

2600-
self.specialization = Some(builder.build(generic_context, *self.call_expression_tcx));
2585+
// Build the specialization first without inferring the type context.
2586+
let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx);
2587+
let isolated_return_ty = self
2588+
.return_ty
2589+
.apply_specialization(self.db, isolated_specialization);
2590+
2591+
let mut try_infer_tcx = || {
2592+
let return_ty = self.signature.return_ty?;
2593+
let call_expression_tcx = self.call_expression_tcx.annotation?;
2594+
2595+
// A type variable is not a useful type-context for expression inference, and applying it
2596+
// to the return type can lead to confusing unions in nested generic calls.
2597+
if call_expression_tcx.is_type_var() {
2598+
return None;
2599+
}
2600+
2601+
// If the return type is already assignable to the annotated type, we can ignore the
2602+
// type context and prefer the narrower inferred type.
2603+
if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) {
2604+
return None;
2605+
}
2606+
2607+
// TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an
2608+
// annotated assignment, to closer match the order of any unions written in the type annotation.
2609+
builder.infer(return_ty, call_expression_tcx).ok()?;
2610+
2611+
// Otherwise, build the specialization again after inferring the type context.
2612+
let specialization = builder.build(generic_context, *self.call_expression_tcx);
2613+
let return_ty = return_ty.apply_specialization(self.db, specialization);
2614+
2615+
Some((Some(specialization), return_ty))
2616+
};
2617+
2618+
(self.specialization, self.return_ty) =
2619+
try_infer_tcx().unwrap_or((Some(isolated_specialization), isolated_return_ty));
26012620
}
26022621

26032622
fn check_argument_type(
@@ -2792,8 +2811,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27922811
}
27932812
}
27942813

2795-
fn finish(self) -> (InferableTypeVars<'db, 'db>, Option<Specialization<'db>>) {
2796-
(self.inferable_typevars, self.specialization)
2814+
fn finish(
2815+
self,
2816+
) -> (
2817+
InferableTypeVars<'db, 'db>,
2818+
Option<Specialization<'db>>,
2819+
Type<'db>,
2820+
) {
2821+
(self.inferable_typevars, self.specialization, self.return_ty)
27972822
}
27982823
}
27992824

@@ -2950,18 +2975,16 @@ impl<'db> Binding<'db> {
29502975
&self.argument_matches,
29512976
&mut self.parameter_tys,
29522977
call_expression_tcx,
2978+
self.return_ty,
29532979
&mut self.errors,
29542980
);
29552981

29562982
// If this overload is generic, first see if we can infer a specialization of the function
29572983
// from the arguments that were passed in.
29582984
checker.infer_specialization();
2959-
29602985
checker.check_argument_types();
2961-
(self.inferable_typevars, self.specialization) = checker.finish();
2962-
if let Some(specialization) = self.specialization {
2963-
self.return_ty = self.return_ty.apply_specialization(db, specialization);
2964-
}
2986+
2987+
(self.inferable_typevars, self.specialization, self.return_ty) = checker.finish();
29652988
}
29662989

29672990
pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) {

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,6 +1229,7 @@ impl<'db> SpecializationBuilder<'db> {
12291229
let tcx = tcx_specialization.and_then(|specialization| {
12301230
specialization.get(self.db, variable.bound_typevar)
12311231
});
1232+
12321233
ty = ty.map(|ty| ty.promote_literals(self.db, TypeContext::new(tcx)));
12331234
}
12341235

@@ -1251,7 +1252,7 @@ impl<'db> SpecializationBuilder<'db> {
12511252
pub(crate) fn infer(
12521253
&mut self,
12531254
formal: Type<'db>,
1254-
mut actual: Type<'db>,
1255+
actual: Type<'db>,
12551256
) -> Result<(), SpecializationError<'db>> {
12561257
if formal == actual {
12571258
return Ok(());
@@ -1282,9 +1283,11 @@ impl<'db> SpecializationBuilder<'db> {
12821283
return Ok(());
12831284
}
12841285

1285-
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`.
1286-
// So, here we remove the union elements that are not related to `formal`.
1287-
actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);
1286+
// Remove the union elements that are not related to `formal`.
1287+
//
1288+
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T`
1289+
// to `int`.
1290+
let actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);
12881291

12891292
match (formal, actual) {
12901293
// TODO: We haven't implemented a full unification solver yet. If typevars appear in

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ impl<'db> TypeContext<'db> {
391391
.and_then(|ty| ty.known_specialization(db, known_class))
392392
}
393393

394-
pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
394+
pub(crate) fn map(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
395395
Self {
396396
annotation: self.annotation.map(f),
397397
}

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5830,6 +5830,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
58305830
parenthesized: _,
58315831
} = tuple;
58325832

5833+
// TODO: Use the list of inferable typevars from the generic context of tuple.
5834+
let inferable = InferableTypeVars::None;
5835+
5836+
// Remove any union elements of that are unrelated to the tuple type.
5837+
let tcx = tcx.map(|annotation| {
5838+
annotation.filter_disjoint_elements(
5839+
self.db(),
5840+
KnownClass::Tuple.to_instance(self.db()),
5841+
inferable,
5842+
)
5843+
});
5844+
58335845
let annotated_tuple = tcx
58345846
.known_specialization(self.db(), KnownClass::Tuple)
58355847
.and_then(|specialization| {
@@ -5895,7 +5907,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
58955907
} = dict;
58965908

58975909
// Validate `TypedDict` dictionary literal assignments.
5898-
if let Some(typed_dict) = tcx.annotation.and_then(Type::as_typed_dict)
5910+
if let Some(tcx) = tcx.annotation
5911+
&& let Some(typed_dict) = tcx
5912+
.filter_union(self.db(), Type::is_typed_dict)
5913+
.as_typed_dict()
58995914
&& let Some(ty) = self.infer_typed_dict_expression(dict, typed_dict)
59005915
{
59015916
return ty;
@@ -5978,9 +5993,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
59785993
// TODO: Use the list of inferable typevars from the generic context of the collection
59795994
// class.
59805995
let inferable = InferableTypeVars::None;
5981-
let tcx = tcx.map_annotation(|annotation| {
5982-
// Remove any union elements of `annotation` that are not related to `collection_ty`.
5983-
// e.g. `annotation: list[int] | None => list[int]` if `collection_ty: list`
5996+
5997+
// Remove any union elements of that are unrelated to the collection type.
5998+
//
5999+
// For example, we only want the `list[int]` from `annotation: list[int] | None` if
6000+
// `collection_ty` is `list`.
6001+
let tcx = tcx.map(|annotation| {
59846002
let collection_ty = collection_class.to_instance(self.db());
59856003
annotation.filter_disjoint_elements(self.db(), collection_ty, inferable)
59866004
});

0 commit comments

Comments
 (0)