Skip to content

Commit 457167e

Browse files
committed
[ty] Add support for Literals in implicit type aliases
1 parent 132d10f commit 457167e

File tree

6 files changed

+156
-68
lines changed

6 files changed

+156
-68
lines changed

crates/ty_python_semantic/resources/mdtest/annotations/literal.md

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -181,30 +181,20 @@ def _(
181181
bool2: Literal[Bool2],
182182
multiple: Literal[SingleInt, SingleStr, SingleEnum],
183183
):
184-
# TODO should be `Literal[1]`
185-
reveal_type(single_int) # revealed: @Todo(Inference of subscript on special form)
186-
# TODO should be `Literal["foo"]`
187-
reveal_type(single_str) # revealed: @Todo(Inference of subscript on special form)
188-
# TODO should be `Literal[b"bar"]`
189-
reveal_type(single_bytes) # revealed: @Todo(Inference of subscript on special form)
190-
# TODO should be `Literal[True]`
191-
reveal_type(single_bool) # revealed: @Todo(Inference of subscript on special form)
192-
# TODO should be `None`
193-
reveal_type(single_none) # revealed: @Todo(Inference of subscript on special form)
194-
# TODO should be `Literal[E.A]`
195-
reveal_type(single_enum) # revealed: @Todo(Inference of subscript on special form)
196-
# TODO should be `Literal[1, "foo", b"bar", True, E.A] | None`
197-
reveal_type(union_literals) # revealed: @Todo(Inference of subscript on special form)
184+
reveal_type(single_int) # revealed: Literal[1]
185+
reveal_type(single_str) # revealed: Literal["foo"]
186+
reveal_type(single_bytes) # revealed: Literal[b"bar"]
187+
reveal_type(single_bool) # revealed: Literal[True]
188+
reveal_type(single_none) # revealed: None
189+
reveal_type(single_enum) # revealed: Literal[E.A]
190+
reveal_type(union_literals) # revealed: Literal[1, "foo", b"bar", True, E.A] | None
198191
# Could also be `E`
199192
reveal_type(an_enum1) # revealed: Unknown
200-
# TODO should be `E`
201-
reveal_type(an_enum2) # revealed: @Todo(Inference of subscript on special form)
193+
reveal_type(an_enum2) # revealed: E
202194
# Could also be `bool`
203195
reveal_type(bool1) # revealed: Unknown
204-
# TODO should be `bool`
205-
reveal_type(bool2) # revealed: @Todo(Inference of subscript on special form)
206-
# TODO should be `Literal[1, "foo", E.A]`
207-
reveal_type(multiple) # revealed: @Todo(Inference of subscript on special form)
196+
reveal_type(bool2) # revealed: bool
197+
reveal_type(multiple) # revealed: Literal[1, "foo", E.A]
208198
```
209199

210200
### Implicit type alias
@@ -246,28 +236,18 @@ def _(
246236
bool2: Literal[Bool2],
247237
multiple: Literal[SingleInt, SingleStr, SingleEnum],
248238
):
249-
# TODO should be `Literal[1]`
250-
reveal_type(single_int) # revealed: @Todo(Inference of subscript on special form)
251-
# TODO should be `Literal["foo"]`
252-
reveal_type(single_str) # revealed: @Todo(Inference of subscript on special form)
253-
# TODO should be `Literal[b"bar"]`
254-
reveal_type(single_bytes) # revealed: @Todo(Inference of subscript on special form)
255-
# TODO should be `Literal[True]`
256-
reveal_type(single_bool) # revealed: @Todo(Inference of subscript on special form)
257-
# TODO should be `None`
258-
reveal_type(single_none) # revealed: @Todo(Inference of subscript on special form)
259-
# TODO should be `Literal[E.A]`
260-
reveal_type(single_enum) # revealed: @Todo(Inference of subscript on special form)
261-
# TODO should be `Literal[1, "foo", b"bar", True, E.A] | None`
262-
reveal_type(union_literals) # revealed: @Todo(Inference of subscript on special form)
239+
reveal_type(single_int) # revealed: Literal[1]
240+
reveal_type(single_str) # revealed: Literal["foo"]
241+
reveal_type(single_bytes) # revealed: Literal[b"bar"]
242+
reveal_type(single_bool) # revealed: Literal[True]
243+
reveal_type(single_none) # revealed: None
244+
reveal_type(single_enum) # revealed: Literal[E.A]
245+
reveal_type(union_literals) # revealed: Literal[1, "foo", b"bar", True, E.A] | None
263246
reveal_type(an_enum1) # revealed: Unknown
264-
# TODO should be `E`
265-
reveal_type(an_enum2) # revealed: @Todo(Inference of subscript on special form)
247+
reveal_type(an_enum2) # revealed: E
266248
reveal_type(bool1) # revealed: Unknown
267-
# TODO should be `bool`
268-
reveal_type(bool2) # revealed: @Todo(Inference of subscript on special form)
269-
# TODO should be `Literal[1, "foo", E.A]`
270-
reveal_type(multiple) # revealed: @Todo(Inference of subscript on special form)
249+
reveal_type(bool2) # revealed: bool
250+
reveal_type(multiple) # revealed: Literal[1, "foo", E.A]
271251
```
272252

