
is_subtype function does not work correctly for two overloads #9147

Closed
@sobolevn

Description

  • Are you reporting a bug, or opening a feature request?

I am reporting a bug.

In the mypy plugin for @curry, I generate types like this for curried functions:

from returns.curry import curry

@curry
def curried(a: int, b: int, c: str) -> float:
    ...

reveal_type(curried)  
# Revealed type is 'Overload(
#   def (a: builtins.int) -> Overload(
#     def (b: builtins.int, c: builtins.str) -> builtins.float, 
#     def (b: builtins.int) -> def (c: builtins.str) -> builtins.float
#   ), 
#   def (a: builtins.int, b: builtins.int) -> def (c: builtins.str) -> builtins.float, 
#   def (a: builtins.int, b: builtins.int, c: builtins.str) -> builtins.float,
# )'
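The revealed type encodes four call shapes. Below is a minimal runtime sketch that accepts exactly those shapes, to make the overloads easier to read. simple_curry is a hypothetical helper written for illustration. It is not returns.curry, and it only handles positional arguments (an assumption that keeps the example short).

```python
import inspect
from typing import Any, Callable


def simple_curry(func: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical minimal curry (not returns.curry), positional args only.

    Each call appends its arguments. Once all parameters are supplied the
    original function runs; otherwise a new partially applied function is
    returned.
    """
    arity = len(inspect.signature(func).parameters)

    def collect(*collected: Any) -> Callable[..., Any]:
        def step(*args: Any) -> Any:
            new_args = collected + args
            if not args or len(new_args) > arity:
                raise TypeError(
                    f"expected between 1 and {arity - len(collected)} arguments")
            if len(new_args) == arity:
                return func(*new_args)
            return collect(*new_args)
        return step

    return collect()


@simple_curry
def curried(a: int, b: int, c: str) -> float:
    return float(a + b + len(c))


# The four call shapes described by the revealed Overload type:
print(curried(1)(2, "xyz"))   # def (a) -> def (b, c) -> float
print(curried(1)(2)("xyz"))   # def (a) -> def (b) -> def (c) -> float
print(curried(1, 2)("xyz"))   # def (a, b) -> def (c) -> float
print(curried(1, 2, "xyz"))   # def (a, b, c) -> float
```

Each of the four calls prints 6.0. The nested Overload in the first item corresponds to the two ways of continuing after curried(1).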

When I try to use this type elsewhere, for example:

from returns.maybe import Maybe

Maybe.from_value(1).apply(Maybe.from_value(curried))

mypy reports an error:

Argument 1 to "Maybe" has incompatible type overloaded function; expected overloaded function

Original issue: dry-python/returns#459

I started investigating here:

mypy/mypy/checkexpr.py, lines 1470 to 1481 in 358522e:

elif not is_subtype(caller_type, callee_type):
    if self.chk.should_suppress_optional_error([caller_type, callee_type]):
        return
    code = messages.incompatible_argument(n,
                                          m,
                                          callee,
                                          original_caller_type,
                                          caller_kind,
                                          context=context,
                                          outer_context=outer_context)
    messages.incompatible_argument_note(original_caller_type, callee_type, context,
                                        code=code)

Then I checked which types I was dealing with:

print('----')
print('--->', caller_type)
print('---<', callee_type)
print(is_subtype(caller_type, callee_type))
print(is_subtype(
    caller_type, callee_type,
    ignore_type_params=True,
    ignore_pos_arg_names=True,
    ignore_declared_variance=True,
    ignore_promotions=True))  # I have tried with all possible ignores set
print(str(caller_type) == str(callee_type))
print('----')

Output:

----
---> Overload(def (a: builtins.int) -> Overload(def (b: builtins.int, c: builtins.str) -> builtins.float, def (b: builtins.int) -> def (c: builtins.str) -> builtins.float), def (a: builtins.int, b: builtins.int) -> def (c: builtins.str) -> builtins.float, def (a: builtins.int, b: builtins.int, c: builtins.str) -> builtins.float)
---< Overload(def (a: builtins.int) -> Overload(def (b: builtins.int, c: builtins.str) -> builtins.float, def (b: builtins.int) -> def (c: builtins.str) -> builtins.float), def (a: builtins.int, b: builtins.int) -> def (c: builtins.str) -> builtins.float, def (a: builtins.int, b: builtins.int, c: builtins.str) -> builtins.float)
False  # is_subtype
False  # is_subtype with ignores
True   # str(callee) == str(caller)
----

So the types look correct (their string forms are identical), but is_subtype does not consider the caller type a subtype of the identical callee type.

I tracked it down to these lines:

mypy/mypy/subtypes.py, lines 401 to 443 in 358522e:

elif isinstance(right, Overloaded):
    # Ensure each overload in the right side (the supertype) is accounted for.
    previous_match_left_index = -1
    matched_overloads = set()
    possible_invalid_overloads = set()
    for right_index, right_item in enumerate(right.items()):
        found_match = False
        for left_index, left_item in enumerate(left.items()):
            subtype_match = self._is_subtype(left_item, right_item)
            # Order matters: we need to make sure that the index of
            # this item is at least the index of the previous one.
            if subtype_match and previous_match_left_index <= left_index:
                if not found_match:
                    # Update the index of the previous match.
                    previous_match_left_index = left_index
                    found_match = True
                    matched_overloads.add(left_item)
                    possible_invalid_overloads.discard(left_item)
            else:
                # If this one overlaps with the supertype in any way, but it wasn't
                # an exact match, then it's a potential error.
                if (is_callable_compatible(left_item, right_item,
                                           is_compat=self._is_subtype, ignore_return=True,
                                           ignore_pos_arg_names=self.ignore_pos_arg_names) or
                        is_callable_compatible(right_item, left_item,
                                               is_compat=self._is_subtype, ignore_return=True,
                                               ignore_pos_arg_names=self.ignore_pos_arg_names)):
                    # If this is an overload that's already been matched, there's no
                    # problem.
                    if left_item not in matched_overloads:
                        possible_invalid_overloads.add(left_item)
        if not found_match:
            return False
    if possible_invalid_overloads:
        # There were potentially invalid overloads that were never matched to the
        # supertype.
        return False
    return True

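I added these debug prints (the first three inside the inner loop, right after subtype_match is computed, and the last one after the outer loop):

print('$ ', right_index, left_index, subtype_match)
print(right_item)
print(left_item)
print('error?', possible_invalid_overloads)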

This is what it produces (the visit and items lines appear to come from other debug prints that are not shown here):

$  0 0 True
def (b: builtins.int, c: builtins.str) -> builtins.float
def (b: builtins.int, c: builtins.str) -> builtins.float
$  0 1 False
def (b: builtins.int, c: builtins.str) -> builtins.float
def (b: builtins.int) -> def (c: builtins.str) -> builtins.float
$  1 0 False
def (b: builtins.int) -> def (c: builtins.str) -> builtins.float
def (b: builtins.int, c: builtins.str) -> builtins.float
$  1 1 True
def (b: builtins.int) -> def (c: builtins.str) -> builtins.float
def (b: builtins.int) -> def (c: builtins.str) -> builtins.float

error? set()

$  0 0 True
def (a: builtins.int) -> Overload(def (b: builtins.int, c: builtins.str) -> builtins.float, def (b: builtins.int) -> def (c: builtins.str) -> builtins.float)
def (a: builtins.int) -> Overload(def (b: builtins.int, c: builtins.str) -> builtins.float, def (b: builtins.int) -> def (c: builtins.str) -> builtins.float)
$  0 1 False
def (a: builtins.int) -> Overload(def (b: builtins.int, c: builtins.str) -> builtins.float, def (b: builtins.int) -> def (c: builtins.str) -> builtins.float)
def (a: builtins.int, b: builtins.int) -> def (c: builtins.str) -> builtins.float
$  0 2 False
def (a: builtins.int) -> Overload(def (b: builtins.int, c: builtins.str) -> builtins.float, def (b: builtins.int) -> def (c: builtins.str) -> builtins.float)
def (a: builtins.int, b: builtins.int, c: builtins.str) -> builtins.float
visit Overload(def (b: builtins.int, c: builtins.str) -> builtins.float, def (b: builtins.int) -> def (c: builtins.str) -> builtins.float) def (c: builtins.str) -> builtins.float
items True def (b: builtins.int, c: builtins.str) -> builtins.float
$  1 0 True
def (a: builtins.int, b: builtins.int) -> def (c: builtins.str) -> builtins.float
def (a: builtins.int) -> Overload(def (b: builtins.int, c: builtins.str) -> builtins.float, def (b: builtins.int) -> def (c: builtins.str) -> builtins.float)
$  1 1 True
def (a: builtins.int, b: builtins.int) -> def (c: builtins.str) -> builtins.float
def (a: builtins.int, b: builtins.int) -> def (c: builtins.str) -> builtins.float
$  1 2 False
def (a: builtins.int, b: builtins.int) -> def (c: builtins.str) -> builtins.float
def (a: builtins.int, b: builtins.int, c: builtins.str) -> builtins.float
visit Overload(def (b: builtins.int, c: builtins.str) -> builtins.float, def (b: builtins.int) -> def (c: builtins.str) -> builtins.float) builtins.float
$  2 0 False
def (a: builtins.int, b: builtins.int, c: builtins.str) -> builtins.float
def (a: builtins.int) -> Overload(def (b: builtins.int, c: builtins.str) -> builtins.float, def (b: builtins.int) -> def (c: builtins.str) -> builtins.float)
$  2 1 False
def (a: builtins.int, b: builtins.int, c: builtins.str) -> builtins.float
def (a: builtins.int, b: builtins.int) -> def (c: builtins.str) -> builtins.float
$  2 2 True
def (a: builtins.int, b: builtins.int, c: builtins.str) -> builtins.float
def (a: builtins.int, b: builtins.int, c: builtins.str) -> builtins.float

error? {def (a: builtins.int, b: builtins.int) -> def (c: builtins.str) -> builtins.float}

So the logic seems to be wrong here, but I am not sure what exactly is wrong. I don't understand this part:

                    # Order matters: we need to make sure that the index of
                    # this item is at least the index of the previous one.
                    if subtype_match and previous_match_left_index <= left_index:
                        if not found_match:
                            # Update the index of the previous match.
                            previous_match_left_index = left_index
                            found_match = True
                            matched_overloads.add(left_item)
                            possible_invalid_overloads.discard(left_item)
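To see what this block does with the results printed above, here is a simplified, standalone transcription of the loop. It is a sketch, not mypy code: overload items are replaced by their indices, and self._is_subtype and is_callable_compatible are replaced by the hypothetical predicates is_sub and overlaps, fed with the True/False values from the debug output. The output does not show which non-matching pairs overlap, so the sketch assumes every such pair does; since left item 1 ends up in the error set, at least one of its pairs must have overlapped, and this assumption reproduces the same sets.

```python
from typing import Callable, Set, Tuple

# Hypothetical stand-ins: overload items are identified by index and the two
# mypy checks become plain predicates over (right_index, left_index).
Relation = Callable[[int, int], bool]


def overloads_match(n_left: int, n_right: int,
                    is_sub: Relation, overlaps: Relation) -> Tuple[bool, Set[int]]:
    """Simplified transcription of the Overloaded branch quoted above."""
    previous_match_left_index = -1
    matched: Set[int] = set()
    possible_invalid: Set[int] = set()
    for right_index in range(n_right):
        found_match = False
        for left_index in range(n_left):
            if is_sub(right_index, left_index) and previous_match_left_index <= left_index:
                # Only the FIRST matching left item is recorded for each right
                # item; later matches for the same right item are skipped.
                if not found_match:
                    previous_match_left_index = left_index
                    found_match = True
                    matched.add(left_index)
                    possible_invalid.discard(left_index)
            elif overlaps(right_index, left_index) and left_index not in matched:
                possible_invalid.add(left_index)
        if not found_match:
            return False, possible_invalid
    return not possible_invalid, possible_invalid


def always_overlaps(right_index: int, left_index: int) -> bool:
    # Assumption: every non-matching pair is treated as overlapping.
    return True


# (right_index, left_index) pairs printed as "$ ... True" in the debug output.
INNER_SUBTYPE = {(0, 0), (1, 1)}
OUTER_SUBTYPE = {(0, 0), (1, 0), (1, 1), (2, 2)}

print(overloads_match(2, 2, lambda r, l: (r, l) in INNER_SUBTYPE, always_overlaps))
print(overloads_match(3, 3, lambda r, l: (r, l) in OUTER_SUBTYPE, always_overlaps))
```

The first call prints (True, set()) and the second prints (False, {1}), matching the two "error?" lines. In the outer case, right item 1 first matches left item 0 (the "$ 1 0 True" line). So when the exact match at (1, 1) comes up, found_match is already set and left item 1 (def (a: builtins.int, b: builtins.int) -> ...) is never added to matched_overloads. It stays in possible_invalid_overloads, and the check fails. Removing (1, 0) from OUTER_SUBTYPE makes the sketch return (True, set()).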

I would love to help with this one; I will try to send a PR shortly.

Related: dry-python/returns#462

Metadata

Assignees: No one assigned
Labels: No labels
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests