Skip to content

Commit 2bb499d

Browse files
committed
Infer user-defined enum classes by checking if the class is a subtype of enum.Enum.
Closes pylint-dev/pylint#8897
1 parent 3752f93 commit 2bb499d

File tree

3 files changed

+41
-17
lines changed

3 files changed

+41
-17
lines changed

ChangeLog

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,10 @@ Release date: TBA
221221

222222
Closes pylint-dev/pylint#8802
223223

224+
* Infer user-defined enum classes by checking if the class is a subtype of ``enum.Enum``.
225+
226+
Closes pylint-dev/pylint#8897
227+
224228
* Fix inference of functions with ``@functools.lru_cache`` decorators without
225229
parentheses.
226230

astroid/brain/brain_namedtuple_enum.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,10 @@
2020
AstroidTypeError,
2121
AstroidValueError,
2222
InferenceError,
23-
MroError,
2423
UseInferenceDefault,
2524
)
2625
from astroid.manager import AstroidManager
2726

28-
ENUM_BASE_NAMES = {
29-
"Enum",
30-
"IntEnum",
31-
"enum.Enum",
32-
"enum.IntEnum",
33-
"IntFlag",
34-
"enum.IntFlag",
35-
}
3627
ENUM_QNAME: Final[str] = "enum.Enum"
3728
TYPING_NAMEDTUPLE_QUALIFIED: Final = {
3829
"typing.NamedTuple",
@@ -644,14 +635,7 @@ def _get_namedtuple_fields(node: nodes.Call) -> str:
644635

645636
def _is_enum_subclass(cls: astroid.ClassDef) -> bool:
646637
"""Return whether cls is a subclass of an Enum."""
647-
try:
648-
return any(
649-
klass.name in ENUM_BASE_NAMES
650-
and getattr(klass.root(), "name", None) == "enum"
651-
for klass in cls.mro()
652-
)
653-
except MroError:
654-
return False
638+
return cls.is_subtype_of("enum.Enum")
655639

656640

657641
AstroidManager().register_transform(

tests/brain/test_enum.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,3 +521,39 @@ def __init__(self, mass, radius):
521521
mars, radius = enum_members.items
522522
assert mars[1].name == "MARS"
523523
assert radius[1].name == "radius"
524+
525+
def test_local_enum_child_class_inference(self) -> None:
526+
"""Originally reported in https://github.com/pylint-dev/pylint/issues/8897
527+
528+
Test that a user-defined enum class is inferred when it subclasses
529+
another user-defined enum class.
530+
"""
531+
enum_class_node, enum_member_value_node = astroid.extract_node(
532+
"""
533+
import sys
534+
535+
from enum import Enum
536+
537+
if sys.version_info >= (3, 11):
538+
from enum import StrEnum
539+
else:
540+
class StrEnum(str, Enum):
541+
pass
542+
543+
544+
class Color(StrEnum): #@
545+
RED = "red"
546+
547+
548+
Color.RED.value #@
549+
"""
550+
)
551+
assert "RED" in enum_class_node.locals
552+
553+
enum_members = enum_class_node.locals["__members__"][0].items
554+
assert len(enum_members) == 1
555+
_, name = enum_members[0]
556+
assert name.name == "RED"
557+
558+
inferred_enum_member_value_node = next(enum_member_value_node.infer())
559+
assert inferred_enum_member_value_node.value == "red"

0 commit comments

Comments
 (0)