
Polynomial time-complexity of typing.overload #10004

Open

pelson opened this issue Feb 1, 2021 · 3 comments · May be fixed by #15865

Comments

@pelson commented Feb 1, 2021

Feature

The time complexity of checking function overloads in mypy appears to be polynomial (empirically, roughly quadratic) in the number of overloads. Naively, this does not seem to be a fundamental limit (though I have no static-analysis experience to base that on). Below is a log-log graph of the number of overloads vs execution time:

[Figure 1: log-log plot of number of overloads vs mypy execution time]

The data for the graph was generated with the following code:
import json
import pathlib
import tempfile
import textwrap
import time

from mypy import api


def generate_n_overloads(n: int, types=('str', 'int', 'float')) -> str:
    result = ['import typing']

    for name_count in range(n):
        type_name = types[name_count % len(types)]
        name = f'overload_{name_count}_{type_name}'
        annotation = textwrap.dedent(f'''
            @typing.overload
            def get_param(param: typing.Literal['{name}']) -> {type_name}: ...
            ''')
        result.append(annotation)
    return '\n'.join(result)


def build_test_case(dest_dir: pathlib.Path, n_overloads: int) -> pathlib.Path:
    pkg_dir = dest_dir / 'example_overloads_pkg'
    pkg_dir.mkdir(exist_ok=True)
    stubfile = pkg_dir / '__init__.pyi'
    stubfile.write_text(generate_n_overloads(n_overloads))

    pkg_file = pkg_dir / '__init__.py'
    pkg_file.write_text('def get_param(param): ...')

    test_script = dest_dir / 'script.py'
    test_script.write_text(textwrap.dedent('''
        from example_overloads_pkg import get_param
        
        get_param('overload_3_str').upper()  # Good
        get_param('overload_2_float').upper()  # Bad
    '''))
    return test_script


if __name__ == '__main__':
    tmpdir = tempfile.TemporaryDirectory()

    scale_times = {}
    numbers_to_check = [16, 32, 64, 128, 256, 512, 1024]
    # numbers_to_check = [16, 32, 64, 128, 256, 300, 400, 500, 600, 700, 800, 900, 1000, 1500, 2000, 2500, 3000, 4000, 5000, 6000, 8000, 10000]
    for number_of_overloads in numbers_to_check:
        print(f'Overload size: {number_of_overloads}')
        test_script = build_test_case(pathlib.Path(tmpdir.name), number_of_overloads)
        start = time.perf_counter()
        result = api.run([str(test_script)])
        end = time.perf_counter()
        elapsed = end - start

    print(' Time:', elapsed)
        if scale_times:
            last_n = next(reversed(scale_times))
            print(f" Slowdown: x{elapsed / scale_times[last_n]:.2f}")
        scale_times[number_of_overloads] = elapsed

    # Save the times to json for plotting/analysis.
    with open('times.json', 'wt') as fh:
        json.dump(scale_times, fh)

And the plot:

import json

import matplotlib.pyplot as plt
import numpy as np


with open('times.json', 'rt') as fh:
    times = json.load(fh)

n_points, times = zip(*times.items())

n_points = np.array(n_points, dtype=int)
times = np.array(times)

# Quadratic fit, overlaid on the measured points for comparison.
fit = np.polyfit(n_points, times, 2)
p = np.poly1d(fit)

plt.title('Number of overloads vs execution time of mypy')

plt.xscale('log')
plt.yscale('log')

plt.scatter(n_points, times, label='measured')
plt.plot(n_points, p(n_points), label='quadratic fit')
plt.legend()
plt.xlabel('Number of overloads')
plt.ylabel('Time to run mypy / seconds')
plt.show()

The full dataset (times.json):

{"16": 0.11653269396629184, "32": 0.13803347101202235, "64": 0.1889677940052934, "128": 0.42529276601271704, "256": 1.4768535409821197, "300": 2.165687720000278, "400": 3.581374307977967, "500": 5.615040614036843, "600": 8.182245634961873, "700": 10.219899011019152, "800": 14.204097100009676, "900": 19.01026200200431, "1000": 21.914704270020593, "1500": 50.26287814299576, "2000": 88.03798637498403, "2500": 140.229756515997, "3000": 206.8722462329897, "4000": 373.4201490940177, "5000": 596.6820141010103, "6000": 831.5039637690061, "8000": 1579.1185438489774, "10000": 2624.7711193099967}

I have inherited a library which has a very simple DSL for querying data from a database, and which returns values of different types depending on the input. Indeed, the language is so simple that I could map the queries to a collection of Literal overloads, with a generic fallback for new/unmapped data. In practice this looks something like:

import typing

@typing.overload
def get_param(param: typing.Literal['some_float']) -> float: ...


@typing.overload
def get_param(param: typing.Literal['some_other_int']) -> int: ...


@typing.overload
def get_param(param: typing.Literal['further_str']) -> str: ...


# generic fallback for new/unmapped data
@typing.overload
def get_param(param: str) -> typing.Any: ...
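
With those stubs in place, mypy narrows the return type per call site, along the lines of:

reveal_type(get_param('some_float'))   # note: Revealed type is "builtins.float"
reveal_type(get_param('further_str'))  # note: Revealed type is "builtins.str"

get_param('some_other_int').upper()    # error: "int" has no attribute "upper"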

Being able to do static analysis (and completion) on such an interface would be hugely beneficial to its users, both for correctness and for ease of use. Unfortunately the total number of overloads would need to be on the order of 100,000. Extrapolating the quadratic trend in the data above, that would put a single mypy run somewhere around 260,000 seconds (roughly three days), so this number is clearly prohibitive with the current implementation of mypy's overload functionality.

Pitch

I'm therefore reaching out to find out if there is a known good reason for this time-complexity, and potentially seeking help/pointers to address the issue in the implementation.

Other places this could be useful

Truth be told, this is a highly specialised requirement which would apply to a fairly limited set of use cases (some of which are likely anti-patterns in the first place).

Given the number of overloads required before this time complexity becomes an issue (around 100), it is clear that you'd almost always want to generate the stubs rather than hand-craft them. Examples of where it could come in handy are other string-based lookups: I can imagine, for example, generating stubs for all of the JClasses that exist in JPype (which can be any Java class available in the JVM), or for an interface that looks up a CSS property by name and returns a slightly different type for each of the ~520 valid property names. I could also imagine it being used to generate the permutations of a multiple-dispatch system.

@henribru (Contributor) commented Feb 6, 2021

This might be of some relevance: microsoft/pyright@d9f621e. It seems Pyright only avoids quadratic complexity by skipping some form of consistency check when the number of overloads is large enough.
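
Schematically, that mitigation amounts to something like the following (my paraphrase; the threshold and names are made up, not Pyright's actual code):

MAX_OVERLOADS_TO_CHECK = 50  # illustrative threshold, not Pyright's real value

def check_overload_consistency(items):
    if len(items) > MAX_OVERLOADS_TO_CHECK:
        return  # too many overloads; skip the O(n**2) check entirely
    ...  # otherwise run the full pairwise comparison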

I'm also interested in this feature, as I maintain one stubs package which currently uses about 300 overloads and another which would benefit from overloads but would require around 10 000 of them.

@henribru (Contributor) commented Mar 4, 2021

Looks like a consistency check is probably the cause for mypy as well: https://github.com/python/mypy/blob/master/mypy/checker.py#L486
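
That check compares every pair of overload items for unsafe overlap, which is inherently quadratic in the number of items. A minimal sketch of that shape (not mypy's actual code; the helper is a placeholder):

from itertools import combinations

def is_unsafe_overlap(first, second):
    ...  # placeholder: the real check compares parameter and return types

def check_overlapping_overloads(items):
    # Every unordered pair is examined once: n * (n - 1) / 2 comparisons,
    # hence the quadratic growth.
    for first, second in combinations(items, 2):
        if is_unsafe_overlap(first, second):
            print(f'overlapping signatures: {first} and {second}')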

@hauntsaninja (Collaborator) commented

#10922 can help here: it skips the quadratic check in files in which mypy wouldn't have reported errors anyway (i.e. third-party code, such as installed stub packages).