273253
## Shortening unions of literals

crates/ty_python_semantic/resources/mdtest/implicit_type_aliases.md

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ g(None)
3333
We also support unions in type aliases:
3434

3535
```py
36-
from typing_extensions import Any, Never
36+
from typing_extensions import Any, Never, Literal
3737
from ty_extensions import Unknown
3838

3939
IntOrStr = int | str
@@ -54,6 +54,8 @@ NeverOrAny = Never | Any
5454
AnyOrNever = Any | Never
5555
UnknownOrInt = Unknown | int
5656
IntOrUnknown = int | Unknown
57+
StrOrZero = str | Literal[0]
58+
ZeroOrStr = Literal[0] | str
5759

5860
reveal_type(IntOrStr) # revealed: types.UnionType
5961
reveal_type(IntOrStrOrBytes1) # revealed: types.UnionType
@@ -73,6 +75,8 @@ reveal_type(NeverOrAny) # revealed: types.UnionType
7375
reveal_type(AnyOrNever) # revealed: types.UnionType
7476
reveal_type(UnknownOrInt) # revealed: types.UnionType
7577
reveal_type(IntOrUnknown) # revealed: types.UnionType
78+
reveal_type(StrOrZero) # revealed: types.UnionType
79+
reveal_type(ZeroOrStr) # revealed: types.UnionType
7680

7781
def _(
7882
int_or_str: IntOrStr,
@@ -93,6 +97,8 @@ def _(
9397
any_or_never: AnyOrNever,
9498
unknown_or_int: UnknownOrInt,
9599
int_or_unknown: IntOrUnknown,
100+
str_or_zero: StrOrZero,
101+
zero_or_str: ZeroOrStr,
96102
):
97103
reveal_type(int_or_str) # revealed: int | str
98104
reveal_type(int_or_str_or_bytes1) # revealed: int | str | bytes
@@ -112,6 +118,8 @@ def _(
112118
reveal_type(any_or_never) # revealed: Any
113119
reveal_type(unknown_or_int) # revealed: Unknown | int
114120
reveal_type(int_or_unknown) # revealed: int | Unknown
121+
reveal_type(str_or_zero) # revealed: str | Literal[0]
122+
reveal_type(zero_or_str) # revealed: Literal[0] | str
115123
```
116124

