Skip to content

Commit 63fd2f4

Browse files
committed
[ty] infer type for enum members
astral-sh/ty#876
1 parent 1fa64a2 commit 63fd2f4

File tree

7 files changed

+110
-28
lines changed

7 files changed

+110
-28
lines changed

crates/ty_python_semantic/resources/mdtest/attributes.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2415,7 +2415,7 @@ class Answer(enum.Enum):
24152415
YES = 1
24162416

24172417
reveal_type(Answer.NO) # revealed: Literal[Answer.NO]
2418-
reveal_type(Answer.NO.value) # revealed: Any
2418+
reveal_type(Answer.NO.value) # revealed: Literal[0]
24192419
reveal_type(Answer.__members__) # revealed: MappingProxyType[str, Unknown]
24202420
```
24212421

crates/ty_python_semantic/resources/mdtest/enums.md

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,16 @@
44

55
```py
66
from enum import Enum
7+
from typing import Literal
78

89
class Color(Enum):
910
RED = 1
1011
GREEN = 2
1112
BLUE = 3
1213

1314
reveal_type(Color.RED) # revealed: Literal[Color.RED]
14-
# TODO: This could be `Literal[1]`
15-
reveal_type(Color.RED.value) # revealed: Any
15+
reveal_type(Color.RED.name) # revealed: Literal["RED"]
16+
reveal_type(Color.RED.value) # revealed: Literal[1]
1617

1718
# TODO: Should be `Color` or `Literal[Color.RED]`
1819
reveal_type(Color["RED"]) # revealed: Unknown
@@ -50,6 +51,40 @@ class ColorStr(Enum):
5051
reveal_type(enum_members(ColorStr))
5152
```
5253

54+
### Generated `_name_` and `_value_` attributes
55+
56+
```py
57+
from enum import Enum
58+
from typing import Literal
59+
60+
class Color(Enum):
61+
RED = 1
62+
GREEN = 2
63+
BLUE = 3
64+
65+
reveal_type(Color.RED._name_) # revealed: Literal["RED"]
66+
reveal_type(Color.RED._value_) # revealed: Literal[1]
67+
```
68+
69+
### `name` attribute Literal unions
70+
71+
```py
72+
from enum import Enum
73+
from typing import Literal
74+
75+
class Color(Enum):
76+
RED = 1
77+
GREEN = 2
78+
BLUE = 3
79+
80+
def func1(red_or_blue: Literal[Color.RED, Color.BLUE]):
81+
reveal_type(red_or_blue.name) # revealed: Literal["RED", "BLUE"]
82+
83+
def func2(any_color: Color):
84+
# TODO: Literal["RED", "GREEN", "BLUE"]
85+
reveal_type(any_color.name) # revealed: Any
86+
```
87+
5388
### When deriving from `IntEnum`
5489

5590
```py
@@ -155,6 +190,7 @@ python-version = "3.11"
155190

156191
```py
157192
from enum import Enum, property as enum_property
193+
from typing import Any
158194
from ty_extensions import enum_members
159195

160196
class Answer(Enum):
@@ -169,6 +205,22 @@ class Answer(Enum):
169205
reveal_type(enum_members(Answer))
170206
```
171207

208+
Enum attributes defined using `enum.property` take precedence over generated attributes.
209+
210+
```py
211+
from enum import Enum, property as enum_property
212+
213+
class Choices(Enum):
214+
A = 1
215+
B = 2
216+
217+
@enum_property
218+
def value(self) -> Any: ...
219+
220+
# TODO: This should be `Any` - overridden by `@enum_property`
221+
reveal_type(Choices.A.value) # revealed: Literal[1]
222+
```
223+
172224
### `types.DynamicClassAttribute`
173225

