Skip to content

Commit 981fd9c

Browse files
committed
avoid unnecessarily widening generic specializations
1 parent 43eddc5 commit 981fd9c

File tree

7 files changed

+117
-52
lines changed

7 files changed

+117
-52
lines changed

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,7 @@ k: list[tuple[list[int], ...]] | None = [([],), ([1, 2], [3, 4]), ([5], [6], [7]
190190
reveal_type(k) # revealed: list[tuple[list[int], ...]]
191191

192192
l: tuple[list[int], *tuple[list[typing.Any], ...], list[str]] | None = ([1, 2, 3], [4, 5, 6], [7, 8, 9], ["10", "11", "12"])
193-
# TODO: this should be `tuple[list[int], list[Any | int], list[Any | int], list[str]]`
194-
reveal_type(l) # revealed: tuple[list[Unknown | int], list[Unknown | int], list[Unknown | int], list[Unknown | str]]
193+
reveal_type(l) # revealed: tuple[list[int], list[Any | int], list[Any | int], list[str]]
195194

196195
type IntList = list[int]
197196

@@ -416,13 +415,13 @@ a = f("a")
416415
reveal_type(a) # revealed: list[Literal["a"]]
417416

418417
b: list[int | Literal["a"]] = f("a")
419-
reveal_type(b) # revealed: list[int | Literal["a"]]
418+
reveal_type(b) # revealed: list[Literal["a"] | int]
420419

421420
c: list[int | str] = f("a")
422-
reveal_type(c) # revealed: list[int | str]
421+
reveal_type(c) # revealed: list[str | int]
423422

424423
d: list[int | tuple[int, int]] = f((1, 2))
425-
reveal_type(d) # revealed: list[int | tuple[int, int]]
424+
reveal_type(d) # revealed: list[tuple[int, int] | int]
426425

427426
e: list[int] = f(True)
428427
reveal_type(e) # revealed: list[int]
@@ -437,8 +436,43 @@ def f2[T: int](x: T) -> T:
437436
return x
438437

439438
i: int = f2(True)
440-
reveal_type(i) # revealed: int
439+
reveal_type(i) # revealed: Literal[True]
441440

442441
j: int | str = f2(True)
443442
reveal_type(j) # revealed: Literal[True]
444443
```
444+
445+
Types are not widened unnecessarily:
446+
447+
```py
448+
def id[T](x: T) -> T:
449+
return x
450+
451+
def lst[T](x: T) -> list[T]:
452+
return [x]
453+
454+
def _(i: int):
455+
a: int | None = i
456+
b: int | None = id(i)
457+
c: int | str | None = id(i)
458+
reveal_type(a) # revealed: int
459+
reveal_type(b) # revealed: int
460+
reveal_type(c) # revealed: int
461+
462+
a: list[int | None] | None = [i]
463+
b: list[int | None] | None = id([i])
464+
c: list[int | None] | int | None = id([i])
465+
reveal_type(a) # revealed: list[int | None]
466+
467+
# TODO: these should reveal `list[int | None]`
468+
# we currently do not use the call expression annotation as type context for argument inference
469+
reveal_type(b) # revealed: list[Unknown | int] | list[int | None] | None
470+
reveal_type(c) # revealed: list[Unknown | int] | list[int | None] | int | None
471+
472+
a: list[int | None] | None = [i]
473+
b: list[int | None] | None = lst(i)
474+
c: list[int | None] | int | None = lst(i)
475+
reveal_type(a) # revealed: list[int | None]
476+
reveal_type(b) # revealed: list[int | None]
477+
reveal_type(c) # revealed: list[int | None]
478+
```

crates/ty_python_semantic/resources/mdtest/typed_dict.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,7 @@ grandchild: Node = {"name": "grandchild", "parent": child}
907907

908908
nested: Node = {"name": "n1", "parent": {"name": "n2", "parent": {"name": "n3", "parent": None}}}
909909

910-
# TODO: this should be an error (invalid type for `name` in innermost node)
910+
# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Node`: value of type `Literal[3]`"
911911
nested_invalid: Node = {"name": "n1", "parent": {"name": "n2", "parent": {"name": 3, "parent": None}}}
912912
```
913913

crates/ty_python_semantic/src/types.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,15 +1196,20 @@ impl<'db> Type<'db> {
11961196
if yes { self.negate(db) } else { *self }
11971197
}
11981198

1199-
/// Remove the union elements that are not related to `target`.
1200-
pub(crate) fn filter_disjoint_elements(self, db: &'db dyn Db, target: Type<'db>) -> Type<'db> {
1199+
/// Filters union elements based on the provided predicate.
1200+
pub(crate) fn filter(self, db: &'db dyn Db, f: impl FnMut(&Type<'db>) -> bool) -> Type<'db> {
12011201
if let Type::Union(union) = self {
1202-
union.filter(db, |elem| !elem.is_disjoint_from(db, target))
1202+
union.filter(db, f)
12031203
} else {
12041204
self
12051205
}
12061206
}
12071207

1208+
/// Remove the union elements that are not related to `target`.
1209+
pub(crate) fn filter_disjoint_elements(self, db: &'db dyn Db, target: Type<'db>) -> Type<'db> {
1210+
self.filter(db, |elem| !elem.is_disjoint_from(db, target))
1211+
}
1212+
12081213
/// Returns the fallback instance type that a literal is an instance of, or `None` if the type
12091214
/// is not a literal.
12101215
pub(crate) fn literal_fallback_instance(self, db: &'db dyn Db) -> Option<Type<'db>> {
@@ -10861,9 +10866,9 @@ impl<'db> UnionType<'db> {
1086110866
pub(crate) fn filter(
1086210867
self,
1086310868
db: &'db dyn Db,
10864-
filter_fn: impl FnMut(&&Type<'db>) -> bool,
10869+
mut f: impl FnMut(&Type<'db>) -> bool,
1086510870
) -> Type<'db> {
10866-
Self::from_elements(db, self.elements(db).iter().filter(filter_fn))
10871+
Self::from_elements(db, self.elements(db).iter().filter(|ty| f(ty)))
1086710872
}
1086810873

1086910874
pub(crate) fn map_with_boundness(

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 46 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2459,19 +2459,22 @@ struct ArgumentTypeChecker<'a, 'db> {
24592459
argument_matches: &'a [MatchedArgument<'db>],
24602460
parameter_tys: &'a mut [Option<Type<'db>>],
24612461
call_expression_tcx: &'a TypeContext<'db>,
2462+
return_ty: Type<'db>,
24622463
errors: &'a mut Vec<BindingError<'db>>,
24632464

24642465
specialization: Option<Specialization<'db>>,
24652466
}
24662467

24672468
impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
2469+
#[allow(clippy::too_many_arguments)]
24682470
fn new(
24692471
db: &'db dyn Db,
24702472
signature: &'a Signature<'db>,
24712473
arguments: &'a CallArguments<'a, 'db>,
24722474
argument_matches: &'a [MatchedArgument<'db>],
24732475
parameter_tys: &'a mut [Option<Type<'db>>],
24742476
call_expression_tcx: &'a TypeContext<'db>,
2477+
return_ty: Type<'db>,
24752478
errors: &'a mut Vec<BindingError<'db>>,
24762479
) -> Self {
24772480
Self {
@@ -2481,6 +2484,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
24812484
argument_matches,
24822485
parameter_tys,
24832486
call_expression_tcx,
2487+
return_ty,
24842488
errors,
24852489
specialization: None,
24862490
}
@@ -2514,31 +2518,12 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25142518
}
25152519

25162520
fn infer_specialization(&mut self) {
2517-
if self.signature.generic_context.is_none() {
2521+
let Some(gc) = self.signature.generic_context else {
25182522
return;
2519-
}
2523+
};
25202524

25212525
let mut builder = SpecializationBuilder::new(self.db);
25222526

2523-
// Note that we infer the annotated type _before_ the arguments if this call is part of
2524-
// an annotated assignment, to closer match the order of any unions written in the type
2525-
// annotation.
2526-
if let Some(return_ty) = self.signature.return_ty
2527-
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
2528-
{
2529-
match call_expression_tcx {
2530-
// A type variable is not a useful type-context for expression inference, and applying it
2531-
// to the return type can lead to confusing unions in nested generic calls.
2532-
Type::TypeVar(_) => {}
2533-
2534-
_ => {
2535-
// Ignore any specialization errors here, because the type context is only used as a hint
2536-
// to infer a more assignable return type.
2537-
let _ = builder.infer(return_ty, call_expression_tcx);
2538-
}
2539-
}
2540-
}
2541-
25422527
let parameters = self.signature.parameters();
25432528
for (argument_index, adjusted_argument_index, _, argument_type) in
25442529
self.enumerate_argument_types()
@@ -2563,10 +2548,42 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25632548
}
25642549
}
25652550

2566-
self.specialization = self
2567-
.signature
2568-
.generic_context
2569-
.map(|gc| builder.build(gc, *self.call_expression_tcx));
2551+
// Build the specialization once without type context.
2552+
let isolated_specialization = builder.build(gc, *self.call_expression_tcx);
2553+
let isolated_return_ty = self
2554+
.return_ty
2555+
.apply_specialization(self.db, isolated_specialization);
2556+
2557+
// TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an
2558+
// annotated assignment, to closer match the order of any unions written in the type annotation.
2559+
if let Some(call_expression_tcx) = self.call_expression_tcx.annotation {
2560+
match call_expression_tcx {
2561+
// A type variable is not a useful type-context for expression inference, and applying it
2562+
// to the return type can lead to confusing unions in nested generic calls.
2563+
Type::TypeVar(_) => {}
2564+
2565+
_ => {
2566+
// Ignore any specialization errors here, because the type context is only used as a hint
2567+
// to infer a more assignable return type.
2568+
let _ = builder.infer(self.return_ty, call_expression_tcx);
2569+
}
2570+
}
2571+
2572+
// Build the specialization a second time with the type context.
2573+
let specialization = builder.build(gc, *self.call_expression_tcx);
2574+
let return_ty = self.return_ty.apply_specialization(self.db, specialization);
2575+
2576+
// The type context should only be used to infer a more assignable return type, not an unnecessarily wide
2577+
// one, so we may ignore the type context here to prefer the narrower return type.
2578+
if !isolated_return_ty.is_subtype_of(self.db, return_ty) {
2579+
self.return_ty = return_ty;
2580+
self.specialization = Some(specialization);
2581+
return;
2582+
}
2583+
}
2584+
2585+
self.return_ty = isolated_return_ty;
2586+
self.specialization = Some(isolated_specialization);
25702587
}
25712588

25722589
fn check_argument_type(
@@ -2754,8 +2771,8 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27542771
}
27552772
}
27562773

2757-
fn finish(self) -> Option<Specialization<'db>> {
2758-
self.specialization
2774+
fn finish(self) -> (Option<Specialization<'db>>, Type<'db>) {
2775+
(self.specialization, self.return_ty)
27592776
}
27602777
}
27612778

@@ -2908,6 +2925,7 @@ impl<'db> Binding<'db> {
29082925
&self.argument_matches,
29092926
&mut self.parameter_tys,
29102927
call_expression_tcx,
2928+
self.return_ty,
29112929
&mut self.errors,
29122930
);
29132931

@@ -2916,10 +2934,7 @@ impl<'db> Binding<'db> {
29162934
checker.infer_specialization();
29172935

29182936
checker.check_argument_types();
2919-
self.specialization = checker.finish();
2920-
if let Some(specialization) = self.specialization {
2921-
self.return_ty = self.return_ty.apply_specialization(db, specialization);
2922-
}
2937+
(self.specialization, self.return_ty) = checker.finish();
29232938
}
29242939

29252940
pub(crate) fn set_return_type(&mut self, return_ty: Type<'db>) {

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,6 +1198,7 @@ impl<'db> SpecializationBuilder<'db> {
11981198
let tcx = tcx_specialization.and_then(|specialization| {
11991199
specialization.get(self.db, variable.bound_typevar)
12001200
});
1201+
12011202
ty = ty.map(|ty| ty.promote_literals(self.db, TypeContext::new(tcx)));
12021203
}
12031204

@@ -1220,7 +1221,7 @@ impl<'db> SpecializationBuilder<'db> {
12201221
pub(crate) fn infer(
12211222
&mut self,
12221223
formal: Type<'db>,
1223-
mut actual: Type<'db>,
1224+
actual: Type<'db>,
12241225
) -> Result<(), SpecializationError<'db>> {
12251226
if formal == actual {
12261227
return Ok(());
@@ -1249,9 +1250,11 @@ impl<'db> SpecializationBuilder<'db> {
12491250
return Ok(());
12501251
}
12511252

1252-
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`.
1253-
// So, here we remove the union elements that are not related to `formal`.
1254-
actual = actual.filter_disjoint_elements(self.db, formal);
1253+
// Remove the union elements that are not related to `formal`.
1254+
//
1255+
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T`
1256+
// to `int`.
1257+
let actual = actual.filter_disjoint_elements(self.db, formal);
12551258

12561259
match (formal, actual) {
12571260
// TODO: We haven't implemented a full unification solver yet. If typevars appear in

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ impl<'db> TypeContext<'db> {
391391
.and_then(|ty| ty.known_specialization(db, known_class))
392392
}
393393

394-
pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
394+
pub(crate) fn map(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
395395
Self {
396396
annotation: self.annotation.map(f),
397397
}

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5819,6 +5819,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
58195819
parenthesized: _,
58205820
} = tuple;
58215821

5822+
// Remove any union elements of that are unrelated to the tuple type.
5823+
let tcx = tcx.map(|annotation| {
5824+
annotation.filter_disjoint_elements(self.db(), KnownClass::Tuple.to_instance(self.db()))
5825+
});
5826+
58225827
let annotated_tuple = tcx
58235828
.known_specialization(self.db(), KnownClass::Tuple)
58245829
.and_then(|specialization| {
@@ -5884,7 +5889,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
58845889
} = dict;
58855890

58865891
// Validate `TypedDict` dictionary literal assignments.
5887-
if let Some(typed_dict) = tcx.annotation.and_then(Type::as_typed_dict)
5892+
if let Some(tcx) = tcx.annotation
5893+
&& let Some(typed_dict) = tcx.filter(self.db(), Type::is_typed_dict).as_typed_dict()
58885894
&& let Some(ty) = self.infer_typed_dict_expression(dict, typed_dict)
58895895
{
58905896
return ty;
@@ -5964,9 +5970,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
59645970
return None;
59655971
};
59665972

5967-
let tcx = tcx.map_annotation(|annotation| {
5968-
// Remove any union elements of `annotation` that are not related to `collection_ty`.
5969-
// e.g. `annotation: list[int] | None => list[int]` if `collection_ty: list`
5973+
// Remove any union elements of that are unrelated to the collection type.
5974+
//
5975+
// For example, we only want the `list[int]` from `annotation: list[int] | None` if
5976+
// `collection_ty` is `list`.
5977+
let tcx = tcx.map(|annotation| {
59705978
let collection_ty = collection_class.to_instance(self.db());
59715979
annotation.filter_disjoint_elements(self.db(), collection_ty)
59725980
});

0 commit comments

Comments
 (0)