diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 30984ad07c460..9c01b52d06819 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -413,3 +413,45 @@ main:3: error: Revealed type is 'm.F' main:2: error: Revealed type is 'm.E' main:3: error: Revealed type is 'm.F' +[case testEnumIteration] +from enum import Enum +class E(Enum): + A = 'a' +l = [e for e in E] +reveal_type(l[0]) # E: Revealed type is '__main__.E' +for e in E: + reveal_type(e) # E: Revealed type is '__main__.E' +[builtins fixtures/list.pyi] + +[case testEnumIterable] +from enum import Enum +from typing import Iterable +class E(Enum): + a = 'a' +def f(ie:Iterable[E]): + pass +f(E) + +[case testIntEnumIterable] +from enum import IntEnum +from typing import Iterable +class N(IntEnum): + x = 1 +def f(ni: Iterable[N]): + pass +def g(ii: Iterable[int]): + pass +f(N) +g(N) + +[case testDerivedEnumIterable] +from enum import Enum +from typing import Iterable +class E(str, Enum): + a = 'foo' +def f(ei: Iterable[E]): + pass +def g(si: Iterable[str]): + pass +f(E) +g(E) diff --git a/test-data/unit/pythoneval.test b/test-data/unit/pythoneval.test index 67374962afc31..f15d14412b8ac 100644 --- a/test-data/unit/pythoneval.test +++ b/test-data/unit/pythoneval.test @@ -1458,3 +1458,45 @@ _testNoCrashOnGenericUnionUnpacking.py:10: error: Revealed type is 'Union[builti _testNoCrashOnGenericUnionUnpacking.py:11: error: Revealed type is 'Union[builtins.str, builtins.int]' _testNoCrashOnGenericUnionUnpacking.py:15: error: Revealed type is 'Union[builtins.int*, builtins.str*]' _testNoCrashOnGenericUnionUnpacking.py:16: error: Revealed type is 'Union[builtins.int*, builtins.str*]' + +[case testEnumIteration] +from enum import Enum +class E(Enum): + A = 'a' +l = [e for e in E] +reveal_type(l[0]) # E: Revealed type is '__main__.E' +for e in E: + reveal_type(e) # E: Revealed type is '__main__.E' + +[case testEnumIterable] +from enum import Enum +from typing import Iterable +class E(Enum): + a = 'a' +def f(ie:Iterable[E]): + pass +f(E) + +[case testIntEnumIterable] +from enum import IntEnum +from typing import Iterable +class N(IntEnum): + x = 1 +def f(ni: Iterable[N]): + pass +def g(ii: Iterable[int]): + pass +f(N) +g(N) + +[case testDerivedEnumIterable] +from enum import Enum +from typing import Iterable +class E(str, Enum): + a = 'foo' +def f(ei: Iterable[E]): + pass +def g(si: Iterable[str]): + pass +f(E) +g(E)