Skip to content

Commit

Permalink
Merge pull request #898 from jhance/cov-contra
Browse files Browse the repository at this point in the history
Fixes for covariance/contravariance + #734
  • Loading branch information
JukkaL committed Oct 12, 2015
2 parents 6d9351a + 5fc466b commit 6eb1e92
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 4 deletions.
18 changes: 16 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
LITERAL_TYPE, BreakStmt, ContinueStmt, ComparisonExpr, StarExpr,
YieldFromExpr, YieldFromStmt, NamedTupleExpr, SetComprehension,
DictionaryComprehension, ComplexExpr, EllipsisExpr, TypeAliasExpr,
RefExpr, YieldExpr
RefExpr, YieldExpr, CONTRAVARIANT, COVARIANT
)
from mypy.nodes import function_type, method_type, method_type_with_fallback
from mypy import nodes
from mypy.types import (
Type, AnyType, CallableType, Void, FunctionLike, Overloaded, TupleType,
Instance, NoneTyp, UnboundType, ErrorType, TypeTranslator, strip_type, UnionType
Instance, NoneTyp, UnboundType, ErrorType, TypeTranslator, strip_type,
UnionType, TypeVarType,
)
from mypy.sametypes import is_same_type
from mypy.messages import MessageBuilder
Expand Down Expand Up @@ -505,12 +506,25 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: str) -> None:
elif name == '__getattr__':
self.check_getattr_method(typ, defn)

# Refuse contravariant return type variable
if isinstance(typ.ret_type, TypeVarType):
if typ.ret_type.variance == CONTRAVARIANT:
self.fail(messages.RETURN_TYPE_CANNOT_BE_CONTRAVARIANT,
typ.ret_type)

# Push return type.
self.return_types.append(typ.ret_type)

# Store argument types.
for i in range(len(typ.arg_types)):
arg_type = typ.arg_types[i]

# Refuse covariant parameter type variables
if isinstance(arg_type, TypeVarType):
if arg_type.variance == COVARIANT:
self.fail(messages.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT,
arg_type)

if typ.arg_kinds[i] == nodes.ARG_STAR:
# builtins.tuple[T] is typing.Tuple[T, ...]
arg_type = self.named_generic_type('builtins.tuple',
Expand Down
2 changes: 2 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
FORMAT_REQUIRES_MAPPING = 'Format requires a mapping'
GENERIC_TYPE_NOT_VALID_AS_EXPRESSION = \
"Generic type not valid as an expression any more (use '# type:' comment instead)"
RETURN_TYPE_CANNOT_BE_CONTRAVARIANT = "Cannot use a contravariant type variable as return type"
FUNCTION_PARAMETER_CANNOT_BE_COVARIANT = "Cannot use a covariant type variable as a parameter"


class MessageBuilder:
Expand Down
40 changes: 40 additions & 0 deletions mypy/test/data/check-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -902,3 +902,43 @@ def f():
"""
return 1
f() + ''
[case testRejectCovariantArgument]
from typing import TypeVar, Generic

t = TypeVar('t', covariant=True)
class A(Generic[t]):
def foo(self, x: t) -> None:
return None
[builtins fixtures/bool.py]
[out]
main: note: In member "foo" of class "A":
main:5: error: Cannot use a covariant type variable as a parameter

[case testRejectContravariantReturnType]
from typing import TypeVar, Generic

t = TypeVar('t', contravariant=True)
class A(Generic[t]):
def foo(self) -> t:
return None
[builtins fixtures/bool.py]
[out]
main: note: In member "foo" of class "A":
main:5: error: Cannot use a contravariant type variable as return type

[case testAcceptCovariantReturnType]
from typing import TypeVar, Generic

t = TypeVar('t', covariant=True)
class A(Generic[t]):
def foo(self) -> t:
return None
[builtins fixtures/bool.py]
[case testAcceptContravariantArgument]
from typing import TypeVar, Generic

t = TypeVar('t', contravariant=True)
class A(Generic[t]):
def foo(self, x: t) -> None:
return None
[builtins fixtures/bool.py]
5 changes: 3 additions & 2 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ def visit_unbound_type(self, t: UnboundType) -> Type:
if len(t.args) > 0:
self.fail('Type variable "{}" used with arguments'.format(
t.name), t)
values = cast(TypeVarExpr, sym.node).values
return TypeVarType(t.name, sym.tvar_id, values,
tvar_expr = cast(TypeVarExpr, sym.node)
return TypeVarType(t.name, sym.tvar_id, tvar_expr.values,
self.builtin_type('builtins.object'),
tvar_expr.variance,
t.line)
elif fullname == 'builtins.None':
return Void()
Expand Down

0 comments on commit 6eb1e92

Please sign in to comment.