Skip to content

Commit a3af30d

Browse files
committed
[ty] infer type for enum members
astral-sh/ty#876
1 parent 1fa64a2 commit a3af30d

File tree

7 files changed

+115
-38
lines changed

7 files changed

+115
-38
lines changed

crates/ruff_memory_usage/src/lib.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use std::sync::{LazyLock, Mutex};
22

33
use get_size2::{GetSize, StandardTracker};
4+
use ordermap::OrderMap;
45
use ordermap::OrderSet;
56

67
/// Returns the memory usage of the provided object, using a global tracker to avoid
@@ -18,3 +19,12 @@ pub fn heap_size<T: GetSize>(value: &T) -> usize {
1819
pub fn order_set_heap_size<T: GetSize, S>(set: &OrderSet<T, S>) -> usize {
1920
(set.capacity() * T::get_stack_size()) + set.iter().map(heap_size).sum::<usize>()
2021
}
22+
23+
/// An implementation of [`GetSize::get_heap_size`] for [`OrderMap`].
24+
pub fn order_map_heap_size<K: GetSize, V: GetSize, S>(map: &OrderMap<K, V, S>) -> usize {
25+
(map.capacity() * (K::get_stack_size() + V::get_stack_size()))
26+
+ map
27+
.iter()
28+
.map(|(k, v)| heap_size(k) + heap_size(v))
29+
.sum::<usize>()
30+
}

crates/ty_python_semantic/resources/mdtest/attributes.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2415,7 +2415,7 @@ class Answer(enum.Enum):
24152415
YES = 1
24162416

24172417
reveal_type(Answer.NO) # revealed: Literal[Answer.NO]
2418-
reveal_type(Answer.NO.value) # revealed: Any
2418+
reveal_type(Answer.NO.value) # revealed: Literal[0]
24192419
reveal_type(Answer.__members__) # revealed: MappingProxyType[str, Unknown]
24202420
```
24212421

crates/ty_python_semantic/resources/mdtest/enums.md

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,22 @@
44

55
```py
66
from enum import Enum
7+
from typing import Literal
78

89
class Color(Enum):
910
RED = 1
1011
GREEN = 2
1112
BLUE = 3
1213

1314
reveal_type(Color.RED) # revealed: Literal[Color.RED]
14-
# TODO: This could be `Literal[1]`
15-
reveal_type(Color.RED.value) # revealed: Any
15+
16+
reveal_type(Color.RED.value) # revealed: Literal[1]
17+
18+
reveal_type(Color.RED._value_) # revealed: Literal[1]
19+
20+
reveal_type(Color.RED.name) # revealed: Literal["RED"]
21+
22+
reveal_type(Color.RED._name_) # revealed: Literal["RED"]
1623

1724
# TODO: Should be `Color` or `Literal[Color.RED]`
1825
reveal_type(Color["RED"]) # revealed: Unknown
@@ -21,6 +28,13 @@ reveal_type(Color["RED"]) # revealed: Unknown
2128
reveal_type(Color(1)) # revealed: Color
2229

2330
reveal_type(Color.RED in Color) # revealed: bool
31+
32+
def func1(red_or_blue: Literal[Color.RED, Color.BLUE]):
33+
reveal_type(red_or_blue.name) # revealed: Literal["RED", "BLUE"]
34+
35+
def func2(any_color: Color):
36+
# TODO: Literal["RED", "GREEN", "BLUE"]
37+
reveal_type(any_color.name) # revealed: Any
2438
```
2539

2640
## Enum members
@@ -155,6 +169,7 @@ python-version = "3.11"
155169

156170
```py
157171
from enum import Enum, property as enum_property
172+
from typing import Any
158173
from ty_extensions import enum_members
159174

160175
class Answer(Enum):
@@ -167,6 +182,16 @@ class Answer(Enum):
167182

168183
# revealed: tuple[Literal["YES"], Literal["NO"]]
169184
reveal_type(enum_members(Answer))
185+
186+
class Choices(Enum):
187+
A = 1
188+
B = 2
189+
190+
@enum_property
191+
def value(self) -> Any: ...
192+
193+
# TODO: This should be `Any` - overridden by `@enum_property`
194+
reveal_type(Choices.A.value) # revealed: Literal[1]
170195
```
171196

172197
### `types.DynamicClassAttribute`
@@ -609,6 +634,12 @@ reveal_type(EnumWithSubclassOfEnumMetaMetaclass.NO) # revealed: Literal[EnumWit
609634
# Attributes like `.value` can *not* be accessed on members of these enums:
610635
# error: [unresolved-attribute]
611636
EnumWithSubclassOfEnumMetaMetaclass.NO.value
637+
# error: [unresolved-attribute]
638+
EnumWithSubclassOfEnumMetaMetaclass.NO._value_
639+
# error: [unresolved-attribute]
640+
EnumWithSubclassOfEnumMetaMetaclass.NO.name
641+
# error: [unresolved-attribute]
642+
EnumWithSubclassOfEnumMetaMetaclass.NO._name_
612643
```
613644

614645
### Enums with (subclasses of) `EnumType` as metaclass

crates/ty_python_semantic/src/types.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3528,6 +3528,29 @@ impl<'db> Type<'db> {
35283528
.value_type(db)
35293529
.member_lookup_with_policy(db, name, policy),
35303530

3531+
Type::EnumLiteral(enum_literal)
3532+
if matches!(name_str, "name" | "_name_")
3533+
&& Type::ClassLiteral(enum_literal.enum_class(db))
3534+
.is_subtype_of(db, KnownClass::Enum.to_subclass_of(db)) =>
3535+
{
3536+
Place::bound(Type::StringLiteral(StringLiteralType::new(
3537+
db,
3538+
enum_literal.name(db).as_str(),
3539+
)))
3540+
.into()
3541+
}
3542+
3543+
Type::EnumLiteral(enum_literal)
3544+
if matches!(name_str, "value" | "_value_")
3545+
&& Type::ClassLiteral(enum_literal.enum_class(db))
3546+
.is_subtype_of(db, KnownClass::Enum.to_subclass_of(db)) =>
3547+
{
3548+
enum_metadata(db, enum_literal.enum_class(db))
3549+
.and_then(|metadata| metadata.members(db).get(enum_literal.name(db)))
3550+
.map_or_else(|| Place::Unbound, Place::bound)
3551+
.into()
3552+
}
3553+
35313554
Type::NominalInstance(..)
35323555
| Type::ProtocolInstance(..)
35333556
| Type::BooleanLiteral(..)
@@ -3645,7 +3668,7 @@ impl<'db> Type<'db> {
36453668
_ => None,
36463669
} {
36473670
if let Some(metadata) = enum_metadata(db, enum_class) {
3648-
if let Some(resolved_name) = metadata.resolve_member(&name) {
3671+
if let Some(resolved_name) = metadata.resolve_member(db, &name) {
36493672
return Place::Type(
36503673
Type::EnumLiteral(EnumLiteralType::new(
36513674
db,

crates/ty_python_semantic/src/types/builder.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,10 @@ impl<'db> UnionBuilder<'db> {
427427
.collect::<FxOrderSet<_>>();
428428

429429
let all_members_are_in_union = metadata
430-
.members
430+
.members(self.db)
431+
.keys()
432+
.cloned()
433+
.collect::<FxOrderSet<_>>()
431434
.difference(&enum_members_in_union)
432435
.next()
433436
.is_none();

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -723,8 +723,8 @@ impl<'db> Bindings<'db> {
723723
Type::heterogeneous_tuple(
724724
db,
725725
metadata
726-
.members
727-
.iter()
726+
.members(db)
727+
.keys()
728728
.map(|member| Type::string_literal(db, member)),
729729
)
730730
} else {

crates/ty_python_semantic/src/types/enums.rs

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
use std::collections::BTreeMap;
2+
13
use ruff_python_ast::name::Name;
24
use rustc_hash::FxHashMap;
35

46
use crate::{
5-
Db, FxOrderSet,
7+
Db, FxOrderMap,
68
place::{Place, PlaceAndQualifiers, place_from_bindings, place_from_declarations},
79
semantic_index::{place_table, use_def_map},
810
types::{
@@ -11,44 +13,52 @@ use crate::{
1113
},
1214
};
1315

14-
#[derive(Debug, PartialEq, Eq)]
15-
pub(crate) struct EnumMetadata {
16-
pub(crate) members: FxOrderSet<Name>,
17-
pub(crate) aliases: FxHashMap<Name, Name>,
16+
#[salsa::interned(debug, heap_size=EnumMetadata::heap_size)]
17+
pub(crate) struct EnumMetadata<'db> {
18+
#[returns(ref)]
19+
pub(crate) members: FxOrderMap<Name, Type<'db>>,
20+
#[returns(ref)]
21+
pub(crate) aliases: BTreeMap<Name, Name>,
1822
}
1923

20-
impl get_size2::GetSize for EnumMetadata {}
24+
impl get_size2::GetSize for EnumMetadata<'_> {}
2125

22-
impl EnumMetadata {
23-
fn empty() -> Self {
24-
EnumMetadata {
25-
members: FxOrderSet::default(),
26-
aliases: FxHashMap::default(),
27-
}
26+
impl<'db> EnumMetadata<'db> {
27+
fn empty(db: &'db dyn Db) -> Self {
28+
EnumMetadata::new(db, FxOrderMap::default(), BTreeMap::default())
2829
}
2930

30-
pub(crate) fn resolve_member<'a>(&'a self, name: &'a Name) -> Option<&'a Name> {
31-
if self.members.contains(name) {
31+
pub(crate) fn resolve_member<'a>(&'a self, db: &'a dyn Db, name: &'a Name) -> Option<&'a Name> {
32+
if self.members(db).contains_key(name) {
3233
Some(name)
3334
} else {
34-
self.aliases.get(name)
35+
self.aliases(db).get(name)
3536
}
3637
}
38+
39+
fn heap_size(
40+
(members, aliases): &(FxOrderMap<Name, Type<'db>>, BTreeMap<Name, Name>),
41+
) -> usize {
42+
ruff_memory_usage::order_map_heap_size(members) + ruff_memory_usage::heap_size(aliases)
43+
}
3744
}
3845

39-
#[allow(clippy::ref_option)]
40-
fn enum_metadata_cycle_recover(
41-
_db: &dyn Db,
42-
_value: &Option<EnumMetadata>,
46+
#[allow(clippy::ref_option, clippy::trivially_copy_pass_by_ref)]
47+
fn enum_metadata_cycle_recover<'db>(
48+
_db: &'db dyn Db,
49+
_value: &Option<EnumMetadata<'db>>,
4350
_count: u32,
44-
_class: ClassLiteral<'_>,
45-
) -> salsa::CycleRecoveryAction<Option<EnumMetadata>> {
51+
_class: ClassLiteral<'db>,
52+
) -> salsa::CycleRecoveryAction<Option<EnumMetadata<'db>>> {
4653
salsa::CycleRecoveryAction::Iterate
4754
}
4855

4956
#[allow(clippy::unnecessary_wraps)]
50-
fn enum_metadata_cycle_initial(_db: &dyn Db, _class: ClassLiteral<'_>) -> Option<EnumMetadata> {
51-
Some(EnumMetadata::empty())
57+
fn enum_metadata_cycle_initial<'db>(
58+
db: &'db dyn Db,
59+
_class: ClassLiteral<'db>,
60+
) -> Option<EnumMetadata<'db>> {
61+
Some(EnumMetadata::empty(db))
5262
}
5363

5464
/// List all members of an enum.
@@ -57,7 +67,7 @@ fn enum_metadata_cycle_initial(_db: &dyn Db, _class: ClassLiteral<'_>) -> Option
5767
pub(crate) fn enum_metadata<'db>(
5868
db: &'db dyn Db,
5969
class: ClassLiteral<'db>,
60-
) -> Option<EnumMetadata> {
70+
) -> Option<EnumMetadata<'db>> {
6171
// This is a fast path to avoid traversing the MRO of known classes
6272
if class
6373
.known(db)
@@ -97,7 +107,7 @@ pub(crate) fn enum_metadata<'db>(
97107
None
98108
};
99109

100-
let mut aliases = FxHashMap::default();
110+
let mut aliases = BTreeMap::default();
101111

102112
let members = use_def_map
103113
.all_end_of_scope_symbol_bindings()
@@ -217,16 +227,16 @@ pub(crate) fn enum_metadata<'db>(
217227
}
218228
}
219229

220-
Some(name.clone())
230+
Some((name.clone(), value_ty))
221231
})
222-
.collect::<FxOrderSet<_>>();
232+
.collect::<FxOrderMap<_, _>>();
223233

224234
if members.is_empty() {
225235
// Enum subclasses without members are not considered enums.
226236
return None;
227237
}
228238

229-
Some(EnumMetadata { members, aliases })
239+
Some(EnumMetadata::new(db, members, aliases))
230240
}
231241

232242
pub(crate) fn enum_member_literals<'a, 'db: 'a>(
@@ -236,15 +246,15 @@ pub(crate) fn enum_member_literals<'a, 'db: 'a>(
236246
) -> Option<impl Iterator<Item = Type<'a>> + 'a> {
237247
enum_metadata(db, class).map(|metadata| {
238248
metadata
239-
.members
240-
.iter()
249+
.members(db)
250+
.keys()
241251
.filter(move |name| Some(*name) != exclude_member)
242252
.map(move |name| Type::EnumLiteral(EnumLiteralType::new(db, class, name.clone())))
243253
})
244254
}
245255

246256
pub(crate) fn is_single_member_enum<'db>(db: &'db dyn Db, class: ClassLiteral<'db>) -> bool {
247-
enum_metadata(db, class).is_some_and(|metadata| metadata.members.len() == 1)
257+
enum_metadata(db, class).is_some_and(|metadata| metadata.members(db).len() == 1)
248258
}
249259

250260
pub(crate) fn is_enum_class<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool {

0 commit comments

Comments
 (0)