Skip to content

Commit fe942dc

Browse files
committed
avoid unnecessarily widening generic specializations
1 parent 73520e4 commit fe942dc

File tree

7 files changed

+136
-54
lines changed

7 files changed

+136
-54
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,7 @@ k: list[tuple[list[int], ...]] | None = [([],), ([1, 2], [3, 4]), ([5], [6], [7]
190190
reveal_type(k) # revealed: list[tuple[list[int], ...]]
191191

192192
l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] | None = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"])
193-
# TODO: this should be `tuple[list[int], list[Any | int], list[Any | int], list[str]]`
194-
reveal_type(l) # revealed: tuple[list[Unknown | int], list[Unknown | int], list[Unknown | int], list[Unknown | str]]
193+
reveal_type(l) # revealed: tuple[list[int], list[Any | int], list[Any | int], list[str]]
195194

196195
type IntList = list[int]
197196

@@ -416,13 +415,13 @@ a = f("a")
416415
reveal_type(a) # revealed: list[Literal["a"]]
417416

418417
b: list[int | Literal["a"]] = f("a")
419-
reveal_type(b) # revealed: list[int | Literal["a"]]
418+
reveal_type(b) # revealed: list[Literal["a"] | int]
420419

421420
c: list[int | str] = f("a")
422-
reveal_type(c) # revealed: list[int | str]
421+
reveal_type(c) # revealed: list[str | int]
423422

424423
d: list[int | tuple[int, int]] = f((1, 2))
425-
reveal_type(d) # revealed: list[int | tuple[int, int]]
424+
reveal_type(d) # revealed: list[tuple[int, int] | int]
426425

427426
e: list[int] = f(True)
428427
reveal_type(e) # revealed: list[int]
@@ -437,8 +436,43 @@ def f2[T: int](x: T) -> T:
437436
return x
438437

439438
i: int = f2(True)
440-
reveal_type(i) # revealed: int
439+
reveal_type(i) # revealed: Literal[True]
441440

442441
j: int | str = f2(True)
443442
reveal_type(j) # revealed: Literal[True]
444443
```
444+
445+
Types are not widened unnecessarily:
446+
447+
```py
448+
def id[T](x: T) -> T:
449+
return x
450+
451+
def lst[T](x: T) -> list[T]:
452+
return [x]
453+
454+
def _(i: int):
455+
a: int | None = i
456+
b: int | None = id(i)
457+
c: int | str | None = id(i)
458+
reveal_type(a) # revealed: int
459+
reveal_type(b) # revealed: int
460+
reveal_type(c) # revealed: int
461+
462+
a: list[int | None] | None = [i]
463+
b: list[int | None] | None = id([i])
464+
c: list[int | None] | int | None = id([i])
465+
reveal_type(a) # revealed: list[int | None]
466+
467+
# TODO: these should reveal `list[int | None]`
468+
# we currently do not use the call expression annotation as type context for argument inference
469+
reveal_type(b) # revealed: list[Unknown | int] | list[int | None] | None
470+
reveal_type(c) # revealed: list[Unknown | int] | list[int | None] | int | None
471+
472+
a: list[int | None] | None = [i]
473+
b: list[int | None] | None = lst(i)
474+
c: list[int | None] | int | None = lst(i)
475+
reveal_type(a) # revealed: list[int | None]
476+
reveal_type(b) # revealed: list[int | None]
477+
reveal_type(c) # revealed: list[int | None]
478+
```

crates/ty_python_semantic/resources/mdtest/typed_dict.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,7 @@ grandchild: Node = {"name": "grandchild", "parent": child}
907907

908908
nested: Node = {"name": "n1", "parent": {"name": "n2", "parent": {"name": "n3", "parent": None}}}
909909

910-
# TODO: this should be an error (invalid type for `name` in innermost node)
910+
# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Node`: value of type `Literal[3]`"
911911
nested_invalid: Node = {"name": "n1", "parent": {"name": "n2", "parent": {"name": 3, "parent": None}}}
912912
```
913913

crates/ty_python_semantic/src/types.rs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,22 +1202,27 @@ impl<'db> Type<'db> {
12021202
if yes { self.negate(db) } else { *self }
12031203
}
12041204

1205+
/// Filters union elements based on the provided predicate.
1206+
pub(crate) fn filter(self, db: &'db dyn Db, f: impl FnMut(&Type<'db>) -> bool) -> Type<'db> {
1207+
if let Type::Union(union) = self {
1208+
union.filter(db, f)
1209+
} else {
1210+
self
1211+
}
1212+
}
1213+
12051214
/// Remove the union elements that are not related to `target`.
12061215
pub(crate) fn filter_disjoint_elements(
12071216
self,
12081217
db: &'db dyn Db,
12091218
target: Type<'db>,
12101219
inferable: InferableTypeVars<'_, 'db>,
12111220
) -> Type<'db> {
1212-
if let Type::Union(union) = self {
1213-
union.filter(db, |elem| {
1214-
!elem
1215-
.when_disjoint_from(db, target, inferable)
1216-
.is_always_satisfied()
1217-
})
1218-
} else {
1219-
self
1220-
}
1221+
self.filter(db, |elem| {
1222+
!elem
1223+
.when_disjoint_from(db, target, inferable)
1224+
.is_always_satisfied()
1225+
})
12211226
}
12221227

12231228
/// Returns the fallback instance type that a literal is an instance of, or `None` if the type
@@ -11127,9 +11132,9 @@ impl<'db> UnionType<'db> {
1112711132
pub(crate) fn filter(
1112811133
self,
1112911134
db: &'db dyn Db,
11130-
filter_fn: impl FnMut(&&Type<'db>) -> bool,
11135+
mut f: impl FnMut(&Type<'db>) -> bool,
1113111136
) -> Type<'db> {
11132-
Self::from_elements(db, self.elements(db).iter().filter(filter_fn))
11137+
Self::from_elements(db, self.elements(db).iter().filter(|ty| f(ty)))
1113311138
}
1113411139

1113511140
pub(crate) fn map_with_boundness(

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2490,20 +2490,23 @@ struct ArgumentTypeChecker<'a, 'db> {
24902490
argument_matches: &'a [MatchedArgument<'db>],
24912491
parameter_tys: &'a mut [Option<Type<'db>>],
24922492
call_expression_tcx: &'a TypeContext<'db>,
2493+
return_ty: Type<'db>,
24932494
errors: &'a mut Vec<BindingError<'db>>,
24942495

24952496
inferable_typevars: InferableTypeVars<'db, 'db>,
24962497
specialization: Option<Specialization<'db>>,
24972498
}
24982499

24992500
impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
2501+
#[allow(clippy::too_many_arguments)]
25002502
fn new(
25012503
db: &'db dyn Db,
25022504
signature: &'a Signature<'db>,
25032505
arguments: &'a CallArguments<'a, 'db>,
25042506
argument_matches: &'a [MatchedArgument<'db>],
25052507
parameter_tys: &'a mut [Option<Type<'db>>],
25062508
call_expression_tcx: &'a TypeContext<'db>,
2509+
return_ty: Type<'db>,
25072510
errors: &'a mut Vec<BindingError<'db>>,
25082511
) -> Self {
25092512
Self {
@@ -2513,6 +2516,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25132516
argument_matches,
25142517
parameter_tys,
25152518
call_expression_tcx,
2519+
return_ty,
25162520
errors,
25172521
inferable_typevars: InferableTypeVars::None,
25182522
specialization: None,
@@ -2554,25 +2558,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25542558
// TODO: Use the list of inferable typevars from the generic context of the callable.
25552559
let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars);
25562560

2557-
// Note that we infer the annotated type _before_ the arguments if this call is part of
2558-
// an annotated assignment, to closer match the order of any unions written in the type
2559-
// annotation.
2560-
if let Some(return_ty) = self.signature.return_ty
2561-
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
2562-
{
2563-
match call_expression_tcx {
2564-
// A type variable is not a useful type-context for expression inference, and applying it
2565-
// to the return type can lead to confusing unions in nested generic calls.
2566-
Type::TypeVar(_) => {}
2567-
2568-
_ => {
2569-
// Ignore any specialization errors here, because the type context is only used as a hint
2570-
// to infer a more assignable return type.
2571-
let _ = builder.infer(return_ty, call_expression_tcx);
2572-
}
2573-
}
2574-
}
2575-
25762561
let parameters = self.signature.parameters();
25772562
for (argument_index, adjusted_argument_index, _, argument_type) in
25782563
self.enumerate_argument_types()
@@ -2597,7 +2582,42 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25972582
}
25982583
}
25992584

2600-
self.specialization = Some(builder.build(generic_context, *self.call_expression_tcx));
2585+
// Build the specialization once without type context.
2586+
let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx);
2587+
let isolated_return_ty = self
2588+
.return_ty
2589+
.apply_specialization(self.db, isolated_specialization);
2590+
2591+
// TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an
2592+
// annotated assignment, to closer match the order of any unions written in the type annotation.
2593+
if let Some(call_expression_tcx) = self.call_expression_tcx.annotation {
2594+
match call_expression_tcx {
2595+
// A type variable is not a useful type-context for expression inference, and applying it
2596+
// to the return type can lead to confusing unions in nested generic calls.
2597+
Type::TypeVar(_) => {}
2598+
2599+
_ => {
2600+
// Ignore any specialization errors here, because the type context is only used as a hint
2601+
// to infer a more assignable return type.
2602+
let _ = builder.infer(self.return_ty, call_expression_tcx);
2603+
}
2604+
}
2605+
2606+
// Build the specialization a second time with the type context.
2607+
let specialization = builder.build(generic_context, *self.call_expression_tcx);
2608+
let return_ty = self.return_ty.apply_specialization(self.db, specialization);
2609+
2610+
// The type context should only be used to infer a more assignable return type, not an unnecessarily wide
2611+
// one, so we may ignore the type context here to prefer the narrower return type.
2612+
if !isolated_return_ty.is_subtype_of(self.db, return_ty) {
2613+
self.return_ty = return_ty;
2614+
self.specialization = Some(specialization);
2615+
return;
2616+
}
2617+
}
2618+
2619+
self.return_ty = isolated_return_ty;
2620+
self.specialization = Some(isolated_specialization);
26012621
}
26022622

26032623
fn check_argument_type(
@@ -2792,8 +2812,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27922812
}
27932813
}
27942814

2795-
fn finish(self) -> (InferableTypeVars<'db, 'db>, Option<Specialization<'db>>) {
2796-
(self.inferable_typevars, self.specialization)
2815+
fn finish(
2816+
self,
2817+
) -> (
2818+
InferableTypeVars<'db, 'db>,
2819+
Option<Specialization<'db>>,
2820+
Type<'db>,
2821+
) {
2822+
(self.inferable_typevars, self.specialization, self.return_ty)
27972823
}
27982824
}
27992825

@@ -2950,18 +2976,16 @@ impl<'db> Binding<'db> {
29502976
&self.argument_matches,
29512977
&mut self.parameter_tys,
29522978
call_expression_tcx,
2979+
self.return_ty,
29532980
&mut self.errors,
29542981
);
29552982

29562983
// If this overload is generic, first see if we can infer a specialization of the function
29572984
// from the arguments that were passed in.
29582985
checker.infer_specialization();
2959-
29602986
checker.check_argument_types();
2961-
(self.inferable_typevars, self.specialization) = checker.finish();
2962-
if let Some(specialization) = self.specialization {
2963-
self.return_ty = self.return_ty.apply_specialization(db, specialization);
2964-
}
2987+
2988+
(self.inferable_typevars, self.specialization, self.return_ty) = checker.finish();
29652989
}
29662990

29672991
pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) {

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,6 +1229,7 @@ impl<'db> SpecializationBuilder<'db> {
12291229
let tcx = tcx_specialization.and_then(|specialization| {
12301230
specialization.get(self.db, variable.bound_typevar)
12311231
});
1232+
12321233
ty = ty.map(|ty| ty.promote_literals(self.db, TypeContext::new(tcx)));
12331234
}
12341235

@@ -1251,7 +1252,7 @@ impl<'db> SpecializationBuilder<'db> {
12511252
pub(crate) fn infer(
12521253
&mut self,
12531254
formal: Type<'db>,
1254-
mut actual: Type<'db>,
1255+
actual: Type<'db>,
12551256
) -> Result<(), SpecializationError<'db>> {
12561257
if formal == actual {
12571258
return Ok(());
@@ -1282,9 +1283,11 @@ impl<'db> SpecializationBuilder<'db> {
12821283
return Ok(());
12831284
}
12841285

1285-
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`.
1286-
// So, here we remove the union elements that are not related to `formal`.
1287-
actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);
1286+
// Remove the union elements that are not related to `formal`.
1287+
//
1288+
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T`
1289+
// to `int`.
1290+
let actual = actual.filter_disjoint_elements(self.db, formal, self.inferable);
12881291

12891292
match (formal, actual) {
12901293
// TODO: We haven't implemented a full unification solver yet. If typevars appear in

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ impl<'db> TypeContext<'db> {
391391
.and_then(|ty| ty.known_specialization(db, known_class))
392392
}
393393

394-
pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
394+
pub(crate) fn map(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
395395
Self {
396396
annotation: self.annotation.map(f),
397397
}

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5830,6 +5830,18 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
58305830
parenthesized: _,
58315831
} = tuple;
58325832

5833+
// TODO: Use the list of inferable typevars from the generic context of tuple.
5834+
let inferable = InferableTypeVars::None;
5835+
5836+
// Remove any union elements of that are unrelated to the tuple type.
5837+
let tcx = tcx.map(|annotation| {
5838+
annotation.filter_disjoint_elements(
5839+
self.db(),
5840+
KnownClass::Tuple.to_instance(self.db()),
5841+
inferable,
5842+
)
5843+
});
5844+
58335845
let annotated_tuple = tcx
58345846
.known_specialization(self.db(), KnownClass::Tuple)
58355847
.and_then(|specialization| {
@@ -5895,7 +5907,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
58955907
} = dict;
58965908

58975909
// Validate `TypedDict` dictionary literal assignments.
5898-
if let Some(typed_dict) = tcx.annotation.and_then(Type::as_typed_dict)
5910+
if let Some(tcx) = tcx.annotation
5911+
&& let Some(typed_dict) = tcx.filter(self.db(), Type::is_typed_dict).as_typed_dict()
58995912
&& let Some(ty) = self.infer_typed_dict_expression(dict, typed_dict)
59005913
{
59015914
return ty;
@@ -5978,9 +5991,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
59785991
// TODO: Use the list of inferable typevars from the generic context of the collection
59795992
// class.
59805993
let inferable = InferableTypeVars::None;
5981-
let tcx = tcx.map_annotation(|annotation| {
5982-
// Remove any union elements of `annotation` that are not related to `collection_ty`.
5983-
// e.g. `annotation: list[int] | None => list[int]` if `collection_ty: list`
5994+
5995+
// Remove any union elements of that are unrelated to the collection type.
5996+
//
5997+
// For example, we only want the `list[int]` from `annotation: list[int] | None` if
5998+
// `collection_ty` is `list`.
5999+
let tcx = tcx.map(|annotation| {
59846000
let collection_ty = collection_class.to_instance(self.db());
59856001
annotation.filter_disjoint_elements(self.db(), collection_ty, inferable)
59866002
});

0 commit comments

Comments
 (0)