From 49316f9fb8ccddc3941a1fbe378e4c96e929152f Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 16 Nov 2022 20:20:30 +0000 Subject: [PATCH] Allow super() for mixin protocols (#14082) Fixes #12344 FWIW this is unsafe (since we don't know where the mixin will appear in the MRO of the actual implementation), but the alternative is having annoying false positives like this issue and e.g. https://github.com/python/mypy/issues/4335 --- mypy/checkexpr.py | 14 ++++++++++++-- test-data/unit/check-selftype.test | 20 ++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 376e1f811692..3d2c69073bc0 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4321,8 +4321,18 @@ def visit_super_expr(self, e: SuperExpr) -> Type: mro = e.info.mro index = mro.index(type_info) if index is None: - self.chk.fail(message_registry.SUPER_ARG_2_NOT_INSTANCE_OF_ARG_1, e) - return AnyType(TypeOfAny.from_error) + if ( + instance_info.is_protocol + and instance_info != type_info + and not type_info.is_protocol + ): + # A special case for mixins, in this case super() should point + # directly to the host protocol, this is not safe, since the real MRO + # is not known yet for mixin, but this feature is more like an escape hatch. + index = -1 + else: + self.chk.fail(message_registry.SUPER_ARG_2_NOT_INSTANCE_OF_ARG_1, e) + return AnyType(TypeOfAny.from_error) if len(mro) == index + 1: self.chk.fail(message_registry.TARGET_CLASS_HAS_NO_BASE_CLASS, e) diff --git a/test-data/unit/check-selftype.test b/test-data/unit/check-selftype.test index a7dc41a2ff86..072978254049 100644 --- a/test-data/unit/check-selftype.test +++ b/test-data/unit/check-selftype.test @@ -792,6 +792,26 @@ reveal_type(f.copy()) # N: Revealed type is "__main__.File" b.copy() # E: Invalid self argument "Bad" to attribute function "copy" with type "Callable[[T], T]" [builtins fixtures/tuple.pyi] +[case testMixinProtocolSuper] +from typing import Protocol + +class Base(Protocol): + def func(self) -> int: + ... + +class TweakFunc: + def func(self: Base) -> int: + return reveal_type(super().func()) # N: Revealed type is "builtins.int" + +class Good: + def func(self) -> int: ... +class C(TweakFunc, Good): pass +C().func() # OK + +class Bad: + def func(self) -> str: ... +class CC(TweakFunc, Bad): pass # E: Definition of "func" in base class "TweakFunc" is incompatible with definition in base class "Bad" + [case testBadClassLevelDecoratorHack] from typing_extensions import Protocol from typing import TypeVar, Any