Skip to content

Commit

Permalink
Predict enum value type for unknown member names (#9443)
Browse files Browse the repository at this point in the history
It is very common for enums to have homogenous member-value types.
In the case where we do not know what enum member we are dealing
with, we should sniff for that case and still collapse to a known
type if that assumption holds.

Handles auto() too, even if you override _get_next_value_.
  • Loading branch information
mgilson authored Sep 17, 2020
1 parent b707d29 commit 37777b3
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 20 deletions.
90 changes: 80 additions & 10 deletions mypy/plugins/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
we actually bake some of it directly in to the semantic analysis layer (see
semanal_enum.py).
"""
from typing import Optional
from typing import Iterable, Optional, TypeVar
from typing_extensions import Final

import mypy.plugin # To avoid circular imports.
from mypy.types import Type, Instance, LiteralType, get_proper_type
from mypy.types import Type, Instance, LiteralType, CallableType, ProperType, get_proper_type

# Note: 'enum.EnumMeta' is deliberately excluded from this list. Classes that directly use
# enum.EnumMeta do not necessarily automatically have the 'name' and 'value' attributes.
Expand Down Expand Up @@ -53,6 +53,56 @@ def enum_name_callback(ctx: 'mypy.plugin.AttributeContext') -> Type:
return str_type.copy_modified(last_known_value=literal_type)


_T = TypeVar('_T')


def _first(it: Iterable[_T]) -> Optional[_T]:
"""Return the first value from any iterable.
Returns ``None`` if the iterable is empty.
"""
for val in it:
return val
return None


def _infer_value_type_with_auto_fallback(
ctx: 'mypy.plugin.AttributeContext',
proper_type: Optional[ProperType]) -> Optional[Type]:
"""Figure out the type of an enum value accounting for `auto()`.
This method is a no-op for a `None` proper_type and also in the case where
the type is not "enum.auto"
"""
if proper_type is None:
return None
if not ((isinstance(proper_type, Instance) and
proper_type.type.fullname == 'enum.auto')):
return proper_type
assert isinstance(ctx.type, Instance), 'An incorrect ctx.type was passed.'
info = ctx.type.type
# Find the first _generate_next_value_ on the mro. We need to know
# if it is `Enum` because `Enum` types say that the return-value of
# `_generate_next_value_` is `Any`. In reality the default `auto()`
# returns an `int` (presumably the `Any` in typeshed is to make it
# easier to subclass and change the returned type).
type_with_gnv = _first(
ti for ti in info.mro if ti.names.get('_generate_next_value_'))
if type_with_gnv is None:
return ctx.default_attr_type

stnode = type_with_gnv.names['_generate_next_value_']

# This should be a `CallableType`
node_type = get_proper_type(stnode.type)
if isinstance(node_type, CallableType):
if type_with_gnv.fullname == 'enum.Enum':
int_type = ctx.api.named_generic_type('builtins.int', [])
return int_type
return get_proper_type(node_type.ret_type)
return ctx.default_attr_type


def enum_value_callback(ctx: 'mypy.plugin.AttributeContext') -> Type:
"""This plugin refines the 'value' attribute in enums to refer to
the original underlying value. For example, suppose we have the
Expand All @@ -78,6 +128,32 @@ class SomeEnum:
"""
enum_field_name = _extract_underlying_field_name(ctx.type)
if enum_field_name is None:
# We do not know the enum field name (perhaps it was passed to a
# function and we only know that it _is_ a member). All is not lost
# however, if we can prove that the all of the enum members have the
# same value-type, then it doesn't matter which member was passed in.
# The value-type is still known.
if isinstance(ctx.type, Instance):
info = ctx.type.type
stnodes = (info.get(name) for name in info.names)
# Enums _can_ have methods.
# Omit methods for our value inference.
node_types = (
get_proper_type(n.type) if n else None
for n in stnodes)
proper_types = (
_infer_value_type_with_auto_fallback(ctx, t)
for t in node_types
if t is None or not isinstance(t, CallableType))
underlying_type = _first(proper_types)
if underlying_type is None:
return ctx.default_attr_type
all_same_value_type = all(
proper_type is not None and proper_type == underlying_type
for proper_type in proper_types)
if all_same_value_type:
if underlying_type is not None:
return underlying_type
return ctx.default_attr_type

assert isinstance(ctx.type, Instance)
Expand All @@ -86,15 +162,9 @@ class SomeEnum:
if stnode is None:
return ctx.default_attr_type

underlying_type = get_proper_type(stnode.type)
underlying_type = _infer_value_type_with_auto_fallback(
ctx, get_proper_type(stnode.type))
if underlying_type is None:
# TODO: Deduce the inferred type if the user omits adding their own default types.
# TODO: Consider using the return type of `Enum._generate_next_value_` here?
return ctx.default_attr_type

if isinstance(underlying_type, Instance) and underlying_type.type.fullname == 'enum.auto':
# TODO: Deduce the correct inferred type when the user uses 'enum.auto'.
# We should use the same strategy we end up picking up above.
return ctx.default_attr_type

return underlying_type
Expand Down
88 changes: 79 additions & 9 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,76 @@ reveal_type(Truth.true.name) # N: Revealed type is 'Literal['true']?'
reveal_type(Truth.false.value) # N: Revealed type is 'builtins.bool'
[builtins fixtures/bool.pyi]

[case testEnumValueExtended]
from enum import Enum
class Truth(Enum):
true = True
false = False

def infer_truth(truth: Truth) -> None:
reveal_type(truth.value) # N: Revealed type is 'builtins.bool'
[builtins fixtures/bool.pyi]

[case testEnumValueAllAuto]
from enum import Enum, auto
class Truth(Enum):
true = auto()
false = auto()

def infer_truth(truth: Truth) -> None:
reveal_type(truth.value) # N: Revealed type is 'builtins.int'
[builtins fixtures/primitives.pyi]

[case testEnumValueSomeAuto]
from enum import Enum, auto
class Truth(Enum):
true = 8675309
false = auto()

def infer_truth(truth: Truth) -> None:
reveal_type(truth.value) # N: Revealed type is 'builtins.int'
[builtins fixtures/primitives.pyi]

[case testEnumValueExtraMethods]
from enum import Enum, auto
class Truth(Enum):
true = True
false = False

def foo(self) -> str:
return 'bar'

def infer_truth(truth: Truth) -> None:
reveal_type(truth.value) # N: Revealed type is 'builtins.bool'
[builtins fixtures/bool.pyi]

[case testEnumValueCustomAuto]
from enum import Enum, auto
class AutoName(Enum):

# In `typeshed`, this is a staticmethod and has more arguments,
# but I have lied a bit to keep the test stubs lean.
def _generate_next_value_(self) -> str:
return "name"

class Truth(AutoName):
true = auto()
false = auto()

def infer_truth(truth: Truth) -> None:
reveal_type(truth.value) # N: Revealed type is 'builtins.str'
[builtins fixtures/primitives.pyi]

[case testEnumValueInhomogenous]
from enum import Enum
class Truth(Enum):
true = 'True'
false = 0

def cannot_infer_truth(truth: Truth) -> None:
reveal_type(truth.value) # N: Revealed type is 'Any'
[builtins fixtures/bool.pyi]

[case testEnumUnique]
import enum
@enum.unique
Expand Down Expand Up @@ -497,8 +567,8 @@ reveal_type(A1.x.value) # N: Revealed type is 'Any'
reveal_type(A1.x._value_) # N: Revealed type is 'Any'
is_x(reveal_type(A2.x.name)) # N: Revealed type is 'Literal['x']'
is_x(reveal_type(A2.x._name_)) # N: Revealed type is 'Literal['x']'
reveal_type(A2.x.value) # N: Revealed type is 'Any'
reveal_type(A2.x._value_) # N: Revealed type is 'Any'
reveal_type(A2.x.value) # N: Revealed type is 'builtins.int'
reveal_type(A2.x._value_) # N: Revealed type is 'builtins.int'
is_x(reveal_type(A3.x.name)) # N: Revealed type is 'Literal['x']'
is_x(reveal_type(A3.x._name_)) # N: Revealed type is 'Literal['x']'
reveal_type(A3.x.value) # N: Revealed type is 'builtins.int'
Expand All @@ -519,7 +589,7 @@ reveal_type(B1.x._value_) # N: Revealed type is 'Any'
is_x(reveal_type(B2.x.name)) # N: Revealed type is 'Literal['x']'
is_x(reveal_type(B2.x._name_)) # N: Revealed type is 'Literal['x']'
reveal_type(B2.x.value) # N: Revealed type is 'builtins.int'
reveal_type(B2.x._value_) # N: Revealed type is 'Any'
reveal_type(B2.x._value_) # N: Revealed type is 'builtins.int'
is_x(reveal_type(B3.x.name)) # N: Revealed type is 'Literal['x']'
is_x(reveal_type(B3.x._name_)) # N: Revealed type is 'Literal['x']'
reveal_type(B3.x.value) # N: Revealed type is 'builtins.int'
Expand All @@ -540,8 +610,8 @@ reveal_type(C1.x.value) # N: Revealed type is 'Any'
reveal_type(C1.x._value_) # N: Revealed type is 'Any'
is_x(reveal_type(C2.x.name)) # N: Revealed type is 'Literal['x']'
is_x(reveal_type(C2.x._name_)) # N: Revealed type is 'Literal['x']'
reveal_type(C2.x.value) # N: Revealed type is 'Any'
reveal_type(C2.x._value_) # N: Revealed type is 'Any'
reveal_type(C2.x.value) # N: Revealed type is 'builtins.int'
reveal_type(C2.x._value_) # N: Revealed type is 'builtins.int'
is_x(reveal_type(C3.x.name)) # N: Revealed type is 'Literal['x']'
is_x(reveal_type(C3.x._name_)) # N: Revealed type is 'Literal['x']'
reveal_type(C3.x.value) # N: Revealed type is 'builtins.int'
Expand All @@ -559,8 +629,8 @@ reveal_type(D1.x.value) # N: Revealed type is 'Any'
reveal_type(D1.x._value_) # N: Revealed type is 'Any'
is_x(reveal_type(D2.x.name)) # N: Revealed type is 'Literal['x']'
is_x(reveal_type(D2.x._name_)) # N: Revealed type is 'Literal['x']'
reveal_type(D2.x.value) # N: Revealed type is 'Any'
reveal_type(D2.x._value_) # N: Revealed type is 'Any'
reveal_type(D2.x.value) # N: Revealed type is 'builtins.int'
reveal_type(D2.x._value_) # N: Revealed type is 'builtins.int'
is_x(reveal_type(D3.x.name)) # N: Revealed type is 'Literal['x']'
is_x(reveal_type(D3.x._name_)) # N: Revealed type is 'Literal['x']'
reveal_type(D3.x.value) # N: Revealed type is 'builtins.int'
Expand All @@ -578,8 +648,8 @@ class E3(Parent):

is_x(reveal_type(E2.x.name)) # N: Revealed type is 'Literal['x']'
is_x(reveal_type(E2.x._name_)) # N: Revealed type is 'Literal['x']'
reveal_type(E2.x.value) # N: Revealed type is 'Any'
reveal_type(E2.x._value_) # N: Revealed type is 'Any'
reveal_type(E2.x.value) # N: Revealed type is 'builtins.int'
reveal_type(E2.x._value_) # N: Revealed type is 'builtins.int'
is_x(reveal_type(E3.x.name)) # N: Revealed type is 'Literal['x']'
is_x(reveal_type(E3.x._name_)) # N: Revealed type is 'Literal['x']'
reveal_type(E3.x.value) # N: Revealed type is 'builtins.int'
Expand Down
6 changes: 5 additions & 1 deletion test-data/unit/lib-stub/enum.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class Enum(metaclass=EnumMeta):
_name_: str
_value_: Any

# In reality, _generate_next_value_ is python3.6 only and has a different signature.
# However, this should be quick and doesn't require additional stubs (e.g. `staticmethod`)
def _generate_next_value_(self) -> Any: pass

class IntEnum(int, Enum):
value: int

Expand All @@ -37,4 +41,4 @@ class IntFlag(int, Flag):


class auto(IntFlag):
value: Any
value: Any

0 comments on commit 37777b3

Please sign in to comment.