Skip to content

Commit 416e956

Browse files
authored
[ty] Infer better specializations of unions with None (etc) (#20749)
This PR adds a specialization inference special case that lets us handle the following examples better: ```py def f[T](t: T | None) -> T: ... def g[T](t: T | int | None) -> T | int: ... def _(x: str | None): reveal_type(f(x)) # revealed: str (previously str | None) def _(y: str | int | None): reveal_type(g(x)) # revealed: str | int (previously str | int | None) ``` We already have a special case for when the formal is a union where one element is a typevar, but it maps the entire actual type to the typevar (as you can see in the "previously" results above). The new special case kicks in when the actual is also a union. Now, we filter out any actual union elements that are already subtypes of the formal, and only bind whatever types remain to the typevar. (The `| None` pattern appears quite often in the ecosystem results, but it's more general and works with any number of non-typevar union elements.) The new constraint solver should handle this case as well, but it's worth adding this heuristic now with the old solver because it eliminates some false positives from the ecosystem report, and makes the ecosystem report less noisy on the other constraint solver PRs.
1 parent 88c0ce3 commit 416e956

File tree

3 files changed

+74
-19
lines changed

3 files changed

+74
-19
lines changed

crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,23 @@ def g[T: A](b: B[T]):
441441
return f(b.x) # Fine
442442
```
443443

444-
## Constrained TypeVar in a union
444+
## Typevars in a union
445+
446+
```py
447+
def takes_in_union[T](t: T | None) -> T:
448+
raise NotImplementedError
449+
450+
def takes_in_bigger_union[T](t: T | int | None) -> T:
451+
raise NotImplementedError
452+
453+
def _(x: str | None) -> None:
454+
reveal_type(takes_in_union(x)) # revealed: str
455+
reveal_type(takes_in_bigger_union(x)) # revealed: str
456+
457+
def _(x: str | int | None) -> None:
458+
reveal_type(takes_in_union(x)) # revealed: str | int
459+
reveal_type(takes_in_bigger_union(x)) # revealed: str
460+
```
445461

446462
This is a regression test for an issue that surfaced in the primer report of an early version of
447463
<https://github.com/astral-sh/ruff/pull/19811>, where we failed to solve the `TypeVar` here due to

crates/ty_python_semantic/src/types.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,17 @@ impl<'db> Type<'db> {
943943
self.apply_type_mapping_impl(db, &TypeMapping::Materialize(materialization_kind), visitor)
944944
}
945945

946+
pub(crate) const fn is_type_var(self) -> bool {
947+
matches!(self, Type::TypeVar(_))
948+
}
949+
950+
pub(crate) const fn into_type_var(self) -> Option<BoundTypeVarInstance<'db>> {
951+
match self {
952+
Type::TypeVar(bound_typevar) => Some(bound_typevar),
953+
_ => None,
954+
}
955+
}
956+
946957
pub(crate) const fn into_class_literal(self) -> Option<ClassLiteral<'db>> {
947958
match self {
948959
Type::ClassLiteral(class_type) => Some(class_type),

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,25 +1138,53 @@ impl<'db> SpecializationBuilder<'db> {
11381138
}
11391139

11401140
match (formal, actual) {
1141-
(Type::Union(formal), _) => {
1142-
// TODO: We haven't implemented a full unification solver yet. If typevars appear
1143-
// in multiple union elements, we ideally want to express that _only one_ of them
1144-
// needs to match, and that we should infer the smallest type mapping that allows
1145-
// that.
1141+
// TODO: We haven't implemented a full unification solver yet. If typevars appear in
1142+
// multiple union elements, we ideally want to express that _only one_ of them needs to
1143+
// match, and that we should infer the smallest type mapping that allows that.
1144+
//
1145+
// For now, we punt on fully handling multiple typevar elements. Instead, we handle two
1146+
// common cases specially:
1147+
(Type::Union(formal_union), Type::Union(actual_union)) => {
1148+
// First, if both formal and actual are unions, and precisely one formal union
1149+
// element _is_ a typevar (not _contains_ a typevar), then we remove any actual
1150+
// union elements that are a subtype of the formal (as a whole), and map the formal
1151+
// typevar to any remaining actual union elements.
1152+
//
1153+
// In particular, this handles cases like
1154+
//
1155+
// ```py
1156+
// def f[T](t: T | None) -> T: ...
1157+
// def g[T](t: T | int | None) -> T | int: ...
11461158
//
1147-
// For now, we punt on handling multiple typevar elements. Instead, if _precisely
1148-
// one_ union element _is_ a typevar (not _contains_ a typevar), then we go ahead
1149-
// and add a mapping between that typevar and the actual type. (Note that we've
1150-
// already handled above the case where the actual is assignable to a _non-typevar_
1151-
// union element.)
1152-
let mut bound_typevars =
1153-
formal.elements(self.db).iter().filter_map(|ty| match ty {
1154-
Type::TypeVar(bound_typevar) => Some(*bound_typevar),
1155-
_ => None,
1156-
});
1157-
let bound_typevar = bound_typevars.next();
1158-
let additional_bound_typevars = bound_typevars.next();
1159-
if let (Some(bound_typevar), None) = (bound_typevar, additional_bound_typevars) {
1159+
// def _(x: str | None):
1160+
// reveal_type(f(x)) # revealed: str
1161+
//
1162+
// def _(y: str | int | None):
1163+
// reveal_type(g(x)) # revealed: str | int
1164+
// ```
1165+
let formal_bound_typevars =
1166+
(formal_union.elements(self.db).iter()).filter_map(|ty| ty.into_type_var());
1167+
let Ok(formal_bound_typevar) = formal_bound_typevars.exactly_one() else {
1168+
return Ok(());
1169+
};
1170+
if (actual_union.elements(self.db).iter()).any(|ty| ty.is_type_var()) {
1171+
return Ok(());
1172+
}
1173+
let remaining_actual =
1174+
actual_union.filter(self.db, |ty| !ty.is_subtype_of(self.db, formal));
1175+
if remaining_actual.is_never() {
1176+
return Ok(());
1177+
}
1178+
self.add_type_mapping(formal_bound_typevar, remaining_actual);
1179+
}
1180+
(Type::Union(formal), _) => {
1181+
// Second, if the formal is a union, and precisely one union element _is_ a typevar (not
1182+
// _contains_ a typevar), then we add a mapping between that typevar and the actual
1183+
// type. (Note that we've already handled above the case where the actual is
1184+
// assignable to any _non-typevar_ union element.)
1185+
let bound_typevars =
1186+
(formal.elements(self.db).iter()).filter_map(|ty| ty.into_type_var());
1187+
if let Ok(bound_typevar) = bound_typevars.exactly_one() {
11601188
self.add_type_mapping(bound_typevar, actual);
11611189
}
11621190
}

0 commit comments

Comments
 (0)