Skip to content

Commit 0fb94c0

Browse files
authored
[ty] Infer parameter specializations of generic aliases (#18021)
This updates our function specialization inference to infer type mappings from parameters that are generic aliases, e.g.: ```py def f[T](x: list[T]) -> T: ... reveal_type(f(["a", "b"])) # revealed: str ``` Though note that we're still inferring the type of list literals as `list[Unknown]`, so for now we actually need something like the following in our tests: ```py def _(x: list[str]): reveal_type(f(x)) # revealed: str ```
1 parent 55df927 commit 0fb94c0

File tree

13 files changed

+98
-85
lines changed

13 files changed

+98
-85
lines changed

crates/ty_python_semantic/resources/mdtest/binary/tuples.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def _(x: tuple[int, str], y: tuple[None, tuple[int]]):
1717

1818
```py
1919
def _(x: tuple[int, ...], y: tuple[str, ...]):
20-
# TODO: should be `tuple[int | str, ...]`
21-
reveal_type(x + y) # revealed: tuple[int | Unknown, ...]
20+
reveal_type(x + y) # revealed: tuple[int | str, ...]
2221
reveal_type(x + (1, 2)) # revealed: tuple[int, ...]
2322
```

crates/ty_python_semantic/resources/mdtest/generics/legacy/functions.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,12 @@ def takes_in_protocol(x: CanIndex[T]) -> T:
8888
return x[0]
8989

9090
def deep_list(x: list[str]) -> None:
91-
# TODO: revealed: list[str]
92-
reveal_type(takes_in_list(x)) # revealed: list[Unknown]
91+
reveal_type(takes_in_list(x)) # revealed: list[str]
9392
# TODO: revealed: str
9493
reveal_type(takes_in_protocol(x)) # revealed: Unknown
9594

9695
def deeper_list(x: list[set[str]]) -> None:
97-
# TODO: revealed: list[set[str]]
98-
reveal_type(takes_in_list(x)) # revealed: list[Unknown]
96+
reveal_type(takes_in_list(x)) # revealed: list[set[str]]
9997
# TODO: revealed: set[str]
10098
reveal_type(takes_in_protocol(x)) # revealed: Unknown
10199

@@ -119,13 +117,11 @@ This also works when passing in arguments that are subclasses of the parameter t
119117
class Sub(list[int]): ...
120118
class GenericSub(list[T]): ...
121119

122-
# TODO: revealed: list[int]
123-
reveal_type(takes_in_list(Sub())) # revealed: list[Unknown]
120+
reveal_type(takes_in_list(Sub())) # revealed: list[int]
124121
# TODO: revealed: int
125122
reveal_type(takes_in_protocol(Sub())) # revealed: Unknown
126123

127-
# TODO: revealed: list[str]
128-
reveal_type(takes_in_list(GenericSub[str]())) # revealed: list[Unknown]
124+
reveal_type(takes_in_list(GenericSub[str]())) # revealed: list[str]
129125
# TODO: revealed: str
130126
reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: Unknown
131127

crates/ty_python_semantic/resources/mdtest/generics/pep695/functions.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,12 @@ def takes_in_protocol[T](x: CanIndex[T]) -> T:
8383
return x[0]
8484

8585
def deep_list(x: list[str]) -> None:
86-
# TODO: revealed: list[str]
87-
reveal_type(takes_in_list(x)) # revealed: list[Unknown]
86+
reveal_type(takes_in_list(x)) # revealed: list[str]
8887
# TODO: revealed: str
8988
reveal_type(takes_in_protocol(x)) # revealed: Unknown
9089

9190
def deeper_list(x: list[set[str]]) -> None:
92-
# TODO: revealed: list[set[str]]
93-
reveal_type(takes_in_list(x)) # revealed: list[Unknown]
91+
reveal_type(takes_in_list(x)) # revealed: list[set[str]]
9492
# TODO: revealed: set[str]
9593
reveal_type(takes_in_protocol(x)) # revealed: Unknown
9694

@@ -114,13 +112,11 @@ This also works when passing in arguments that are subclasses of the parameter t
114112
class Sub(list[int]): ...
115113
class GenericSub[T](list[T]): ...
116114

117-
# TODO: revealed: list[int]
118-
reveal_type(takes_in_list(Sub())) # revealed: list[Unknown]
115+
reveal_type(takes_in_list(Sub())) # revealed: list[int]
119116
# TODO: revealed: int
120117
reveal_type(takes_in_protocol(Sub())) # revealed: Unknown
121118

122-
# TODO: revealed: list[str]
123-
reveal_type(takes_in_list(GenericSub[str]())) # revealed: list[Unknown]
119+
reveal_type(takes_in_list(GenericSub[str]())) # revealed: list[str]
124120
# TODO: revealed: str
125121
reveal_type(takes_in_protocol(GenericSub[str]())) # revealed: Unknown
126122

crates/ty_python_semantic/src/types.rs

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -562,25 +562,22 @@ impl<'db> Type<'db> {
562562

563563
fn is_none(&self, db: &'db dyn Db) -> bool {
564564
self.into_nominal_instance()
565-
.is_some_and(|instance| instance.class().is_known(db, KnownClass::NoneType))
565+
.is_some_and(|instance| instance.class.is_known(db, KnownClass::NoneType))
566566
}
567567

568568
fn is_bool(&self, db: &'db dyn Db) -> bool {
569569
self.into_nominal_instance()
570-
.is_some_and(|instance| instance.class().is_known(db, KnownClass::Bool))
570+
.is_some_and(|instance| instance.class.is_known(db, KnownClass::Bool))
571571
}
572572

573573
pub fn is_notimplemented(&self, db: &'db dyn Db) -> bool {
574-
self.into_nominal_instance().is_some_and(|instance| {
575-
instance
576-
.class()
577-
.is_known(db, KnownClass::NotImplementedType)
578-
})
574+
self.into_nominal_instance()
575+
.is_some_and(|instance| instance.class.is_known(db, KnownClass::NotImplementedType))
579576
}
580577

581578
pub fn is_object(&self, db: &'db dyn Db) -> bool {
582579
self.into_nominal_instance()
583-
.is_some_and(|instance| instance.class().is_object(db))
580+
.is_some_and(|instance| instance.class.is_object(db))
584581
}
585582

586583
pub const fn is_todo(&self) -> bool {
@@ -1063,7 +1060,7 @@ impl<'db> Type<'db> {
10631060
(_, Type::Never) => false,
10641061

10651062
// Everything is a subtype of `object`.
1066-
(_, Type::NominalInstance(instance)) if instance.class().is_object(db) => true,
1063+
(_, Type::NominalInstance(instance)) if instance.class.is_object(db) => true,
10671064

10681065
// In general, a TypeVar `T` is not a subtype of a type `S` unless one of the two conditions is satisfied:
10691066
// 1. `T` is a bound TypeVar and `T`'s upper bound is a subtype of `S`.
@@ -1373,7 +1370,7 @@ impl<'db> Type<'db> {
13731370

13741371
// All types are assignable to `object`.
13751372
// TODO this special case might be removable once the below cases are comprehensive
1376-
(_, Type::NominalInstance(instance)) if instance.class().is_object(db) => true,
1373+
(_, Type::NominalInstance(instance)) if instance.class.is_object(db) => true,
13771374

13781375
// In general, a TypeVar `T` is not assignable to a type `S` unless one of the two conditions is satisfied:
13791376
// 1. `T` is a bound TypeVar and `T`'s upper bound is assignable to `S`.
@@ -1547,7 +1544,7 @@ impl<'db> Type<'db> {
15471544
}
15481545

15491546
(Type::NominalInstance(instance), Type::Callable(_))
1550-
if instance.class().is_subclass_of_any_or_unknown(db) =>
1547+
if instance.class.is_subclass_of_any_or_unknown(db) =>
15511548
{
15521549
true
15531550
}
@@ -1616,7 +1613,7 @@ impl<'db> Type<'db> {
16161613
}
16171614
(Type::ProtocolInstance(protocol), nominal @ Type::NominalInstance(n))
16181615
| (nominal @ Type::NominalInstance(n), Type::ProtocolInstance(protocol)) => {
1619-
n.class().is_object(db) && protocol.normalized(db) == nominal
1616+
n.class.is_object(db) && protocol.normalized(db) == nominal
16201617
}
16211618
_ => self == other && self.is_fully_static(db) && other.is_fully_static(db),
16221619
}
@@ -1671,7 +1668,7 @@ impl<'db> Type<'db> {
16711668
}
16721669
(Type::ProtocolInstance(protocol), nominal @ Type::NominalInstance(n))
16731670
| (nominal @ Type::NominalInstance(n), Type::ProtocolInstance(protocol)) => {
1674-
n.class().is_object(db) && protocol.normalized(db) == nominal
1671+
n.class.is_object(db) && protocol.normalized(db) == nominal
16751672
}
16761673
_ => false,
16771674
}
@@ -1883,7 +1880,7 @@ impl<'db> Type<'db> {
18831880
// member on `protocol`.
18841881
(Type::ProtocolInstance(protocol), nominal @ Type::NominalInstance(n))
18851882
| (nominal @ Type::NominalInstance(n), Type::ProtocolInstance(protocol)) => {
1886-
n.class().is_final(db) && !nominal.satisfies_protocol(db, protocol)
1883+
n.class.is_final(db) && !nominal.satisfies_protocol(db, protocol)
18871884
}
18881885

18891886
(
@@ -1948,7 +1945,7 @@ impl<'db> Type<'db> {
19481945

19491946
(Type::KnownInstance(known_instance), Type::NominalInstance(instance))
19501947
| (Type::NominalInstance(instance), Type::KnownInstance(known_instance)) => {
1951-
!known_instance.is_instance_of(db, instance.class())
1948+
!known_instance.is_instance_of(db, instance.class)
19521949
}
19531950

19541951
(known_instance_ty @ Type::KnownInstance(_), Type::Tuple(tuple))
@@ -1960,7 +1957,7 @@ impl<'db> Type<'db> {
19601957
| (Type::NominalInstance(instance), Type::BooleanLiteral(..)) => {
19611958
// A `Type::BooleanLiteral()` must be an instance of exactly `bool`
19621959
// (it cannot be an instance of a `bool` subclass)
1963-
!KnownClass::Bool.is_subclass_of(db, instance.class())
1960+
!KnownClass::Bool.is_subclass_of(db, instance.class)
19641961
}
19651962

19661963
(Type::BooleanLiteral(..), _) | (_, Type::BooleanLiteral(..)) => true,
@@ -1969,7 +1966,7 @@ impl<'db> Type<'db> {
19691966
| (Type::NominalInstance(instance), Type::IntLiteral(..)) => {
19701967
// A `Type::IntLiteral()` must be an instance of exactly `int`
19711968
// (it cannot be an instance of an `int` subclass)
1972-
!KnownClass::Int.is_subclass_of(db, instance.class())
1969+
!KnownClass::Int.is_subclass_of(db, instance.class)
19731970
}
19741971

19751972
(Type::IntLiteral(..), _) | (_, Type::IntLiteral(..)) => true,
@@ -1981,7 +1978,7 @@ impl<'db> Type<'db> {
19811978
| (Type::NominalInstance(instance), Type::StringLiteral(..) | Type::LiteralString) => {
19821979
// A `Type::StringLiteral()` or a `Type::LiteralString` must be an instance of exactly `str`
19831980
// (it cannot be an instance of a `str` subclass)
1984-
!KnownClass::Str.is_subclass_of(db, instance.class())
1981+
!KnownClass::Str.is_subclass_of(db, instance.class)
19851982
}
19861983

19871984
(Type::LiteralString, Type::LiteralString) => false,
@@ -1991,7 +1988,7 @@ impl<'db> Type<'db> {
19911988
| (Type::NominalInstance(instance), Type::BytesLiteral(..)) => {
19921989
// A `Type::BytesLiteral()` must be an instance of exactly `bytes`
19931990
// (it cannot be an instance of a `bytes` subclass)
1994-
!KnownClass::Bytes.is_subclass_of(db, instance.class())
1991+
!KnownClass::Bytes.is_subclass_of(db, instance.class)
19951992
}
19961993

19971994
// A class-literal type `X` is always disjoint from an instance type `Y`,
@@ -2012,7 +2009,7 @@ impl<'db> Type<'db> {
20122009
| (Type::NominalInstance(instance), Type::FunctionLiteral(..)) => {
20132010
// A `Type::FunctionLiteral()` must be an instance of exactly `types.FunctionType`
20142011
// (it cannot be an instance of a `types.FunctionType` subclass)
2015-
!KnownClass::FunctionType.is_subclass_of(db, instance.class())
2012+
!KnownClass::FunctionType.is_subclass_of(db, instance.class)
20162013
}
20172014

20182015
(Type::BoundMethod(_), other) | (other, Type::BoundMethod(_)) => KnownClass::MethodType
@@ -2440,7 +2437,7 @@ impl<'db> Type<'db> {
24402437
// i.e. Type::NominalInstance(type). So looking up a name in the MRO of
24412438
// `Type::NominalInstance(type)` is equivalent to looking up the name in the
24422439
// MRO of the class `object`.
2443-
Type::NominalInstance(instance) if instance.class().is_known(db, KnownClass::Type) => {
2440+
Type::NominalInstance(instance) if instance.class.is_known(db, KnownClass::Type) => {
24442441
KnownClass::Object
24452442
.to_class_literal(db)
24462443
.find_name_in_mro_with_policy(db, name, policy)
@@ -2530,7 +2527,7 @@ impl<'db> Type<'db> {
25302527

25312528
Type::Dynamic(_) | Type::Never => Symbol::bound(self).into(),
25322529

2533-
Type::NominalInstance(instance) => instance.class().instance_member(db, name),
2530+
Type::NominalInstance(instance) => instance.class.instance_member(db, name),
25342531

25352532
Type::ProtocolInstance(protocol) => protocol.instance_member(db, name),
25362533

@@ -2978,7 +2975,7 @@ impl<'db> Type<'db> {
29782975

29792976
Type::NominalInstance(instance)
29802977
if matches!(name.as_str(), "major" | "minor")
2981-
&& instance.class().is_known(db, KnownClass::VersionInfo) =>
2978+
&& instance.class.is_known(db, KnownClass::VersionInfo) =>
29822979
{
29832980
let python_version = Program::get(db).python_version(db);
29842981
let segment = if name == "major" {
@@ -3050,7 +3047,7 @@ impl<'db> Type<'db> {
30503047
// resolve the attribute.
30513048
if matches!(
30523049
self.into_nominal_instance()
3053-
.and_then(|instance| instance.class().known(db)),
3050+
.and_then(|instance| instance.class.known(db)),
30543051
Some(KnownClass::ModuleType | KnownClass::GenericAlias)
30553052
) {
30563053
return Symbol::Unbound.into();
@@ -3309,7 +3306,7 @@ impl<'db> Type<'db> {
33093306
}
33103307
},
33113308

3312-
Type::NominalInstance(instance) => match instance.class().known(db) {
3309+
Type::NominalInstance(instance) => match instance.class.known(db) {
33133310
Some(known_class) => known_class.bool(),
33143311
None => try_dunder_bool()?,
33153312
},
@@ -4863,7 +4860,7 @@ impl<'db> Type<'db> {
48634860

48644861
Type::Dynamic(_) => Ok(*self),
48654862

4866-
Type::NominalInstance(instance) => match instance.class().known(db) {
4863+
Type::NominalInstance(instance) => match instance.class.known(db) {
48674864
Some(KnownClass::TypeVar) => Ok(todo_type!(
48684865
"Support for `typing.TypeVar` instances in type expressions"
48694866
)),
@@ -5291,7 +5288,7 @@ impl<'db> Type<'db> {
52915288
}
52925289
Self::GenericAlias(alias) => Some(TypeDefinition::Class(alias.definition(db))),
52935290
Self::NominalInstance(instance) => {
5294-
Some(TypeDefinition::Class(instance.class().definition(db)))
5291+
Some(TypeDefinition::Class(instance.class.definition(db)))
52955292
}
52965293
Self::KnownInstance(instance) => match instance {
52975294
KnownInstanceType::TypeVar(var) => {
@@ -8046,7 +8043,7 @@ impl<'db> SuperOwnerKind<'db> {
80468043
Either::Left(ClassBase::Dynamic(dynamic).mro(db, None))
80478044
}
80488045
SuperOwnerKind::Class(class) => Either::Right(class.iter_mro(db)),
8049-
SuperOwnerKind::Instance(instance) => Either::Right(instance.class().iter_mro(db)),
8046+
SuperOwnerKind::Instance(instance) => Either::Right(instance.class.iter_mro(db)),
80508047
}
80518048
}
80528049

@@ -8062,7 +8059,7 @@ impl<'db> SuperOwnerKind<'db> {
80628059
match self {
80638060
SuperOwnerKind::Dynamic(_) => None,
80648061
SuperOwnerKind::Class(class) => Some(class),
8065-
SuperOwnerKind::Instance(instance) => Some(instance.class()),
8062+
SuperOwnerKind::Instance(instance) => Some(instance.class),
80668063
}
80678064
}
80688065

@@ -8240,7 +8237,7 @@ impl<'db> BoundSuperType<'db> {
82408237
.expect("Calling `find_name_in_mro` on dynamic type should return `Some`")
82418238
}
82428239
SuperOwnerKind::Class(class) => class,
8243-
SuperOwnerKind::Instance(instance) => instance.class(),
8240+
SuperOwnerKind::Instance(instance) => instance.class,
82448241
};
82458242

82468243
let (class_literal, _) = class.class_literal(db);

crates/ty_python_semantic/src/types/builder.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ impl<'db> InnerIntersectionBuilder<'db> {
614614
_ => {
615615
let known_instance = new_positive
616616
.into_nominal_instance()
617-
.and_then(|instance| instance.class().known(db));
617+
.and_then(|instance| instance.class.known(db));
618618

619619
if known_instance == Some(KnownClass::Object) {
620620
// `object & T` -> `T`; it is always redundant to add `object` to an intersection
@@ -634,7 +634,7 @@ impl<'db> InnerIntersectionBuilder<'db> {
634634
new_positive = Type::BooleanLiteral(false);
635635
}
636636
Type::NominalInstance(instance)
637-
if instance.class().is_known(db, KnownClass::Bool) =>
637+
if instance.class.is_known(db, KnownClass::Bool) =>
638638
{
639639
match new_positive {
640640
// `bool & AlwaysTruthy` -> `Literal[True]`
@@ -728,7 +728,7 @@ impl<'db> InnerIntersectionBuilder<'db> {
728728
self.positive
729729
.iter()
730730
.filter_map(|ty| ty.into_nominal_instance())
731-
.filter_map(|instance| instance.class().known(db))
731+
.filter_map(|instance| instance.class.known(db))
732732
.any(KnownClass::is_bool)
733733
};
734734

@@ -744,7 +744,7 @@ impl<'db> InnerIntersectionBuilder<'db> {
744744
Type::Never => {
745745
// Adding ~Never to an intersection is a no-op.
746746
}
747-
Type::NominalInstance(instance) if instance.class().is_object(db) => {
747+
Type::NominalInstance(instance) if instance.class.is_object(db) => {
748748
// Adding ~object to an intersection results in Never.
749749
*self = Self::default();
750750
self.positive.insert(Type::Never);

crates/ty_python_semantic/src/types/class.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2733,7 +2733,7 @@ impl<'db> Type<'db> {
27332733
/// The type must be a specialization of the `slice` builtin type, where the specialized
27342734
/// typevars are statically known integers or `None`.
27352735
pub(crate) fn slice_literal(self, db: &'db dyn Db) -> Option<SliceLiteral> {
2736-
let ClassType::Generic(alias) = self.into_nominal_instance()?.class() else {
2736+
let ClassType::Generic(alias) = self.into_nominal_instance()?.class else {
27372737
return None;
27382738
};
27392739
if !alias.origin(db).is_known(db, KnownClass::Slice) {
@@ -2747,7 +2747,7 @@ impl<'db> Type<'db> {
27472747
Type::IntLiteral(n) => i32::try_from(*n).map(Some).ok(),
27482748
Type::BooleanLiteral(b) => Some(Some(i32::from(*b))),
27492749
Type::NominalInstance(instance)
2750-
if instance.class().is_known(db, KnownClass::NoneType) =>
2750+
if instance.class.is_known(db, KnownClass::NoneType) =>
27512751
{
27522752
Some(None)
27532753
}

crates/ty_python_semantic/src/types/class_base.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ impl<'db> ClassBase<'db> {
103103
}
104104
Type::GenericAlias(generic) => Some(Self::Class(ClassType::Generic(generic))),
105105
Type::NominalInstance(instance)
106-
if instance.class().is_known(db, KnownClass::GenericAlias) =>
106+
if instance.class.is_known(db, KnownClass::GenericAlias) =>
107107
{
108108
Self::try_from_type(db, todo_type!("GenericAlias instance"))
109109
}

crates/ty_python_semantic/src/types/display.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ impl Display for DisplayRepresentation<'_> {
6969
Type::Dynamic(dynamic) => dynamic.fmt(f),
7070
Type::Never => f.write_str("Never"),
7171
Type::NominalInstance(instance) => {
72-
match (instance.class(), instance.class().known(self.db)) {
72+
match (instance.class, instance.class.known(self.db)) {
7373
(_, Some(KnownClass::NoneType)) => f.write_str("None"),
7474
(_, Some(KnownClass::NoDefaultType)) => f.write_str("NoDefault"),
7575
(ClassType::NonGeneric(class), _) => f.write_str(class.name(self.db)),

0 commit comments

Comments
 (0)