Skip to content

Commit 64830e6

Browse files
committed
[ty] infer type for enum members
astral-sh/ty#876
1 parent 61f906d commit 64830e6

File tree

6 files changed

+79
-35
lines changed

6 files changed

+79
-35
lines changed

crates/ruff_memory_usage/src/lib.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,12 @@ pub fn heap_size<T: GetSize>(value: &T) -> usize {
1818
pub fn order_set_heap_size<T: GetSize, S>(set: &OrderSet<T, S>) -> usize {
1919
(set.capacity() * T::get_stack_size()) + set.iter().map(heap_size).sum::<usize>()
2020
}
21+
22+
/// An implementation of [`GetSize::get_heap_size`] for [`OrderMap`].
23+
pub fn order_map_heap_size<K: GetSize, V: GetSize, S>(map: &ordermap::OrderMap<K, V, S>) -> usize {
24+
(map.capacity() * (K::get_stack_size() + V::get_stack_size()))
25+
+ map
26+
.iter()
27+
.map(|(k, v)| heap_size(k) + heap_size(v))
28+
.sum::<usize>()
29+
}

crates/ty_python_semantic/resources/mdtest/enums.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
```py
66
from enum import Enum
7+
from typing import Literal
78

89
class Color(Enum):
910
RED = 1
@@ -13,6 +14,12 @@ class Color(Enum):
1314
reveal_type(Color.RED) # revealed: Literal[Color.RED]
1415
# TODO: This could be `Literal[1]`
1516
reveal_type(Color.RED.value) # revealed: Any
17+
# TODO: This could be `Literal[1]`
18+
reveal_type(Color.RED._value_) # revealed: Any
19+
20+
reveal_type(Color.RED.name) # revealed: Literal["RED"]
21+
22+
reveal_type(Color.RED._name_) # revealed: Literal["RED"]
1623

1724
# TODO: Should be `Color` or `Literal[Color.RED]`
1825
reveal_type(Color["RED"]) # revealed: Unknown
@@ -21,6 +28,13 @@ reveal_type(Color["RED"]) # revealed: Unknown
2128
reveal_type(Color(1)) # revealed: Color
2229

2330
reveal_type(Color.RED in Color) # revealed: bool
31+
32+
def func1(red_or_blue: Literal[Color.RED, Color.BLUE]):
33+
reveal_type(red_or_blue.name) # revealed: Literal["RED", "BLUE"]
34+
35+
def func2(any_color: Color):
36+
# TODO: Literal["RED", "GREEN", "BLUE"]
37+
reveal_type(any_color.name) # revealed: Any
2438
```
2539

2640
## Enum members

crates/ty_python_semantic/src/types.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3465,6 +3465,14 @@ impl<'db> Type<'db> {
34653465
.value_type(db)
34663466
.member_lookup_with_policy(db, name, policy),
34673467

3468+
Type::EnumLiteral(enum_literal) if matches!(name_str, "name" | "_name_") => {
3469+
Place::bound(Type::StringLiteral(StringLiteralType::new(
3470+
db,
3471+
enum_literal.name(db).as_str(),
3472+
)))
3473+
.into()
3474+
}
3475+
34683476
Type::NominalInstance(..)
34693477
| Type::ProtocolInstance(..)
34703478
| Type::BooleanLiteral(..)
@@ -3582,7 +3590,7 @@ impl<'db> Type<'db> {
35823590
_ => None,
35833591
} {
35843592
if let Some(metadata) = enum_metadata(db, enum_class) {
3585-
if let Some(resolved_name) = metadata.resolve_member(&name) {
3593+
if let Some(resolved_name) = metadata.resolve_member(db, &name) {
35863594
return Place::Type(
35873595
Type::EnumLiteral(EnumLiteralType::new(
35883596
db,

crates/ty_python_semantic/src/types/builder.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,10 @@ impl<'db> UnionBuilder<'db> {
428428
.collect::<FxOrderSet<_>>();
429429

430430
let all_members_are_in_union = metadata
431-
.members
431+
.members(self.db)
432+
.keys()
433+
.cloned()
434+
.collect::<FxOrderSet<_>>()
432435
.difference(&enum_members_in_union)
433436
.next()
434437
.is_none();

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,8 +710,8 @@ impl<'db> Bindings<'db> {
710710
Type::heterogeneous_tuple(
711711
db,
712712
metadata
713-
.members
714-
.iter()
713+
.members(db)
714+
.keys()
715715
.map(|member| Type::string_literal(db, member)),
716716
)
717717
} else {

crates/ty_python_semantic/src/types/enums.rs

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
use std::collections::BTreeMap;
2+
13
use ruff_python_ast::name::Name;
24
use rustc_hash::FxHashMap;
35

46
use crate::{
5-
Db, FxOrderSet,
7+
Db, FxOrderMap,
68
place::{Place, PlaceAndQualifiers, place_from_bindings, place_from_declarations},
79
semantic_index::{place_table, use_def_map},
810
types::{
@@ -11,44 +13,52 @@ use crate::{
1113
},
1214
};
1315

14-
#[derive(Debug, PartialEq, Eq)]
15-
pub(crate) struct EnumMetadata {
16-
pub(crate) members: FxOrderSet<Name>,
17-
pub(crate) aliases: FxHashMap<Name, Name>,
16+
#[salsa::interned(debug, heap_size=EnumMetadata::heap_size)]
17+
pub(crate) struct EnumMetadata<'db> {
18+
#[returns(ref)]
19+
pub(crate) members: FxOrderMap<Name, Type<'db>>,
20+
#[returns(ref)]
21+
pub(crate) aliases: BTreeMap<Name, Name>,
1822
}
1923

20-
impl get_size2::GetSize for EnumMetadata {}
24+
impl get_size2::GetSize for EnumMetadata<'_> {}
2125

22-
impl EnumMetadata {
23-
fn empty() -> Self {
24-
EnumMetadata {
25-
members: FxOrderSet::default(),
26-
aliases: FxHashMap::default(),
27-
}
26+
impl<'db> EnumMetadata<'db> {
27+
fn empty(db: &'db dyn Db) -> Self {
28+
EnumMetadata::new(db, FxOrderMap::default(), BTreeMap::default())
2829
}
2930

30-
pub(crate) fn resolve_member<'a>(&'a self, name: &'a Name) -> Option<&'a Name> {
31-
if self.members.contains(name) {
31+
pub(crate) fn resolve_member<'a>(&'a self, db: &'a dyn Db, name: &'a Name) -> Option<&'a Name> {
32+
if self.members(db).contains_key(name) {
3233
Some(name)
3334
} else {
34-
self.aliases.get(name)
35+
self.aliases(db).get(name)
3536
}
3637
}
38+
39+
fn heap_size(
40+
(members, aliases): &(FxOrderMap<Name, Type<'db>>, BTreeMap<Name, Name>),
41+
) -> usize {
42+
ruff_memory_usage::order_map_heap_size(members) + ruff_memory_usage::heap_size(aliases)
43+
}
3744
}
3845

39-
#[allow(clippy::ref_option)]
40-
fn enum_metadata_cycle_recover(
41-
_db: &dyn Db,
42-
_value: &Option<EnumMetadata>,
46+
#[allow(clippy::ref_option, clippy::trivially_copy_pass_by_ref)]
47+
fn enum_metadata_cycle_recover<'db>(
48+
_db: &'db dyn Db,
49+
_value: &Option<EnumMetadata<'db>>,
4350
_count: u32,
44-
_class: ClassLiteral<'_>,
45-
) -> salsa::CycleRecoveryAction<Option<EnumMetadata>> {
51+
_class: ClassLiteral<'db>,
52+
) -> salsa::CycleRecoveryAction<Option<EnumMetadata<'db>>> {
4653
salsa::CycleRecoveryAction::Iterate
4754
}
4855

4956
#[allow(clippy::unnecessary_wraps)]
50-
fn enum_metadata_cycle_initial(_db: &dyn Db, _class: ClassLiteral<'_>) -> Option<EnumMetadata> {
51-
Some(EnumMetadata::empty())
57+
fn enum_metadata_cycle_initial<'db>(
58+
db: &'db dyn Db,
59+
_class: ClassLiteral<'db>,
60+
) -> Option<EnumMetadata<'db>> {
61+
Some(EnumMetadata::empty(db))
5262
}
5363

5464
/// List all members of an enum.
@@ -57,7 +67,7 @@ fn enum_metadata_cycle_initial(_db: &dyn Db, _class: ClassLiteral<'_>) -> Option
5767
pub(crate) fn enum_metadata<'db>(
5868
db: &'db dyn Db,
5969
class: ClassLiteral<'db>,
60-
) -> Option<EnumMetadata> {
70+
) -> Option<EnumMetadata<'db>> {
6171
// This is a fast path to avoid traversing the MRO of known classes
6272
if class
6373
.known(db)
@@ -97,7 +107,7 @@ pub(crate) fn enum_metadata<'db>(
97107
None
98108
};
99109

100-
let mut aliases = FxHashMap::default();
110+
let mut aliases = BTreeMap::default();
101111

102112
let members = use_def_map
103113
.all_end_of_scope_symbol_bindings()
@@ -217,16 +227,16 @@ pub(crate) fn enum_metadata<'db>(
217227
}
218228
}
219229

220-
Some(name.clone())
230+
Some((name.clone(), value_ty))
221231
})
222-
.collect::<FxOrderSet<_>>();
232+
.collect::<FxOrderMap<_, _>>();
223233

224234
if members.is_empty() {
225235
// Enum subclasses without members are not considered enums.
226236
return None;
227237
}
228238

229-
Some(EnumMetadata { members, aliases })
239+
Some(EnumMetadata::new(db, members, aliases))
230240
}
231241

232242
pub(crate) fn enum_member_literals<'a, 'db: 'a>(
@@ -236,15 +246,15 @@ pub(crate) fn enum_member_literals<'a, 'db: 'a>(
236246
) -> Option<impl Iterator<Item = Type<'a>> + 'a> {
237247
enum_metadata(db, class).map(|metadata| {
238248
metadata
239-
.members
240-
.iter()
249+
.members(db)
250+
.keys()
241251
.filter(move |name| Some(*name) != exclude_member)
242252
.map(move |name| Type::EnumLiteral(EnumLiteralType::new(db, class, name.clone())))
243253
})
244254
}
245255

246256
pub(crate) fn is_single_member_enum<'db>(db: &'db dyn Db, class: ClassLiteral<'db>) -> bool {
247-
enum_metadata(db, class).is_some_and(|metadata| metadata.members.len() == 1)
257+
enum_metadata(db, class).is_some_and(|metadata| metadata.members(db).len() == 1)
248258
}
249259

250260
pub(crate) fn is_enum_class<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool {

0 commit comments

Comments
 (0)