174226
Attributes defined using `types.DynamicClassAttribute` are not considered members:
@@ -609,6 +661,12 @@ reveal_type(EnumWithSubclassOfEnumMetaMetaclass.NO) # revealed: Literal[EnumWit
609661
# Attributes like `.value` can *not* be accessed on members of these enums:
610662
# error: [unresolved-attribute]
611663
EnumWithSubclassOfEnumMetaMetaclass.NO.value
664+
# error: [unresolved-attribute]
665+
EnumWithSubclassOfEnumMetaMetaclass.NO._value_
666+
# error: [unresolved-attribute]
667+
EnumWithSubclassOfEnumMetaMetaclass.NO.name
668+
# error: [unresolved-attribute]
669+
EnumWithSubclassOfEnumMetaMetaclass.NO._name_
612670
```
613671

614672
### Enums with (subclasses of) `EnumType` as metaclass

crates/ty_python_semantic/src/lib.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
)]
55
use std::hash::BuildHasherDefault;
66

7-
use rustc_hash::FxHasher;
8-
97
use crate::lint::{LintRegistry, LintRegistryBuilder};
108
use crate::suppression::{INVALID_IGNORE_COMMENT, UNKNOWN_RULE, UNUSED_IGNORE_COMMENT};
119
pub use db::Db;
@@ -19,6 +17,7 @@ pub use program::{
1917
PythonVersionWithSource, SearchPathSettings,
2018
};
2119
pub use python_platform::PythonPlatform;
20+
use rustc_hash::FxHasher;
2221
pub use semantic_model::{
2322
Completion, CompletionKind, HasDefinition, HasType, NameKind, SemanticModel,
2423
};

crates/ty_python_semantic/src/types.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3528,6 +3528,29 @@ impl<'db> Type<'db> {
35283528
.value_type(db)
35293529
.member_lookup_with_policy(db, name, policy),
35303530

3531+
Type::EnumLiteral(enum_literal)
3532+
if matches!(name_str, "name" | "_name_")
3533+
&& Type::ClassLiteral(enum_literal.enum_class(db))
3534+
.is_subtype_of(db, KnownClass::Enum.to_subclass_of(db)) =>
3535+
{
3536+
Place::bound(Type::StringLiteral(StringLiteralType::new(
3537+
db,
3538+
enum_literal.name(db).as_str(),
3539+
)))
3540+
.into()
3541+
}
3542+
3543+
Type::EnumLiteral(enum_literal)
3544+
if matches!(name_str, "value" | "_value_")
3545+
&& Type::ClassLiteral(enum_literal.enum_class(db))
3546+
.is_subtype_of(db, KnownClass::Enum.to_subclass_of(db)) =>
3547+
{
3548+
enum_metadata(db, enum_literal.enum_class(db))
3549+
.and_then(|metadata| metadata.members.get(enum_literal.name(db)))
3550+
.map_or_else(|| Place::Unbound, Place::bound)
3551+
.into()
3552+
}
3553+
35313554
Type::NominalInstance(..)
35323555
| Type::ProtocolInstance(..)
35333556
| Type::BooleanLiteral(..)

crates/ty_python_semantic/src/types/builder.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -428,9 +428,8 @@ impl<'db> UnionBuilder<'db> {
428428

429429
let all_members_are_in_union = metadata
430430
.members
431-
.difference(&enum_members_in_union)
432-
.next()
433-
.is_none();
431+
.keys()
432+
.all(|name| enum_members_in_union.contains(name));
434433

435434
if all_members_are_in_union {
436435
self.add_in_place_impl(

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -724,7 +724,7 @@ impl<'db> Bindings<'db> {
724724
db,
725725
metadata
726726
.members
727-
.iter()
727+
.keys()
728728
.map(|member| Type::string_literal(db, member)),
729729
)
730730
} else {

crates/ty_python_semantic/src/types/enums.rs

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use ruff_python_ast::name::Name;
22
use rustc_hash::FxHashMap;
33

44
use crate::{
5-
Db, FxOrderSet,
5+
Db, FxIndexMap,
66
place::{Place, PlaceAndQualifiers, place_from_bindings, place_from_declarations},
77
semantic_index::{place_table, use_def_map},
88
types::{
@@ -11,43 +11,46 @@ use crate::{
1111
},
1212
};
1313

14-
#[derive(Debug, PartialEq, Eq)]
15-
pub(crate) struct EnumMetadata {
16-
pub(crate) members: FxOrderSet<Name>,
14+
#[derive(Debug, PartialEq, Eq, salsa::Update)]
15+
pub(crate) struct EnumMetadata<'db> {
16+
pub(crate) members: FxIndexMap<Name, Type<'db>>,
1717
pub(crate) aliases: FxHashMap<Name, Name>,
1818
}
1919

20-
impl get_size2::GetSize for EnumMetadata {}
20+
impl get_size2::GetSize for EnumMetadata<'_> {}
2121

22-
impl EnumMetadata {
22+
impl EnumMetadata<'_> {
2323
fn empty() -> Self {
2424
EnumMetadata {
25-
members: FxOrderSet::default(),
25+
members: FxIndexMap::default(),
2626
aliases: FxHashMap::default(),
2727
}
2828
}
2929

3030
pub(crate) fn resolve_member<'a>(&'a self, name: &'a Name) -> Option<&'a Name> {
31-
if self.members.contains(name) {
31+
if self.members.contains_key(name) {
3232
Some(name)
3333
} else {
3434
self.aliases.get(name)
3535
}
3636
}
3737
}
3838

39-
#[allow(clippy::ref_option)]
40-
fn enum_metadata_cycle_recover(
41-
_db: &dyn Db,
42-
_value: &Option<EnumMetadata>,
39+
#[allow(clippy::ref_option, clippy::trivially_copy_pass_by_ref)]
40+
fn enum_metadata_cycle_recover<'db>(
41+
_db: &'db dyn Db,
42+
_value: &Option<EnumMetadata<'db>>,
4343
_count: u32,
44-
_class: ClassLiteral<'_>,
45-
) -> salsa::CycleRecoveryAction<Option<EnumMetadata>> {
44+
_class: ClassLiteral<'db>,
45+
) -> salsa::CycleRecoveryAction<Option<EnumMetadata<'db>>> {
4646
salsa::CycleRecoveryAction::Iterate
4747
}
4848

4949
#[allow(clippy::unnecessary_wraps)]
50-
fn enum_metadata_cycle_initial(_db: &dyn Db, _class: ClassLiteral<'_>) -> Option<EnumMetadata> {
50+
fn enum_metadata_cycle_initial<'db>(
51+
_db: &'db dyn Db,
52+
_class: ClassLiteral<'db>,
53+
) -> Option<EnumMetadata<'db>> {
5154
Some(EnumMetadata::empty())
5255
}
5356

@@ -57,7 +60,7 @@ fn enum_metadata_cycle_initial(_db: &dyn Db, _class: ClassLiteral<'_>) -> Option
5760
pub(crate) fn enum_metadata<'db>(
5861
db: &'db dyn Db,
5962
class: ClassLiteral<'db>,
60-
) -> Option<EnumMetadata> {
63+
) -> Option<EnumMetadata<'db>> {
6164
// This is a fast path to avoid traversing the MRO of known classes
6265
if class
6366
.known(db)
@@ -217,9 +220,9 @@ pub(crate) fn enum_metadata<'db>(
217220
}
218221
}
219222

220-
Some(name.clone())
223+
Some((name.clone(), value_ty))
221224
})
222-
.collect::<FxOrderSet<_>>();
225+
.collect::<FxIndexMap<_, _>>();
223226

224227
if members.is_empty() {
225228
// Enum subclasses without members are not considered enums.
@@ -237,7 +240,7 @@ pub(crate) fn enum_member_literals<'a, 'db: 'a>(
237240
enum_metadata(db, class).map(|metadata| {
238241
metadata
239242
.members
240-
.iter()
243+
.keys()
241244
.filter(move |name| Some(*name) != exclude_member)
242245
.map(move |name| Type::EnumLiteral(EnumLiteralType::new(db, class, name.clone())))
243246
})

0 commit comments

Comments
 (0)