117125
If a type is unioned with itself in a value expression, the result is just that type. No
@@ -255,6 +263,60 @@ def _(list_or_tuple: ListOrTuple[int]):
255263
reveal_type(list_or_tuple) # revealed: @Todo(Generic specialization of types.UnionType)
256264
```
257265

266+
## `Literal`s
267+
268+
We also support `typing.Literal` in implicit type aliases.
269+
270+
```py
271+
from typing import Literal
272+
from enum import Enum
273+
274+
IntLiteral1 = Literal[26]
275+
IntLiteral2 = Literal[0x1A]
276+
IntLiterals = Literal[-1, 0, 1]
277+
NestedLiteral = Literal[Literal[1]]
278+
StringLiteral = Literal["a"]
279+
BytesLiteral = Literal[b"b"]
280+
BoolLiteral = Literal[True]
281+
MixedLiterals = Literal[1, "a", True, None]
282+
283+
class Color(Enum):
284+
RED = 0
285+
GREEN = 1
286+
BLUE = 2
287+
288+
EnumLiteral = Literal[Color.RED]
289+
290+
def _(
291+
int_literal1: IntLiteral1,
292+
int_literal2: IntLiteral2,
293+
int_literals: IntLiterals,
294+
nested_literal: NestedLiteral,
295+
string_literal: StringLiteral,
296+
bytes_literal: BytesLiteral,
297+
bool_literal: BoolLiteral,
298+
mixed_literals: MixedLiterals,
299+
enum_literal: EnumLiteral,
300+
):
301+
reveal_type(int_literal1) # revealed: Literal[26]
302+
reveal_type(int_literal2) # revealed: Literal[26]
303+
reveal_type(int_literals) # revealed: Literal[-1, 0, 1]
304+
reveal_type(nested_literal) # revealed: Literal[1]
305+
reveal_type(string_literal) # revealed: Literal["a"]
306+
reveal_type(bytes_literal) # revealed: Literal[b"b"]
307+
reveal_type(bool_literal) # revealed: Literal[True]
308+
reveal_type(mixed_literals) # revealed: Literal[1, "a", True] | None
309+
reveal_type(enum_literal) # revealed: Literal[Color.RED]
310+
```
311+
312+
We reject invalid uses:
313+
314+
```py
315+
# error: [invalid-type-form] "`typing.Literal` instances are not allowed in type expressions"
316+
def _(weird: IntLiteral1[int]):
317+
reveal_type(weird) # revealed: Unknown
318+
```
319+
258320
## Stringified annotations?
259321

260322
From the [typing spec on type aliases](https://typing.python.org/en/latest/spec/aliases.html):

crates/ty_python_semantic/src/types.rs

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6444,9 +6444,9 @@ impl<'db> Type<'db> {
64446444
invalid_expressions: smallvec::smallvec_inline![InvalidTypeExpression::Generic],
64456445
fallback_type: Type::unknown(),
64466446
}),
6447-
KnownInstanceType::UnionType(union_type) => {
6447+
KnownInstanceType::UnionType(list) => {
64486448
let mut builder = UnionBuilder::new(db);
6449-
for element in union_type.elements(db) {
6449+
for element in list.elements(db) {
64506450
builder = builder.add(element.in_type_expression(
64516451
db,
64526452
scope_id,
@@ -6455,6 +6455,7 @@ impl<'db> Type<'db> {
64556455
}
64566456
Ok(builder.build())
64576457
}
6458+
KnownInstanceType::Literal(list) => Ok(list.to_union(db)),
64586459
},
64596460

64606461
Type::SpecialForm(special_form) => match special_form {
@@ -7668,7 +7669,10 @@ pub enum KnownInstanceType<'db> {
76687669

76697670
/// A single instance of `types.UnionType`, which stores the left- and
76707671
/// right-hand sides of a PEP 604 union.
7671-
UnionType(UnionTypeInstance<'db>),
7672+
UnionType(TypeList<'db>),
7673+
7674+
/// A single instance of `typing.Literal`
7675+
Literal(TypeList<'db>),
76727676
}
76737677

76747678
fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
@@ -7695,9 +7699,9 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
76957699
visitor.visit_type(db, default_ty);
76967700
}
76977701
}
7698-
KnownInstanceType::UnionType(union_type) => {
7699-
for element in union_type.elements(db) {
7700-
visitor.visit_type(db, element);
7702+
KnownInstanceType::UnionType(list) | KnownInstanceType::Literal(list) => {
7703+
for element in list.elements(db) {
7704+
visitor.visit_type(db, *element);
77017705
}
77027706
}
77037707
}
@@ -7736,7 +7740,8 @@ impl<'db> KnownInstanceType<'db> {
77367740
// Nothing to normalize
77377741
Self::ConstraintSet(set)
77387742
}
7739-
Self::UnionType(union_type) => Self::UnionType(union_type.normalized_impl(db, visitor)),
7743+
Self::UnionType(list) => Self::UnionType(list.normalized_impl(db, visitor)),
7744+
Self::Literal(list) => Self::Literal(list.normalized_impl(db, visitor)),
77407745
}
77417746
}
77427747

@@ -7752,6 +7757,7 @@ impl<'db> KnownInstanceType<'db> {
77527757
Self::Field(_) => KnownClass::Field,
77537758
Self::ConstraintSet(_) => KnownClass::ConstraintSet,
77547759
Self::UnionType(_) => KnownClass::UnionType,
7760+
Self::Literal(_) => KnownClass::GenericAlias,
77557761
}
77567762
}
77577763

@@ -7826,6 +7832,7 @@ impl<'db> KnownInstanceType<'db> {
78267832
)
78277833
}
78287834
KnownInstanceType::UnionType(_) => f.write_str("types.UnionType"),
7835+
KnownInstanceType::Literal(_) => f.write_str("typing.Literal"),
78297836
}
78307837
}
78317838
}
@@ -8949,32 +8956,46 @@ impl<'db> TypeVarBoundOrConstraints<'db> {
89498956
}
89508957
}
89518958

8952-
/// An instance of `types.UnionType`.
8959+
/// A salsa-interned list of types.
89538960
///
89548961
/// # Ordering
89558962
/// Ordering is based on the context's salsa-assigned id and not on its values.
89568963
/// The id may change between runs, or when the context was garbage collected and recreated.
89578964
#[salsa::interned(debug)]
89588965
#[derive(PartialOrd, Ord)]
8959-
pub struct UnionTypeInstance<'db> {
8960-
left: Type<'db>,
8961-
right: Type<'db>,
8966+
pub struct TypeList<'db> {
8967+
#[returns(deref)]
8968+
elements: Box<[Type<'db>]>,
89628969
}
89638970

8964-
impl get_size2::GetSize for UnionTypeInstance<'_> {}
8971+
impl get_size2::GetSize for TypeList<'_> {}
8972+
8973+
impl<'db> TypeList<'db> {
8974+
pub(crate) fn from_elements(
8975+
db: &'db dyn Db,
8976+
elements: impl IntoIterator<Item = Type<'db>>,
8977+
) -> TypeList<'db> {
8978+
TypeList::new(db, elements.into_iter().collect::<Box<[_]>>())
8979+
}
89658980

8966-
impl<'db> UnionTypeInstance<'db> {
8967-
pub(crate) fn elements(self, db: &'db dyn Db) -> [Type<'db>; 2] {
8968-
[self.left(db), self.right(db)]
8981+
pub(crate) fn singleton(db: &'db dyn Db, element: Type<'db>) -> TypeList<'db> {
8982+
TypeList::from_elements(db, [element])
89698983
}
89708984

89718985
pub(crate) fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
8972-
UnionTypeInstance::new(
8986+
TypeList::new(
89738987
db,
8974-
self.left(db).normalized_impl(db, visitor),
8975-
self.right(db).normalized_impl(db, visitor),
8988+
self.elements(db)
8989+
.iter()
8990+
.map(|ty| ty.normalized_impl(db, visitor))
8991+
.collect::<Box<[_]>>(),
89768992
)
89778993
}
8994+
8995+
/// Turn this list of types `[T1, T2, ...]` into a union type `T1 | T2 | ...`.
8996+
pub(crate) fn to_union(self, db: &'db dyn Db) -> Type<'db> {
8997+
UnionType::from_elements(db, self.elements(db))
8998+
}
89788999
}
89799000

89809001
/// Error returned if a type is not awaitable.

crates/ty_python_semantic/src/types/class_base.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ impl<'db> ClassBase<'db> {
171171
| KnownInstanceType::Deprecated(_)
172172
| KnownInstanceType::Field(_)
173173
| KnownInstanceType::ConstraintSet(_)
174-
| KnownInstanceType::UnionType(_) => None,
174+
| KnownInstanceType::UnionType(_)
175+
| KnownInstanceType::Literal(_) => None,
175176
},
176177

177178
Type::SpecialForm(special_form) => match special_form {

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,10 @@ use crate::types::{
102102
DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType,
103103
MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm,
104104
Parameters, SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type,
105-
TypeAliasType, TypeAndQualifiers, TypeContext, TypeQualifiers,
105+
TypeAliasType, TypeAndQualifiers, TypeContext, TypeList, TypeQualifiers,
106106
TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarIdentity,
107107
TypeVarInstance, TypeVarKind, TypeVarVariance, TypedDictType, UnionBuilder, UnionType,
108-
UnionTypeInstance, binding_type, todo_type,
108+
binding_type, todo_type,
109109
};
110110
use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic};
111111
use crate::unpack::{EvaluationMode, UnpackPosition};
@@ -8473,19 +8473,23 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
84738473
| Type::SubclassOf(..)
84748474
| Type::GenericAlias(..)
84758475
| Type::SpecialForm(_)
8476-
| Type::KnownInstance(KnownInstanceType::UnionType(_)),
8476+
| Type::KnownInstance(
8477+
KnownInstanceType::UnionType(_) | KnownInstanceType::Literal(_),
8478+
),
84778479
Type::ClassLiteral(..)
84788480
| Type::SubclassOf(..)
84798481
| Type::GenericAlias(..)
84808482
| Type::SpecialForm(_)
8481-
| Type::KnownInstance(KnownInstanceType::UnionType(_)),
8483+
| Type::KnownInstance(
8484+
KnownInstanceType::UnionType(_) | KnownInstanceType::Literal(_),
8485+
),
84828486
ast::Operator::BitOr,
84838487
) if Program::get(self.db()).python_version(self.db()) >= PythonVersion::PY310 => {
84848488
if left_ty.is_equivalent_to(self.db(), right_ty) {
84858489
Some(left_ty)
84868490
} else {
84878491
Some(Type::KnownInstance(KnownInstanceType::UnionType(
8488-
UnionTypeInstance::new(self.db(), left_ty, right_ty),
8492+
TypeList::from_elements(self.db(), [left_ty, right_ty]),
84898493
)))
84908494
}
84918495
}
@@ -8510,7 +8514,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
85108514
&& instance.has_known_class(self.db(), KnownClass::NoneType) =>
85118515
{
85128516
Some(Type::KnownInstance(KnownInstanceType::UnionType(
8513-
UnionTypeInstance::new(self.db(), left_ty, right_ty),
8517+
TypeList::from_elements(self.db(), [left_ty, right_ty]),
85148518
)))
85158519
}
85168520

@@ -9643,6 +9647,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
96439647
);
96449648
}
96459649
}
9650+
if value_ty == Type::SpecialForm(SpecialFormType::Literal) {
9651+
let result = self
9652+
.infer_literal_parameter_type(slice)
9653+
.unwrap_or_else(|_| Type::unknown());
9654+
return Type::KnownInstance(KnownInstanceType::Literal(TypeList::singleton(
9655+
self.db(),
9656+
result,
9657+
)));
9658+
}
96469659

96479660
let slice_ty = self.infer_expression(slice, TypeContext::default());
96489661
let result_ty = self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx);

0 commit comments

Comments
 (0)