Skip to content

Commit da7ed1f

Browse files
committed
[ty] infer type for enum members
astral-sh/ty#876
1 parent 61f906d commit da7ed1f

File tree

5 files changed

+33
-8
lines changed

5 files changed

+33
-8
lines changed

crates/ty_python_semantic/resources/mdtest/enums.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
```py
66
from enum import Enum
7+
from typing import Literal
78

89
class Color(Enum):
910
RED = 1
@@ -13,6 +14,12 @@ class Color(Enum):
1314
reveal_type(Color.RED) # revealed: Literal[Color.RED]
1415
# TODO: This could be `Literal[1]`
1516
reveal_type(Color.RED.value) # revealed: Any
17+
# TODO: This could be `Literal[1]`
18+
reveal_type(Color.RED._value_) # revealed: Any
19+
20+
reveal_type(Color.RED.name) # revealed: Literal["RED"]
21+
22+
reveal_type(Color.RED._name_) # revealed: Literal["RED"]
1623

1724
# TODO: Should be `Color` or `Literal[Color.RED]`
1825
reveal_type(Color["RED"]) # revealed: Unknown
@@ -21,6 +28,13 @@ reveal_type(Color["RED"]) # revealed: Unknown
2128
reveal_type(Color(1)) # revealed: Color
2229

2330
reveal_type(Color.RED in Color) # revealed: bool
31+
32+
def func1(red_or_blue: Literal[Color.RED, Color.BLUE]):
33+
reveal_type(red_or_blue.name) # revealed: Literal["RED", "BLUE"]
34+
35+
def func2(any_color: Color):
36+
# TODO: Literal["RED", "GREEN", "BLUE"]
37+
reveal_type(any_color.name) # revealed: Any
2438
```
2539

2640
## Enum members

crates/ty_python_semantic/src/types.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3465,6 +3465,14 @@ impl<'db> Type<'db> {
34653465
.value_type(db)
34663466
.member_lookup_with_policy(db, name, policy),
34673467

3468+
Type::EnumLiteral(enum_literal) if matches!(name_str, "name" | "_name_") => {
3469+
Place::bound(Type::StringLiteral(StringLiteralType::new(
3470+
db,
3471+
enum_literal.name(db).as_str(),
3472+
)))
3473+
.into()
3474+
}
3475+
34683476
Type::NominalInstance(..)
34693477
| Type::ProtocolInstance(..)
34703478
| Type::BooleanLiteral(..)

crates/ty_python_semantic/src/types/builder.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,9 @@ impl<'db> UnionBuilder<'db> {
429429

430430
let all_members_are_in_union = metadata
431431
.members
432+
.keys()
433+
.cloned()
434+
.collect::<FxOrderSet<_>>()
432435
.difference(&enum_members_in_union)
433436
.next()
434437
.is_none();

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ impl<'db> Bindings<'db> {
711711
db,
712712
metadata
713713
.members
714-
.iter()
714+
.keys()
715715
.map(|member| Type::string_literal(db, member)),
716716
)
717717
} else {

crates/ty_python_semantic/src/types/enums.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use ruff_python_ast::name::Name;
22
use rustc_hash::FxHashMap;
33

44
use crate::{
5-
Db, FxOrderSet,
5+
Db, FxOrderMap,
66
place::{Place, PlaceAndQualifiers, place_from_bindings, place_from_declarations},
77
semantic_index::{place_table, use_def_map},
88
types::{
@@ -13,7 +13,7 @@ use crate::{
1313

1414
#[derive(Debug, PartialEq, Eq)]
1515
pub(crate) struct EnumMetadata {
16-
pub(crate) members: FxOrderSet<Name>,
16+
pub(crate) members: FxOrderMap<Name, Name>,
1717
pub(crate) aliases: FxHashMap<Name, Name>,
1818
}
1919

@@ -22,13 +22,13 @@ impl get_size2::GetSize for EnumMetadata {}
2222
impl EnumMetadata {
2323
fn empty() -> Self {
2424
EnumMetadata {
25-
members: FxOrderSet::default(),
25+
members: FxOrderMap::default(),
2626
aliases: FxHashMap::default(),
2727
}
2828
}
2929

3030
pub(crate) fn resolve_member<'a>(&'a self, name: &'a Name) -> Option<&'a Name> {
31-
if self.members.contains(name) {
31+
if self.members.contains_key(name) {
3232
Some(name)
3333
} else {
3434
self.aliases.get(name)
@@ -217,9 +217,9 @@ pub(crate) fn enum_metadata<'db>(
217217
}
218218
}
219219

220-
Some(name.clone())
220+
Some((name.clone(), name.clone()))
221221
})
222-
.collect::<FxOrderSet<_>>();
222+
.collect::<FxOrderMap<_, _>>();
223223

224224
if members.is_empty() {
225225
// Enum subclasses without members are not considered enums.
@@ -237,7 +237,7 @@ pub(crate) fn enum_member_literals<'a, 'db: 'a>(
237237
enum_metadata(db, class).map(|metadata| {
238238
metadata
239239
.members
240-
.iter()
240+
.keys()
241241
.filter(move |name| Some(*name) != exclude_member)
242242
.map(move |name| Type::EnumLiteral(EnumLiteralType::new(db, class, name.clone())))
243243
})

0 commit comments

Comments
 (0)