Skip to content

Commit

Permalink
Fix constraint inference for non-invariant instances (#5817)
Browse files Browse the repository at this point in the history
Fixes #2035

Note that in two tests we now infer `object` instead of giving an error. This may be not what a user expects, but I think this is still OK, and after #3816 mypy will always ask for an annotation in such cases.
  • Loading branch information
ilevkivskyi authored Oct 22, 2018
1 parent a7f0263 commit 2b71c2f
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 17 deletions.
32 changes: 19 additions & 13 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import mypy.subtypes
from mypy.sametypes import is_same_type
from mypy.erasetype import erase_typevars
from mypy.nodes import COVARIANT, CONTRAVARIANT

MYPY = False
if MYPY:
Expand Down Expand Up @@ -337,25 +338,30 @@ def visit_instance(self, template: Instance) -> List[Constraint]:
if (self.direction == SUBTYPE_OF and
template.type.has_base(instance.type.fullname())):
mapped = map_instance_to_supertype(template, instance.type)
tvars = mapped.type.defn.type_vars
for i in range(len(instance.args)):
# The constraints for generic type parameters are
# invariant. Include constraints from both directions
# to achieve the effect.
res.extend(infer_constraints(
mapped.args[i], instance.args[i], self.direction))
res.extend(infer_constraints(
mapped.args[i], instance.args[i], neg_op(self.direction)))
# The constraints for generic type parameters depend on variance.
# Include constraints from both directions if invariant.
if tvars[i].variance != CONTRAVARIANT:
res.extend(infer_constraints(
mapped.args[i], instance.args[i], self.direction))
if tvars[i].variance != COVARIANT:
res.extend(infer_constraints(
mapped.args[i], instance.args[i], neg_op(self.direction)))
return res
elif (self.direction == SUPERTYPE_OF and
instance.type.has_base(template.type.fullname())):
mapped = map_instance_to_supertype(instance, template.type)
tvars = template.type.defn.type_vars
for j in range(len(template.args)):
# The constraints for generic type parameters are
# invariant.
res.extend(infer_constraints(
template.args[j], mapped.args[j], self.direction))
res.extend(infer_constraints(
template.args[j], mapped.args[j], neg_op(self.direction)))
# The constraints for generic type parameters depend on variance.
# Include constraints from both directions if invariant.
if tvars[j].variance != CONTRAVARIANT:
res.extend(infer_constraints(
template.args[j], mapped.args[j], self.direction))
if tvars[j].variance != COVARIANT:
res.extend(infer_constraints(
template.args[j], mapped.args[j], neg_op(self.direction)))
return res
if (template.type.is_protocol and self.direction == SUPERTYPE_OF and
# We avoid infinite recursion for structural subtypes by checking
Expand Down
8 changes: 4 additions & 4 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -1666,8 +1666,8 @@ it = [('x', 1)]
d = dict(it, x=1)
d() # E: "Dict[str, int]" not callable

d2 = dict(it, x='') # E: Cannot infer type argument 2 of "dict"
d2() # E: "Dict[Any, Any]" not callable
d2 = dict(it, x='')
d2() # E: "Dict[str, object]" not callable

d3 = dict(it, x='') # type: Dict[str, int] # E: Argument "x" to "dict" has incompatible type "str"; expected "int"
[builtins fixtures/dict.pyi]
Expand All @@ -1691,8 +1691,8 @@ d = dict(it, **kw)
d() # E: "Dict[str, int]" not callable

kw2 = {'x': ''}
d2 = dict(it, **kw2) # E: Cannot infer type argument 2 of "dict"
d2() # E: "Dict[Any, Any]" not callable
d2 = dict(it, **kw2)
d2() # E: "Dict[str, object]" not callable

d3 = dict(it, **kw2) # type: Dict[str, int] # E: Argument 2 to "dict" has incompatible type "**Dict[str, str]"; expected "int"
[builtins fixtures/dict.pyi]
Expand Down
65 changes: 65 additions & 0 deletions test-data/unit/check-inference-context.test
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,71 @@ class C:
g: Callable[[], int] = lambda: 1 or self.x
self.x = int()

[case testInferTypeVariableFromTwoGenericTypes1]
from typing import TypeVar, List, Sequence

T = TypeVar('T')

class C: ...
class D(C): ...

def f(x: Sequence[T], y: Sequence[T]) -> List[T]: ...

reveal_type(f([C()], [D()])) # E: Revealed type is 'builtins.list[__main__.C*]'
[builtins fixtures/list.pyi]

[case testInferTypeVariableFromTwoGenericTypes2]
from typing import TypeVar, List

T = TypeVar('T')

class C: ...
class D(C): ...

def f(x: List[T], y: List[T]) -> List[T]: ...

f([C()], [D()]) # E: Cannot infer type argument 1 of "f"
[builtins fixtures/list.pyi]

[case testInferTypeVariableFromTwoGenericTypes3]
from typing import Generic, TypeVar

T = TypeVar('T')
T_contra = TypeVar('T_contra', contravariant=True)

class A(Generic[T_contra]): pass
class B(A[T]): pass

class C: ...
class D(C): ...

def f(x: A[T], y: A[T]) -> B[T]: ...

c: B[C]
d: B[D]
reveal_type(f(c, d)) # E: Revealed type is '__main__.B[__main__.D*]'

[case testInferTypeVariableFromTwoGenericTypes4]
from typing import Generic, TypeVar, Callable, List

T = TypeVar('T')
T_contra = TypeVar('T_contra', contravariant=True)

class A(Generic[T_contra]): pass
class B(A[T_contra]): pass

class C: ...
class D(C): ...

def f(x: Callable[[B[T]], None],
y: Callable[[B[T]], None]) -> List[T]: ...

def gc(x: A[C]) -> None: pass # B[C]
def gd(x: A[D]) -> None: pass # B[C]

reveal_type(f(gc, gd)) # E: Revealed type is 'builtins.list[__main__.C*]'
[builtins fixtures/list.pyi]

[case testWideOuterContextSubClassBound]
from typing import TypeVar

Expand Down

0 comments on commit 2b71c2f

Please sign in to comment.