Skip to content

Commit

Permalink
Fix recently added enum value type prediction
Browse files Browse the repository at this point in the history
In python#9443, some code was added to predict the type of enum values where
it is not explicitly when all enum members have the same type.

However, it didn't consider that subclasses of Enum that have a custom
__new__ implementation may use any type for the enum value (typically it
would use only one of their parameters instead of a whole tuple that is
specified in the definition of the member). Fix this by avoiding to
guess the enum value type in classes that implement __new__.

In addition, the added code was buggy in that it didn't only consider
class attributes as enum members, but also instance attributes assigned
to self.* in __init__. Fix this by ignoring implicit nodes when checking
the enum members.

Fixes python#10000.
  • Loading branch information
taljeth committed Feb 10, 2021
1 parent 11d4fb2 commit 585152b
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 3 deletions.
35 changes: 32 additions & 3 deletions mypy/plugins/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import mypy.plugin # To avoid circular imports.
from mypy.types import Type, Instance, LiteralType, CallableType, ProperType, get_proper_type
from mypy.nodes import TypeInfo

# Note: 'enum.EnumMeta' is deliberately excluded from this list. Classes that directly use
# enum.EnumMeta do not necessarily automatically have the 'name' and 'value' attributes.
Expand Down Expand Up @@ -103,6 +104,17 @@ def _infer_value_type_with_auto_fallback(
return ctx.default_attr_type


def _implements_new(info: TypeInfo) -> bool:
"""Check whether __new__ comes from enum.Enum or was implemented in a
subclass. In the latter case, we must infer Any as long as mypy can't infer
the type of _value_ from assignments in __new__.
"""
type_with_new = _first(ti for ti in info.mro if ti.names.get('__new__'))
if type_with_new is None:
return False
return type_with_new.fullname != 'enum.Enum'


def enum_value_callback(ctx: 'mypy.plugin.AttributeContext') -> Type:
"""This plugin refines the 'value' attribute in enums to refer to
the original underlying value. For example, suppose we have the
Expand Down Expand Up @@ -135,12 +147,22 @@ class SomeEnum:
# The value-type is still known.
if isinstance(ctx.type, Instance):
info = ctx.type.type

# As long as mypy doesn't understand attribute creation in __new__,
# there is no way to predict the value type if the enum class has a
# custom implementation
if _implements_new(info):
return ctx.default_attr_type

stnodes = (info.get(name) for name in info.names)
# Enums _can_ have methods.
# Omit methods for our value inference.

# Enums _can_ have methods and instance attributes.
# Omit methods and attributes created by assigning to self.*
# for our value inference.
node_types = (
get_proper_type(n.type) if n else None
for n in stnodes)
for n in stnodes
if n is None or not n.implicit)
proper_types = (
_infer_value_type_with_auto_fallback(ctx, t)
for t in node_types
Expand All @@ -158,6 +180,13 @@ class SomeEnum:

assert isinstance(ctx.type, Instance)
info = ctx.type.type

# As long as mypy doesn't understand attribute creation in __new__,
# there is no way to predict the value type if the enum class has a
# custom implementation
if _implements_new(info):
return ctx.default_attr_type

stnode = info.get(enum_field_name)
if stnode is None:
return ctx.default_attr_type
Expand Down
56 changes: 56 additions & 0 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -1243,3 +1243,59 @@ class Comparator(enum.Enum):

reveal_type(Comparator.__foo__) # N: Revealed type is 'builtins.dict[builtins.str, builtins.int]'
[builtins fixtures/dict.pyi]

[case testEnumWithInstanceAttributes]
from enum import Enum
class Foo(Enum):
def __init__(self, value: int) -> None:
self.foo = "bar"
A = 1
B = 2

a = Foo.A
reveal_type(a.value) # N: Revealed type is 'builtins.int'
reveal_type(a._value_) # N: Revealed type is 'builtins.int'

[case testNewSetsUnexpectedValueType]
from enum import Enum

class bytes:
def __new__(cls): pass

class Foo(bytes, Enum):
def __new__(cls, value: int) -> 'Foo':
obj = bytes.__new__(cls)
obj._value_ = "Number %d" % value
return obj
A = 1
B = 2

a = Foo.A
reveal_type(a.value) # N: Revealed type is 'Any'
reveal_type(a._value_) # N: Revealed type is 'Any'
[builtins fixtures/__new__.pyi]
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testValueTypeWithNewInParentClass]
from enum import Enum

class bytes:
def __new__(cls): pass

class Foo(bytes, Enum):
def __new__(cls, value: int) -> 'Foo':
obj = bytes.__new__(cls)
obj._value_ = "Number %d" % value
return obj

class Bar(Foo):
A = 1
B = 2

a = Bar.A
reveal_type(a.value) # N: Revealed type is 'Any'
reveal_type(a._value_) # N: Revealed type is 'Any'
[builtins fixtures/__new__.pyi]
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

0 comments on commit 585152b

Please sign in to comment.