Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with in operator used with a union of Container and Iterable #14384

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4499,6 +4499,26 @@ def analyze_iterable_item_type(self, expr: Expression) -> tuple[Type, Type]:
# Non-tuple iterable.
return iterator, echk.check_method_call_by_name("__next__", iterator, [], [], expr)[0]

def analyze_iterable_item_type_without_expression(
self, type: Type, context: Context
) -> tuple[Type, Type]:
"""Analyse iterable type and return iterator and iterator item types."""
echk = self.expr_checker
iterable = get_proper_type(type)
iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], context)[0]

if isinstance(iterable, TupleType):
joined: Type = UninhabitedType()
for item in iterable.items:
joined = join_types(joined, item)
return iterator, joined
else:
# Non-tuple iterable.
return (
iterator,
echk.check_method_call_by_name("__next__", iterator, [], [], context)[0],
)

def analyze_range_native_int_type(self, expr: Expression) -> Type | None:
"""Try to infer native int item type from arguments to range(...).

Expand Down
141 changes: 90 additions & 51 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2914,75 +2914,116 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
That is, 'a < b > c == d' is check as 'a < b and b > c and c == d'
"""
result: Type | None = None
sub_result: Type | None = None
sub_result: Type

# Check each consecutive operand pair and their operator
for left, right, operator in zip(e.operands, e.operands[1:], e.operators):
left_type = self.accept(left)

method_type: mypy.types.Type | None = None

if operator == "in" or operator == "not in":
# This case covers both iterables and containers, which have different meanings.
# For a container, the in operator calls the __contains__ method.
# For an iterable, the in operator iterates over the iterable, and compares each item one-by-one.
# We allow `in` for a union of containers and iterables as long as at least one of them matches the
# type of the left operand, as the operation will simply return False if the union's container/iterator
# type doesn't match the left operand.

# If the right operand has partial type, look it up without triggering
# a "Need type annotation ..." message, as it would be noise.
right_type = self.find_partial_type_ref_fast_path(right)
if right_type is None:
right_type = self.accept(right) # Validate the right operand

# Keep track of whether we get type check errors (these won't be reported, they
# are just to verify whether something is valid typing wise).
with self.msg.filter_errors(save_filtered_errors=True) as local_errors:
_, method_type = self.check_method_call_by_name(
method="__contains__",
base_type=right_type,
args=[left],
arg_kinds=[ARG_POS],
context=e,
)
right_type = get_proper_type(right_type)
item_types: Sequence[Type] = [right_type]
if isinstance(right_type, UnionType):
item_types = list(right_type.items)

sub_result = self.bool_type()
# Container item type for strict type overlap checks. Note: we need to only
# check for nominal type, because a usual "Unsupported operands for in"
# will be reported for types incompatible with __contains__().
# See testCustomContainsCheckStrictEquality for an example.
cont_type = self.chk.analyze_container_item_type(right_type)
if isinstance(right_type, PartialType):
# We don't really know if this is an error or not, so just shut up.
pass
elif (
local_errors.has_new_errors()
and
# is_valid_var_arg is True for any Iterable
self.is_valid_var_arg(right_type)
):
_, itertype = self.chk.analyze_iterable_item_type(right)
method_type = CallableType(
[left_type],
[nodes.ARG_POS],
[None],
self.bool_type(),
self.named_type("builtins.function"),
)
if not is_subtype(left_type, itertype):
self.msg.unsupported_operand_types("in", left_type, right_type, e)
# Only show dangerous overlap if there are no other errors.
elif (
not local_errors.has_new_errors()
and cont_type
and self.dangerous_comparison(
left_type, cont_type, original_container=right_type, prefer_literal=False
)
):
self.msg.dangerous_comparison(left_type, cont_type, "container", e)
else:
self.msg.add_errors(local_errors.filtered_errors())

container_types: list[Type] = []
iterable_types: list[Type] = []
failed_out = False
encountered_partial_type = False

for item_type in item_types:
# Keep track of whether we get type check errors (these won't be reported, they
# are just to verify whether something is valid typing wise).
with self.msg.filter_errors(save_filtered_errors=True) as container_errors:
_, method_type = self.check_method_call_by_name(
method="__contains__",
base_type=item_type,
args=[left],
arg_kinds=[ARG_POS],
context=e,
original_type=right_type,
)
# Container item type for strict type overlap checks. Note: we need to only
# check for nominal type, because a usual "Unsupported operands for in"
# will be reported for types incompatible with __contains__().
# See testCustomContainsCheckStrictEquality for an example.
cont_type = self.chk.analyze_container_item_type(item_type)

if isinstance(item_type, PartialType):
# We don't really know if this is an error or not, so just shut up.
encountered_partial_type = True
pass
elif (
container_errors.has_new_errors()
and
# is_valid_var_arg is True for any Iterable
self.is_valid_var_arg(item_type)
):
# it's not a container, but it is an iterable
with self.msg.filter_errors(save_filtered_errors=True) as iterable_errors:
_, itertype = self.chk.analyze_iterable_item_type_without_expression(
item_type, e
)
if iterable_errors.has_new_errors():
self.msg.add_errors(iterable_errors.filtered_errors())
failed_out = True
else:
method_type = CallableType(
[left_type],
[nodes.ARG_POS],
[None],
self.bool_type(),
self.named_type("builtins.function"),
)
e.method_types.append(method_type)
iterable_types.append(itertype)
elif not container_errors.has_new_errors() and cont_type:
container_types.append(cont_type)
e.method_types.append(method_type)
else:
self.msg.add_errors(container_errors.filtered_errors())
failed_out = True

if not encountered_partial_type and not failed_out:
iterable_type = UnionType.make_union(iterable_types)
if not is_subtype(left_type, iterable_type):
if len(container_types) == 0:
self.msg.unsupported_operand_types("in", left_type, right_type, e)
else:
container_type = UnionType.make_union(container_types)
if self.dangerous_comparison(
left_type,
container_type,
original_container=right_type,
prefer_literal=False,
):
self.msg.dangerous_comparison(
left_type, container_type, "container", e
)

elif operator in operators.op_methods:
method = operators.op_methods[operator]

with ErrorWatcher(self.msg.errors) as w:
sub_result, method_type = self.check_op(
method, left_type, right, e, allow_reverse=True
)
e.method_types.append(method_type)

# Only show dangerous overlap if there are no other errors. See
# testCustomEqCheckStrictEquality for an example.
Expand All @@ -3002,12 +3043,10 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
left_type = try_getting_literal(left_type)
right_type = try_getting_literal(right_type)
self.msg.dangerous_comparison(left_type, right_type, "identity", e)
method_type = None
e.method_types.append(None)
else:
raise RuntimeError(f"Unknown comparison operator {operator}")

e.method_types.append(method_type)

# Determine type of boolean-and of result and sub_result
if result is None:
result = sub_result
Expand Down
17 changes: 17 additions & 0 deletions test-data/unit/check-unions.test
Original file line number Diff line number Diff line change
Expand Up @@ -1202,3 +1202,20 @@ def foo(
yield i
foo([1])
[builtins fixtures/list.pyi]

[case testUnionIterableContainer]
from typing import Iterable, Container, Union

i: Iterable[str]
c: Container[str]
u: Union[Iterable[str], Container[str]]
ni: Union[Iterable[str], int]
nc: Union[Container[str], int]

'x' in i
'x' in c
'x' in u
'x' in ni # E: Unsupported right operand type for in ("Union[Iterable[str], int]")
'x' in nc # E: Unsupported right operand type for in ("Union[Container[str], int]")
[builtins fixtures/tuple.pyi]
[typing fixtures/typing-full.pyi]