From 17b1c3c3c9dc2edd7bc2b99a77e1453e0f574684 Mon Sep 17 00:00:00 2001 From: KotlinIsland Date: Thu, 21 Jul 2022 00:04:54 +1000 Subject: [PATCH] format files with black/isort --- .gitignore | 2 +- conftest.py | 9 +- docs/source/conf.py | 168 +- misc/actions_stubs.py | 113 +- misc/analyze_cache.py | 62 +- misc/apply-cache-diff.py | 15 +- misc/async_matrix.py | 57 +- misc/build_wheel.py | 90 +- misc/cherry-pick-typeshed.py | 32 +- misc/convert-cache.py | 27 +- misc/diff-cache.py | 9 +- misc/dump-ast.py | 24 +- misc/fix_annotate.py | 67 +- misc/incremental_checker.py | 285 ++- misc/perf_checker.py | 48 +- misc/proper_plugin.py | 99 +- misc/sync-typeshed.py | 69 +- misc/test_case_to_actual.py | 24 +- misc/touch_checker.py | 62 +- misc/variadics.py | 53 +- mypy/__main__.py | 2 +- mypy/api.py | 9 +- mypy/applytype.py | 35 +- mypy/argmap.py | 98 +- mypy/backports.py | 2 + mypy/binder.py | 87 +- mypy/bogus_type.py | 5 +- mypy/build.py | 1259 ++++++----- mypy/checker.py | 2747 +++++++++++++++---------- mypy/checkexpr.py | 2658 ++++++++++++++---------- mypy/checkmember.py | 568 +++-- mypy/checkpattern.py | 255 +-- mypy/checkstrformat.py | 688 ++++--- mypy/config_parser.py | 380 ++-- mypy/constraints.py | 401 ++-- mypy/copytype.py | 40 +- mypy/defaults.py | 13 +- mypy/dmypy/__main__.py | 2 +- mypy/dmypy/client.py | 359 ++-- mypy/dmypy_os.py | 13 +- mypy/dmypy_server.py | 394 ++-- mypy/dmypy_util.py | 4 +- mypy/erasetype.py | 51 +- mypy/errorcodes.py | 19 +- mypy/errors.py | 540 +++-- mypy/expandtype.py | 79 +- mypy/exprtotype.py | 116 +- mypy/fastparse.py | 1000 +++++---- mypy/fastparse2.py | 608 +++--- mypy/find_sources.py | 31 +- mypy/fixup.py | 117 +- mypy/freetree.py | 2 +- mypy/fscache.py | 22 +- mypy/fswatcher.py | 8 +- mypy/gclogger.py | 21 +- mypy/indirection.py | 2 +- mypy/infer.py | 35 +- mypy/ipc.py | 88 +- mypy/join.py | 127 +- mypy/literals.py | 93 +- mypy/lookup.py | 14 +- mypy/main.py | 1188 ++++++----- mypy/maptype.py | 14 +- mypy/meet.py | 278 ++- mypy/memprofile.py | 38 +- mypy/message_registry.py | 23 +- mypy/messages.py | 2312 ++++++++++++--------- mypy/metastore.py | 38 +- mypy/mixedtraverser.py | 20 +- mypy/modulefinder.py | 265 +-- mypy/moduleinspect.py | 63 +- mypy/mro.py | 10 +- mypy/nodes.py | 1591 +++++++------- mypy/operators.py | 121 +- mypy/options.py | 44 +- mypy/parse.py | 36 +- mypy/patterns.py | 26 +- mypy/plugin.py | 199 +- mypy/plugins/attrs.py | 452 ++-- mypy/plugins/common.py | 113 +- mypy/plugins/ctypes.py | 131 +- mypy/plugins/dataclasses.py | 345 ++-- mypy/plugins/default.py | 275 +-- mypy/plugins/enums.py | 52 +- mypy/plugins/functools.py | 35 +- mypy/plugins/singledispatch.py | 75 +- mypy/pyinfo.py | 17 +- mypy/reachability.py | 120 +- mypy/renaming.py | 32 +- mypy/report.py | 520 ++--- mypy/sametypes.py | 97 +- mypy/scope.py | 8 +- mypy/semanal.py | 2176 ++++++++++++-------- mypy/semanal_classprop.py | 73 +- mypy/semanal_enum.py | 146 +- mypy/semanal_infer.py | 28 +- mypy/semanal_main.py | 205 +- mypy/semanal_namedtuple.py | 319 +-- mypy/semanal_newtype.py | 128 +- mypy/semanal_pass1.py | 35 +- mypy/semanal_shared.py | 133 +- mypy/semanal_typeargs.py | 73 +- mypy/semanal_typeddict.py | 224 +- mypy/server/astdiff.py | 281 +-- mypy/server/astmerge.py | 114 +- mypy/server/aststrip.py | 43 +- mypy/server/deps.py | 353 ++-- mypy/server/mergecheck.py | 32 +- mypy/server/objgraph.py | 54 +- mypy/server/subexpr.py | 34 +- mypy/server/target.py | 6 +- mypy/server/trigger.py | 4 +- mypy/server/update.py | 417 ++-- mypy/sharedparse.py | 1 + mypy/solve.py | 11 +- mypy/split_namespace.py | 13 +- mypy/state.py | 2 +- mypy/stats.py | 171 +- mypy/strconv.py | 410 ++-- mypy/stubdoc.py | 161 +- mypy/stubgen.py | 1085 +++++----- mypy/stubgenc.py | 494 +++-- mypy/stubinfo.py | 124 +- mypy/stubtest.py | 91 +- mypy/stubutil.py | 117 +- mypy/subtypes.py | 828 +++++--- mypy/suggestions.py | 348 ++-- mypy/test/config.py | 10 +- mypy/test/data.py | 340 +-- mypy/test/helpers.py | 265 ++- mypy/test/test_find_sources.py | 51 +- mypy/test/testapi.py | 22 +- mypy/test/testargs.py | 21 +- mypy/test/testcheck.py | 160 +- mypy/test/testcmdline.py | 87 +- mypy/test/testconstraints.py | 8 +- mypy/test/testdaemon.py | 97 +- mypy/test/testdeps.py | 56 +- mypy/test/testdiff.py | 44 +- mypy/test/testerrorstream.py | 26 +- mypy/test/testfinegrained.py | 179 +- mypy/test/testfinegrainedcache.py | 7 +- mypy/test/testformatter.py | 98 +- mypy/test/testfscache.py | 98 +- mypy/test/testgraph.py | 71 +- mypy/test/testinfer.py | 382 ++-- mypy/test/testipc.py | 30 +- mypy/test/testmerge.py | 139 +- mypy/test/testmodulefinder.py | 49 +- mypy/test/testmypyc.py | 10 +- mypy/test/testparse.py | 57 +- mypy/test/testpep561.py | 120 +- mypy/test/testpythoneval.py | 51 +- mypy/test/testreports.py | 26 +- mypy/test/testsemanal.py | 146 +- mypy/test/testsolve.py | 169 +- mypy/test/teststubgen.py | 954 +++++---- mypy/test/teststubinfo.py | 16 +- mypy/test/teststubtest.py | 132 +- mypy/test/testsubtypes.py | 176 +- mypy/test/testtransform.py | 61 +- mypy/test/testtypegen.py | 46 +- mypy/test/testtypes.py | 653 +++--- mypy/test/testutil.py | 6 +- mypy/test/typefixture.py | 304 +-- mypy/test/visitors.py | 18 +- mypy/traverser.py | 89 +- mypy/treetransform.py | 323 +-- mypy/tvar_scope.py | 48 +- mypy/type_visitor.py | 97 +- mypy/typeanal.py | 938 +++++---- mypy/typeops.py | 266 ++- mypy/types.py | 1629 ++++++++------- mypy/typestate.py | 41 +- mypy/typetraverser.py | 34 +- mypy/typevars.py | 38 +- mypy/typevartuples.py | 14 +- mypy/util.py | 352 ++-- mypy/version.py | 9 +- mypy/visitor.py | 345 ++-- mypyc/__main__.py | 20 +- mypyc/analysis/attrdefined.py | 159 +- mypyc/analysis/blockfreq.py | 2 +- mypyc/analysis/dataflow.py | 205 +- mypyc/analysis/ircheck.py | 133 +- mypyc/analysis/selfleaks.py | 64 +- mypyc/build.py | 269 +-- mypyc/codegen/cstring.py | 11 +- mypyc/codegen/emit.py | 802 ++++---- mypyc/codegen/emitclass.py | 1048 +++++----- mypyc/codegen/emitfunc.py | 502 +++-- mypyc/codegen/emitmodule.py | 608 +++--- mypyc/codegen/emitwrapper.py | 711 ++++--- mypyc/codegen/literals.py | 51 +- mypyc/common.py | 37 +- mypyc/crash.py | 18 +- mypyc/doc/conf.py | 23 +- mypyc/errors.py | 6 +- mypyc/ir/class_ir.py | 230 ++- mypyc/ir/func_ir.py | 149 +- mypyc/ir/module_ir.py | 50 +- mypyc/ir/ops.py | 303 +-- mypyc/ir/pprint.py | 244 ++- mypyc/ir/rtypes.py | 341 +-- mypyc/irbuild/ast_helpers.py | 40 +- mypyc/irbuild/builder.py | 504 +++-- mypyc/irbuild/callable_class.py | 58 +- mypyc/irbuild/classdef.py | 413 ++-- mypyc/irbuild/constant_fold.py | 36 +- mypyc/irbuild/context.py | 35 +- mypyc/irbuild/env_class.py | 49 +- mypyc/irbuild/expression.py | 404 ++-- mypyc/irbuild/for_helpers.py | 352 ++-- mypyc/irbuild/format_str_tokenizer.py | 64 +- mypyc/irbuild/function.py | 437 ++-- mypyc/irbuild/generator.py | 235 ++- mypyc/irbuild/ll_builder.py | 817 ++++---- mypyc/irbuild/main.py | 67 +- mypyc/irbuild/mapper.py | 99 +- mypyc/irbuild/nonlocalcontrol.py | 70 +- mypyc/irbuild/prebuildvisitor.py | 16 +- mypyc/irbuild/prepare.py | 209 +- mypyc/irbuild/specialize.py | 376 ++-- mypyc/irbuild/statement.py | 330 +-- mypyc/irbuild/targets.py | 8 +- mypyc/irbuild/util.py | 95 +- mypyc/irbuild/visitor.py | 173 +- mypyc/irbuild/vtable.py | 15 +- mypyc/lib-rt/setup.py | 36 +- mypyc/namegen.py | 18 +- mypyc/options.py | 22 +- mypyc/primitives/bytes_ops.py | 71 +- mypyc/primitives/dict_ops.py | 221 +- mypyc/primitives/exc_ops.py | 67 +- mypyc/primitives/float_ops.py | 27 +- mypyc/primitives/generic_ops.py | 337 +-- mypyc/primitives/int_ops.py | 220 +- mypyc/primitives/list_ops.py | 204 +- mypyc/primitives/misc_ops.py | 164 +- mypyc/primitives/registry.py | 266 ++- mypyc/primitives/set_ops.py | 101 +- mypyc/primitives/str_ops.py | 160 +- mypyc/primitives/tuple_ops.py | 58 +- mypyc/rt_subtype.py | 18 +- mypyc/sametype.py | 41 +- mypyc/subtype.py | 32 +- mypyc/test/config.py | 4 +- mypyc/test/test_alwaysdefined.py | 24 +- mypyc/test/test_analysis.py | 59 +- mypyc/test/test_cheader.py | 27 +- mypyc/test/test_commandline.py | 41 +- mypyc/test/test_emit.py | 23 +- mypyc/test/test_emitclass.py | 14 +- mypyc/test/test_emitfunc.py | 776 ++++--- mypyc/test/test_emitwrapper.py | 71 +- mypyc/test/test_exceptions.py | 34 +- mypyc/test/test_external.py | 37 +- mypyc/test/test_irbuild.py | 64 +- mypyc/test/test_ircheck.py | 99 +- mypyc/test/test_literals.py | 97 +- mypyc/test/test_namegen.py | 58 +- mypyc/test/test_pprint.py | 16 +- mypyc/test/test_rarray.py | 10 +- mypyc/test/test_refcount.py | 28 +- mypyc/test/test_run.py | 252 +-- mypyc/test/test_serialization.py | 28 +- mypyc/test/test_struct.py | 55 +- mypyc/test/test_subtype.py | 10 +- mypyc/test/test_tuplename.py | 26 +- mypyc/test/testutil.py | 125 +- mypyc/transform/exceptions.py | 43 +- mypyc/transform/refcount.py | 150 +- mypyc/transform/uninit.py | 58 +- runtests.py | 101 +- scripts/find_type.py | 39 +- setup.py | 233 ++- tox.ini | 4 +- 277 files changed, 32334 insertions(+), 24210 deletions(-) diff --git a/.gitignore b/.gitignore index b2306b96036f2..c6761f0ed7365 100644 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,7 @@ docs/source/_build mypyc/doc/_build *.iml /out/ -.venv +.venv* venv/ .mypy_cache/ .incremental_checker_cache.json diff --git a/conftest.py b/conftest.py index 83a6689f6373f..b40d4675c8545 100644 --- a/conftest.py +++ b/conftest.py @@ -1,8 +1,6 @@ import os.path -pytest_plugins = [ - 'mypy.test.data', -] +pytest_plugins = ["mypy.test.data"] def pytest_configure(config): @@ -14,5 +12,6 @@ def pytest_configure(config): # This function name is special to pytest. See # http://doc.pytest.org/en/latest/writing_plugins.html#initialization-command-line-and-configuration-hooks def pytest_addoption(parser) -> None: - parser.addoption('--bench', action='store_true', default=False, - help='Enable the benchmark test runs') + parser.addoption( + "--bench", action="store_true", default=False, help="Enable the benchmark test runs" + ) diff --git a/docs/source/conf.py b/docs/source/conf.py index 6f6b8b276d60a..18602dacbbcd0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,8 +12,8 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys import os +import sys from sphinx.application import Sphinx from sphinx.util.docfields import Field @@ -21,54 +21,54 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) from mypy.version import __version__ as mypy_version # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = ['sphinx.ext.intersphinx'] +extensions = ["sphinx.ext.intersphinx"] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'mypy' -copyright = u'2012-2022 Jukka Lehtosalo and mypy contributors' +project = "mypy" +copyright = "2012-2022 Jukka Lehtosalo and mypy contributors" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = mypy_version.split('-')[0] +version = mypy_version.split("-")[0] # The full version, including alpha/beta/rc tags. release = mypy_version # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -76,27 +76,27 @@ # The reST default role (used for this markup: `text`) to use for all # documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # -- Options for HTML output ---------------------------------------------- @@ -108,17 +108,17 @@ # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. @@ -127,116 +127,108 @@ # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -#html_static_path = ['_static'] +# html_static_path = ['_static'] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'mypydoc' +htmlhelp_basename = "mypydoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). -#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). -latex_documents = [ - ('index', 'Mypy.tex', u'Mypy Documentation', - u'Jukka', 'manual'), -] +latex_documents = [("index", "Mypy.tex", "Mypy Documentation", "Jukka", "manual")] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output --------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [ - ('index', 'mypy', u'Mypy Documentation', - [u'Jukka Lehtosalo'], 1) -] +man_pages = [("index", "mypy", "Mypy Documentation", ["Jukka Lehtosalo"], 1)] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------- @@ -245,43 +237,49 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'Mypy', u'Mypy Documentation', - u'Jukka', 'Mypy', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "Mypy", + "Mypy Documentation", + "Jukka", + "Mypy", + "One line description of project.", + "Miscellaneous", + ) ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. -#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False -rst_prolog = '.. |...| unicode:: U+2026 .. ellipsis\n' +rst_prolog = ".. |...| unicode:: U+2026 .. ellipsis\n" intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'six': ('https://six.readthedocs.io', None), - 'attrs': ('http://www.attrs.org/en/stable', None), - 'cython': ('http://docs.cython.org/en/latest', None), - 'monkeytype': ('https://monkeytype.readthedocs.io/en/latest', None), - 'setuptools': ('https://setuptools.readthedocs.io/en/latest', None), + "python": ("https://docs.python.org/3", None), + "six": ("https://six.readthedocs.io", None), + "attrs": ("http://www.attrs.org/en/stable", None), + "cython": ("http://docs.cython.org/en/latest", None), + "monkeytype": ("https://monkeytype.readthedocs.io/en/latest", None), + "setuptools": ("https://setuptools.readthedocs.io/en/latest", None), } def setup(app: Sphinx) -> None: app.add_object_type( - 'confval', - 'confval', - objname='configuration value', - indextemplate='pair: %s; configuration value', + "confval", + "confval", + objname="configuration value", + indextemplate="pair: %s; configuration value", doc_field_types=[ - Field('type', label='Type', has_arg=False, names=('type',)), - Field('default', label='Default', has_arg=False, names=('default',)), - ] + Field("type", label="Type", has_arg=False, names=("type",)), + Field("default", label="Default", has_arg=False, names=("default",)), + ], ) diff --git a/misc/actions_stubs.py b/misc/actions_stubs.py index 0d52a882463d8..d7613cb06a5f7 100644 --- a/misc/actions_stubs.py +++ b/misc/actions_stubs.py @@ -1,20 +1,27 @@ #!/usr/bin/env python3 import os import shutil -from typing import Tuple, Any +from typing import Any, Tuple + try: import click except ImportError: - print("You need the module \'click\'") + print("You need the module 'click'") exit(1) base_path = os.getcwd() # I don't know how to set callables with different args -def apply_all(func: Any, directory: str, extension: str, - to_extension: str='', exclude: Tuple[str]=('',), - recursive: bool=True, debug: bool=False) -> None: - excluded = [x+extension for x in exclude] if exclude else [] +def apply_all( + func: Any, + directory: str, + extension: str, + to_extension: str = "", + exclude: Tuple[str] = ("",), + recursive: bool = True, + debug: bool = False, +) -> None: + excluded = [x + extension for x in exclude] if exclude else [] for p, d, files in os.walk(os.path.join(base_path, directory)): for f in files: if f in excluded: @@ -24,39 +31,75 @@ def apply_all(func: Any, directory: str, extension: str, continue if to_extension: new_path = f"{inner_path[:-len(extension)]}{to_extension}" - func(inner_path,new_path) + func(inner_path, new_path) else: func(inner_path) if not recursive: break -def confirm(resp: bool=False, **kargs) -> bool: - kargs['rest'] = "to this {f2}/*{e2}".format(**kargs) if kargs.get('f2') else '' + +def confirm(resp: bool = False, **kargs) -> bool: + kargs["rest"] = "to this {f2}/*{e2}".format(**kargs) if kargs.get("f2") else "" prompt = "{act} all files {rec}matching this expression {f1}/*{e1} {rest}".format(**kargs) prompt.format(**kargs) - prompt = "{} [{}]|{}: ".format(prompt, 'Y' if resp else 'N', 'n' if resp else 'y') + prompt = "{} [{}]|{}: ".format(prompt, "Y" if resp else "N", "n" if resp else "y") while True: ans = input(prompt).lower() if not ans: return resp - if ans not in ['y','n']: - print( 'Please, enter (y) or (n).') + if ans not in ["y", "n"]: + print("Please, enter (y) or (n).") continue - if ans == 'y': + if ans == "y": return True else: return False -actions = ['cp', 'mv', 'rm'] -@click.command(context_settings=dict(help_option_names=['-h', '--help'])) -@click.option('--action', '-a', type=click.Choice(actions), required=True, help="What do I have to do :-)") -@click.option('--dir', '-d', 'directory', default='stubs', help="Directory to start search!") -@click.option('--ext', '-e', 'extension', default='.py', help="Extension \"from\" will be applied the action. Default .py") -@click.option('--to', '-t', 'to_extension', default='.pyi', help="Extension \"to\" will be applied the action if can. Default .pyi") -@click.option('--exclude', '-x', multiple=True, default=('__init__',), help="For every appear, will ignore this files. (can set multiples times)") -@click.option('--not-recursive', '-n', default=True, is_flag=True, help="Set if don't want to walk recursively.") -def main(action: str, directory: str, extension: str, to_extension: str, - exclude: Tuple[str], not_recursive: bool) -> None: + +actions = ["cp", "mv", "rm"] + + +@click.command(context_settings=dict(help_option_names=["-h", "--help"])) +@click.option( + "--action", "-a", type=click.Choice(actions), required=True, help="What do I have to do :-)" +) +@click.option("--dir", "-d", "directory", default="stubs", help="Directory to start search!") +@click.option( + "--ext", + "-e", + "extension", + default=".py", + help='Extension "from" will be applied the action. Default .py', +) +@click.option( + "--to", + "-t", + "to_extension", + default=".pyi", + help='Extension "to" will be applied the action if can. Default .pyi', +) +@click.option( + "--exclude", + "-x", + multiple=True, + default=("__init__",), + help="For every appear, will ignore this files. (can set multiples times)", +) +@click.option( + "--not-recursive", + "-n", + default=True, + is_flag=True, + help="Set if don't want to walk recursively.", +) +def main( + action: str, + directory: str, + extension: str, + to_extension: str, + exclude: Tuple[str], + not_recursive: bool, +) -> None: """ This script helps to copy/move/remove files based on their extension. @@ -86,26 +129,26 @@ def main(action: str, directory: str, extension: str, to_extension: str, """ if action not in actions: - print("Your action have to be one of this: {}".format(', '.join(actions))) + print("Your action have to be one of this: {}".format(", ".join(actions))) return - rec = "[Recursively] " if not_recursive else '' - if not extension.startswith('.'): + rec = "[Recursively] " if not_recursive else "" + if not extension.startswith("."): extension = f".{extension}" - if not to_extension.startswith('.'): + if not to_extension.startswith("."): to_extension = f".{to_extension}" - if directory.endswith('/'): + if directory.endswith("/"): directory = directory[:-1] - if action == 'cp': - if confirm(act='Copy',rec=rec, f1=directory, e1=extension, f2=directory, e2=to_extension): + if action == "cp": + if confirm(act="Copy", rec=rec, f1=directory, e1=extension, f2=directory, e2=to_extension): apply_all(shutil.copy, directory, extension, to_extension, exclude, not_recursive) - elif action == 'rm': - if confirm(act='Remove',rec=rec, f1=directory, e1=extension): + elif action == "rm": + if confirm(act="Remove", rec=rec, f1=directory, e1=extension): apply_all(os.remove, directory, extension, exclude=exclude, recursive=not_recursive) - elif action == 'mv': - if confirm(act='Move',rec=rec, f1=directory, e1=extension, f2=directory, e2=to_extension): + elif action == "mv": + if confirm(act="Move", rec=rec, f1=directory, e1=extension, f2=directory, e2=to_extension): apply_all(shutil.move, directory, extension, to_extension, exclude, not_recursive) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/misc/analyze_cache.py b/misc/analyze_cache.py index 5f2048b5c11cb..333a188971b66 100644 --- a/misc/analyze_cache.py +++ b/misc/analyze_cache.py @@ -1,19 +1,25 @@ #!/usr/bin/env python -from typing import Any, Dict, Iterable, List, Optional -from collections import Counter - +import json import os import os.path -import json +from collections import Counter +from typing import Any, Dict, Iterable, List, Optional ROOT = ".mypy_cache/3.5" JsonDict = Dict[str, Any] + class CacheData: - def __init__(self, filename: str, data_json: JsonDict, meta_json: JsonDict, - data_size: int, meta_size: int) -> None: + def __init__( + self, + filename: str, + data_json: JsonDict, + meta_json: JsonDict, + data_size: int, + meta_size: int, + ) -> None: self.filename = filename self.data = data_json self.meta = meta_json @@ -33,6 +39,7 @@ def extract(chunks: Iterable[JsonDict]) -> Iterable[JsonDict]: yield from extract(chunk.values()) elif isinstance(chunk, list): yield from extract(chunk) + yield from extract([chunk.data for chunk in chunks]) @@ -46,8 +53,9 @@ def load_json(data_path: str, meta_path: str) -> CacheData: data_size = os.path.getsize(data_path) meta_size = os.path.getsize(meta_path) - return CacheData(data_path.replace(".data.json", ".*.json"), - data_json, meta_json, data_size, meta_size) + return CacheData( + data_path.replace(".data.json", ".*.json"), data_json, meta_json, data_size, meta_size + ) def get_files(root: str) -> Iterable[CacheData]: @@ -56,17 +64,17 @@ def get_files(root: str) -> Iterable[CacheData]: if filename.endswith(".data.json"): meta_filename = filename.replace(".data.json", ".meta.json") yield load_json( - os.path.join(dirpath, filename), - os.path.join(dirpath, meta_filename)) + os.path.join(dirpath, filename), os.path.join(dirpath, meta_filename) + ) def pluck(name: str, chunks: Iterable[JsonDict]) -> Iterable[JsonDict]: - return (chunk for chunk in chunks if chunk['.class'] == name) + return (chunk for chunk in chunks if chunk[".class"] == name) def report_counter(counter: Counter, amount: Optional[int] = None) -> None: for name, count in counter.most_common(amount): - print(f' {count: <8} {name}') + print(f" {count: <8} {name}") print() @@ -77,6 +85,7 @@ def report_most_common(chunks: List[JsonDict], amount: Optional[int] = None) -> def compress(chunk: JsonDict) -> JsonDict: cache = {} # type: Dict[int, JsonDict] counter = 0 + def helper(chunk: Any) -> Any: nonlocal counter if not isinstance(chunk, dict): @@ -89,8 +98,8 @@ def helper(chunk: Any) -> Any: if id in cache: return cache[id] else: - cache[id] = {'.id': counter} - chunk['.cache_id'] = counter + cache[id] = {".id": counter} + chunk[".cache_id"] = counter counter += 1 for name in sorted(chunk.keys()): @@ -101,21 +110,24 @@ def helper(chunk: Any) -> Any: chunk[name] = helper(value) return chunk + out = helper(chunk) return out + def decompress(chunk: JsonDict) -> JsonDict: cache = {} # type: Dict[int, JsonDict] + def helper(chunk: Any) -> Any: if not isinstance(chunk, dict): return chunk - if '.id' in chunk: - return cache[chunk['.id']] + if ".id" in chunk: + return cache[chunk[".id"]] counter = None - if '.cache_id' in chunk: - counter = chunk['.cache_id'] - del chunk['.cache_id'] + if ".cache_id" in chunk: + counter = chunk[".cache_id"] + del chunk[".cache_id"] for name in sorted(chunk.keys()): value = chunk[name] @@ -128,9 +140,8 @@ def helper(chunk: Any) -> Any: cache[counter] = chunk return chunk - return helper(chunk) - + return helper(chunk) def main() -> None: @@ -150,7 +161,7 @@ def main() -> None: build = None for chunk in json_chunks: - if 'build.*.json' in chunk.filename: + if "build.*.json" in chunk.filename: build = chunk break original = json.dumps(build.data, sort_keys=True) @@ -166,8 +177,7 @@ def main() -> None: print("Lossless conversion back", original == decompressed) - - '''var_chunks = list(pluck("Var", class_chunks)) + """var_chunks = list(pluck("Var", class_chunks)) report_most_common(var_chunks, 20) print() @@ -182,8 +192,8 @@ def main() -> None: print() print("Most common") report_most_common(class_chunks, 20) - print()''' + print()""" -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/misc/apply-cache-diff.py b/misc/apply-cache-diff.py index a9e13a1af9a51..53fd7e52b0663 100644 --- a/misc/apply-cache-diff.py +++ b/misc/apply-cache-diff.py @@ -12,7 +12,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from mypy.metastore import MetadataStore, FilesystemMetadataStore, SqliteMetadataStore +from mypy.metastore import FilesystemMetadataStore, MetadataStore, SqliteMetadataStore def make_cache(input_dir: str, sqlite: bool) -> MetadataStore: @@ -34,7 +34,7 @@ def apply_diff(cache_dir: str, diff_file: str, sqlite: bool = False) -> None: cache.remove(file) else: cache.write(file, data) - if file.endswith('.meta.json') and "@deps" not in file: + if file.endswith(".meta.json") and "@deps" not in file: meta = json.loads(data) old_deps["snapshot"][meta["id"]] = meta["hash"] @@ -45,16 +45,13 @@ def apply_diff(cache_dir: str, diff_file: str, sqlite: bool = False) -> None: def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument('--sqlite', action='store_true', default=False, - help='Use a sqlite cache') - parser.add_argument('cache_dir', - help="Directory for the cache") - parser.add_argument('diff', - help="Cache diff file") + parser.add_argument("--sqlite", action="store_true", default=False, help="Use a sqlite cache") + parser.add_argument("cache_dir", help="Directory for the cache") + parser.add_argument("diff", help="Cache diff file") args = parser.parse_args() apply_diff(args.cache_dir, args.diff, args.sqlite) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/misc/async_matrix.py b/misc/async_matrix.py index c266d0400aba4..33d194c291168 100644 --- a/misc/async_matrix.py +++ b/misc/async_matrix.py @@ -11,48 +11,60 @@ # The various things you might try to use in `await` or `yield from`. + def plain_generator() -> Generator[str, None, int]: - yield 'a' + yield "a" return 1 + async def plain_coroutine() -> int: return 1 + @coroutine def decorated_generator() -> Generator[str, None, int]: - yield 'a' + yield "a" return 1 + @coroutine async def decorated_coroutine() -> int: return 1 + class It(Iterator[str]): stop = False - def __iter__(self) -> 'It': + + def __iter__(self) -> "It": return self + def __next__(self) -> str: if self.stop: - raise StopIteration('end') + raise StopIteration("end") else: self.stop = True - return 'a' + return "a" + def other_iterator() -> It: return It() + class Aw(Awaitable[int]): def __await__(self) -> Generator[str, Any, int]: - yield 'a' + yield "a" return 1 + def other_coroutine() -> Aw: return Aw() + # The various contexts in which `await` or `yield from` might occur. + def plain_host_generator(func) -> Generator[str, None, None]: - yield 'a' + yield "a" x = 0 f = func() try: @@ -63,13 +75,15 @@ def plain_host_generator(func) -> Generator[str, None, None]: except AttributeError: pass + async def plain_host_coroutine(func) -> None: x = 0 x = await func() + @coroutine def decorated_host_generator(func) -> Generator[str, None, None]: - yield 'a' + yield "a" x = 0 f = func() try: @@ -80,22 +94,34 @@ def decorated_host_generator(func) -> Generator[str, None, None]: except AttributeError: pass + @coroutine async def decorated_host_coroutine(func) -> None: x = 0 x = await func() + # Main driver. + def main(): - verbose = ('-v' in sys.argv) - for host in [plain_host_generator, plain_host_coroutine, - decorated_host_generator, decorated_host_coroutine]: + verbose = "-v" in sys.argv + for host in [ + plain_host_generator, + plain_host_coroutine, + decorated_host_generator, + decorated_host_coroutine, + ]: print() print("==== Host:", host.__name__) - for func in [plain_generator, plain_coroutine, - decorated_generator, decorated_coroutine, - other_iterator, other_coroutine]: + for func in [ + plain_generator, + plain_coroutine, + decorated_generator, + decorated_coroutine, + other_iterator, + other_coroutine, + ]: print(" ---- Func:", func.__name__) try: f = host(func) @@ -114,7 +140,8 @@ def main(): except Exception as e: print(" error:", repr(e)) + # Run main(). -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/misc/build_wheel.py b/misc/build_wheel.py index 6d1b6c669c1a5..0f62ce1ea64cf 100644 --- a/misc/build_wheel.py +++ b/misc/build_wheel.py @@ -25,7 +25,7 @@ from typing import Dict # Clang package we use on Linux -LLVM_URL = 'https://github.com/mypyc/mypy_mypyc-wheels/releases/download/llvm/llvm-centos-5.tar.gz' +LLVM_URL = "https://github.com/mypyc/mypy_mypyc-wheels/releases/download/llvm/llvm-centos-5.tar.gz" # Mypy repository root ROOT_DIR = os.path.dirname(os.path.dirname(__file__)) @@ -35,63 +35,77 @@ def create_environ(python_version: str) -> Dict[str, str]: """Set up environment variables for cibuildwheel.""" env = os.environ.copy() - env['CIBW_BUILD'] = f"cp{python_version}-*" + env["CIBW_BUILD"] = f"cp{python_version}-*" # Don't build 32-bit wheels - env['CIBW_SKIP'] = "*-manylinux_i686 *-win32 *-musllinux_*" + env["CIBW_SKIP"] = "*-manylinux_i686 *-win32 *-musllinux_*" # Apple Silicon support # When cross-compiling on Intel, it is not possible to test arm64 and # the arm64 part of a universal2 wheel. Warnings will be silenced with # following CIBW_TEST_SKIP - env['CIBW_ARCHS_MACOS'] = "x86_64 arm64 universal2" - env['CIBW_TEST_SKIP'] = "*-macosx_arm64 *_universal2:arm64" + env["CIBW_ARCHS_MACOS"] = "x86_64 arm64 universal2" + env["CIBW_TEST_SKIP"] = "*-macosx_arm64 *_universal2:arm64" - env['CIBW_BUILD_VERBOSITY'] = '1' + env["CIBW_BUILD_VERBOSITY"] = "1" # mypy's isolated builds don't specify the requirements mypyc needs, so install # requirements and don't use isolated builds. we need to use build-requirements.txt # with recent mypy commits to get stub packages needed for compilation. - env['CIBW_BEFORE_BUILD'] = """ + env[ + "CIBW_BEFORE_BUILD" + ] = """ pip install -r {package}/build-requirements.txt - """.replace('\n', ' ') + """.replace( + "\n", " " + ) # download a copy of clang to use to compile on linux. this was probably built in 2018, # speeds up compilation 2x - env['CIBW_BEFORE_BUILD_LINUX'] = """ + env["CIBW_BEFORE_BUILD_LINUX"] = ( + """ (cd / && curl -L %s | tar xzf -) && pip install -r {package}/build-requirements.txt - """.replace('\n', ' ') % LLVM_URL + """.replace( + "\n", " " + ) + % LLVM_URL + ) # the double negative is counterintuitive, https://github.com/pypa/pip/issues/5735 - env['CIBW_ENVIRONMENT'] = 'MYPY_USE_MYPYC=1 MYPYC_OPT_LEVEL=3 PIP_NO_BUILD_ISOLATION=no' - env['CIBW_ENVIRONMENT_LINUX'] = ( - 'MYPY_USE_MYPYC=1 MYPYC_OPT_LEVEL=3 PIP_NO_BUILD_ISOLATION=no ' + - 'CC=/opt/llvm/bin/clang' - ) - env['CIBW_ENVIRONMENT_WINDOWS'] = ( - 'MYPY_USE_MYPYC=1 MYPYC_OPT_LEVEL=2 PIP_NO_BUILD_ISOLATION=no' + env["CIBW_ENVIRONMENT"] = "MYPY_USE_MYPYC=1 MYPYC_OPT_LEVEL=3 PIP_NO_BUILD_ISOLATION=no" + env["CIBW_ENVIRONMENT_LINUX"] = ( + "MYPY_USE_MYPYC=1 MYPYC_OPT_LEVEL=3 PIP_NO_BUILD_ISOLATION=no " + "CC=/opt/llvm/bin/clang" ) + env[ + "CIBW_ENVIRONMENT_WINDOWS" + ] = "MYPY_USE_MYPYC=1 MYPYC_OPT_LEVEL=2 PIP_NO_BUILD_ISOLATION=no" # lxml doesn't have a wheel for Python 3.10 on the manylinux image we use. # lxml has historically been slow to support new Pythons as well. - env['CIBW_BEFORE_TEST'] = """ + env[ + "CIBW_BEFORE_TEST" + ] = """ ( grep -v lxml {project}/mypy/test-requirements.txt > /tmp/test-requirements.txt && cp {project}/mypy/mypy-requirements.txt /tmp/mypy-requirements.txt && cp {project}/mypy/build-requirements.txt /tmp/build-requirements.txt && pip install -r /tmp/test-requirements.txt ) - """.replace('\n', ' ') + """.replace( + "\n", " " + ) # lxml currently has wheels on Windows and doesn't have grep, so special case - env['CIBW_BEFORE_TEST_WINDOWS'] = "pip install -r {project}/mypy/test-requirements.txt" + env["CIBW_BEFORE_TEST_WINDOWS"] = "pip install -r {project}/mypy/test-requirements.txt" # pytest looks for configuration files in the parent directories of where the tests live. # since we are trying to run the tests from their installed location, we copy those into # the venv. Ew ew ew. # We don't run external mypyc tests since there's some issue with compilation on the # manylinux image we use. - env['CIBW_TEST_COMMAND'] = """ + env[ + "CIBW_TEST_COMMAND" + ] = """ ( DIR=$(python -c 'import mypy, os; dn = os.path.dirname; print(dn(dn(mypy.__path__[0])))') && cp '{project}/mypy/pytest.ini' '{project}/mypy/conftest.py' $DIR @@ -102,11 +116,15 @@ def create_environ(python_version: str) -> Dict[str, str]: && MYPYC_TEST_DIR=$(python -c 'import mypyc.test; print(mypyc.test.__path__[0])') && MYPY_TEST_PREFIX='{project}/mypy' pytest $MYPYC_TEST_DIR -k 'not test_external' ) - """.replace('\n', ' ') + """.replace( + "\n", " " + ) # i ran into some flaky tests on windows, so only run testcheck. it looks like we # previously didn't run any tests on windows wheels, so this is a net win. - env['CIBW_TEST_COMMAND_WINDOWS'] = """ + env[ + "CIBW_TEST_COMMAND_WINDOWS" + ] = """ bash -c " ( DIR=$(python -c 'import mypy, os; dn = os.path.dirname; print(dn(dn(mypy.__path__[0])))') @@ -116,26 +134,34 @@ def create_environ(python_version: str) -> Dict[str, str]: && MYPY_TEST_PREFIX='{project}/mypy' pytest $MYPY_TEST_DIR/testcheck.py ) " - """.replace('\n', ' ') + """.replace( + "\n", " " + ) return env def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument('--python-version', required=True, metavar='XY', - help='Python version (e.g. 38 or 39)') - parser.add_argument('--output-dir', required=True, metavar='DIR', - help='Output directory for created wheels') - parser.add_argument('--extra-opts', default='', metavar='OPTIONS', - help='Extra options passed to cibuildwheel verbatim') + parser.add_argument( + "--python-version", required=True, metavar="XY", help="Python version (e.g. 38 or 39)" + ) + parser.add_argument( + "--output-dir", required=True, metavar="DIR", help="Output directory for created wheels" + ) + parser.add_argument( + "--extra-opts", + default="", + metavar="OPTIONS", + help="Extra options passed to cibuildwheel verbatim", + ) args = parser.parse_args() python_version = args.python_version output_dir = args.output_dir extra_opts = args.extra_opts environ = create_environ(python_version) - script = f'python -m cibuildwheel {extra_opts} --output-dir {output_dir} {ROOT_DIR}' + script = f"python -m cibuildwheel {extra_opts} --output-dir {output_dir} {ROOT_DIR}" subprocess.check_call(script, shell=True, env=environ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/misc/cherry-pick-typeshed.py b/misc/cherry-pick-typeshed.py index 627c8990a155f..ae8ca3ac517a0 100644 --- a/misc/cherry-pick-typeshed.py +++ b/misc/cherry-pick-typeshed.py @@ -24,9 +24,7 @@ def main() -> None: parser.add_argument( "--typeshed-dir", help="location of typeshed", metavar="dir", required=True ) - parser.add_argument( - "commit", help="typeshed commit hash to cherry-pick" - ) + parser.add_argument("commit", help="typeshed commit hash to cherry-pick") args = parser.parse_args() typeshed_dir = args.typeshed_dir commit = args.commit @@ -41,20 +39,22 @@ def main() -> None: with tempfile.TemporaryDirectory() as d: diff_file = os.path.join(d, "diff") - out = subprocess.run(["git", "show", commit], - capture_output=True, - text=True, - check=True, - cwd=typeshed_dir) + out = subprocess.run( + ["git", "show", commit], capture_output=True, text=True, check=True, cwd=typeshed_dir + ) with open(diff_file, "w") as f: f.write(out.stdout) - subprocess.run(["git", - "apply", - "--index", - "--directory=mypy/typeshed", - "--exclude=**/tests/**", - diff_file], - check=True) + subprocess.run( + [ + "git", + "apply", + "--index", + "--directory=mypy/typeshed", + "--exclude=**/tests/**", + diff_file, + ], + check=True, + ) title = parse_commit_title(out.stdout) subprocess.run(["git", "commit", "-m", f"Typeshed cherry-pick: {title}"], check=True) @@ -63,5 +63,5 @@ def main() -> None: print(f"Cherry-picked commit {commit} from {typeshed_dir}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/misc/convert-cache.py b/misc/convert-cache.py index 412238cfbc02f..a83eddf1bcd7e 100755 --- a/misc/convert-cache.py +++ b/misc/convert-cache.py @@ -5,22 +5,31 @@ See mypy/metastore.py for details. """ -import sys import os +import sys + sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import argparse + from mypy.metastore import FilesystemMetadataStore, SqliteMetadataStore def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument('--to-sqlite', action='store_true', default=False, - help='Convert to a sqlite cache (default: convert from)') - parser.add_argument('--output_dir', action='store', default=None, - help="Output cache location (default: same as input)") - parser.add_argument('input_dir', - help="Input directory for the cache") + parser.add_argument( + "--to-sqlite", + action="store_true", + default=False, + help="Convert to a sqlite cache (default: convert from)", + ) + parser.add_argument( + "--output_dir", + action="store", + default=None, + help="Output cache location (default: same as input)", + ) + parser.add_argument("input_dir", help="Input directory for the cache") args = parser.parse_args() input_dir = args.input_dir @@ -31,10 +40,10 @@ def main() -> None: input, output = SqliteMetadataStore(input_dir), FilesystemMetadataStore(output_dir) for s in input.list_all(): - if s.endswith('.json'): + if s.endswith(".json"): assert output.write(s, input.read(s), input.getmtime(s)), "Failed to write cache file!" output.commit() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/misc/diff-cache.py b/misc/diff-cache.py index 11811cc3ae55d..50dd54e12b3d8 100644 --- a/misc/diff-cache.py +++ b/misc/diff-cache.py @@ -9,7 +9,6 @@ import json import os import sys - from collections import defaultdict from typing import Any, Dict, Optional, Set @@ -59,12 +58,8 @@ def unzip(x: Any) -> Any: def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument( - "--verbose", action="store_true", default=False, help="Increase verbosity" - ) - parser.add_argument( - "--sqlite", action="store_true", default=False, help="Use a sqlite cache" - ) + parser.add_argument("--verbose", action="store_true", default=False, help="Increase verbosity") + parser.add_argument("--sqlite", action="store_true", default=False, help="Use a sqlite cache") parser.add_argument("input_dir1", help="Input directory for the cache") parser.add_argument("input_dir2", help="Input directory for the cache") parser.add_argument("output", help="Output file") diff --git a/misc/dump-ast.py b/misc/dump-ast.py index 8ded2389e77db..bf59a9a012363 100755 --- a/misc/dump-ast.py +++ b/misc/dump-ast.py @@ -3,22 +3,20 @@ Parse source files and print the abstract syntax trees. """ -from typing import Tuple -import sys import argparse +import sys +from typing import Tuple +from mypy import defaults from mypy.errors import CompileError from mypy.options import Options -from mypy import defaults from mypy.parse import parse -def dump(fname: str, - python_version: Tuple[int, int], - quiet: bool = False) -> None: +def dump(fname: str, python_version: Tuple[int, int], quiet: bool = False) -> None: options = Options() options.python_version = python_version - with open(fname, 'rb') as f: + with open(fname, "rb") as f: s = f.read() tree = parse(s, fname, None, errors=None, options=options) if not quiet: @@ -28,11 +26,11 @@ def dump(fname: str, def main() -> None: # Parse a file and dump the AST (or display errors). parser = argparse.ArgumentParser( - description="Parse source files and print the abstract syntax tree (AST).", + description="Parse source files and print the abstract syntax tree (AST)." ) - parser.add_argument('--py2', action='store_true', help='parse FILEs as Python 2') - parser.add_argument('--quiet', action='store_true', help='do not print AST') - parser.add_argument('FILE', nargs='*', help='files to parse') + parser.add_argument("--py2", action="store_true", help="parse FILEs as Python 2") + parser.add_argument("--quiet", action="store_true", help="do not print AST") + parser.add_argument("FILE", nargs="*", help="files to parse") args = parser.parse_args() if args.py2: @@ -46,10 +44,10 @@ def main() -> None: dump(fname, pyversion, args.quiet) except CompileError as e: for msg in e.messages: - sys.stderr.write('%s\n' % msg) + sys.stderr.write("%s\n" % msg) status = 1 sys.exit(status) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/misc/fix_annotate.py b/misc/fix_annotate.py index 4c34e03817036..3815dd1c26f1f 100644 --- a/misc/fix_annotate.py +++ b/misc/fix_annotate.py @@ -30,11 +30,10 @@ def foo(self, bar, baz=12): import os import re - from lib2to3.fixer_base import BaseFix +from lib2to3.fixer_util import syms, token, touch_import from lib2to3.patcomp import compile_pattern from lib2to3.pytree import Leaf, Node -from lib2to3.fixer_util import token, syms, touch_import class FixAnnotate(BaseFix): @@ -50,13 +49,13 @@ class FixAnnotate(BaseFix): funcdef< 'def' name=any parameters< '(' [args=any] ')' > ':' suite=any+ > """ - counter = None if not os.getenv('MAXFIXES') else int(os.getenv('MAXFIXES')) + counter = None if not os.getenv("MAXFIXES") else int(os.getenv("MAXFIXES")) def transform(self, node, results): if FixAnnotate.counter is not None: if FixAnnotate.counter <= 0: return - suite = results['suite'] + suite = results["suite"] children = suite[0].children # NOTE: I've reverse-engineered the structure of the parse tree. @@ -80,7 +79,7 @@ def transform(self, node, results): # Check if there's already an annotation. for ch in children: - if ch.prefix.lstrip().startswith('# type:'): + if ch.prefix.lstrip().startswith("# type:"): return # There's already a # type: comment here; don't change anything. # Compute the annotation @@ -89,26 +88,28 @@ def transform(self, node, results): # Insert '# type: {annot}' comment. # For reference, see lib2to3/fixes/fix_tuple_params.py in stdlib. if len(children) >= 2 and children[1].type == token.INDENT: - children[1].prefix = '{}# type: {}\n{}'.format(children[1].value, annot, children[1].prefix) + children[1].prefix = "{}# type: {}\n{}".format( + children[1].value, annot, children[1].prefix + ) children[1].changed() if FixAnnotate.counter is not None: FixAnnotate.counter -= 1 # Also add 'from typing import Any' at the top. - if 'Any' in annot: - touch_import('typing', 'Any', node) + if "Any" in annot: + touch_import("typing", "Any", node) def make_annotation(self, node, results): - name = results['name'] + name = results["name"] assert isinstance(name, Leaf), repr(name) assert name.type == token.NAME, repr(name) decorators = self.get_decorators(node) is_method = self.is_method(node) - if name.value == '__init__' or not self.has_return_exprs(node): - restype = 'None' + if name.value == "__init__" or not self.has_return_exprs(node): + restype = "None" else: - restype = 'Any' - args = results.get('args') + restype = "Any" + args = results.get("args") argtypes = [] if isinstance(args, Node): children = args.children @@ -118,48 +119,48 @@ def make_annotation(self, node, results): children = [] # Interpret children according to the following grammar: # (('*'|'**')? NAME ['=' expr] ','?)* - stars = inferred_type = '' + stars = inferred_type = "" in_default = False at_start = True for child in children: if isinstance(child, Leaf): - if child.value in ('*', '**'): + if child.value in ("*", "**"): stars += child.value elif child.type == token.NAME and not in_default: - if not is_method or not at_start or 'staticmethod' in decorators: - inferred_type = 'Any' + if not is_method or not at_start or "staticmethod" in decorators: + inferred_type = "Any" else: # Always skip the first argument if it's named 'self'. # Always skip the first argument of a class method. - if child.value == 'self' or 'classmethod' in decorators: + if child.value == "self" or "classmethod" in decorators: pass else: - inferred_type = 'Any' - elif child.value == '=': + inferred_type = "Any" + elif child.value == "=": in_default = True - elif in_default and child.value != ',': + elif in_default and child.value != ",": if child.type == token.NUMBER: - if re.match(r'\d+[lL]?$', child.value): - inferred_type = 'int' + if re.match(r"\d+[lL]?$", child.value): + inferred_type = "int" else: - inferred_type = 'float' # TODO: complex? + inferred_type = "float" # TODO: complex? elif child.type == token.STRING: - if child.value.startswith(('u', 'U')): - inferred_type = 'unicode' + if child.value.startswith(("u", "U")): + inferred_type = "unicode" else: - inferred_type = 'str' - elif child.type == token.NAME and child.value in ('True', 'False'): - inferred_type = 'bool' - elif child.value == ',': + inferred_type = "str" + elif child.type == token.NAME and child.value in ("True", "False"): + inferred_type = "bool" + elif child.value == ",": if inferred_type: argtypes.append(stars + inferred_type) # Reset - stars = inferred_type = '' + stars = inferred_type = "" in_default = False at_start = False if inferred_type: argtypes.append(stars + inferred_type) - return '(' + ', '.join(argtypes) + ') -> ' + restype + return "(" + ", ".join(argtypes) + ") -> " + restype # The parse tree has a different shape when there is a single # decorator vs. when there are multiple decorators. @@ -180,7 +181,7 @@ def get_decorators(self, node): results = {} if not self.decorated.match(node.parent, results): return [] - decorators = results.get('dd') or [results['d']] + decorators = results.get("dd") or [results["d"]] decs = [] for d in decorators: for child in d.children: diff --git a/misc/incremental_checker.py b/misc/incremental_checker.py index 8eea983ff5994..8a441d6dc4017 100755 --- a/misc/incremental_checker.py +++ b/misc/incremental_checker.py @@ -31,9 +31,6 @@ python3 misc/incremental_checker.py commit 2a432b """ -from typing import Any, Dict, List, Optional, Tuple - -from argparse import ArgumentParser, RawDescriptionHelpFormatter, Namespace import base64 import json import os @@ -44,7 +41,8 @@ import sys import textwrap import time - +from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter +from typing import Any, Dict, List, Optional, Tuple CACHE_PATH = ".incremental_checker_cache.json" MYPY_REPO_URL = "https://github.com/python/mypy.git" @@ -56,7 +54,7 @@ def print_offset(text: str, indent_length: int = 4) -> None: print() - print(textwrap.indent(text, ' ' * indent_length)) + print(textwrap.indent(text, " " * indent_length)) print() @@ -67,21 +65,19 @@ def delete_folder(folder_path: str) -> None: def execute(command: List[str], fail_on_error: bool = True) -> Tuple[str, str, int]: proc = subprocess.Popen( - ' '.join(command), - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, - shell=True) + " ".join(command), stderr=subprocess.PIPE, stdout=subprocess.PIPE, shell=True + ) stdout_bytes, stderr_bytes = proc.communicate() # type: Tuple[bytes, bytes] - stdout, stderr = stdout_bytes.decode('utf-8'), stderr_bytes.decode('utf-8') + stdout, stderr = stdout_bytes.decode("utf-8"), stderr_bytes.decode("utf-8") if fail_on_error and proc.returncode != 0: - print('EXECUTED COMMAND:', repr(command)) - print('RETURN CODE:', proc.returncode) + print("EXECUTED COMMAND:", repr(command)) + print("RETURN CODE:", proc.returncode) print() - print('STDOUT:') + print("STDOUT:") print_offset(stdout) - print('STDERR:') + print("STDERR:") print_offset(stderr) - raise RuntimeError('Unexpected error from external tool.') + raise RuntimeError("Unexpected error from external tool.") return stdout, stderr, proc.returncode @@ -100,32 +96,35 @@ def initialize_repo(repo_url: str, temp_repo_path: str, branch: str) -> None: def get_commits(repo_folder_path: str, commit_range: str) -> List[Tuple[str, str]]: - raw_data, _stderr, _errcode = execute([ - "git", "-C", repo_folder_path, "log", "--reverse", "--oneline", commit_range]) + raw_data, _stderr, _errcode = execute( + ["git", "-C", repo_folder_path, "log", "--reverse", "--oneline", commit_range] + ) output = [] - for line in raw_data.strip().split('\n'): - commit_id, _, message = line.partition(' ') + for line in raw_data.strip().split("\n"): + commit_id, _, message = line.partition(" ") output.append((commit_id, message)) return output def get_commits_starting_at(repo_folder_path: str, start_commit: str) -> List[Tuple[str, str]]: print(f"Fetching commits starting at {start_commit}") - return get_commits(repo_folder_path, f'{start_commit}^..HEAD') + return get_commits(repo_folder_path, f"{start_commit}^..HEAD") def get_nth_commit(repo_folder_path: str, n: int) -> Tuple[str, str]: print(f"Fetching last {n} commits (or all, if there are fewer commits than n)") - return get_commits(repo_folder_path, f'-{n}')[0] - - -def run_mypy(target_file_path: Optional[str], - mypy_cache_path: str, - mypy_script: Optional[str], - *, - incremental: bool = False, - daemon: bool = False, - verbose: bool = False) -> Tuple[float, str, Dict[str, Any]]: + return get_commits(repo_folder_path, f"-{n}")[0] + + +def run_mypy( + target_file_path: Optional[str], + mypy_cache_path: str, + mypy_script: Optional[str], + *, + incremental: bool = False, + daemon: bool = False, + verbose: bool = False, +) -> Tuple[float, str, Dict[str, Any]]: """Runs mypy against `target_file_path` and returns what mypy prints to stdout as a string. If `incremental` is set to True, this function will use store and retrieve all caching data @@ -165,19 +164,26 @@ def filter_daemon_stats(output: str) -> Tuple[str, Dict[str, Any]]: lines = output.splitlines() output_lines = [] for line in lines: - m = re.match(r'(\w+)\s+:\s+(.*)', line) + m = re.match(r"(\w+)\s+:\s+(.*)", line) if m: key, value = m.groups() stats[key] = value else: output_lines.append(line) if output_lines: - output_lines.append('\n') - return '\n'.join(output_lines), stats + output_lines.append("\n") + return "\n".join(output_lines), stats def start_daemon(mypy_cache_path: str) -> None: - cmd = DAEMON_CMD + ["restart", "--log-file", "./@incr-chk-logs", "--", "--cache-dir", mypy_cache_path] + cmd = DAEMON_CMD + [ + "restart", + "--log-file", + "./@incr-chk-logs", + "--", + "--cache-dir", + mypy_cache_path, + ] execute(cmd) @@ -194,16 +200,18 @@ def load_cache(incremental_cache_path: str = CACHE_PATH) -> JsonDict: def save_cache(cache: JsonDict, incremental_cache_path: str = CACHE_PATH) -> None: - with open(incremental_cache_path, 'w') as stream: + with open(incremental_cache_path, "w") as stream: json.dump(cache, stream, indent=2) -def set_expected(commits: List[Tuple[str, str]], - cache: JsonDict, - temp_repo_path: str, - target_file_path: Optional[str], - mypy_cache_path: str, - mypy_script: Optional[str]) -> None: +def set_expected( + commits: List[Tuple[str, str]], + cache: JsonDict, + temp_repo_path: str, + target_file_path: Optional[str], + mypy_cache_path: str, + mypy_script: Optional[str], +) -> None: """Populates the given `cache` with the expected results for all of the given `commits`. This function runs mypy on the `target_file_path` inside the `temp_repo_path`, and stores @@ -217,9 +225,10 @@ def set_expected(commits: List[Tuple[str, str]], else: print(f'Caching expected output for commit {commit_id}: "{message}"') execute(["git", "-C", temp_repo_path, "checkout", commit_id]) - runtime, output, stats = run_mypy(target_file_path, mypy_cache_path, mypy_script, - incremental=False) - cache[commit_id] = {'runtime': runtime, 'output': output} + runtime, output, stats = run_mypy( + target_file_path, mypy_cache_path, mypy_script, incremental=False + ) + cache[commit_id] = {"runtime": runtime, "output": output} if output == "": print(f" Clean output ({runtime:.3f} sec)") else: @@ -228,15 +237,17 @@ def set_expected(commits: List[Tuple[str, str]], print() -def test_incremental(commits: List[Tuple[str, str]], - cache: JsonDict, - temp_repo_path: str, - target_file_path: Optional[str], - mypy_cache_path: str, - *, - mypy_script: Optional[str] = None, - daemon: bool = False, - exit_on_error: bool = False) -> None: +def test_incremental( + commits: List[Tuple[str, str]], + cache: JsonDict, + temp_repo_path: str, + target_file_path: Optional[str], + mypy_cache_path: str, + *, + mypy_script: Optional[str] = None, + daemon: bool = False, + exit_on_error: bool = False, +) -> None: """Runs incremental mode on all `commits` to verify the output matches the expected output. This function runs mypy on the `target_file_path` inside the `temp_repo_path`. The @@ -248,11 +259,12 @@ def test_incremental(commits: List[Tuple[str, str]], for commit_id, message in commits: print(f'Now testing commit {commit_id}: "{message}"') execute(["git", "-C", temp_repo_path, "checkout", commit_id]) - runtime, output, stats = run_mypy(target_file_path, mypy_cache_path, mypy_script, - incremental=True, daemon=daemon) + runtime, output, stats = run_mypy( + target_file_path, mypy_cache_path, mypy_script, incremental=True, daemon=daemon + ) relevant_stats = combine_stats(overall_stats, stats) - expected_runtime = cache[commit_id]['runtime'] # type: float - expected_output = cache[commit_id]['output'] # type: str + expected_runtime = cache[commit_id]["runtime"] # type: float + expected_output = cache[commit_id]["output"] # type: str if output != expected_output: print(" Output does not match expected result!") print(f" Expected output ({expected_runtime:.3f} sec):") @@ -271,9 +283,8 @@ def test_incremental(commits: List[Tuple[str, str]], print("Overall stats:", overall_stats) -def combine_stats(overall_stats: Dict[str, float], - new_stats: Dict[str, Any]) -> Dict[str, float]: - INTERESTING_KEYS = ['build_time', 'gc_time'] +def combine_stats(overall_stats: Dict[str, float], new_stats: Dict[str, Any]) -> Dict[str, float]: + INTERESTING_KEYS = ["build_time", "gc_time"] # For now, we only support float keys relevant_stats = {} # type: Dict[str, float] for key in INTERESTING_KEYS: @@ -289,11 +300,18 @@ def cleanup(temp_repo_path: str, mypy_cache_path: str) -> None: delete_folder(mypy_cache_path) -def test_repo(target_repo_url: str, temp_repo_path: str, - target_file_path: Optional[str], - mypy_path: str, incremental_cache_path: str, mypy_cache_path: str, - range_type: str, range_start: str, branch: str, - params: Namespace) -> None: +def test_repo( + target_repo_url: str, + temp_repo_path: str, + target_file_path: Optional[str], + mypy_path: str, + incremental_cache_path: str, + mypy_cache_path: str, + range_type: str, + range_start: str, + branch: str, + params: Namespace, +) -> None: """Tests incremental mode against the repo specified in `target_repo_url`. This algorithm runs in five main stages: @@ -327,67 +345,110 @@ def test_repo(target_repo_url: str, temp_repo_path: str, raise RuntimeError(f"Invalid option: {range_type}") commits = get_commits_starting_at(temp_repo_path, start_commit) if params.limit: - commits = commits[:params.limit] + commits = commits[: params.limit] if params.sample: - seed = params.seed or base64.urlsafe_b64encode(os.urandom(15)).decode('ascii') + seed = params.seed or base64.urlsafe_b64encode(os.urandom(15)).decode("ascii") random.seed(seed) commits = random.sample(commits, params.sample) print("Sampled down to %d commits using random seed %s" % (len(commits), seed)) # Stage 3: Find and cache expected results for each commit (without incremental mode) cache = load_cache(incremental_cache_path) - set_expected(commits, cache, temp_repo_path, target_file_path, mypy_cache_path, - mypy_script=params.mypy_script) + set_expected( + commits, + cache, + temp_repo_path, + target_file_path, + mypy_cache_path, + mypy_script=params.mypy_script, + ) save_cache(cache, incremental_cache_path) # Stage 4: Rewind and re-run mypy (with incremental mode enabled) if params.daemon: - print('Starting daemon') + print("Starting daemon") start_daemon(mypy_cache_path) - test_incremental(commits, cache, temp_repo_path, target_file_path, mypy_cache_path, - mypy_script=params.mypy_script, daemon=params.daemon, - exit_on_error=params.exit_on_error) + test_incremental( + commits, + cache, + temp_repo_path, + target_file_path, + mypy_cache_path, + mypy_script=params.mypy_script, + daemon=params.daemon, + exit_on_error=params.exit_on_error, + ) # Stage 5: Remove temp files, stop daemon if not params.keep_temporary_files: cleanup(temp_repo_path, mypy_cache_path) if params.daemon: - print('Stopping daemon') + print("Stopping daemon") stop_daemon() def main() -> None: - help_factory = (lambda prog: RawDescriptionHelpFormatter(prog=prog, max_help_position=32)) # type: Any + help_factory = lambda prog: RawDescriptionHelpFormatter( + prog=prog, max_help_position=32 + ) # type: Any parser = ArgumentParser( - prog='incremental_checker', - description=__doc__, - formatter_class=help_factory) - - parser.add_argument("range_type", metavar="START_TYPE", choices=["last", "commit"], - help="must be one of 'last' or 'commit'") - parser.add_argument("range_start", metavar="COMMIT_ID_OR_NUMBER", - help="the commit id to start from, or the number of " - "commits to move back (see above)") - parser.add_argument("-r", "--repo_url", default=MYPY_REPO_URL, metavar="URL", - help="the repo to clone and run tests on") - parser.add_argument("-f", "--file-path", default=MYPY_TARGET_FILE, metavar="FILE", - help="the name of the file or directory to typecheck") - parser.add_argument("-x", "--exit-on-error", action='store_true', - help="Exits as soon as an error occurs") - parser.add_argument("--keep-temporary-files", action='store_true', - help="Keep temporary files on exit") - parser.add_argument("--cache-path", default=CACHE_PATH, metavar="DIR", - help="sets a custom location to store cache data") - parser.add_argument("--branch", default=None, metavar="NAME", - help="check out and test a custom branch" - "uses the default if not specified") + prog="incremental_checker", description=__doc__, formatter_class=help_factory + ) + + parser.add_argument( + "range_type", + metavar="START_TYPE", + choices=["last", "commit"], + help="must be one of 'last' or 'commit'", + ) + parser.add_argument( + "range_start", + metavar="COMMIT_ID_OR_NUMBER", + help="the commit id to start from, or the number of " "commits to move back (see above)", + ) + parser.add_argument( + "-r", + "--repo_url", + default=MYPY_REPO_URL, + metavar="URL", + help="the repo to clone and run tests on", + ) + parser.add_argument( + "-f", + "--file-path", + default=MYPY_TARGET_FILE, + metavar="FILE", + help="the name of the file or directory to typecheck", + ) + parser.add_argument( + "-x", "--exit-on-error", action="store_true", help="Exits as soon as an error occurs" + ) + parser.add_argument( + "--keep-temporary-files", action="store_true", help="Keep temporary files on exit" + ) + parser.add_argument( + "--cache-path", + default=CACHE_PATH, + metavar="DIR", + help="sets a custom location to store cache data", + ) + parser.add_argument( + "--branch", + default=None, + metavar="NAME", + help="check out and test a custom branch" "uses the default if not specified", + ) parser.add_argument("--sample", type=int, help="use a random sample of size SAMPLE") parser.add_argument("--seed", type=str, help="random seed") - parser.add_argument("--limit", type=int, - help="maximum number of commits to use (default until end)") + parser.add_argument( + "--limit", type=int, help="maximum number of commits to use (default until end)" + ) parser.add_argument("--mypy-script", type=str, help="alternate mypy script to run") - parser.add_argument("--daemon", action='store_true', - help="use mypy daemon instead of incremental (highly experimental)") + parser.add_argument( + "--daemon", + action="store_true", + help="use mypy daemon instead of incremental (highly experimental)", + ) if len(sys.argv[1:]) == 0: parser.print_help() @@ -425,11 +486,19 @@ def main() -> None: print(f"Using cache data located at {incremental_cache_path}") print() - test_repo(params.repo_url, temp_repo_path, target_file_path, - mypy_path, incremental_cache_path, mypy_cache_path, - params.range_type, params.range_start, params.branch, - params) - - -if __name__ == '__main__': + test_repo( + params.repo_url, + temp_repo_path, + target_file_path, + mypy_path, + incremental_cache_path, + mypy_cache_path, + params.range_type, + params.range_start, + params.branch, + params, + ) + + +if __name__ == "__main__": main() diff --git a/misc/perf_checker.py b/misc/perf_checker.py index 38a80c1481879..5cf03d4b86f51 100644 --- a/misc/perf_checker.py +++ b/misc/perf_checker.py @@ -1,13 +1,12 @@ #!/usr/bin/env python3 -from typing import Callable, List, Tuple - import os import shutil import statistics import subprocess import textwrap import time +from typing import Callable, List, Tuple class Command: @@ -18,7 +17,7 @@ def __init__(self, setup: Callable[[], None], command: Callable[[], None]) -> No def print_offset(text: str, indent_length: int = 4) -> None: print() - print(textwrap.indent(text, ' ' * indent_length)) + print(textwrap.indent(text, " " * indent_length)) print() @@ -29,21 +28,19 @@ def delete_folder(folder_path: str) -> None: def execute(command: List[str]) -> None: proc = subprocess.Popen( - ' '.join(command), - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, - shell=True) + " ".join(command), stderr=subprocess.PIPE, stdout=subprocess.PIPE, shell=True + ) stdout_bytes, stderr_bytes = proc.communicate() # type: Tuple[bytes, bytes] - stdout, stderr = stdout_bytes.decode('utf-8'), stderr_bytes.decode('utf-8') + stdout, stderr = stdout_bytes.decode("utf-8"), stderr_bytes.decode("utf-8") if proc.returncode != 0: - print('EXECUTED COMMAND:', repr(command)) - print('RETURN CODE:', proc.returncode) + print("EXECUTED COMMAND:", repr(command)) + print("RETURN CODE:", proc.returncode) print() - print('STDOUT:') + print("STDOUT:") print_offset(stdout) - print('STDERR:') + print("STDERR:") print_offset(stderr) - raise RuntimeError('Unexpected error from external tool.') + raise RuntimeError("Unexpected error from external tool.") def trial(num_trials: int, command: Command) -> List[float]: @@ -69,25 +66,28 @@ def main() -> None: trials = 3 print("Testing baseline") - baseline = trial(trials, Command( - lambda: None, - lambda: execute(["python3", "-m", "mypy", "mypy"]))) + baseline = trial( + trials, Command(lambda: None, lambda: execute(["python3", "-m", "mypy", "mypy"])) + ) report("Baseline", baseline) print("Testing cold cache") - cold_cache = trial(trials, Command( - lambda: delete_folder(".mypy_cache"), - lambda: execute(["python3", "-m", "mypy", "-i", "mypy"]))) + cold_cache = trial( + trials, + Command( + lambda: delete_folder(".mypy_cache"), + lambda: execute(["python3", "-m", "mypy", "-i", "mypy"]), + ), + ) report("Cold cache", cold_cache) print("Testing warm cache") execute(["python3", "-m", "mypy", "-i", "mypy"]) - warm_cache = trial(trials, Command( - lambda: None, - lambda: execute(["python3", "-m", "mypy", "-i", "mypy"]))) + warm_cache = trial( + trials, Command(lambda: None, lambda: execute(["python3", "-m", "mypy", "-i", "mypy"])) + ) report("Warm cache", warm_cache) -if __name__ == '__main__': +if __name__ == "__main__": main() - diff --git a/misc/proper_plugin.py b/misc/proper_plugin.py index acd77500cd5db..20a697ae4bbd8 100644 --- a/misc/proper_plugin.py +++ b/misc/proper_plugin.py @@ -1,13 +1,23 @@ -from mypy.plugin import Plugin, FunctionContext -from mypy.types import ( - FunctionLike, Type, Instance, CallableType, UnionType, get_proper_type, ProperType, - get_proper_types, TupleType, NoneTyp, AnyType -) -from mypy.nodes import TypeInfo -from mypy.subtypes import is_proper_subtype +from typing import Callable, Optional from typing_extensions import Type as typing_Type -from typing import Optional, Callable + +from mypy.nodes import TypeInfo +from mypy.plugin import FunctionContext, Plugin +from mypy.subtypes import is_proper_subtype +from mypy.types import ( + AnyType, + CallableType, + FunctionLike, + Instance, + NoneTyp, + ProperType, + TupleType, + Type, + UnionType, + get_proper_type, + get_proper_types, +) class ProperTypePlugin(Plugin): @@ -22,13 +32,13 @@ class ProperTypePlugin(Plugin): But after introducing a new type TypeAliasType (and removing immediate expansion) all these became dangerous because typ may be e.g. an alias to union. """ - def get_function_hook(self, fullname: str - ) -> Optional[Callable[[FunctionContext], Type]]: - if fullname == 'builtins.isinstance': + + def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]: + if fullname == "builtins.isinstance": return isinstance_proper_hook - if fullname == 'mypy.types.get_proper_type': + if fullname == "mypy.types.get_proper_type": return proper_type_hook - if fullname == 'mypy.types.get_proper_types': + if fullname == "mypy.types.get_proper_types": return proper_types_hook return None @@ -39,41 +49,50 @@ def isinstance_proper_hook(ctx: FunctionContext) -> Type: right = get_proper_type(ctx.arg_types[1][0]) for arg in ctx.arg_types[0]: - if (is_improper_type(arg) or - isinstance(get_proper_type(arg), AnyType) and is_dangerous_target(right)): + if ( + is_improper_type(arg) + or isinstance(get_proper_type(arg), AnyType) + and is_dangerous_target(right) + ): if is_special_target(right): return ctx.default_return_type - ctx.api.fail('Never apply isinstance() to unexpanded types;' - ' use mypy.types.get_proper_type() first', ctx.context) - ctx.api.note('If you pass on the original type' # type: ignore[attr-defined] - ' after the check, always use its unexpanded version', ctx.context) + ctx.api.fail( + "Never apply isinstance() to unexpanded types;" + " use mypy.types.get_proper_type() first", + ctx.context, + ) + ctx.api.note( # type: ignore[attr-defined] + "If you pass on the original type" + " after the check, always use its unexpanded version", + ctx.context, + ) return ctx.default_return_type def is_special_target(right: ProperType) -> bool: """Whitelist some special cases for use in isinstance() with improper types.""" if isinstance(right, FunctionLike) and right.is_type_obj(): - if right.type_object().fullname == 'builtins.tuple': + if right.type_object().fullname == "builtins.tuple": # Used with Union[Type, Tuple[Type, ...]]. return True if right.type_object().fullname in ( - 'mypy.types.Type', - 'mypy.types.ProperType', - 'mypy.types.TypeAliasType' + "mypy.types.Type", + "mypy.types.ProperType", + "mypy.types.TypeAliasType", ): # Special case: things like assert isinstance(typ, ProperType) are always OK. return True if right.type_object().fullname in ( - 'mypy.types.UnboundType', - 'mypy.types.TypeVarType', - 'mypy.types.ParamSpecType', - 'mypy.types.RawExpressionType', - 'mypy.types.EllipsisType', - 'mypy.types.StarType', - 'mypy.types.TypeList', - 'mypy.types.CallableArgument', - 'mypy.types.PartialType', - 'mypy.types.ErasedType' + "mypy.types.UnboundType", + "mypy.types.TypeVarType", + "mypy.types.ParamSpecType", + "mypy.types.RawExpressionType", + "mypy.types.EllipsisType", + "mypy.types.StarType", + "mypy.types.TypeList", + "mypy.types.CallableArgument", + "mypy.types.PartialType", + "mypy.types.ErasedType", ): # Special case: these are not valid targets for a type alias and thus safe. # TODO: introduce a SyntheticType base to simplify this? @@ -88,7 +107,7 @@ def is_improper_type(typ: Type) -> bool: typ = get_proper_type(typ) if isinstance(typ, Instance): info = typ.type - return info.has_base('mypy.types.Type') and not info.has_base('mypy.types.ProperType') + return info.has_base("mypy.types.Type") and not info.has_base("mypy.types.ProperType") if isinstance(typ, UnionType): return any(is_improper_type(t) for t in typ.items) return False @@ -99,7 +118,7 @@ def is_dangerous_target(typ: ProperType) -> bool: if isinstance(typ, TupleType): return any(is_dangerous_target(get_proper_type(t)) for t in typ.items) if isinstance(typ, CallableType) and typ.is_type_obj(): - return typ.type_object().has_base('mypy.types.Type') + return typ.type_object().has_base("mypy.types.Type") return False @@ -113,7 +132,7 @@ def proper_type_hook(ctx: FunctionContext) -> Type: # Minimize amount of spurious errors from overload machinery. # TODO: call the hook on the overload as a whole? if isinstance(arg_type, (UnionType, Instance)): - ctx.api.fail('Redundant call to get_proper_type()', ctx.context) + ctx.api.fail("Redundant call to get_proper_type()", ctx.context) return ctx.default_return_type @@ -124,15 +143,15 @@ def proper_types_hook(ctx: FunctionContext) -> Type: arg_type = arg_types[0] proper_type = get_proper_type_instance(ctx) item_type = UnionType.make_union([NoneTyp(), proper_type]) - ok_type = ctx.api.named_generic_type('typing.Iterable', [item_type]) + ok_type = ctx.api.named_generic_type("typing.Iterable", [item_type]) if is_proper_subtype(arg_type, ok_type): - ctx.api.fail('Redundant call to get_proper_types()', ctx.context) + ctx.api.fail("Redundant call to get_proper_types()", ctx.context) return ctx.default_return_type def get_proper_type_instance(ctx: FunctionContext) -> Instance: - types = ctx.api.modules['mypy.types'] # type: ignore - proper_type_info = types.names['ProperType'] + types = ctx.api.modules["mypy.types"] # type: ignore + proper_type_info = types.names["ProperType"] assert isinstance(proper_type_info.node, TypeInfo) return Instance(proper_type_info.node, []) diff --git a/misc/sync-typeshed.py b/misc/sync-typeshed.py index 8f4bba8487b32..e74c3723ef070 100644 --- a/misc/sync-typeshed.py +++ b/misc/sync-typeshed.py @@ -18,9 +18,9 @@ def check_state() -> None: - if not os.path.isfile('README.md'): - sys.exit('error: The current working directory must be the mypy repository root') - out = subprocess.check_output(['git', 'status', '-s', os.path.join('mypy', 'typeshed')]) + if not os.path.isfile("README.md"): + sys.exit("error: The current working directory must be the mypy repository root") + out = subprocess.check_output(["git", "status", "-s", os.path.join("mypy", "typeshed")]) if out: # If there are local changes under mypy/typeshed, they would be lost. sys.exit('error: Output of "git status -s mypy/typeshed" must be empty') @@ -31,56 +31,61 @@ def update_typeshed(typeshed_dir: str, commit: Optional[str]) -> str: Return the normalized typeshed commit hash. """ - assert os.path.isdir(os.path.join(typeshed_dir, 'stdlib')) - assert os.path.isdir(os.path.join(typeshed_dir, 'stubs')) + assert os.path.isdir(os.path.join(typeshed_dir, "stdlib")) + assert os.path.isdir(os.path.join(typeshed_dir, "stubs")) if commit: - subprocess.run(['git', 'checkout', commit], check=True, cwd=typeshed_dir) + subprocess.run(["git", "checkout", commit], check=True, cwd=typeshed_dir) commit = git_head_commit(typeshed_dir) - stdlib_dir = os.path.join('mypy', 'typeshed', 'stdlib') + stdlib_dir = os.path.join("mypy", "typeshed", "stdlib") # Remove existing stubs. shutil.rmtree(stdlib_dir) # Copy new stdlib stubs. - shutil.copytree(os.path.join(typeshed_dir, 'stdlib'), stdlib_dir) + shutil.copytree(os.path.join(typeshed_dir, "stdlib"), stdlib_dir) # Copy mypy_extensions stubs. We don't want to use a stub package, since it's # treated specially by mypy and we make assumptions about what's there. - stubs_dir = os.path.join('mypy', 'typeshed', 'stubs') + stubs_dir = os.path.join("mypy", "typeshed", "stubs") shutil.rmtree(stubs_dir) os.makedirs(stubs_dir) - shutil.copytree(os.path.join(typeshed_dir, 'stubs', 'mypy-extensions'), - os.path.join(stubs_dir, 'mypy-extensions')) - shutil.copy(os.path.join(typeshed_dir, 'LICENSE'), os.path.join('mypy', 'typeshed')) + shutil.copytree( + os.path.join(typeshed_dir, "stubs", "mypy-extensions"), + os.path.join(stubs_dir, "mypy-extensions"), + ) + shutil.copy(os.path.join(typeshed_dir, "LICENSE"), os.path.join("mypy", "typeshed")) return commit def git_head_commit(repo: str) -> str: - commit = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=repo).decode('ascii') + commit = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=repo).decode("ascii") return commit.strip() def main() -> None: parser = argparse.ArgumentParser() parser.add_argument( - "--commit", default=None, - help="Typeshed commit (default to latest master if using a repository clone)" + "--commit", + default=None, + help="Typeshed commit (default to latest master if using a repository clone)", ) parser.add_argument( - "--typeshed-dir", default=None, - help="Location of typeshed (default to a temporary repository clone)" + "--typeshed-dir", + default=None, + help="Location of typeshed (default to a temporary repository clone)", ) args = parser.parse_args() check_state() - print('Update contents of mypy/typeshed from typeshed? [yN] ', end='') + print("Update contents of mypy/typeshed from typeshed? [yN] ", end="") answer = input() - if answer.lower() != 'y': - sys.exit('Aborting') + if answer.lower() != "y": + sys.exit("Aborting") if not args.typeshed_dir: # Clone typeshed repo if no directory given. with tempfile.TemporaryDirectory() as tempdir: - print(f'Cloning typeshed in {tempdir}...') - subprocess.run(['git', 'clone', 'https://github.com/python/typeshed.git'], - check=True, cwd=tempdir) - repo = os.path.join(tempdir, 'typeshed') + print(f"Cloning typeshed in {tempdir}...") + subprocess.run( + ["git", "clone", "https://github.com/python/typeshed.git"], check=True, cwd=tempdir + ) + repo = os.path.join(tempdir, "typeshed") commit = update_typeshed(repo, args.commit) else: commit = update_typeshed(args.typeshed_dir, args.commit) @@ -88,16 +93,20 @@ def main() -> None: assert commit # Create a commit - message = textwrap.dedent("""\ + message = textwrap.dedent( + """\ Sync typeshed Source commit: https://github.com/python/typeshed/commit/{commit} - """.format(commit=commit)) - subprocess.run(['git', 'add', '--all', os.path.join('mypy', 'typeshed')], check=True) - subprocess.run(['git', 'commit', '-m', message], check=True) - print('Created typeshed sync commit.') + """.format( + commit=commit + ) + ) + subprocess.run(["git", "add", "--all", os.path.join("mypy", "typeshed")], check=True) + subprocess.run(["git", "commit", "-m", message], check=True) + print("Created typeshed sync commit.") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/misc/test_case_to_actual.py b/misc/test_case_to_actual.py index ccf631286802a..dd8a8a293c3cf 100644 --- a/misc/test_case_to_actual.py +++ b/misc/test_case_to_actual.py @@ -1,7 +1,7 @@ -from typing import Iterator, List -import sys import os import os.path +import sys +from typing import Iterator, List class Chunk: @@ -12,7 +12,7 @@ def __init__(self, header_type: str, args: str) -> None: def is_header(line: str) -> bool: - return line.startswith('[') and line.endswith(']') + return line.startswith("[") and line.endswith("]") def normalize(lines: Iterator[str]) -> Iterator[str]: @@ -25,8 +25,8 @@ def produce_chunks(lines: Iterator[str]) -> Iterator[Chunk]: if is_header(line): if current_chunk is not None: yield current_chunk - parts = line[1:-1].split(' ', 1) - args = parts[1] if len(parts) > 1 else '' + parts = line[1:-1].split(" ", 1) + args = parts[1] if len(parts) > 1 else "" current_chunk = Chunk(parts[0], args) else: current_chunk.lines.append(line) @@ -36,19 +36,19 @@ def produce_chunks(lines: Iterator[str]) -> Iterator[Chunk]: def write_out(filename: str, lines: List[str]) -> None: os.makedirs(os.path.dirname(filename), exist_ok=True) - with open(filename, 'w') as stream: - stream.write('\n'.join(lines)) + with open(filename, "w") as stream: + stream.write("\n".join(lines)) def write_tree(root: str, chunks: Iterator[Chunk]) -> None: init = next(chunks) - assert init.header_type == 'case' - + assert init.header_type == "case" + root = os.path.join(root, init.args) - write_out(os.path.join(root, 'main.py'), init.lines) + write_out(os.path.join(root, "main.py"), init.lines) for chunk in chunks: - if chunk.header_type == 'file' and chunk.args.endswith('.py'): + if chunk.header_type == "file" and chunk.args.endswith(".py"): write_out(os.path.join(root, chunk.args), chunk.lines) @@ -67,5 +67,5 @@ def main() -> None: write_tree(root_path, chunks) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/misc/touch_checker.py b/misc/touch_checker.py index d12c2e8166145..0cb9d7e5cf80c 100644 --- a/misc/touch_checker.py +++ b/misc/touch_checker.py @@ -1,20 +1,19 @@ #!/usr/bin/env python3 -from typing import Callable, List, Tuple, Optional - -import sys import glob import os import shutil import statistics import subprocess +import sys import textwrap import time +from typing import Callable, List, Optional, Tuple def print_offset(text: str, indent_length: int = 4) -> None: print() - print(textwrap.indent(text, ' ' * indent_length)) + print(textwrap.indent(text, " " * indent_length)) print() @@ -25,19 +24,17 @@ def delete_folder(folder_path: str) -> None: def execute(command: List[str]) -> None: proc = subprocess.Popen( - ' '.join(command), - stderr=subprocess.PIPE, - stdout=subprocess.PIPE, - shell=True) + " ".join(command), stderr=subprocess.PIPE, stdout=subprocess.PIPE, shell=True + ) stdout_bytes, stderr_bytes = proc.communicate() # type: Tuple[bytes, bytes] - stdout, stderr = stdout_bytes.decode('utf-8'), stderr_bytes.decode('utf-8') + stdout, stderr = stdout_bytes.decode("utf-8"), stderr_bytes.decode("utf-8") if proc.returncode != 0: - print('EXECUTED COMMAND:', repr(command)) - print('RETURN CODE:', proc.returncode) + print("EXECUTED COMMAND:", repr(command)) + print("RETURN CODE:", proc.returncode) print() - print('STDOUT:') + print("STDOUT:") print_offset(stdout) - print('STDERR:') + print("STDERR:") print_offset(stderr) print() @@ -57,8 +54,10 @@ def test(setup: Command, command: Command, teardown: Command) -> float: def make_touch_wrappers(filename: str) -> Tuple[Command, Command]: def setup() -> None: execute(["touch", filename]) + def teardown() -> None: pass + return setup, teardown @@ -69,12 +68,12 @@ def setup() -> None: nonlocal copy with open(filename) as stream: copy = stream.read() - with open(filename, 'a') as stream: - stream.write('\n\nfoo = 3') + with open(filename, "a") as stream: + stream.write("\n\nfoo = 3") def teardown() -> None: assert copy is not None - with open(filename, 'w') as stream: + with open(filename, "w") as stream: stream.write(copy) # Re-run to reset cache @@ -82,15 +81,16 @@ def teardown() -> None: return setup, teardown + def main() -> None: - if len(sys.argv) != 2 or sys.argv[1] not in {'touch', 'change'}: + if len(sys.argv) != 2 or sys.argv[1] not in {"touch", "change"}: print("First argument should be 'touch' or 'change'") return - if sys.argv[1] == 'touch': + if sys.argv[1] == "touch": make_wrappers = make_touch_wrappers verb = "Touching" - elif sys.argv[1] == 'change': + elif sys.argv[1] == "change": make_wrappers = make_change_wrappers verb = "Changing" else: @@ -98,22 +98,19 @@ def main() -> None: print("Setting up...") - baseline = test( - lambda: None, - lambda: execute(["python3", "-m", "mypy", "mypy"]), - lambda: None) + baseline = test(lambda: None, lambda: execute(["python3", "-m", "mypy", "mypy"]), lambda: None) print(f"Baseline: {baseline}") cold = test( lambda: delete_folder(".mypy_cache"), lambda: execute(["python3", "-m", "mypy", "-i", "mypy"]), - lambda: None) + lambda: None, + ) print(f"Cold cache: {cold}") warm = test( - lambda: None, - lambda: execute(["python3", "-m", "mypy", "-i", "mypy"]), - lambda: None) + lambda: None, lambda: execute(["python3", "-m", "mypy", "-i", "mypy"]), lambda: None + ) print(f"Warm cache: {warm}") print() @@ -121,12 +118,9 @@ def main() -> None: deltas = [] for filename in glob.iglob("mypy/**/*.py", recursive=True): print(f"{verb} {filename}") - + setup, teardown = make_wrappers(filename) - delta = test( - setup, - lambda: execute(["python3", "-m", "mypy", "-i", "mypy"]), - teardown) + delta = test(setup, lambda: execute(["python3", "-m", "mypy", "-i", "mypy"]), teardown) print(f" Time: {delta}") deltas.append(delta) print() @@ -146,6 +140,6 @@ def main() -> None: print(f" Total: {sum(deltas)}") print() -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/misc/variadics.py b/misc/variadics.py index 3ffc2a9678292..a216543a29c88 100644 --- a/misc/variadics.py +++ b/misc/variadics.py @@ -4,51 +4,52 @@ """ LIMIT = 5 -BOUND = 'object' +BOUND = "object" + def prelude(limit: int, bound: str) -> None: - print('from typing import Callable, Iterable, Iterator, Tuple, TypeVar, overload') + print("from typing import Callable, Iterable, Iterator, Tuple, TypeVar, overload") print(f"Ts = TypeVar('Ts', bound={bound})") print("R = TypeVar('R')") for i in range(LIMIT): - print('T{i} = TypeVar(\'T{i}\', bound={bound})'.format(i=i+1, bound=bound)) + print("T{i} = TypeVar('T{i}', bound={bound})".format(i=i + 1, bound=bound)) + -def expand_template(template: str, - arg_template: str = 'arg{i}: {Ts}', - lower: int = 0, - limit: int = LIMIT) -> None: +def expand_template( + template: str, arg_template: str = "arg{i}: {Ts}", lower: int = 0, limit: int = LIMIT +) -> None: print() for i in range(lower, limit): - tvs = ', '.join(f'T{j+1}' for j in range(i)) - args = ', '.join(arg_template.format(i=j+1, Ts=f'T{j+1}') - for j in range(i)) - print('@overload') + tvs = ", ".join(f"T{j+1}" for j in range(i)) + args = ", ".join(arg_template.format(i=j + 1, Ts=f"T{j+1}") for j in range(i)) + print("@overload") s = template.format(Ts=tvs, argsTs=args) - s = s.replace('Tuple[]', 'Tuple[()]') + s = s.replace("Tuple[]", "Tuple[()]") print(s) - args_l = [arg_template.format(i=j+1, Ts='Ts') for j in range(limit)] - args_l.append('*' + (arg_template.format(i='s', Ts='Ts'))) - args = ', '.join(args_l) - s = template.format(Ts='Ts, ...', argsTs=args) - s = s.replace('Callable[[Ts, ...]', 'Callable[...') - print('@overload') + args_l = [arg_template.format(i=j + 1, Ts="Ts") for j in range(limit)] + args_l.append("*" + (arg_template.format(i="s", Ts="Ts"))) + args = ", ".join(args_l) + s = template.format(Ts="Ts, ...", argsTs=args) + s = s.replace("Callable[[Ts, ...]", "Callable[...") + print("@overload") print(s) + def main(): prelude(LIMIT, BOUND) # map() - expand_template('def map(func: Callable[[{Ts}], R], {argsTs}) -> R: ...', - lower=1) + expand_template("def map(func: Callable[[{Ts}], R], {argsTs}) -> R: ...", lower=1) # zip() - expand_template('def zip({argsTs}) -> Tuple[{Ts}]: ...') + expand_template("def zip({argsTs}) -> Tuple[{Ts}]: ...") # Naomi's examples - expand_template('def my_zip({argsTs}) -> Iterator[Tuple[{Ts}]]: ...', - 'arg{i}: Iterable[{Ts}]') - expand_template('def make_check({argsTs}) -> Callable[[{Ts}], bool]: ...') - expand_template('def my_map(f: Callable[[{Ts}], R], {argsTs}) -> Iterator[R]: ...', - 'arg{i}: Iterable[{Ts}]') + expand_template("def my_zip({argsTs}) -> Iterator[Tuple[{Ts}]]: ...", "arg{i}: Iterable[{Ts}]") + expand_template("def make_check({argsTs}) -> Callable[[{Ts}], bool]: ...") + expand_template( + "def my_map(f: Callable[[{Ts}], R], {argsTs}) -> Iterator[R]: ...", + "arg{i}: Iterable[{Ts}]", + ) main() diff --git a/mypy/__main__.py b/mypy/__main__.py index aebeb4baedf82..f06a705668ac8 100644 --- a/mypy/__main__.py +++ b/mypy/__main__.py @@ -30,5 +30,5 @@ def console_entry() -> None: sys.exit(2) -if __name__ == '__main__': +if __name__ == "__main__": console_entry() diff --git a/mypy/api.py b/mypy/api.py index 28e8d835c7f8a..30a3739a52acb 100644 --- a/mypy/api.py +++ b/mypy/api.py @@ -44,9 +44,8 @@ """ import sys - from io import StringIO -from typing import List, Tuple, TextIO, Callable +from typing import Callable, List, TextIO, Tuple def _run(main_wrapper: Callable[[TextIO, TextIO], None]) -> Tuple[str, str, int]: @@ -66,8 +65,10 @@ def _run(main_wrapper: Callable[[TextIO, TextIO], None]) -> Tuple[str, str, int] def run(args: List[str]) -> Tuple[str, str, int]: # Lazy import to avoid needing to import all of mypy to call run_dmypy from mypy.main import main - return _run(lambda stdout, stderr: main(None, args=args, - stdout=stdout, stderr=stderr, clean_exit=True)) + + return _run( + lambda stdout, stderr: main(None, args=args, stdout=stdout, stderr=stderr, clean_exit=True) + ) def run_dmypy(args: List[str]) -> Tuple[str, str, int]: diff --git a/mypy/applytype.py b/mypy/applytype.py index b32b88fa32763..847c399a2e8a4 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -1,14 +1,24 @@ -from typing import Dict, Sequence, Optional, Callable +from typing import Callable, Dict, Optional, Sequence -import mypy.subtypes import mypy.sametypes +import mypy.subtypes from mypy.expandtype import expand_type +from mypy.nodes import Context from mypy.types import ( - Type, TypeVarId, TypeVarType, CallableType, AnyType, PartialType, get_proper_types, - TypeVarLikeType, ProperType, ParamSpecType, Parameters, get_proper_type, + AnyType, + CallableType, + Parameters, + ParamSpecType, + PartialType, + ProperType, + Type, + TypeVarId, + TypeVarLikeType, TypeVarTupleType, + TypeVarType, + get_proper_type, + get_proper_types, ) -from mypy.nodes import Context def get_target_type( @@ -17,7 +27,7 @@ def get_target_type( callable: CallableType, report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None], context: Context, - skip_unsatisfied: bool + skip_unsatisfied: bool, ) -> Optional[Type]: if isinstance(tvar, ParamSpecType): return type @@ -31,8 +41,7 @@ def get_target_type( if isinstance(type, TypeVarType) and type.values: # Allow substituting T1 for T if every allowed value of T1 # is also a legal value of T. - if all(any(mypy.sametypes.is_same_type(v, v1) for v in values) - for v1 in type.values): + if all(any(mypy.sametypes.is_same_type(v, v1) for v in values) for v1 in type.values): return type matching = [] for value in values: @@ -58,10 +67,12 @@ def get_target_type( def apply_generic_arguments( - callable: CallableType, orig_types: Sequence[Optional[Type]], - report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None], - context: Context, - skip_unsatisfied: bool = False) -> CallableType: + callable: CallableType, + orig_types: Sequence[Optional[Type]], + report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None], + context: Context, + skip_unsatisfied: bool = False, +) -> CallableType: """Apply generic type arguments to a callable type. For example, applying [int] to 'def [T] (T) -> T' results in diff --git a/mypy/argmap.py b/mypy/argmap.py index bcb8644720380..ac710f1b78d8a 100644 --- a/mypy/argmap.py +++ b/mypy/argmap.py @@ -1,23 +1,31 @@ """Utilities for mapping between actual and formal arguments (and their types).""" -from typing import TYPE_CHECKING, List, Optional, Sequence, Callable, Set +from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set +from mypy import nodes from mypy.maptype import map_instance_to_supertype from mypy.types import ( - Type, Instance, TupleType, AnyType, TypeOfAny, TypedDictType, ParamSpecType, get_proper_type + AnyType, + Instance, + ParamSpecType, + TupleType, + Type, + TypedDictType, + TypeOfAny, + get_proper_type, ) -from mypy import nodes if TYPE_CHECKING: from mypy.infer import ArgumentInferContext -def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind], - actual_names: Optional[Sequence[Optional[str]]], - formal_kinds: List[nodes.ArgKind], - formal_names: Sequence[Optional[str]], - actual_arg_type: Callable[[int], - Type]) -> List[List[int]]: +def map_actuals_to_formals( + actual_kinds: List[nodes.ArgKind], + actual_names: Optional[Sequence[Optional[str]]], + formal_kinds: List[nodes.ArgKind], + formal_names: Sequence[Optional[str]], + actual_arg_type: Callable[[int], Type], +) -> List[List[int]]: """Calculate mapping between actual (caller) args and formals. The result contains a list of caller argument indexes mapping to each @@ -89,12 +97,19 @@ def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind], # # TODO: If there are also tuple varargs, we might be missing some potential # matches if the tuple was short enough to not match everything. - unmatched_formals = [fi for fi in range(nformals) - if (formal_names[fi] - and (not formal_to_actual[fi] - or actual_kinds[formal_to_actual[fi][0]] == nodes.ARG_STAR) - and formal_kinds[fi] != nodes.ARG_STAR) - or formal_kinds[fi] == nodes.ARG_STAR2] + unmatched_formals = [ + fi + for fi in range(nformals) + if ( + formal_names[fi] + and ( + not formal_to_actual[fi] + or actual_kinds[formal_to_actual[fi][0]] == nodes.ARG_STAR + ) + and formal_kinds[fi] != nodes.ARG_STAR + ) + or formal_kinds[fi] == nodes.ARG_STAR2 + ] for ai in ambiguous_actual_kwargs: for fi in unmatched_formals: formal_to_actual[fi].append(ai) @@ -102,18 +117,17 @@ def map_actuals_to_formals(actual_kinds: List[nodes.ArgKind], return formal_to_actual -def map_formals_to_actuals(actual_kinds: List[nodes.ArgKind], - actual_names: Optional[Sequence[Optional[str]]], - formal_kinds: List[nodes.ArgKind], - formal_names: List[Optional[str]], - actual_arg_type: Callable[[int], - Type]) -> List[List[int]]: +def map_formals_to_actuals( + actual_kinds: List[nodes.ArgKind], + actual_names: Optional[Sequence[Optional[str]]], + formal_kinds: List[nodes.ArgKind], + formal_names: List[Optional[str]], + actual_arg_type: Callable[[int], Type], +) -> List[List[int]]: """Calculate the reverse mapping of map_actuals_to_formals.""" - formal_to_actual = map_actuals_to_formals(actual_kinds, - actual_names, - formal_kinds, - formal_names, - actual_arg_type) + formal_to_actual = map_actuals_to_formals( + actual_kinds, actual_names, formal_kinds, formal_names, actual_arg_type + ) # Now reverse the mapping. actual_to_formal: List[List[int]] = [[] for _ in actual_kinds] for formal, actuals in enumerate(formal_to_actual): @@ -144,7 +158,7 @@ def f(x: int, *args: str) -> None: ... needs a separate instance since instances have per-call state. """ - def __init__(self, context: 'ArgumentInferContext') -> None: + def __init__(self, context: "ArgumentInferContext") -> None: # Next tuple *args index to use. self.tuple_index = 0 # Keyword arguments in TypedDict **kwargs used. @@ -152,11 +166,13 @@ def __init__(self, context: 'ArgumentInferContext') -> None: # Type context for `*` and `**` arg kinds. self.context = context - def expand_actual_type(self, - actual_type: Type, - actual_kind: nodes.ArgKind, - formal_name: Optional[str], - formal_kind: nodes.ArgKind) -> Type: + def expand_actual_type( + self, + actual_type: Type, + actual_kind: nodes.ArgKind, + formal_name: Optional[str], + formal_kind: nodes.ArgKind, + ) -> Type: """Return the actual (caller) type(s) of a formal argument with the given kinds. If the actual argument is a tuple *args, return the next individual tuple item that @@ -172,10 +188,10 @@ def expand_actual_type(self, if actual_kind == nodes.ARG_STAR: if isinstance(actual_type, Instance) and actual_type.args: from mypy.subtypes import is_subtype + if is_subtype(actual_type, self.context.iterable_type): return map_instance_to_supertype( - actual_type, - self.context.iterable_type.type, + actual_type, self.context.iterable_type.type ).args[0] else: # We cannot properly unpack anything other @@ -198,6 +214,7 @@ def expand_actual_type(self, return AnyType(TypeOfAny.from_error) elif actual_kind == nodes.ARG_STAR2: from mypy.subtypes import is_subtype + if isinstance(actual_type, TypedDictType): if formal_kind != nodes.ARG_STAR2 and formal_name in actual_type.items: # Lookup type based on keyword argument name. @@ -208,16 +225,15 @@ def expand_actual_type(self, self.kwargs_used.add(formal_name) return actual_type.items[formal_name] elif ( - isinstance(actual_type, Instance) and - len(actual_type.args) > 1 and - is_subtype(actual_type, self.context.mapping_type) + isinstance(actual_type, Instance) + and len(actual_type.args) > 1 + and is_subtype(actual_type, self.context.mapping_type) ): # Only `Mapping` type can be unpacked with `**`. # Other types will produce an error somewhere else. - return map_instance_to_supertype( - actual_type, - self.context.mapping_type.type, - ).args[1] + return map_instance_to_supertype(actual_type, self.context.mapping_type.type).args[ + 1 + ] elif isinstance(actual_type, ParamSpecType): # ParamSpec is valid in **kwargs but it can't be unpacked. return actual_type diff --git a/mypy/backports.py b/mypy/backports.py index df5afcb2416f7..2a6397ff73166 100644 --- a/mypy/backports.py +++ b/mypy/backports.py @@ -11,8 +11,10 @@ if sys.version_info < (3, 7): + @contextmanager def nullcontext() -> Iterator[None]: yield + else: from contextlib import nullcontext as nullcontext # noqa: F401 diff --git a/mypy/binder.py b/mypy/binder.py index 1dffb55a54acf..df2f9d8b4c01a 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -1,20 +1,16 @@ -from contextlib import contextmanager from collections import defaultdict +from contextlib import contextmanager +from typing import Dict, Iterator, List, Optional, Set, Tuple, Union, cast -from typing import Dict, List, Set, Iterator, Union, Optional, Tuple, cast from typing_extensions import DefaultDict, TypeAlias as _TypeAlias -from mypy.types import ( - Type, AnyType, PartialType, UnionType, TypeOfAny, NoneType, get_proper_type -) -from mypy.subtypes import is_subtype -from mypy.join import join_simple -from mypy.sametypes import is_same_type from mypy.erasetype import remove_instance_last_known_values -from mypy.nodes import Expression, Var, RefExpr +from mypy.join import join_simple from mypy.literals import Key, literal, literal_hash, subkeys -from mypy.nodes import IndexExpr, MemberExpr, AssignmentExpr, NameExpr - +from mypy.nodes import AssignmentExpr, Expression, IndexExpr, MemberExpr, NameExpr, RefExpr, Var +from mypy.sametypes import is_same_type +from mypy.subtypes import is_subtype +from mypy.types import AnyType, NoneType, PartialType, Type, TypeOfAny, UnionType, get_proper_type BindableExpression: _TypeAlias = Union[IndexExpr, MemberExpr, AssignmentExpr, NameExpr] @@ -69,6 +65,7 @@ class A: reveal_type(lst[0].a) # str ``` """ + # Stored assignments for situations with tuple/list lvalue and rvalue of union type. # This maps an expression to a list of bound types for every item in the union type. type_assignments: Optional[Assigns] = None @@ -141,7 +138,7 @@ def put(self, expr: Expression, typ: Type) -> None: if not literal(expr): return key = literal_hash(expr) - assert key is not None, 'Internal error: binder tried to put non-literal' + assert key is not None, "Internal error: binder tried to put non-literal" if key not in self.declarations: self.declarations[key] = get_declaration(expr) self._add_dependencies(key) @@ -155,7 +152,7 @@ def suppress_unreachable_warnings(self) -> None: def get(self, expr: Expression) -> Optional[Type]: key = literal_hash(expr) - assert key is not None, 'Internal error: binder tried to get non-literal' + assert key is not None, "Internal error: binder tried to get non-literal" return self._get(key) def is_unreachable(self) -> bool: @@ -170,7 +167,7 @@ def is_unreachable_warning_suppressed(self) -> bool: def cleanse(self, expr: Expression) -> None: """Remove all references to a Node from the binder.""" key = literal_hash(expr) - assert key is not None, 'Internal error: binder tried cleanse non-literal' + assert key is not None, "Internal error: binder tried cleanse non-literal" self._cleanse_key(key) def _cleanse_key(self, key: Key) -> None: @@ -239,7 +236,7 @@ def pop_frame(self, can_skip: bool, fall_through: int) -> Frame: return result @contextmanager - def accumulate_type_assignments(self) -> 'Iterator[Assigns]': + def accumulate_type_assignments(self) -> "Iterator[Assigns]": """Push a new map to collect assigned types in multiassign from union. If this map is not None, actual binding is deferred until all items in @@ -253,10 +250,13 @@ def accumulate_type_assignments(self) -> 'Iterator[Assigns]': yield self.type_assignments self.type_assignments = old_assignments - def assign_type(self, expr: Expression, - type: Type, - declared_type: Optional[Type], - restrict_any: bool = False) -> None: + def assign_type( + self, + expr: Expression, + type: Type, + declared_type: Optional[Type], + restrict_any: bool = False, + ) -> None: # We should erase last known value in binder, because if we are using it, # it means that the target is not final, and therefore can't hold a literal. type = remove_instance_last_known_values(type) @@ -302,19 +302,24 @@ def assign_type(self, expr: Expression, # This overrides the normal behavior of ignoring Any assignments to variables # in order to prevent false positives. # (See discussion in #3526) - elif (isinstance(type, AnyType) - and isinstance(declared_type, UnionType) - and any(isinstance(get_proper_type(item), NoneType) for item in declared_type.items) - and isinstance(get_proper_type(self.most_recent_enclosing_type(expr, NoneType())), - NoneType)): + elif ( + isinstance(type, AnyType) + and isinstance(declared_type, UnionType) + and any(isinstance(get_proper_type(item), NoneType) for item in declared_type.items) + and isinstance( + get_proper_type(self.most_recent_enclosing_type(expr, NoneType())), NoneType + ) + ): # Replace any Nones in the union type with Any - new_items = [type if isinstance(get_proper_type(item), NoneType) else item - for item in declared_type.items] + new_items = [ + type if isinstance(get_proper_type(item), NoneType) else item + for item in declared_type.items + ] self.put(expr, UnionType(new_items)) - elif (isinstance(type, AnyType) - and not (isinstance(declared_type, UnionType) - and any(isinstance(get_proper_type(item), AnyType) - for item in declared_type.items))): + elif isinstance(type, AnyType) and not ( + isinstance(declared_type, UnionType) + and any(isinstance(get_proper_type(item), AnyType) for item in declared_type.items) + ): # Assigning an Any value doesn't affect the type to avoid false negatives, unless # there is an Any item in a declared union type. self.put(expr, declared_type) @@ -345,9 +350,9 @@ def most_recent_enclosing_type(self, expr: BindableExpression, type: Type) -> Op return get_declaration(expr) key = literal_hash(expr) assert key is not None - enclosers = ([get_declaration(expr)] + - [f.types[key] for f in self.frames - if key in f.types and is_subtype(type, f.types[key])]) + enclosers = [get_declaration(expr)] + [ + f.types[key] for f in self.frames if key in f.types and is_subtype(type, f.types[key]) + ] return enclosers[-1] def allow_jump(self, index: int) -> None: @@ -356,7 +361,7 @@ def allow_jump(self, index: int) -> None: if index < 0: index += len(self.options_on_return) frame = Frame(self._get_id()) - for f in self.frames[index + 1:]: + for f in self.frames[index + 1 :]: frame.types.update(f.types) if f.unreachable: frame.unreachable = True @@ -371,10 +376,16 @@ def handle_continue(self) -> None: self.unreachable() @contextmanager - def frame_context(self, *, can_skip: bool, fall_through: int = 1, - break_frame: int = 0, continue_frame: int = 0, - conditional_frame: bool = False, - try_frame: bool = False) -> Iterator[Frame]: + def frame_context( + self, + *, + can_skip: bool, + fall_through: int = 1, + break_frame: int = 0, + continue_frame: int = 0, + conditional_frame: bool = False, + try_frame: bool = False, + ) -> Iterator[Frame]: """Return a context manager that pushes/pops frames on enter/exit. If can_skip is True, control flow is allowed to bypass the diff --git a/mypy/bogus_type.py b/mypy/bogus_type.py index eb19e9c5db48d..2193a986c57ca 100644 --- a/mypy/bogus_type.py +++ b/mypy/bogus_type.py @@ -10,10 +10,11 @@ For those cases some other technique should be used. """ +from typing import Any, TypeVar + from mypy_extensions import FlexibleAlias -from typing import TypeVar, Any -T = TypeVar('T') +T = TypeVar("T") # This won't ever be true at runtime, but we consider it true during # mypyc compilations. diff --git a/mypy/build.py b/mypy/build.py index ecb04ada91e10..ff7e7a3295477 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -21,48 +21,79 @@ import sys import time import types +from typing import ( + AbstractSet, + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + NamedTuple, + Optional, + Sequence, + Set, + TextIO, + Tuple, + TypeVar, + Union, +) -from typing import (AbstractSet, Any, Dict, Iterable, Iterator, List, Sequence, - Mapping, NamedTuple, Optional, Set, Tuple, TypeVar, Union, Callable, TextIO) -from typing_extensions import ClassVar, NoReturn, Final, TYPE_CHECKING, TypeAlias as _TypeAlias from mypy_extensions import TypedDict +from typing_extensions import TYPE_CHECKING, ClassVar, Final, NoReturn, TypeAlias as _TypeAlias -from mypy.nodes import MypyFile, ImportBase, Import, ImportFrom, ImportAll, SymbolTable -from mypy.semanal_pass1 import SemanticAnalyzerPreAnalysis -from mypy.semanal import SemanticAnalyzer import mypy.semanal_main from mypy.checker import TypeChecker +from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error from mypy.indirection import TypeIndirectionVisitor -from mypy.errors import Errors, CompileError, ErrorInfo, report_internal_error +from mypy.nodes import Import, ImportAll, ImportBase, ImportFrom, MypyFile, SymbolTable +from mypy.semanal import SemanticAnalyzer +from mypy.semanal_pass1 import SemanticAnalyzerPreAnalysis from mypy.util import ( - DecodeError, decode_python_encoding, is_sub_path, get_mypy_comments, module_prefix, - read_py_file, hash_digest, is_typeshed_file, is_stub_package_file, get_top_two_prefixes, - time_ref, time_spent_us + DecodeError, + decode_python_encoding, + get_mypy_comments, + get_top_two_prefixes, + hash_digest, + is_stub_package_file, + is_sub_path, + is_typeshed_file, + module_prefix, + read_py_file, + time_ref, + time_spent_us, ) + if TYPE_CHECKING: from mypy.report import Reports # Avoid unconditional slow import + +from mypy import errorcodes as codes +from mypy.config_parser import parse_mypy_comments from mypy.fixup import fixup_module +from mypy.freetree import free_tree +from mypy.fscache import FileSystemCache +from mypy.metastore import FilesystemMetadataStore, MetadataStore, SqliteMetadataStore from mypy.modulefinder import ( - BuildSource, BuildSourceSet, compute_search_paths, FindModuleCache, SearchPaths, - ModuleSearchResult, ModuleNotFoundReason + BuildSource, + BuildSourceSet, + FindModuleCache, + ModuleNotFoundReason, + ModuleSearchResult, + SearchPaths, + compute_search_paths, ) from mypy.nodes import Expression from mypy.options import Options from mypy.parse import parse +from mypy.plugin import ChainedPlugin, Plugin, ReportConfigContext +from mypy.plugins.default import DefaultPlugin +from mypy.renaming import LimitedVariableRenameVisitor, VariableRenameVisitor from mypy.stats import dump_type_stats +from mypy.stubinfo import is_legacy_bundled_package, legacy_bundled_packages from mypy.types import Type -from mypy.version import __version__ -from mypy.plugin import Plugin, ChainedPlugin, ReportConfigContext -from mypy.plugins.default import DefaultPlugin -from mypy.fscache import FileSystemCache -from mypy.metastore import MetadataStore, FilesystemMetadataStore, SqliteMetadataStore from mypy.typestate import TypeState, reset_global_state -from mypy.renaming import VariableRenameVisitor, LimitedVariableRenameVisitor -from mypy.config_parser import parse_mypy_comments -from mypy.freetree import free_tree -from mypy.stubinfo import legacy_bundled_packages, is_legacy_bundled_package -from mypy import errorcodes as codes - +from mypy.version import __version__ # Switch to True to produce debug output related to fine-grained incremental # mode only that is useful during development. This produces only a subset of @@ -72,18 +103,18 @@ # These modules are special and should always come from typeshed. CORE_BUILTIN_MODULES: Final = { - 'builtins', - 'typing', - 'types', - 'typing_extensions', - 'mypy_extensions', - '_importlib_modulespec', - 'sys', - 'abc', + "builtins", + "typing", + "types", + "typing_extensions", + "mypy_extensions", + "_importlib_modulespec", + "sys", + "abc", } -Graph: _TypeAlias = Dict[str, 'State'] +Graph: _TypeAlias = Dict[str, "State"] # TODO: Get rid of BuildResult. We might as well return a BuildManager. @@ -98,7 +129,7 @@ class BuildResult: errors: List of error messages. """ - def __init__(self, manager: 'BuildManager', graph: Graph) -> None: + def __init__(self, manager: "BuildManager", graph: Graph) -> None: self.manager = manager self.graph = graph self.files = manager.modules @@ -107,15 +138,16 @@ def __init__(self, manager: 'BuildManager', graph: Graph) -> None: self.errors: List[str] = [] # Filled in by build if desired -def build(sources: List[BuildSource], - options: Options, - alt_lib_path: Optional[str] = None, - flush_errors: Optional[Callable[[List[str], bool], None]] = None, - fscache: Optional[FileSystemCache] = None, - stdout: Optional[TextIO] = None, - stderr: Optional[TextIO] = None, - extra_plugins: Optional[Sequence[Plugin]] = None, - ) -> BuildResult: +def build( + sources: List[BuildSource], + options: Options, + alt_lib_path: Optional[str] = None, + flush_errors: Optional[Callable[[List[str], bool], None]] = None, + fscache: Optional[FileSystemCache] = None, + stdout: Optional[TextIO] = None, + stderr: Optional[TextIO] = None, + extra_plugins: Optional[Sequence[Plugin]] = None, +) -> BuildResult: """Analyze a program. A single call to build performs parsing, semantic analysis and optionally @@ -167,16 +199,17 @@ def default_flush_errors(new_messages: List[str], is_serious: bool) -> None: raise -def _build(sources: List[BuildSource], - options: Options, - alt_lib_path: Optional[str], - flush_errors: Callable[[List[str], bool], None], - fscache: Optional[FileSystemCache], - stdout: TextIO, - stderr: TextIO, - extra_plugins: Sequence[Plugin], - ) -> BuildResult: - if platform.python_implementation() == 'CPython': +def _build( + sources: List[BuildSource], + options: Options, + alt_lib_path: Optional[str], + flush_errors: Callable[[List[str], bool], None], + fscache: Optional[FileSystemCache], + stdout: TextIO, + stderr: TextIO, + extra_plugins: Sequence[Plugin], +) -> BuildResult: + if platform.python_implementation() == "CPython": # This seems the most reasonable place to tune garbage collection. gc.set_threshold(150 * 1000) @@ -189,20 +222,23 @@ def _build(sources: List[BuildSource], if options.report_dirs: # Import lazily to avoid slowing down startup. from mypy.report import Reports # noqa + reports = Reports(data_dir, options.report_dirs) source_set = BuildSourceSet(sources) cached_read = fscache.read - errors = Errors(options.show_error_context, - options.show_column_numbers, - options.show_error_codes, - options.pretty, - options.show_error_end, - lambda path: read_py_file(path, cached_read, options.python_version), - options.show_absolute_path, - options.enabled_error_codes, - options.disabled_error_codes, - options.many_errors_threshold) + errors = Errors( + options.show_error_context, + options.show_column_numbers, + options.show_error_codes, + options.pretty, + options.show_error_end, + lambda path: read_py_file(path, cached_read, options.python_version), + options.show_absolute_path, + options.enabled_error_codes, + options.disabled_error_codes, + options.many_errors_threshold, + ) plugin, snapshot = load_plugins(options, errors, stdout, extra_plugins) # Add catch-all .gitignore to cache dir if we created it @@ -211,19 +247,22 @@ def _build(sources: List[BuildSource], # Construct a build manager object to hold state during the build. # # Ignore current directory prefix in error messages. - manager = BuildManager(data_dir, search_paths, - ignore_prefix=os.getcwd(), - source_set=source_set, - reports=reports, - options=options, - version_id=__version__, - plugin=plugin, - plugins_snapshot=snapshot, - errors=errors, - flush_errors=flush_errors, - fscache=fscache, - stdout=stdout, - stderr=stderr) + manager = BuildManager( + data_dir, + search_paths, + ignore_prefix=os.getcwd(), + source_set=source_set, + reports=reports, + options=options, + version_id=__version__, + plugin=plugin, + plugins_snapshot=snapshot, + errors=errors, + flush_errors=flush_errors, + fscache=fscache, + stdout=stdout, + stderr=stderr, + ) manager.trace(repr(options)) reset_global_state() @@ -238,10 +277,14 @@ def _build(sources: List[BuildSource], t0 = time.time() manager.metastore.commit() manager.add_stats(cache_commit_time=time.time() - t0) - manager.log("Build finished in %.3f seconds with %d modules, and %d errors" % - (time.time() - manager.start_time, - len(manager.modules), - manager.errors.num_messages())) + manager.log( + "Build finished in %.3f seconds with %d modules, and %d errors" + % ( + time.time() - manager.start_time, + len(manager.modules), + manager.errors.num_messages(), + ) + ) manager.dump_stats() if reports is not None: # Finish the HTML or XML reports even if CompileError was raised. @@ -292,13 +335,14 @@ class CacheMeta(NamedTuple): ignore_all: bool # if errors were ignored plugin_data: Any # config data from plugins + # NOTE: dependencies + suppressed == all reachable imports; # suppressed contains those reachable imports that were prevented by # silent mode or simply not found. # Metadata for the fine-grained dependencies file associated with a module. -FgDepMeta = TypedDict('FgDepMeta', {'path': str, 'mtime': int}) +FgDepMeta = TypedDict("FgDepMeta", {"path": str, "mtime": int}) def cache_meta_from_dict(meta: Dict[str, Any], data_json: str) -> CacheMeta: @@ -310,22 +354,22 @@ def cache_meta_from_dict(meta: Dict[str, Any], data_json: str) -> CacheMeta: """ sentinel: Any = None # Values to be validated by the caller return CacheMeta( - meta.get('id', sentinel), - meta.get('path', sentinel), - int(meta['mtime']) if 'mtime' in meta else sentinel, - meta.get('size', sentinel), - meta.get('hash', sentinel), - meta.get('dependencies', []), - int(meta['data_mtime']) if 'data_mtime' in meta else sentinel, + meta.get("id", sentinel), + meta.get("path", sentinel), + int(meta["mtime"]) if "mtime" in meta else sentinel, + meta.get("size", sentinel), + meta.get("hash", sentinel), + meta.get("dependencies", []), + int(meta["data_mtime"]) if "data_mtime" in meta else sentinel, data_json, - meta.get('suppressed', []), - meta.get('options'), - meta.get('dep_prios', []), - meta.get('dep_lines', []), - meta.get('interface_hash', ''), - meta.get('version_id', sentinel), - meta.get('ignore_all', True), - meta.get('plugin_data', None), + meta.get("suppressed", []), + meta.get("options"), + meta.get("dep_prios", []), + meta.get("dep_lines", []), + meta.get("interface_hash", ""), + meta.get("version_id", sentinel), + meta.get("ignore_all", True), + meta.get("plugin_data", None), ) @@ -368,7 +412,7 @@ def load_plugins_from_config( if not options.config_file: return [], snapshot - line = find_config_file_line_number(options.config_file, 'mypy', 'plugins') + line = find_config_file_line_number(options.config_file, "mypy", "plugins") if line == -1: line = 1 # We need to pick some line number that doesn't look too confusing @@ -379,11 +423,11 @@ def plugin_error(message: str) -> NoReturn: custom_plugins: List[Plugin] = [] errors.set_file(options.config_file, None) for plugin_path in options.plugins: - func_name = 'plugin' + func_name = "plugin" plugin_dir: Optional[str] = None - if ':' in os.path.basename(plugin_path): - plugin_path, func_name = plugin_path.rsplit(':', 1) - if plugin_path.endswith('.py'): + if ":" in os.path.basename(plugin_path): + plugin_path, func_name = plugin_path.rsplit(":", 1) + if plugin_path.endswith(".py"): # Plugin paths can be relative to the config file location. plugin_path = os.path.join(os.path.dirname(options.config_file), plugin_path) if not os.path.isfile(plugin_path): @@ -395,7 +439,7 @@ def plugin_error(message: str) -> NoReturn: fnam = os.path.basename(plugin_path) module_name = fnam[:-3] sys.path.insert(0, plugin_dir) - elif re.search(r'[\\/]', plugin_path): + elif re.search(r"[\\/]", plugin_path): fnam = os.path.basename(plugin_path) plugin_error(f'Plugin "{fnam}" does not have a .py extension') else: @@ -411,40 +455,42 @@ def plugin_error(message: str) -> NoReturn: del sys.path[0] if not hasattr(module, func_name): - plugin_error('Plugin "{}" does not define entry point function "{}"'.format( - plugin_path, func_name)) + plugin_error( + 'Plugin "{}" does not define entry point function "{}"'.format( + plugin_path, func_name + ) + ) try: plugin_type = getattr(module, func_name)(__version__) except Exception: - print(f'Error calling the plugin(version) entry point of {plugin_path}\n', - file=stdout) + print(f"Error calling the plugin(version) entry point of {plugin_path}\n", file=stdout) raise # Propagate to display traceback if not isinstance(plugin_type, type): plugin_error( 'Type object expected as the return value of "plugin"; got {!r} (in {})'.format( - plugin_type, plugin_path)) + plugin_type, plugin_path + ) + ) if not issubclass(plugin_type, Plugin): plugin_error( 'Return value of "plugin" must be a subclass of "mypy.plugin.Plugin" ' - '(in {})'.format(plugin_path)) + "(in {})".format(plugin_path) + ) try: custom_plugins.append(plugin_type(options)) snapshot[module_name] = take_module_snapshot(module) except Exception: - print(f'Error constructing plugin instance of {plugin_type.__name__}\n', - file=stdout) + print(f"Error constructing plugin instance of {plugin_type.__name__}\n", file=stdout) raise # Propagate to display traceback return custom_plugins, snapshot -def load_plugins(options: Options, - errors: Errors, - stdout: TextIO, - extra_plugins: Sequence[Plugin], - ) -> Tuple[Plugin, Dict[str, str]]: +def load_plugins( + options: Options, errors: Errors, stdout: TextIO, extra_plugins: Sequence[Plugin] +) -> Tuple[Plugin, Dict[str, str]]: """Load all configured plugins. Return a plugin that encapsulates all plugins chained together. Always @@ -470,14 +516,14 @@ def take_module_snapshot(module: types.ModuleType) -> str: We record _both_ hash and the version to detect more possible changes (e.g. if there is a change in modules imported by a plugin). """ - if hasattr(module, '__file__'): + if hasattr(module, "__file__"): assert module.__file__ is not None - with open(module.__file__, 'rb') as f: + with open(module.__file__, "rb") as f: digest = hash_digest(f.read()) else: - digest = 'unknown' - ver = getattr(module, '__version__', 'none') - return f'{ver}:{digest}' + digest = "unknown" + ver = getattr(module, "__version__", "none") + return f"{ver}:{digest}" def find_config_file_line_number(path: str, section: str, setting_name: str) -> int: @@ -491,10 +537,10 @@ def find_config_file_line_number(path: str, section: str, setting_name: str) -> with open(path, encoding="UTF-8") as f: for i, line in enumerate(f): line = line.strip() - if line.startswith('[') and line.endswith(']'): + if line.startswith("[") and line.endswith("]"): current_section = line[1:-1].strip() - in_desired_section = (current_section == section) - elif in_desired_section and re.match(fr'{setting_name}\s*=', line): + in_desired_section = current_section == section + elif in_desired_section and re.match(rf"{setting_name}\s*=", line): results.append(i + 1) if len(results) == 1: return results[0] @@ -545,21 +591,23 @@ class BuildManager: ast_cache: AST cache to speed up mypy daemon """ - def __init__(self, data_dir: str, - search_paths: SearchPaths, - ignore_prefix: str, - source_set: BuildSourceSet, - reports: 'Optional[Reports]', - options: Options, - version_id: str, - plugin: Plugin, - plugins_snapshot: Dict[str, str], - errors: Errors, - flush_errors: Callable[[List[str], bool], None], - fscache: FileSystemCache, - stdout: TextIO, - stderr: TextIO, - ) -> None: + def __init__( + self, + data_dir: str, + search_paths: SearchPaths, + ignore_prefix: str, + source_set: BuildSourceSet, + reports: "Optional[Reports]", + options: Options, + version_id: str, + plugin: Plugin, + plugins_snapshot: Dict[str, str], + errors: Errors, + flush_errors: Callable[[List[str], bool], None], + fscache: FileSystemCache, + stdout: TextIO, + stderr: TextIO, + ) -> None: self.stats: Dict[str, Any] = {} # Values are ints or floats self.stdout = stdout self.stderr = stderr @@ -592,29 +640,32 @@ def __init__(self, data_dir: str, self.missing_modules, self.incomplete_namespaces, self.errors, - self.plugin) + self.plugin, + ) self.all_types: Dict[Expression, Type] = {} # Enabled by export_types self.indirection_detector = TypeIndirectionVisitor() self.stale_modules: Set[str] = set() self.rechecked_modules: Set[str] = set() self.flush_errors = flush_errors has_reporters = reports is not None and reports.reporters - self.cache_enabled = (options.incremental - and (not options.fine_grained_incremental - or options.use_fine_grained_cache) - and not has_reporters) + self.cache_enabled = ( + options.incremental + and (not options.fine_grained_incremental or options.use_fine_grained_cache) + and not has_reporters + ) self.fscache = fscache - self.find_module_cache = FindModuleCache(self.search_paths, self.fscache, self.options, - source_set=self.source_set) + self.find_module_cache = FindModuleCache( + self.search_paths, self.fscache, self.options, source_set=self.source_set + ) self.metastore = create_metastore(options) # a mapping from source files to their corresponding shadow files # for efficient lookup self.shadow_map: Dict[str, str] = {} if self.options.shadow_file is not None: - self.shadow_map = {source_file: shadow_file - for (source_file, shadow_file) - in self.options.shadow_file} + self.shadow_map = { + source_file: shadow_file for (source_file, shadow_file) in self.options.shadow_file + } # a mapping from each file being typechecked to its possible shadow file self.shadow_equivalence_map: Dict[str, Optional[str]] = {} self.plugin = plugin @@ -677,8 +728,7 @@ def getmtime(self, path: str) -> int: else: return int(self.metastore.getmtime(path)) - def all_imported_modules_in_file(self, - file: MypyFile) -> List[Tuple[int, str, int]]: + def all_imported_modules_in_file(self, file: MypyFile) -> List[Tuple[int, str, int]]: """Find all reachable import statements in a file. Return list of tuples (priority, module id, import line number) @@ -693,7 +743,7 @@ def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str: rel = imp.relative if rel == 0: return imp.id - if os.path.basename(file.path).startswith('__init__.'): + if os.path.basename(file.path).startswith("__init__."): rel -= 1 if rel != 0: file_id = ".".join(file_id.split(".")[:-rel]) @@ -701,9 +751,9 @@ def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str: if not new_id: self.errors.set_file(file.path, file.name) - self.errors.report(imp.line, 0, - "No parent module -- cannot perform relative import", - blocker=True) + self.errors.report( + imp.line, 0, "No parent module -- cannot perform relative import", blocker=True + ) return new_id @@ -726,7 +776,7 @@ def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str: # Also add any imported names that are submodules. pri = import_priority(imp, PRI_MED) for name, __ in imp.names: - sub_id = cur_id + '.' + name + sub_id = cur_id + "." + name if self.is_module(sub_id): res.append((pri, sub_id, imp.line)) else: @@ -756,8 +806,9 @@ def is_module(self, id: str) -> bool: """Is there a file in the file system corresponding to module id?""" return find_module_simple(id, self) is not None - def parse_file(self, id: str, path: str, source: str, ignore_errors: bool, - options: Options) -> MypyFile: + def parse_file( + self, id: str, path: str, source: str, ignore_errors: bool, options: Options + ) -> MypyFile: """Parse the source of a file with the given name. Raise CompileError if there is a parse error. @@ -765,10 +816,12 @@ def parse_file(self, id: str, path: str, source: str, ignore_errors: bool, t0 = time.time() tree = parse(source, path, id, self.errors, options=options) tree._fullname = id - self.add_stats(files_parsed=1, - modules_parsed=int(not tree.is_stub), - stubs_parsed=int(tree.is_stub), - parse_time=time.time() - t0) + self.add_stats( + files_parsed=1, + modules_parsed=int(not tree.is_stub), + stubs_parsed=int(tree.is_stub), + parse_time=time.time() - t0, + ) if self.errors.is_blockers(): self.log("Bailing due to parse errors") @@ -781,17 +834,16 @@ def load_fine_grained_deps(self, id: str) -> Dict[str, Set[str]]: t0 = time.time() if id in self.fg_deps_meta: # TODO: Assert deps file wasn't changed. - deps = json.loads(self.metastore.read(self.fg_deps_meta[id]['path'])) + deps = json.loads(self.metastore.read(self.fg_deps_meta[id]["path"])) else: deps = {} val = {k: set(v) for k, v in deps.items()} self.add_stats(load_fg_deps_time=time.time() - t0) return val - def report_file(self, - file: MypyFile, - type_map: Dict[Expression, Type], - options: Options) -> None: + def report_file( + self, file: MypyFile, type_map: Dict[Expression, Type], options: Options + ) -> None: if self.reports is not None and self.source_set.is_source(file): self.reports.file(file, self.modules, type_map, options) @@ -801,15 +853,16 @@ def verbosity(self) -> int: def log(self, *message: str) -> None: if self.verbosity() >= 1: if message: - print('LOG: ', *message, file=self.stderr) + print("LOG: ", *message, file=self.stderr) else: print(file=self.stderr) self.stderr.flush() def log_fine_grained(self, *message: str) -> None: import mypy.build + if self.verbosity() >= 1: - self.log('fine-grained:', *message) + self.log("fine-grained:", *message) elif mypy.build.DEBUG_FINE_GRAINED: # Output log in a simplified format that is quick to browse. if message: @@ -820,7 +873,7 @@ def log_fine_grained(self, *message: str) -> None: def trace(self, *message: str) -> None: if self.verbosity() >= 2: - print('TRACE:', *message, file=self.stderr) + print("TRACE:", *message, file=self.stderr) self.stderr.flush() def add_stats(self, **kwds: Any) -> None: @@ -848,8 +901,9 @@ def deps_to_json(x: Dict[str, Set[str]]) -> str: FAKE_ROOT_MODULE: Final = "@root" -def write_deps_cache(rdeps: Dict[str, Dict[str, Set[str]]], - manager: BuildManager, graph: Graph) -> None: +def write_deps_cache( + rdeps: Dict[str, Dict[str, Set[str]]], manager: BuildManager, graph: Graph +) -> None: """Write cache files for fine-grained dependencies. Serialize fine-grained dependencies map for fine grained mode. @@ -886,7 +940,7 @@ def write_deps_cache(rdeps: Dict[str, Dict[str, Set[str]]], manager.log(f"Error writing fine-grained deps JSON file {deps_json}") error = True else: - fg_deps_meta[id] = {'path': deps_json, 'mtime': manager.getmtime(deps_json)} + fg_deps_meta[id] = {"path": deps_json, "mtime": manager.getmtime(deps_json)} meta_snapshot: Dict[str, str] = {} for id, st in graph.items(): @@ -900,7 +954,7 @@ def write_deps_cache(rdeps: Dict[str, Dict[str, Set[str]]], hash = st.meta.hash meta_snapshot[id] = hash - meta = {'snapshot': meta_snapshot, 'deps_meta': fg_deps_meta} + meta = {"snapshot": meta_snapshot, "deps_meta": fg_deps_meta} if not metastore.write(DEPS_META_FILE, json.dumps(meta)): manager.log(f"Error writing fine-grained deps meta JSON file {DEPS_META_FILE}") @@ -908,12 +962,10 @@ def write_deps_cache(rdeps: Dict[str, Dict[str, Set[str]]], if error: manager.errors.set_file(_cache_dir_prefix(manager.options), None) - manager.errors.report(0, 0, "Error writing fine-grained dependencies cache", - blocker=True) + manager.errors.report(0, 0, "Error writing fine-grained dependencies cache", blocker=True) -def invert_deps(deps: Dict[str, Set[str]], - graph: Graph) -> Dict[str, Dict[str, Set[str]]]: +def invert_deps(deps: Dict[str, Set[str]], graph: Graph) -> Dict[str, Dict[str, Set[str]]]: """Splits fine-grained dependencies based on the module of the trigger. Returns a dictionary from module ids to all dependencies on that @@ -939,8 +991,7 @@ def invert_deps(deps: Dict[str, Set[str]], return rdeps -def generate_deps_for_cache(manager: BuildManager, - graph: Graph) -> Dict[str, Dict[str, Set[str]]]: +def generate_deps_for_cache(manager: BuildManager, graph: Graph) -> Dict[str, Dict[str, Set[str]]]: """Generate fine-grained dependencies into a form suitable for serializing. This does a couple things: @@ -975,27 +1026,30 @@ def write_plugins_snapshot(manager: BuildManager) -> None: """Write snapshot of versions and hashes of currently active plugins.""" if not manager.metastore.write(PLUGIN_SNAPSHOT_FILE, json.dumps(manager.plugins_snapshot)): manager.errors.set_file(_cache_dir_prefix(manager.options), None) - manager.errors.report(0, 0, "Error writing plugins snapshot", - blocker=True) + manager.errors.report(0, 0, "Error writing plugins snapshot", blocker=True) def read_plugins_snapshot(manager: BuildManager) -> Optional[Dict[str, str]]: """Read cached snapshot of versions and hashes of plugins from previous run.""" - snapshot = _load_json_file(PLUGIN_SNAPSHOT_FILE, manager, - log_success='Plugins snapshot ', - log_error='Could not load plugins snapshot: ') + snapshot = _load_json_file( + PLUGIN_SNAPSHOT_FILE, + manager, + log_success="Plugins snapshot ", + log_error="Could not load plugins snapshot: ", + ) if snapshot is None: return None if not isinstance(snapshot, dict): - manager.log('Could not load plugins snapshot: cache is not a dict: {}' - .format(type(snapshot))) + manager.log( + "Could not load plugins snapshot: cache is not a dict: {}".format(type(snapshot)) + ) return None return snapshot -def read_quickstart_file(options: Options, - stdout: TextIO, - ) -> Optional[Dict[str, Tuple[float, int, str]]]: +def read_quickstart_file( + options: Options, stdout: TextIO +) -> Optional[Dict[str, Tuple[float, int, str]]]: quickstart: Optional[Dict[str, Tuple[float, int, str]]] = None if options.quickstart_file: # This is very "best effort". If the file is missing or malformed, @@ -1013,8 +1067,7 @@ def read_quickstart_file(options: Options, return quickstart -def read_deps_cache(manager: BuildManager, - graph: Graph) -> Optional[Dict[str, FgDepMeta]]: +def read_deps_cache(manager: BuildManager, graph: Graph) -> Optional[Dict[str, FgDepMeta]]: """Read and validate the fine-grained dependencies cache. See the write_deps_cache documentation for more information on @@ -1022,29 +1075,33 @@ def read_deps_cache(manager: BuildManager, Returns None if the cache was invalid in some way. """ - deps_meta = _load_json_file(DEPS_META_FILE, manager, - log_success='Deps meta ', - log_error='Could not load fine-grained dependency metadata: ') + deps_meta = _load_json_file( + DEPS_META_FILE, + manager, + log_success="Deps meta ", + log_error="Could not load fine-grained dependency metadata: ", + ) if deps_meta is None: return None - meta_snapshot = deps_meta['snapshot'] + meta_snapshot = deps_meta["snapshot"] # Take a snapshot of the source hashes from all of the metas we found. # (Including the ones we rejected because they were out of date.) # We use this to verify that they match up with the proto_deps. - current_meta_snapshot = {id: st.meta_source_hash for id, st in graph.items() - if st.meta_source_hash is not None} + current_meta_snapshot = { + id: st.meta_source_hash for id, st in graph.items() if st.meta_source_hash is not None + } common = set(meta_snapshot.keys()) & set(current_meta_snapshot.keys()) if any(meta_snapshot[id] != current_meta_snapshot[id] for id in common): # TODO: invalidate also if options changed (like --strict-optional)? - manager.log('Fine-grained dependencies cache inconsistent, ignoring') + manager.log("Fine-grained dependencies cache inconsistent, ignoring") return None - module_deps_metas = deps_meta['deps_meta'] + module_deps_metas = deps_meta["deps_meta"] if not manager.options.skip_cache_mtime_checks: for id, meta in module_deps_metas.items(): try: - matched = manager.getmtime(meta['path']) == meta['mtime'] + matched = manager.getmtime(meta["path"]) == meta["mtime"] except FileNotFoundError: matched = False if not matched: @@ -1054,8 +1111,9 @@ def read_deps_cache(manager: BuildManager, return module_deps_metas -def _load_json_file(file: str, manager: BuildManager, - log_success: str, log_error: str) -> Optional[Dict[str, Any]]: +def _load_json_file( + file: str, manager: BuildManager, log_success: str, log_error: str +) -> Optional[Dict[str, Any]]: """A simple helper to read a JSON file with logging.""" t0 = time.time() try: @@ -1073,14 +1131,15 @@ def _load_json_file(file: str, manager: BuildManager, manager.add_stats(data_json_load_time=time.time() - t1) except json.JSONDecodeError: manager.errors.set_file(file, None) - manager.errors.report(-1, -1, - "Error reading JSON file;" - " you likely have a bad cache.\n" - "Try removing the {cache_dir} directory" - " and run mypy again.".format( - cache_dir=manager.options.cache_dir - ), - blocker=True) + manager.errors.report( + -1, + -1, + "Error reading JSON file;" + " you likely have a bad cache.\n" + "Try removing the {cache_dir} directory" + " and run mypy again.".format(cache_dir=manager.options.cache_dir), + blocker=True, + ) return None else: return result @@ -1093,7 +1152,7 @@ def _cache_dir_prefix(options: Options) -> str: return os.curdir cache_dir = options.cache_dir pyversion = options.python_version - base = os.path.join(cache_dir, '%d.%d' % pyversion) + base = os.path.join(cache_dir, "%d.%d" % pyversion) return base @@ -1119,10 +1178,12 @@ def exclude_from_backups(target_dir: str) -> None: cachedir_tag = os.path.join(target_dir, "CACHEDIR.TAG") try: with open(cachedir_tag, "x") as f: - f.write("""Signature: 8a477f597d28d172789f06886806bc55 + f.write( + """Signature: 8a477f597d28d172789f06886806bc55 # This file is a cache directory tag automatically created by mypy. # For information about cache directory tags see https://bford.info/cachedir/ -""") +""" + ) except FileExistsError: pass @@ -1161,15 +1222,15 @@ def get_cache_names(id: str, path: str, options: Options) -> Tuple[str, str, Opt # This only makes sense when using the filesystem backed cache. root = _cache_dir_prefix(options) return (os.path.relpath(pair[0], root), os.path.relpath(pair[1], root), None) - prefix = os.path.join(*id.split('.')) - is_package = os.path.basename(path).startswith('__init__.py') + prefix = os.path.join(*id.split(".")) + is_package = os.path.basename(path).startswith("__init__.py") if is_package: - prefix = os.path.join(prefix, '__init__') + prefix = os.path.join(prefix, "__init__") deps_json = None if options.cache_fine_grained: - deps_json = prefix + '.deps.json' - return (prefix + '.meta.json', prefix + '.data.json', deps_json) + deps_json = prefix + ".deps.json" + return (prefix + ".meta.json", prefix + ".data.json", deps_json) def find_cache_meta(id: str, path: str, manager: BuildManager) -> Optional[CacheMeta]: @@ -1186,37 +1247,44 @@ def find_cache_meta(id: str, path: str, manager: BuildManager) -> Optional[Cache """ # TODO: May need to take more build options into account meta_json, data_json, _ = get_cache_names(id, path, manager.options) - manager.trace(f'Looking for {id} at {meta_json}') + manager.trace(f"Looking for {id} at {meta_json}") t0 = time.time() - meta = _load_json_file(meta_json, manager, - log_success=f'Meta {id} ', - log_error=f'Could not load cache for {id}: ') + meta = _load_json_file( + meta_json, manager, log_success=f"Meta {id} ", log_error=f"Could not load cache for {id}: " + ) t1 = time.time() if meta is None: return None if not isinstance(meta, dict): - manager.log('Could not load cache for {}: meta cache is not a dict: {}' - .format(id, repr(meta))) + manager.log( + "Could not load cache for {}: meta cache is not a dict: {}".format(id, repr(meta)) + ) return None m = cache_meta_from_dict(meta, data_json) t2 = time.time() - manager.add_stats(load_meta_time=t2 - t0, - load_meta_load_time=t1 - t0, - load_meta_from_dict_time=t2 - t1) + manager.add_stats( + load_meta_time=t2 - t0, load_meta_load_time=t1 - t0, load_meta_from_dict_time=t2 - t1 + ) # Don't check for path match, that is dealt with in validate_meta(). - if (m.id != id or - m.mtime is None or m.size is None or - m.dependencies is None or m.data_mtime is None): - manager.log(f'Metadata abandoned for {id}: attributes are missing') + if ( + m.id != id + or m.mtime is None + or m.size is None + or m.dependencies is None + or m.data_mtime is None + ): + manager.log(f"Metadata abandoned for {id}: attributes are missing") return None # Ignore cache if generated by an older mypy version. - if ((m.version_id != manager.version_id and not manager.options.skip_version_check) - or m.options is None - or len(m.dependencies) + len(m.suppressed) != len(m.dep_prios) - or len(m.dependencies) + len(m.suppressed) != len(m.dep_lines)): - manager.log(f'Metadata abandoned for {id}: new attributes are missing') + if ( + (m.version_id != manager.version_id and not manager.options.skip_version_check) + or m.options is None + or len(m.dependencies) + len(m.suppressed) != len(m.dep_prios) + or len(m.dependencies) + len(m.suppressed) != len(m.dep_lines) + ): + manager.log(f"Metadata abandoned for {id}: new attributes are missing") return None # Ignore cache if (relevant) options aren't the same. @@ -1225,57 +1293,65 @@ def find_cache_meta(id: str, path: str, manager: BuildManager) -> Optional[Cache current_options = manager.options.clone_for_module(id).select_options_affecting_cache() if manager.options.skip_version_check: # When we're lax about version we're also lax about platform. - cached_options['platform'] = current_options['platform'] - if 'debug_cache' in cached_options: + cached_options["platform"] = current_options["platform"] + if "debug_cache" in cached_options: # Older versions included debug_cache, but it's silly to compare it. - del cached_options['debug_cache'] + del cached_options["debug_cache"] if cached_options != current_options: - manager.log(f'Metadata abandoned for {id}: options differ') + manager.log(f"Metadata abandoned for {id}: options differ") if manager.options.verbosity >= 2: for key in sorted(set(cached_options) | set(current_options)): if cached_options.get(key) != current_options.get(key): - manager.trace(' {}: {} != {}' - .format(key, cached_options.get(key), current_options.get(key))) + manager.trace( + " {}: {} != {}".format( + key, cached_options.get(key), current_options.get(key) + ) + ) return None if manager.old_plugins_snapshot and manager.plugins_snapshot: # Check if plugins are still the same. if manager.plugins_snapshot != manager.old_plugins_snapshot: - manager.log(f'Metadata abandoned for {id}: plugins differ') + manager.log(f"Metadata abandoned for {id}: plugins differ") return None # So that plugins can return data with tuples in it without # things silently always invalidating modules, we round-trip # the config data. This isn't beautiful. - plugin_data = json.loads(json.dumps( - manager.plugin.report_config_data(ReportConfigContext(id, path, is_check=True)) - )) + plugin_data = json.loads( + json.dumps(manager.plugin.report_config_data(ReportConfigContext(id, path, is_check=True))) + ) if m.plugin_data != plugin_data: - manager.log(f'Metadata abandoned for {id}: plugin configuration differs') + manager.log(f"Metadata abandoned for {id}: plugin configuration differs") return None manager.add_stats(fresh_metas=1) return m -def validate_meta(meta: Optional[CacheMeta], id: str, path: Optional[str], - ignore_all: bool, manager: BuildManager) -> Optional[CacheMeta]: - '''Checks whether the cached AST of this module can be used. +def validate_meta( + meta: Optional[CacheMeta], + id: str, + path: Optional[str], + ignore_all: bool, + manager: BuildManager, +) -> Optional[CacheMeta]: + """Checks whether the cached AST of this module can be used. Returns: None, if the cached AST is unusable. Original meta, if mtime/size matched. Meta with mtime updated to match source file, if hash/size matched but mtime/path didn't. - ''' + """ # This requires two steps. The first one is obvious: we check that the module source file # contents is the same as it was when the cache data file was created. The second one is not # too obvious: we check that the cache data file mtime has not changed; it is needed because # we use cache data file mtime to propagate information about changes in the dependencies. if meta is None: - manager.log(f'Metadata not found for {id}') + manager.log(f"Metadata not found for {id}") return None if meta.ignore_all and not ignore_all: - manager.log(f'Metadata abandoned for {id}: errors were previously ignored') + manager.log(f"Metadata abandoned for {id}: errors were previously ignored") return None t0 = time.time() @@ -1286,10 +1362,10 @@ def validate_meta(meta: Optional[CacheMeta], id: str, path: Optional[str], try: data_mtime = manager.getmtime(meta.data_json) except OSError: - manager.log(f'Metadata abandoned for {id}: failed to stat data_json') + manager.log(f"Metadata abandoned for {id}: failed to stat data_json") return None if data_mtime != meta.data_mtime: - manager.log(f'Metadata abandoned for {id}: data cache is modified') + manager.log(f"Metadata abandoned for {id}: data cache is modified") return None if bazel: @@ -1300,7 +1376,7 @@ def validate_meta(meta: Optional[CacheMeta], id: str, path: Optional[str], except OSError: return None if not (stat.S_ISREG(st.st_mode) or stat.S_ISDIR(st.st_mode)): - manager.log(f'Metadata abandoned for {id}: file {path} does not exist') + manager.log(f"Metadata abandoned for {id}: file {path} does not exist") return None manager.add_stats(validate_stat_time=time.time() - t0) @@ -1323,7 +1399,7 @@ def validate_meta(meta: Optional[CacheMeta], id: str, path: Optional[str], size = st.st_size # Bazel ensures the cache is valid. if size != meta.size and not bazel and not fine_grained_cache: - manager.log(f'Metadata abandoned for {id}: file {path} has different size') + manager.log(f"Metadata abandoned for {id}: file {path} has different size") return None # Bazel ensures the cache is valid. @@ -1336,7 +1412,7 @@ def validate_meta(meta: Optional[CacheMeta], id: str, path: Optional[str], # the file is up to date even though the mtime is wrong, without needing to hash it. qmtime, qsize, qhash = manager.quickstart_state[path] if int(qmtime) == mtime and qsize == size and qhash == meta.hash: - manager.log(f'Metadata fresh (by quickstart) for {id}: file {path}') + manager.log(f"Metadata fresh (by quickstart) for {id}: file {path}") meta = meta._replace(mtime=mtime, path=path) return meta @@ -1344,7 +1420,7 @@ def validate_meta(meta: Optional[CacheMeta], id: str, path: Optional[str], try: # dir means it is a namespace package if stat.S_ISDIR(st.st_mode): - source_hash = '' + source_hash = "" else: source_hash = manager.fscache.hash_digest(path) except (OSError, UnicodeDecodeError, DecodeError): @@ -1352,11 +1428,12 @@ def validate_meta(meta: Optional[CacheMeta], id: str, path: Optional[str], manager.add_stats(validate_hash_time=time.time() - t0) if source_hash != meta.hash: if fine_grained_cache: - manager.log(f'Using stale metadata for {id}: file {path}') + manager.log(f"Using stale metadata for {id}: file {path}") return meta else: - manager.log('Metadata abandoned for {}: file {} has different hash'.format( - id, path)) + manager.log( + "Metadata abandoned for {}: file {} has different hash".format(id, path) + ) return None else: t0 = time.time() @@ -1364,38 +1441,39 @@ def validate_meta(meta: Optional[CacheMeta], id: str, path: Optional[str], meta = meta._replace(mtime=mtime, path=path) # Construct a dict we can pass to json.dumps() (compare to write_cache()). meta_dict = { - 'id': id, - 'path': path, - 'mtime': mtime, - 'size': size, - 'hash': source_hash, - 'data_mtime': meta.data_mtime, - 'dependencies': meta.dependencies, - 'suppressed': meta.suppressed, - 'options': (manager.options.clone_for_module(id) - .select_options_affecting_cache()), - 'dep_prios': meta.dep_prios, - 'dep_lines': meta.dep_lines, - 'interface_hash': meta.interface_hash, - 'version_id': manager.version_id, - 'ignore_all': meta.ignore_all, - 'plugin_data': meta.plugin_data, + "id": id, + "path": path, + "mtime": mtime, + "size": size, + "hash": source_hash, + "data_mtime": meta.data_mtime, + "dependencies": meta.dependencies, + "suppressed": meta.suppressed, + "options": (manager.options.clone_for_module(id).select_options_affecting_cache()), + "dep_prios": meta.dep_prios, + "dep_lines": meta.dep_lines, + "interface_hash": meta.interface_hash, + "version_id": manager.version_id, + "ignore_all": meta.ignore_all, + "plugin_data": meta.plugin_data, } if manager.options.debug_cache: meta_str = json.dumps(meta_dict, indent=2, sort_keys=True) else: meta_str = json.dumps(meta_dict) meta_json, _, _ = get_cache_names(id, path, manager.options) - manager.log('Updating mtime for {}: file {}, meta {}, mtime {}' - .format(id, path, meta_json, meta.mtime)) + manager.log( + "Updating mtime for {}: file {}, meta {}, mtime {}".format( + id, path, meta_json, meta.mtime + ) + ) t1 = time.time() manager.metastore.write(meta_json, meta_str) # Ignore errors, just an optimization. - manager.add_stats(validate_update_time=time.time() - t1, - validate_munging_time=t1 - t0) + manager.add_stats(validate_update_time=time.time() - t1, validate_munging_time=t1 - t0) return meta # It's a match on (id, path, size, hash, mtime). - manager.log(f'Metadata fresh for {id}: file {path}') + manager.log(f"Metadata fresh for {id}: file {path}") return meta @@ -1405,7 +1483,7 @@ def compute_hash(text: str) -> str: # hash randomization (enabled by default in Python 3.3). See the # note in # https://docs.python.org/3/reference/datamodel.html#object.__hash__. - return hash_digest(text.encode('utf-8')) + return hash_digest(text.encode("utf-8")) def json_dumps(obj: Any, debug_cache: bool) -> str: @@ -1415,11 +1493,19 @@ def json_dumps(obj: Any, debug_cache: bool) -> str: return json.dumps(obj, sort_keys=True) -def write_cache(id: str, path: str, tree: MypyFile, - dependencies: List[str], suppressed: List[str], - dep_prios: List[int], dep_lines: List[int], - old_interface_hash: str, source_hash: str, - ignore_all: bool, manager: BuildManager) -> Tuple[str, Optional[CacheMeta]]: +def write_cache( + id: str, + path: str, + tree: MypyFile, + dependencies: List[str], + suppressed: List[str], + dep_prios: List[int], + dep_lines: List[int], + old_interface_hash: str, + source_hash: str, + ignore_all: bool, + manager: BuildManager, +) -> Tuple[str, Optional[CacheMeta]]: """Write cache files for a module. Note that this mypy's behavior is still correct when any given @@ -1450,7 +1536,7 @@ def write_cache(id: str, path: str, tree: MypyFile, # Obtain file paths. meta_json, data_json, _ = get_cache_names(id, path, manager.options) - manager.log(f'Writing {id} {path} {meta_json} {data_json}') + manager.log(f"Writing {id} {path} {meta_json} {data_json}") # Update tree.path so that in bazel mode it's made relative (since # sometimes paths leak out). @@ -1514,22 +1600,23 @@ def write_cache(id: str, path: str, tree: MypyFile, # verifying the cache. options = manager.options.clone_for_module(id) assert source_hash is not None - meta = {'id': id, - 'path': path, - 'mtime': mtime, - 'size': size, - 'hash': source_hash, - 'data_mtime': data_mtime, - 'dependencies': dependencies, - 'suppressed': suppressed, - 'options': options.select_options_affecting_cache(), - 'dep_prios': dep_prios, - 'dep_lines': dep_lines, - 'interface_hash': interface_hash, - 'version_id': manager.version_id, - 'ignore_all': ignore_all, - 'plugin_data': plugin_data, - } + meta = { + "id": id, + "path": path, + "mtime": mtime, + "size": size, + "hash": source_hash, + "data_mtime": data_mtime, + "dependencies": dependencies, + "suppressed": suppressed, + "options": options.select_options_affecting_cache(), + "dep_prios": dep_prios, + "dep_lines": dep_lines, + "interface_hash": interface_hash, + "version_id": manager.version_id, + "ignore_all": ignore_all, + "plugin_data": plugin_data, + } # Write meta cache file meta_str = json_dumps(meta, manager.options.debug_cache) @@ -1778,21 +1865,22 @@ class State: # Cumulative time spent on this file, in microseconds (for profiling stats) time_spent_us: int = 0 - def __init__(self, - id: Optional[str], - path: Optional[str], - source: Optional[str], - manager: BuildManager, - caller_state: 'Optional[State]' = None, - caller_line: int = 0, - ancestor_for: 'Optional[State]' = None, - root_source: bool = False, - # If `temporary` is True, this State is being created to just - # quickly parse/load the tree, without an intention to further - # process it. With this flag, any changes to external state as well - # as error reporting should be avoided. - temporary: bool = False, - ) -> None: + def __init__( + self, + id: Optional[str], + path: Optional[str], + source: Optional[str], + manager: BuildManager, + caller_state: "Optional[State]" = None, + caller_line: int = 0, + ancestor_for: "Optional[State]" = None, + root_source: bool = False, + # If `temporary` is True, this State is being created to just + # quickly parse/load the tree, without an intention to further + # process it. With this flag, any changes to external state as well + # as error reporting should be avoided. + temporary: bool = False, + ) -> None: if not temporary: assert id or path or source is not None, "Neither id, path nor source given" self.manager = manager @@ -1805,7 +1893,7 @@ def __init__(self, self.import_context.append((caller_state.xpath, caller_line)) else: self.import_context = [] - self.id = id or '__main__' + self.id = id or "__main__" self.options = manager.options.clone_for_module(self.id) self.early_errors = [] self._type_checker = None @@ -1813,18 +1901,25 @@ def __init__(self, assert id is not None try: path, follow_imports = find_module_and_diagnose( - manager, id, self.options, caller_state, caller_line, - ancestor_for, root_source, skip_diagnose=temporary) + manager, + id, + self.options, + caller_state, + caller_line, + ancestor_for, + root_source, + skip_diagnose=temporary, + ) except ModuleNotFound: if not temporary: manager.missing_modules.add(id) raise - if follow_imports == 'silent': + if follow_imports == "silent": self.ignore_all = True self.path = path if path: self.abspath = os.path.abspath(path) - self.xpath = path or '' + self.xpath = path or "" if path and source is None and self.manager.cache_enabled: self.meta = find_cache_meta(self.id, path, manager) # TODO: Get mtime if not cached. @@ -1832,7 +1927,7 @@ def __init__(self, self.interface_hash = self.meta.interface_hash self.meta_source_hash = self.meta.hash if path and source is None and self.manager.fscache.isdir(path): - source = '' + source = "" self.source = source self.add_ancestors() t0 = time.time() @@ -1847,11 +1942,9 @@ def __init__(self, self.suppressed_set = set(self.suppressed) all_deps = self.dependencies + self.suppressed assert len(all_deps) == len(self.meta.dep_prios) - self.priorities = {id: pri - for id, pri in zip(all_deps, self.meta.dep_prios)} + self.priorities = {id: pri for id, pri in zip(all_deps, self.meta.dep_prios)} assert len(all_deps) == len(self.meta.dep_lines) - self.dep_line_map = {id: line - for id, line in zip(all_deps, self.meta.dep_lines)} + self.dep_line_map = {id: line for id, line in zip(all_deps, self.meta.dep_lines)} if temporary: self.load_tree(temporary=True) if not manager.use_fine_grained_cache(): @@ -1888,15 +1981,15 @@ def add_ancestors(self) -> None: if self.path is not None: _, name = os.path.split(self.path) base, _ = os.path.splitext(name) - if '.' in base: + if "." in base: # This is just a weird filename, don't add anything self.ancestors = [] return # All parent packages are new ancestors. ancestors = [] parent = self.id - while '.' in parent: - parent, _ = parent.rsplit('.', 1) + while "." in parent: + parent, _ = parent.rsplit(".", 1) ancestors.append(parent) self.ancestors = ancestors @@ -1906,9 +1999,11 @@ def is_fresh(self) -> bool: # self.meta.dependencies when a dependency is dropped due to # suppression by silent mode. However when a suppressed # dependency is added back we find out later in the process. - return (self.meta is not None - and self.is_interface_fresh() - and self.dependencies == self.meta.dependencies) + return ( + self.meta is not None + and self.is_interface_fresh() + and self.dependencies == self.meta.dependencies + ) def is_interface_fresh(self) -> bool: return self.externally_same @@ -1947,8 +2042,15 @@ def wrap_context(self, check_blockers: bool = True) -> Iterator[None]: except CompileError: raise except Exception as err: - report_internal_error(err, self.path, 0, self.manager.errors, - self.options, self.manager.stdout, self.manager.stderr) + report_internal_error( + err, + self.path, + 0, + self.manager.errors, + self.options, + self.manager.stdout, + self.manager.stderr, + ) self.manager.errors.set_import_context(save_import_context) # TODO: Move this away once we've removed the old semantic analyzer? if check_blockers: @@ -1958,11 +2060,13 @@ def load_fine_grained_deps(self) -> Dict[str, Set[str]]: return self.manager.load_fine_grained_deps(self.id) def load_tree(self, temporary: bool = False) -> None: - assert self.meta is not None, "Internal error: this method must be called only" \ - " for cached modules" + assert self.meta is not None, ( + "Internal error: this method must be called only" " for cached modules" + ) - data = _load_json_file(self.meta.data_json, self.manager, "Load tree ", - "Could not load tree: ") + data = _load_json_file( + self.meta.data_json, self.manager, "Load tree ", "Could not load tree: " + ) if data is None: return None @@ -1979,8 +2083,7 @@ def fix_cross_refs(self) -> None: assert self.tree is not None, "Internal error: method must be called on parsed file only" # We need to set allow_missing when doing a fine grained cache # load because we need to gracefully handle missing modules. - fixup_module(self.tree, self.manager.modules, - self.options.use_fine_grained_cache) + fixup_module(self.tree, self.manager.modules, self.options.use_fine_grained_cache) # Methods for processing modules from source code. @@ -2012,36 +2115,45 @@ def parse_file(self) -> None: if self.path and source is None: try: path = manager.maybe_swap_for_shadow_path(self.path) - source = decode_python_encoding(manager.fscache.read(path), - manager.options.python_version) + source = decode_python_encoding( + manager.fscache.read(path), manager.options.python_version + ) self.source_hash = manager.fscache.hash_digest(path) except OSError as ioerr: # ioerr.strerror differs for os.stat failures between Windows and # other systems, but os.strerror(ioerr.errno) does not, so we use that. # (We want the error messages to be platform-independent so that the # tests have predictable output.) - raise CompileError([ - "mypy: can't read file '{}': {}".format( - self.path, os.strerror(ioerr.errno))], - module_with_blocker=self.id) from ioerr + raise CompileError( + [ + "mypy: can't read file '{}': {}".format( + self.path, os.strerror(ioerr.errno) + ) + ], + module_with_blocker=self.id, + ) from ioerr except (UnicodeDecodeError, DecodeError) as decodeerr: - if self.path.endswith('.pyd'): + if self.path.endswith(".pyd"): err = f"mypy: stubgen does not support .pyd files: '{self.path}'" else: err = f"mypy: can't decode file '{self.path}': {str(decodeerr)}" raise CompileError([err], module_with_blocker=self.id) from decodeerr elif self.path and self.manager.fscache.isdir(self.path): - source = '' - self.source_hash = '' + source = "" + self.source_hash = "" else: assert source is not None self.source_hash = compute_hash(source) self.parse_inline_configuration(source) if not cached: - self.tree = manager.parse_file(self.id, self.xpath, source, - self.ignore_all or self.options.ignore_errors, - self.options) + self.tree = manager.parse_file( + self.id, + self.xpath, + source, + self.ignore_all or self.options.ignore_errors, + self.options, + ) else: # Reuse a cached AST @@ -2049,7 +2161,8 @@ def parse_file(self) -> None: manager.errors.set_file_ignored_lines( self.xpath, self.tree.ignored_lines, - self.ignore_all or self.options.ignore_errors) + self.ignore_all or self.options.ignore_errors, + ) self.time_spent_us += time_spent_us(t0) @@ -2148,8 +2261,9 @@ def compute_dependencies(self) -> None: self.suppressed_set = set() self.priorities = {} # id -> priority self.dep_line_map = {} # id -> line - dep_entries = (manager.all_imported_modules_in_file(self.tree) + - self.manager.plugin.get_additional_deps(self.tree)) + dep_entries = manager.all_imported_modules_in_file( + self.tree + ) + self.manager.plugin.get_additional_deps(self.tree) for pri, id, line in dep_entries: self.priorities[id] = min(pri, self.priorities.get(id, PRI_ALL)) if id == self.id: @@ -2158,8 +2272,8 @@ def compute_dependencies(self) -> None: if id not in self.dep_line_map: self.dep_line_map[id] = line # Every module implicitly depends on builtins. - if self.id != 'builtins': - self.add_dependency('builtins') + if self.id != "builtins": + self.add_dependency("builtins") self.check_blockers() # Can fail due to bogus relative imports @@ -2176,8 +2290,12 @@ def type_checker(self) -> TypeChecker: assert self.tree is not None, "Internal error: must be called on parsed file only" manager = self.manager self._type_checker = TypeChecker( - manager.errors, manager.modules, self.options, - self.tree, self.xpath, manager.plugin, + manager.errors, + manager.modules, + self.options, + self.tree, + self.xpath, + manager.plugin, ) return self._type_checker @@ -2212,11 +2330,13 @@ def finish_passes(self) -> None: self._patch_indirect_dependencies(self.type_checker().module_refs, self.type_map()) if self.options.dump_inference_stats: - dump_type_stats(self.tree, - self.xpath, - modules=self.manager.modules, - inferred=True, - typemap=self.type_map()) + dump_type_stats( + self.tree, + self.xpath, + modules=self.manager.modules, + inferred=True, + typemap=self.type_map(), + ) manager.report_file(self.tree, self.type_map(), self.options) self.update_fine_grained_deps(self.manager.fg_deps) @@ -2230,9 +2350,9 @@ def free_state(self) -> None: self._type_checker.reset() self._type_checker = None - def _patch_indirect_dependencies(self, - module_refs: Set[str], - type_map: Dict[Expression, Type]) -> None: + def _patch_indirect_dependencies( + self, module_refs: Set[str], type_map: Dict[Expression, Type] + ) -> None: types = set(type_map.values()) assert None not in types valid = self.valid_references() @@ -2251,7 +2371,7 @@ def _patch_indirect_dependencies(self, def compute_fine_grained_deps(self) -> Dict[str, Set[str]]: assert self.tree is not None - if self.id in ('builtins', 'typing', 'types', 'sys', '_typeshed'): + if self.id in ("builtins", "typing", "types", "sys", "_typeshed"): # We don't track changes to core parts of typeshed -- the # assumption is that they are only changed as part of mypy # updates, which will invalidate everything anyway. These @@ -2261,15 +2381,19 @@ def compute_fine_grained_deps(self) -> Dict[str, Set[str]]: # dependencies then to handle cyclic imports. return {} from mypy.server.deps import get_dependencies # Lazy import to speed up startup - return get_dependencies(target=self.tree, - type_map=self.type_map(), - python_version=self.options.python_version, - options=self.manager.options) + + return get_dependencies( + target=self.tree, + type_map=self.type_map(), + python_version=self.options.python_version, + options=self.manager.options, + ) def update_fine_grained_deps(self, deps: Dict[str, Set[str]]) -> None: options = self.manager.options if options.cache_fine_grained or options.fine_grained_incremental: from mypy.server.deps import merge_dependencies # Lazy import to speed up startup + merge_dependencies(self.compute_fine_grained_deps(), deps) TypeState.update_protocol_deps(deps) @@ -2286,9 +2410,11 @@ def valid_references(self) -> Set[str]: def write_cache(self) -> None: assert self.tree is not None, "Internal error: method must be called on parsed file only" # We don't support writing cache files in fine-grained incremental mode. - if (not self.path - or self.options.cache_dir == os.devnull - or self.options.fine_grained_incremental): + if ( + not self.path + or self.options.cache_dir == os.devnull + or self.options.fine_grained_incremental + ): return is_errors = self.transitive_error if is_errors: @@ -2299,13 +2425,22 @@ def write_cache(self) -> None: dep_prios = self.dependency_priorities() dep_lines = self.dependency_lines() assert self.source_hash is not None - assert len(set(self.dependencies)) == len(self.dependencies), ( - f"Duplicates in dependencies list for {self.id} ({self.dependencies})") + assert len(set(self.dependencies)) == len( + self.dependencies + ), f"Duplicates in dependencies list for {self.id} ({self.dependencies})" new_interface_hash, self.meta = write_cache( - self.id, self.path, self.tree, - list(self.dependencies), list(self.suppressed), - dep_prios, dep_lines, self.interface_hash, self.source_hash, self.ignore_all, - self.manager) + self.id, + self.path, + self.tree, + list(self.dependencies), + list(self.suppressed), + dep_prios, + dep_lines, + self.interface_hash, + self.source_hash, + self.ignore_all, + self.manager, + ) if new_interface_hash == self.interface_hash: self.manager.log(f"Cached module {self.id} has same interface") else: @@ -2324,8 +2459,9 @@ def verify_dependencies(self, suppressed_only: bool = False) -> None: all_deps = self.suppressed else: # Strip out indirect dependencies. See comment in build.load_graph(). - dependencies = [dep for dep in self.dependencies - if self.priorities.get(dep) != PRI_INDIRECT] + dependencies = [ + dep for dep in self.dependencies if self.priorities.get(dep) != PRI_INDIRECT + ] all_deps = dependencies + self.suppressed + self.ancestors for dep in all_deps: if dep in manager.modules: @@ -2341,9 +2477,13 @@ def verify_dependencies(self, suppressed_only: bool = False) -> None: state, ancestor = self, None # Called just for its side effects of producing diagnostics. find_module_and_diagnose( - manager, dep, options, - caller_state=state, caller_line=line, - ancestor_for=ancestor) + manager, + dep, + options, + caller_state=state, + caller_line=line, + ancestor_for=ancestor, + ) except (ModuleNotFound, CompileError): # Swallow up any ModuleNotFounds or CompilerErrors while generating # a diagnostic. CompileErrors may get generated in @@ -2370,22 +2510,23 @@ def generate_unused_ignore_notes(self) -> None: def generate_ignore_without_code_notes(self) -> None: if self.manager.errors.is_error_code_enabled(codes.IGNORE_WITHOUT_CODE): self.manager.errors.generate_ignore_without_code_errors( - self.xpath, - self.options.warn_unused_ignores, + self.xpath, self.options.warn_unused_ignores ) # Module import and diagnostic glue -def find_module_and_diagnose(manager: BuildManager, - id: str, - options: Options, - caller_state: 'Optional[State]' = None, - caller_line: int = 0, - ancestor_for: 'Optional[State]' = None, - root_source: bool = False, - skip_diagnose: bool = False) -> Tuple[str, str]: +def find_module_and_diagnose( + manager: BuildManager, + id: str, + options: Options, + caller_state: "Optional[State]" = None, + caller_line: int = 0, + ancestor_for: "Optional[State]" = None, + root_source: bool = False, + skip_diagnose: bool = False, +) -> Tuple[str, str]: """Find a module by name, respecting follow_imports and producing diagnostics. If the module is not found, then the ModuleNotFound exception is raised. @@ -2407,7 +2548,7 @@ def find_module_and_diagnose(manager: BuildManager, Returns a tuple containing (file path, target's effective follow_imports setting) """ file_id = id - if id == 'builtins' and options.python_version[0] == 2: + if id == "builtins" and options.python_version[0] == 2: # The __builtin__ module is called internally by mypy # 'builtins' in Python 2 mode (similar to Python 3), # but the stub file is __builtin__.pyi. The reason is @@ -2416,7 +2557,7 @@ def find_module_and_diagnose(manager: BuildManager, # that the implementation can mostly ignore the # difference and just assume 'builtins' everywhere, # which simplifies code. - file_id = '__builtin__' + file_id = "__builtin__" result = find_module_with_reason(file_id, manager) if isinstance(result, str): # For non-stubs, look at options.follow_imports: @@ -2424,41 +2565,48 @@ def find_module_and_diagnose(manager: BuildManager, # - silent -> analyze but silence errors # - skip -> don't analyze, make the type Any follow_imports = options.follow_imports - if (root_source # Honor top-level modules - or (not result.endswith('.py') # Stubs are always normal - and not options.follow_imports_for_stubs) # except when they aren't - or id in mypy.semanal_main.core_modules): # core is always normal - follow_imports = 'normal' + if ( + root_source # Honor top-level modules + or ( + not result.endswith(".py") # Stubs are always normal + and not options.follow_imports_for_stubs + ) # except when they aren't + or id in mypy.semanal_main.core_modules + ): # core is always normal + follow_imports = "normal" if skip_diagnose: pass - elif follow_imports == 'silent': + elif follow_imports == "silent": # Still import it, but silence non-blocker errors. manager.log(f"Silencing {result} ({id})") - elif follow_imports == 'skip' or follow_imports == 'error': + elif follow_imports == "skip" or follow_imports == "error": # In 'error' mode, produce special error messages. if id not in manager.missing_modules: manager.log(f"Skipping {result} ({id})") - if follow_imports == 'error': + if follow_imports == "error": if ancestor_for: skipping_ancestor(manager, id, result, ancestor_for) else: - skipping_module(manager, caller_line, caller_state, - id, result) + skipping_module(manager, caller_line, caller_state, id, result) raise ModuleNotFound if not manager.options.no_silence_site_packages: for dir in manager.search_paths.package_path + manager.search_paths.typeshed_path: if is_sub_path(result, dir): # Silence errors in site-package dirs and typeshed - follow_imports = 'silent' - if (id in CORE_BUILTIN_MODULES - and not is_typeshed_file(result) - and not is_stub_package_file(result) - and not options.use_builtins_fixtures - and not options.custom_typeshed_dir): - raise CompileError([ - f'mypy: "{os.path.relpath(result)}" shadows library module "{id}"', - f'note: A user-defined top-level module with name "{id}" is not supported' - ]) + follow_imports = "silent" + if ( + id in CORE_BUILTIN_MODULES + and not is_typeshed_file(result) + and not is_stub_package_file(result) + and not options.use_builtins_fixtures + and not options.custom_typeshed_dir + ): + raise CompileError( + [ + f'mypy: "{os.path.relpath(result)}" shadows library module "{id}"', + f'note: A user-defined top-level module with name "{id}" is not supported', + ] + ) return (result, follow_imports) else: # Could not find a module. Typically the reason is a @@ -2473,11 +2621,15 @@ def find_module_and_diagnose(manager: BuildManager, # negatives. (Unless there are stubs but they are incomplete.) global_ignore_missing_imports = manager.options.ignore_missing_imports py_ver = options.python_version[0] - if ((is_legacy_bundled_package(top_level, py_ver) - or is_legacy_bundled_package(second_level, py_ver)) - and global_ignore_missing_imports - and not options.ignore_missing_imports_per_module - and result is ModuleNotFoundReason.APPROVED_STUBS_NOT_INSTALLED): + if ( + ( + is_legacy_bundled_package(top_level, py_ver) + or is_legacy_bundled_package(second_level, py_ver) + ) + and global_ignore_missing_imports + and not options.ignore_missing_imports_per_module + and result is ModuleNotFoundReason.APPROVED_STUBS_NOT_INSTALLED + ): ignore_missing_imports = False if skip_diagnose: @@ -2495,8 +2647,7 @@ def find_module_and_diagnose(manager: BuildManager, raise ModuleNotFound -def exist_added_packages(suppressed: List[str], - manager: BuildManager, options: Options) -> bool: +def exist_added_packages(suppressed: List[str], manager: BuildManager, options: Options) -> bool: """Find if there are any newly added packages that were previously suppressed. Exclude everything not in build for follow-imports=skip. @@ -2509,10 +2660,11 @@ def exist_added_packages(suppressed: List[str], path = find_module_simple(dep, manager) if not path: continue - if (options.follow_imports == 'skip' and - (not path.endswith('.pyi') or options.follow_imports_for_stubs)): + if options.follow_imports == "skip" and ( + not path.endswith(".pyi") or options.follow_imports_for_stubs + ): continue - if '__init__.py' in path: + if "__init__.py" in path: # It is better to have a bit lenient test, this will only slightly reduce # performance, while having a too strict test may affect correctness. return True @@ -2541,15 +2693,16 @@ def in_partial_package(id: str, manager: BuildManager) -> bool: This checks if there is any existing parent __init__.pyi stub that defines a module-level __getattr__ (a.k.a. partial stub package). """ - while '.' in id: - parent, _ = id.rsplit('.', 1) + while "." in id: + parent, _ = id.rsplit(".", 1) if parent in manager.modules: parent_mod: Optional[MypyFile] = manager.modules[parent] else: # Parent is not in build, try quickly if we can find it. try: - parent_st = State(id=parent, path=None, source=None, manager=manager, - temporary=True) + parent_st = State( + id=parent, path=None, source=None, manager=manager, temporary=True + ) except (ModuleNotFound, CompileError): parent_mod = None else: @@ -2564,50 +2717,59 @@ def in_partial_package(id: str, manager: BuildManager) -> bool: return False -def module_not_found(manager: BuildManager, line: int, caller_state: State, - target: str, reason: ModuleNotFoundReason) -> None: +def module_not_found( + manager: BuildManager, + line: int, + caller_state: State, + target: str, + reason: ModuleNotFoundReason, +) -> None: errors = manager.errors save_import_context = errors.import_context() errors.set_import_context(caller_state.import_context) errors.set_file(caller_state.xpath, caller_state.id) - if target == 'builtins': - errors.report(line, 0, "Cannot find 'builtins' module. Typeshed appears broken!", - blocker=True) + if target == "builtins": + errors.report( + line, 0, "Cannot find 'builtins' module. Typeshed appears broken!", blocker=True + ) errors.raise_error() else: daemon = manager.options.fine_grained_incremental msg, notes = reason.error_message_templates(daemon) - pyver = '%d.%d' % manager.options.python_version + pyver = "%d.%d" % manager.options.python_version errors.report(line, 0, msg.format(module=target, pyver=pyver), code=codes.IMPORT) top_level, second_level = get_top_two_prefixes(target) if second_level in legacy_bundled_packages: top_level = second_level for note in notes: - if '{stub_dist}' in note: + if "{stub_dist}" in note: note = note.format(stub_dist=legacy_bundled_packages[top_level].name) - errors.report(line, 0, note, severity='note', only_once=True, code=codes.IMPORT) + errors.report(line, 0, note, severity="note", only_once=True, code=codes.IMPORT) if reason is ModuleNotFoundReason.APPROVED_STUBS_NOT_INSTALLED: manager.missing_stub_packages.add(legacy_bundled_packages[top_level].name) errors.set_import_context(save_import_context) -def skipping_module(manager: BuildManager, line: int, caller_state: Optional[State], - id: str, path: str) -> None: +def skipping_module( + manager: BuildManager, line: int, caller_state: Optional[State], id: str, path: str +) -> None: """Produce an error for an import ignored due to --follow_imports=error""" assert caller_state, (id, path) save_import_context = manager.errors.import_context() manager.errors.set_import_context(caller_state.import_context) manager.errors.set_file(caller_state.xpath, caller_state.id) - manager.errors.report(line, 0, - f'Import of "{id}" ignored', - severity='error') - manager.errors.report(line, 0, - "(Using --follow-imports=error, module not passed on command line)", - severity='note', only_once=True) + manager.errors.report(line, 0, f'Import of "{id}" ignored', severity="error") + manager.errors.report( + line, + 0, + "(Using --follow-imports=error, module not passed on command line)", + severity="note", + only_once=True, + ) manager.errors.set_import_context(save_import_context) -def skipping_ancestor(manager: BuildManager, id: str, path: str, ancestor_for: 'State') -> None: +def skipping_ancestor(manager: BuildManager, id: str, path: str, ancestor_for: "State") -> None: """Produce an error for an ancestor ignored due to --follow_imports=error""" # TODO: Read the path (the __init__.py file) and return # immediately if it's empty or only contains comments. @@ -2615,11 +2777,16 @@ def skipping_ancestor(manager: BuildManager, id: str, path: str, ancestor_for: ' # so we'd need to cache the decision. manager.errors.set_import_context([]) manager.errors.set_file(ancestor_for.xpath, ancestor_for.id) - manager.errors.report(-1, -1, f'Ancestor package "{id}" ignored', - severity='error', only_once=True) - manager.errors.report(-1, -1, - "(Using --follow-imports=error, submodule passed on command line)", - severity='note', only_once=True) + manager.errors.report( + -1, -1, f'Ancestor package "{id}" ignored', severity="error", only_once=True + ) + manager.errors.report( + -1, + -1, + "(Using --follow-imports=error, submodule passed on command line)", + severity="note", + only_once=True, + ) def log_configuration(manager: BuildManager, sources: List[BuildSource]) -> None: @@ -2657,10 +2824,7 @@ def log_configuration(manager: BuildManager, sources: List[BuildSource]) -> None # The driver -def dispatch(sources: List[BuildSource], - manager: BuildManager, - stdout: TextIO, - ) -> Graph: +def dispatch(sources: List[BuildSource], manager: BuildManager, stdout: TextIO) -> Graph: log_configuration(manager, sources) t0 = time.time() @@ -2678,12 +2842,12 @@ def dispatch(sources: List[BuildSource], graph = load_graph(sources, manager) t1 = time.time() - manager.add_stats(graph_size=len(graph), - stubs_found=sum(g.path is not None and g.path.endswith('.pyi') - for g in graph.values()), - graph_load_time=(t1 - t0), - fm_cache_size=len(manager.find_module_cache.results), - ) + manager.add_stats( + graph_size=len(graph), + stubs_found=sum(g.path is not None and g.path.endswith(".pyi") for g in graph.values()), + graph_load_time=(t1 - t0), + fm_cache_size=len(manager.find_module_cache.results), + ) if not graph: print("Nothing to do?!", file=stdout) return graph @@ -2704,7 +2868,7 @@ def dispatch(sources: List[BuildSource], manager.add_stats(load_fg_deps_time=time.time() - t2) if fg_deps_meta is not None: manager.fg_deps_meta = fg_deps_meta - elif manager.stats.get('fresh_metas', 0) > 0: + elif manager.stats.get("fresh_metas", 0) > 0: # Clear the stats so we don't infinite loop because of positive fresh_metas manager.stats.clear() # There were some cache files read, but no fine-grained dependencies loaded. @@ -2734,8 +2898,10 @@ def dispatch(sources: List[BuildSource], if manager.options.dump_deps: # This speeds up startup a little when not using the daemon mode. from mypy.server.deps import dump_all_dependencies - dump_all_dependencies(manager.modules, manager.all_types, - manager.options.python_version, manager.options) + + dump_all_dependencies( + manager.modules, manager.all_types, manager.options.python_version, manager.options + ) return graph @@ -2751,21 +2917,23 @@ def __init__(self, index: int, scc: List[str]) -> None: def dumps(self) -> str: """Convert to JSON string.""" total_size = sum(self.sizes.values()) - return "[{}, {}, {},\n {},\n {}]".format(json.dumps(self.node_id), - json.dumps(total_size), - json.dumps(self.scc), - json.dumps(self.sizes), - json.dumps(self.deps)) + return "[{}, {}, {},\n {},\n {}]".format( + json.dumps(self.node_id), + json.dumps(total_size), + json.dumps(self.scc), + json.dumps(self.sizes), + json.dumps(self.deps), + ) def dump_timing_stats(path: str, graph: Graph) -> None: """ Dump timing stats for each file in the given graph """ - with open(path, 'w') as f: + with open(path, "w") as f: for k in sorted(graph.keys()): v = graph[k] - f.write(f'{v.id} {v.time_spent_us}\n') + f.write(f"{v.id} {v.time_spent_us}\n") def dump_graph(graph: Graph, stdout: Optional[TextIO] = None) -> None: @@ -2800,15 +2968,19 @@ def dump_graph(graph: Graph, stdout: Optional[TextIO] = None) -> None: pri = state.priorities[dep] if dep in inv_nodes: dep_id = inv_nodes[dep] - if (dep_id != node.node_id and - (dep_id not in node.deps or pri < node.deps[dep_id])): + if dep_id != node.node_id and ( + dep_id not in node.deps or pri < node.deps[dep_id] + ): node.deps[dep_id] = pri print("[" + ",\n ".join(node.dumps() for node in nodes) + "\n]", file=stdout) -def load_graph(sources: List[BuildSource], manager: BuildManager, - old_graph: Optional[Graph] = None, - new_modules: Optional[List[State]] = None) -> Graph: +def load_graph( + sources: List[BuildSource], + manager: BuildManager, + old_graph: Optional[Graph] = None, + new_modules: Optional[List[State]] = None, +) -> Graph: """Given some source files, load the full dependency graph. If an old_graph is passed in, it is used as the starting point and @@ -2832,29 +3004,33 @@ def load_graph(sources: List[BuildSource], manager: BuildManager, # Seed the graph with the initial root sources. for bs in sources: try: - st = State(id=bs.module, path=bs.path, source=bs.text, manager=manager, - root_source=True) + st = State( + id=bs.module, path=bs.path, source=bs.text, manager=manager, root_source=True + ) except ModuleNotFound: continue if st.id in graph: manager.errors.set_file(st.xpath, st.id) manager.errors.report( - -1, -1, + -1, + -1, f'Duplicate module named "{st.id}" (also at "{graph[st.id].xpath}")', blocker=True, ) manager.errors.report( - -1, -1, + -1, + -1, "See https://mypy.readthedocs.io/en/stable/running_mypy.html#mapping-file-paths-to-modules " # noqa: E501 "for more info", - severity='note', + severity="note", ) manager.errors.report( - -1, -1, + -1, + -1, "Common resolutions include: a) using `--exclude` to avoid checking one of them, " "b) adding `__init__.py` somewhere, c) using `--explicit-package-bases` or " "adjusting MYPYPATH", - severity='note' + severity="note", ) manager.errors.raise_error() @@ -2902,11 +3078,18 @@ def load_graph(sources: List[BuildSource], manager: BuildManager, if dep in st.ancestors: # TODO: Why not 'if dep not in st.dependencies' ? # Ancestors don't have import context. - newst = State(id=dep, path=None, source=None, manager=manager, - ancestor_for=st) + newst = State( + id=dep, path=None, source=None, manager=manager, ancestor_for=st + ) else: - newst = State(id=dep, path=None, source=None, manager=manager, - caller_state=st, caller_line=st.dep_line_map.get(dep, 1)) + newst = State( + id=dep, + path=None, + source=None, + manager=manager, + caller_state=st, + caller_line=st.dep_line_map.get(dep, 1), + ) except ModuleNotFound: if dep in st.dependencies_set: st.suppress_dependency(dep) @@ -2916,22 +3099,25 @@ def load_graph(sources: List[BuildSource], manager: BuildManager, if newst_path in seen_files: manager.errors.report( - -1, 0, - 'Source file found twice under different module names: ' + -1, + 0, + "Source file found twice under different module names: " '"{}" and "{}"'.format(seen_files[newst_path].id, newst.id), blocker=True, ) manager.errors.report( - -1, 0, + -1, + 0, "See https://mypy.readthedocs.io/en/stable/running_mypy.html#mapping-file-paths-to-modules " # noqa: E501 "for more info", - severity='note', + severity="note", ) manager.errors.report( - -1, 0, + -1, + 0, "Common resolutions include: a) adding `__init__.py` somewhere, " "b) using `--explicit-package-bases` or adjusting MYPYPATH", - severity='note', + severity="note", ) manager.errors.raise_error() @@ -2950,8 +3136,7 @@ def load_graph(sources: List[BuildSource], manager: BuildManager, def process_graph(graph: Graph, manager: BuildManager) -> None: """Process everything in dependency order.""" sccs = sorted_components(graph) - manager.log("Found %d SCCs; largest has %d nodes" % - (len(sccs), max(len(scc) for scc in sccs))) + manager.log("Found %d SCCs; largest has %d nodes" % (len(sccs), max(len(scc) for scc in sccs))) fresh_scc_queue: List[List[str]] = [] @@ -2965,21 +3150,25 @@ def process_graph(graph: Graph, manager: BuildManager) -> None: # Make the order of the SCC that includes 'builtins' and 'typing', # among other things, predictable. Various things may break if # the order changes. - if 'builtins' in ascc: + if "builtins" in ascc: scc = sorted(scc, reverse=True) # If builtins is in the list, move it last. (This is a bit of # a hack, but it's necessary because the builtins module is # part of a small cycle involving at least {builtins, abc, # typing}. Of these, builtins must be processed last or else # some builtin objects will be incompletely processed.) - scc.remove('builtins') - scc.append('builtins') + scc.remove("builtins") + scc.append("builtins") if manager.options.verbosity >= 2: for id in scc: - manager.trace(f"Priorities for {id}:", - " ".join("%s:%d" % (x, graph[id].priorities[x]) - for x in graph[id].dependencies - if x in ascc and x in graph[id].priorities)) + manager.trace( + f"Priorities for {id}:", + " ".join( + "%s:%d" % (x, graph[id].priorities[x]) + for x in graph[id].dependencies + if x in ascc and x in graph[id].priorities + ), + ) # Because the SCCs are presented in topological sort order, we # don't need to look at dependencies recursively for staleness # -- the immediate dependencies are sufficient. @@ -3006,8 +3195,9 @@ def process_graph(graph: Graph, manager: BuildManager) -> None: # cache file is newer than any scc node's cache file. oldest_in_scc = min(graph[id].xmeta.data_mtime for id in scc) viable = {id for id in stale_deps if graph[id].meta is not None} - newest_in_deps = 0 if not viable else max(graph[dep].xmeta.data_mtime - for dep in viable) + newest_in_deps = ( + 0 if not viable else max(graph[dep].xmeta.data_mtime for dep in viable) + ) if manager.options.verbosity >= 3: # Dump all mtimes for extreme debugging. all_ids = sorted(ascc | viable, key=lambda id: graph[id].xmeta.data_mtime) for id in all_ids: @@ -3081,8 +3271,11 @@ def process_graph(graph: Graph, manager: BuildManager) -> None: nodes_left = sum(len(scc) for scc in fresh_scc_queue) manager.add_stats(sccs_left=sccs_left, nodes_left=nodes_left) if sccs_left: - manager.log("{} fresh SCCs ({} nodes) left in queue (and will remain unprocessed)" - .format(sccs_left, nodes_left)) + manager.log( + "{} fresh SCCs ({} nodes) left in queue (and will remain unprocessed)".format( + sccs_left, nodes_left + ) + ) manager.trace(str(fresh_scc_queue)) else: manager.log("No fresh SCCs left in queue") @@ -3161,11 +3354,11 @@ def process_stale_scc(graph: Graph, scc: List[str], manager: BuildManager) -> No # We may already have parsed the module, or not. # If the former, parse_file() is a no-op. graph[id].parse_file() - if 'typing' in scc: + if "typing" in scc: # For historical reasons we need to manually add typing aliases # for built-in generic collections, see docstring of # SemanticAnalyzerPass2.add_builtin_aliases for details. - typing_mod = graph['typing'].tree + typing_mod = graph["typing"].tree assert typing_mod, "The typing module was not parsed" mypy.semanal_main.semantic_analysis_for_scc(graph, scc, manager.errors) @@ -3197,9 +3390,9 @@ def process_stale_scc(graph: Graph, scc: List[str], manager: BuildManager) -> No graph[id].mark_as_rechecked() -def sorted_components(graph: Graph, - vertices: Optional[AbstractSet[str]] = None, - pri_max: int = PRI_ALL) -> List[AbstractSet[str]]: +def sorted_components( + graph: Graph, vertices: Optional[AbstractSet[str]] = None, pri_max: int = PRI_ALL +) -> List[AbstractSet[str]]: """Return the graph's SCCs, topologically sorted by dependencies. The sort order is from leaves (nodes without dependencies) to @@ -3231,8 +3424,7 @@ def sorted_components(graph: Graph, # - If ready is [{a, b}, {c, d}], a.order == 1, b.order == 3, # c.order == 2, d.order == 4, the sort keys become [1, 2] # and the result is [{c, d}, {a, b}]. - res.extend(sorted(ready, - key=lambda scc: -min(graph[id].order for id in scc))) + res.extend(sorted(ready, key=lambda scc: -min(graph[id].order for id in scc))) return res @@ -3241,13 +3433,16 @@ def deps_filtered(graph: Graph, vertices: AbstractSet[str], id: str, pri_max: in if id not in vertices: return [] state = graph[id] - return [dep - for dep in state.dependencies - if dep in vertices and state.priorities.get(dep, PRI_HIGH) < pri_max] + return [ + dep + for dep in state.dependencies + if dep in vertices and state.priorities.get(dep, PRI_HIGH) < pri_max + ] -def strongly_connected_components(vertices: AbstractSet[str], - edges: Dict[str, List[str]]) -> Iterator[Set[str]]: +def strongly_connected_components( + vertices: AbstractSet[str], edges: Dict[str, List[str]] +) -> Iterator[Set[str]]: """Compute Strongly Connected Components of a directed graph. Args: @@ -3281,8 +3476,8 @@ def dfs(v: str) -> Iterator[Set[str]]: if boundaries[-1] == index[v]: boundaries.pop() - scc = set(stack[index[v]:]) - del stack[index[v]:] + scc = set(stack[index[v] :]) + del stack[index[v] :] identified.update(scc) yield scc @@ -3335,14 +3530,12 @@ def topsort(data: Dict[T, Set[T]]) -> Iterable[Set[T]]: if not ready: break yield ready - data = {item: (dep - ready) - for item, dep in data.items() - if item not in ready} + data = {item: (dep - ready) for item, dep in data.items() if item not in ready} assert not data, f"A cyclic dependency exists amongst {data!r}" def missing_stubs_file(cache_dir: str) -> str: - return os.path.join(cache_dir, 'missing_stubs') + return os.path.join(cache_dir, "missing_stubs") def record_missing_stub_packages(cache_dir: str, missing_stub_packages: Set[str]) -> None: @@ -3353,9 +3546,9 @@ def record_missing_stub_packages(cache_dir: str, missing_stub_packages: Set[str] """ fnam = missing_stubs_file(cache_dir) if missing_stub_packages: - with open(fnam, 'w') as f: + with open(fnam, "w") as f: for pkg in sorted(missing_stub_packages): - f.write(f'{pkg}\n') + f.write(f"{pkg}\n") else: if os.path.isfile(fnam): os.remove(fnam) diff --git a/mypy/checker.py b/mypy/checker.py index c131e80d47f08..02fd439ba7aa5 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1,98 +1,221 @@ """Mypy type checker.""" -import itertools import fnmatch +import itertools from collections import defaultdict from contextlib import contextmanager - from typing import ( - Any, Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, - Iterable, Sequence, Mapping, Generic, AbstractSet, Callable, overload + AbstractSet, + Any, + Callable, + Dict, + Generic, + Iterable, + Iterator, + List, + Mapping, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, + cast, + overload, ) + from typing_extensions import Final, TypeAlias as _TypeAlias -from mypy.backports import nullcontext -from mypy.errorcodes import TYPE_VAR -from mypy.errors import Errors, report_internal_error, ErrorWatcher -from mypy.nodes import ( - SymbolTable, Statement, MypyFile, Var, Expression, Lvalue, Node, - OverloadedFuncDef, FuncDef, FuncItem, FuncBase, TypeInfo, - ClassDef, Block, AssignmentStmt, NameExpr, MemberExpr, IndexExpr, - TupleExpr, ListExpr, ExpressionStmt, ReturnStmt, IfStmt, - WhileStmt, OperatorAssignmentStmt, WithStmt, AssertStmt, - RaiseStmt, TryStmt, ForStmt, DelStmt, CallExpr, IntExpr, StrExpr, - UnicodeExpr, OpExpr, UnaryExpr, LambdaExpr, TempNode, SymbolTableNode, - Context, Decorator, PrintStmt, BreakStmt, PassStmt, ContinueStmt, - ComparisonExpr, StarExpr, EllipsisExpr, RefExpr, PromoteExpr, - Import, ImportFrom, ImportAll, ImportBase, TypeAlias, - ARG_POS, ARG_STAR, ARG_NAMED, LITERAL_TYPE, LDEF, MDEF, GDEF, - CONTRAVARIANT, COVARIANT, INVARIANT, TypeVarExpr, AssignmentExpr, - is_final_node, MatchStmt) -from mypy import nodes -from mypy import operators -from mypy.literals import literal, literal_hash, Key -from mypy.typeanal import has_any_from_unimported_type, check_for_explicit_any, make_optional_type -from mypy.types import ( - Type, AnyType, CallableType, FunctionLike, Overloaded, TupleType, TypedDictType, - Instance, NoneType, strip_type, TypeType, TypeOfAny, - UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType, - is_named_instance, union_items, TypeQuery, LiteralType, - is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType, - get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType, ParamSpecType, - OVERLOAD_NAMES, UnboundType -) -from mypy.typetraverser import TypeTraverserVisitor -from mypy.sametypes import is_same_type -from mypy.messages import ( - MessageBuilder, make_inferred_type_note, append_invariance_notes, pretty_seq, - format_type, format_type_bare, format_type_distinctly, SUGGESTED_TEST_FIXTURES -) import mypy.checkexpr +from mypy import errorcodes as codes, message_registry, nodes, operators +from mypy.backports import nullcontext +from mypy.binder import ConditionalTypeBinder, get_declaration from mypy.checkmember import ( - MemberContext, analyze_member_access, analyze_descriptor_access, - type_object_type, + MemberContext, analyze_decorator_or_funcbase_access, + analyze_descriptor_access, + analyze_member_access, + type_object_type, ) from mypy.checkpattern import PatternChecker -from mypy.semanal_enum import ENUM_BASES, ENUM_SPECIAL_PROPS -from mypy.typeops import ( - map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union, - erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal, - try_getting_str_literals_from_type, try_getting_int_literals_from_type, - tuple_fallback, is_singleton_type, try_expanding_sum_type_to_union, - true_only, false_only, function_type, get_type_vars, custom_special_method, - is_literal_type_like, -) -from mypy import message_registry -from mypy.message_registry import ErrorMessage -from mypy.subtypes import ( - is_subtype, is_equivalent, is_proper_subtype, is_more_precise, - restrict_subtype_away, is_callable_compatible, - unify_generic_callable, find_member -) from mypy.constraints import SUPERTYPE_OF -from mypy.maptype import map_instance_to_supertype -from mypy.typevars import fill_typevars, has_no_typevars, fill_typevars_with_any -from mypy.semanal import set_callable_name, refers_to_fullname -from mypy.mro import calculate_mro, MroError -from mypy.erasetype import erase_typevars, remove_instance_last_known_values, erase_type +from mypy.erasetype import erase_type, erase_typevars, remove_instance_last_known_values +from mypy.errorcodes import TYPE_VAR, UNUSED_AWAITABLE, UNUSED_COROUTINE, ErrorCode +from mypy.errors import Errors, ErrorWatcher, report_internal_error from mypy.expandtype import expand_type, expand_type_by_instance -from mypy.visitor import NodeVisitor from mypy.join import join_types -from mypy.treetransform import TransformVisitor -from mypy.binder import ConditionalTypeBinder, get_declaration +from mypy.literals import Key, literal, literal_hash +from mypy.maptype import map_instance_to_supertype from mypy.meet import is_overlapping_erased_types, is_overlapping_types +from mypy.message_registry import ErrorMessage +from mypy.messages import ( + SUGGESTED_TEST_FIXTURES, + MessageBuilder, + append_invariance_notes, + format_type, + format_type_bare, + format_type_distinctly, + make_inferred_type_note, + pretty_seq, +) +from mypy.mro import MroError, calculate_mro +from mypy.nodes import ( + ARG_NAMED, + ARG_POS, + ARG_STAR, + CONTRAVARIANT, + COVARIANT, + GDEF, + INVARIANT, + LDEF, + LITERAL_TYPE, + MDEF, + AssertStmt, + AssignmentExpr, + AssignmentStmt, + Block, + BreakStmt, + CallExpr, + ClassDef, + ComparisonExpr, + Context, + ContinueStmt, + Decorator, + DelStmt, + EllipsisExpr, + Expression, + ExpressionStmt, + ForStmt, + FuncBase, + FuncDef, + FuncItem, + IfStmt, + Import, + ImportAll, + ImportBase, + ImportFrom, + IndexExpr, + IntExpr, + LambdaExpr, + ListExpr, + Lvalue, + MatchStmt, + MemberExpr, + MypyFile, + NameExpr, + Node, + OperatorAssignmentStmt, + OpExpr, + OverloadedFuncDef, + PassStmt, + PrintStmt, + PromoteExpr, + RaiseStmt, + RefExpr, + ReturnStmt, + StarExpr, + Statement, + StrExpr, + SymbolTable, + SymbolTableNode, + TempNode, + TryStmt, + TupleExpr, + TypeAlias, + TypeInfo, + TypeVarExpr, + UnaryExpr, + UnicodeExpr, + Var, + WhileStmt, + WithStmt, + is_final_node, +) from mypy.options import Options -from mypy.plugin import Plugin, CheckerPluginInterface -from mypy.sharedparse import BINARY_MAGIC_METHODS +from mypy.plugin import CheckerPluginInterface, Plugin +from mypy.sametypes import is_same_type from mypy.scope import Scope -from mypy import errorcodes as codes +from mypy.semanal import refers_to_fullname, set_callable_name +from mypy.semanal_enum import ENUM_BASES, ENUM_SPECIAL_PROPS +from mypy.sharedparse import BINARY_MAGIC_METHODS from mypy.state import state -from mypy.traverser import has_return_statement, all_return_statements -from mypy.errorcodes import ErrorCode, UNUSED_AWAITABLE, UNUSED_COROUTINE -from mypy.util import is_typeshed_file, is_dunder, is_sunder +from mypy.subtypes import ( + find_member, + is_callable_compatible, + is_equivalent, + is_more_precise, + is_proper_subtype, + is_subtype, + restrict_subtype_away, + unify_generic_callable, +) +from mypy.traverser import all_return_statements, has_return_statement +from mypy.treetransform import TransformVisitor +from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type, make_optional_type +from mypy.typeops import ( + bind_self, + coerce_to_literal, + custom_special_method, + erase_def_to_union_or_bound, + erase_to_bound, + erase_to_union_or_bound, + false_only, + function_type, + get_type_vars, + is_literal_type_like, + is_singleton_type, + make_simplified_union, + map_type_from_supertype, + true_only, + try_expanding_sum_type_to_union, + try_getting_int_literals_from_type, + try_getting_str_literals_from_type, + tuple_fallback, +) +from mypy.types import ( + OVERLOAD_NAMES, + AnyType, + CallableType, + DeletedType, + FunctionLike, + Instance, + LiteralType, + NoneType, + Overloaded, + ParamSpecType, + PartialType, + ProperType, + StarType, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeGuardedType, + TypeOfAny, + TypeQuery, + TypeTranslator, + TypeType, + TypeVarId, + TypeVarType, + UnboundType, + UninhabitedType, + UnionType, + get_proper_type, + get_proper_types, + is_literal_type, + is_named_instance, + is_optional, + remove_optional, + strip_type, + union_items, +) +from mypy.typetraverser import TypeTraverserVisitor +from mypy.typevars import fill_typevars, fill_typevars_with_any, has_no_typevars +from mypy.util import is_dunder, is_sunder, is_typeshed_file +from mypy.visitor import NodeVisitor -T = TypeVar('T') +T = TypeVar("T") DEFAULT_LAST_PASS: Final = 1 # Pass numbers start at 0 @@ -230,8 +353,15 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface): # functions such as open(), etc. plugin: Plugin - def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Options, - tree: MypyFile, path: str, plugin: Plugin) -> None: + def __init__( + self, + errors: Errors, + modules: Dict[str, MypyFile], + options: Options, + tree: MypyFile, + path: str, + plugin: Plugin, + ) -> None: """Construct a type checker. Use errors to report type check errors. @@ -265,9 +395,9 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile], options: Option if options.strict_optional_whitelist is None: self.suppress_none_errors = not options.show_none_errors else: - self.suppress_none_errors = not any(fnmatch.fnmatch(path, pattern) - for pattern - in options.strict_optional_whitelist) + self.suppress_none_errors = not any( + fnmatch.fnmatch(path, pattern) for pattern in options.strict_optional_whitelist + ) # If True, process function definitions. If False, don't. This is used # for processing module top levels in fine-grained incremental mode. self.recurse_into_functions = True @@ -323,33 +453,37 @@ def check_first_pass(self) -> None: with self.tscope.module_scope(self.tree.fullname): with self.enter_partial_types(), self.binder.top_frame_context(): for d in self.tree.defs: - if (self.binder.is_unreachable() - and self.should_report_unreachable_issues() - and not self.is_raising_or_empty(d)): + if ( + self.binder.is_unreachable() + and self.should_report_unreachable_issues() + and not self.is_raising_or_empty(d) + ): self.msg.unreachable_statement(d) break self.accept(d) assert not self.current_node_deferred - all_ = self.globals.get('__all__') + all_ = self.globals.get("__all__") if all_ is not None and all_.type is not None: all_node = all_.node assert all_node is not None - seq_str = self.named_generic_type('typing.Sequence', - [self.named_type('builtins.str')]) + seq_str = self.named_generic_type( + "typing.Sequence", [self.named_type("builtins.str")] + ) if self.options.python_version[0] < 3: - seq_str = self.named_generic_type('typing.Sequence', - [self.named_type('builtins.unicode')]) + seq_str = self.named_generic_type( + "typing.Sequence", [self.named_type("builtins.unicode")] + ) if not is_subtype(all_.type, seq_str): str_seq_s, all_s = format_type_distinctly(seq_str, all_.type) - self.fail(message_registry.ALL_MUST_BE_SEQ_STR.format(str_seq_s, all_s), - all_node) + self.fail( + message_registry.ALL_MUST_BE_SEQ_STR.format(str_seq_s, all_s), all_node + ) - def check_second_pass(self, - todo: Optional[Sequence[Union[DeferredNode, - FineGrainedDeferredNode]]] = None - ) -> bool: + def check_second_pass( + self, todo: Optional[Sequence[Union[DeferredNode, FineGrainedDeferredNode]]] = None + ) -> bool: """Run second or following pass of type checking. This goes through deferred nodes, returning True if there were any. @@ -374,10 +508,12 @@ def check_second_pass(self, # print("XXX in pass %d, class %s, function %s" % # (self.pass_num, type_name, node.fullname or node.name)) done.add(node) - with self.tscope.class_scope(active_typeinfo) if active_typeinfo \ - else nullcontext(): - with self.scope.push_class(active_typeinfo) if active_typeinfo \ - else nullcontext(): + with self.tscope.class_scope( + active_typeinfo + ) if active_typeinfo else nullcontext(): + with self.scope.push_class( + active_typeinfo + ) if active_typeinfo else nullcontext(): self.check_partial(node) return True @@ -438,8 +574,13 @@ def accept(self, stmt: Statement) -> None: except Exception as err: report_internal_error(err, self.errors.file, stmt.line, self.errors, self.options) - def accept_loop(self, body: Statement, else_body: Optional[Statement] = None, *, - exit_condition: Optional[Expression] = None) -> None: + def accept_loop( + self, + body: Statement, + else_body: Optional[Statement] = None, + *, + exit_condition: Optional[Expression] = None, + ) -> None: """Repeatedly type check a loop body until the frame doesn't change. If exit_condition is set, assume it must be False on exit from the loop. @@ -448,8 +589,7 @@ def accept_loop(self, body: Statement, else_body: Optional[Statement] = None, *, # The outer frame accumulates the results of all iterations with self.binder.frame_context(can_skip=False, conditional_frame=True): while True: - with self.binder.frame_context(can_skip=True, - break_frame=2, continue_frame=1): + with self.binder.frame_context(can_skip=True, break_frame=2, continue_frame=1): self.accept(body) if not self.binder.last_pop_changed: break @@ -527,7 +667,7 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: elif isinstance(inner_type, Instance): inner_call = get_proper_type( analyze_member_access( - name='__call__', + name="__call__", typ=inner_type, context=defn.impl, is_lvalue=False, @@ -536,7 +676,7 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: msg=self.msg, original_type=inner_type, chk=self, - ), + ) ) if isinstance(inner_call, CallableType): impl_type = inner_call @@ -550,7 +690,7 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: sig1 = self.function_type(item.func) assert isinstance(sig1, CallableType) - for j, item2 in enumerate(defn.items[i + 1:]): + for j, item2 in enumerate(defn.items[i + 1 :]): assert isinstance(item2, Decorator) sig2 = self.function_type(item2.func) assert isinstance(sig2, CallableType) @@ -559,8 +699,7 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: continue if overload_can_never_match(sig1, sig2): - self.msg.overloaded_signature_will_never_match( - i + 1, i + j + 2, item2.func) + self.msg.overloaded_signature_will_never_match(i + 1, i + j + 2, item2.func) elif not is_descriptor_get: # Note: we force mypy to check overload signatures in strict-optional mode # so we don't incorrectly report errors when a user tries typing an overload @@ -577,8 +716,7 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: # See Python 2's map function for a concrete example of this kind of overload. with state.strict_optional_set(True): if is_unsafe_overlapping_overload_signatures(sig1, sig2): - self.msg.overloaded_signatures_overlap( - i + 1, i + j + 2, item.func) + self.msg.overloaded_signatures_overlap(i + 1, i + j + 2, item.func) if impl_type is not None: assert defn.impl is not None @@ -592,9 +730,12 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: # This is to match the direction the implementation's return # needs to be compatible in. if impl_type.variables: - impl = unify_generic_callable(impl_type, sig1, - ignore_return=False, - return_constraint_direction=SUPERTYPE_OF) + impl = unify_generic_callable( + impl_type, + sig1, + ignore_return=False, + return_constraint_direction=SUPERTYPE_OF, + ) if impl is None: self.msg.overloaded_signatures_typevar_specific(i + 1, defn.impl) continue @@ -607,14 +748,16 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: impl = impl.copy_modified(arg_types=[sig1.arg_types[0]] + impl.arg_types[1:]) # Is the overload alternative's arguments subtypes of the implementation's? - if not is_callable_compatible(impl, sig1, - is_compat=is_subtype_no_promote, - ignore_return=True): + if not is_callable_compatible( + impl, sig1, is_compat=is_subtype_no_promote, ignore_return=True + ): self.msg.overloaded_signatures_arg_specific(i + 1, defn.impl) # Is the overload alternative's return type a subtype of the implementation's? - if not (is_subtype_no_promote(sig1.ret_type, impl.ret_type) or - is_subtype_no_promote(impl.ret_type, sig1.ret_type)): + if not ( + is_subtype_no_promote(sig1.ret_type, impl.ret_type) + or is_subtype_no_promote(impl.ret_type, sig1.ret_type) + ): self.msg.overloaded_signatures_ret_specific(i + 1, defn.impl) # Here's the scoop about generators and coroutines. @@ -669,15 +812,15 @@ def is_generator_return_type(self, typ: Type, is_coroutine: bool) -> bool: typ = get_proper_type(typ) if is_coroutine: # This means we're in Python 3.5 or later. - at = self.named_generic_type('typing.Awaitable', [AnyType(TypeOfAny.special_form)]) + at = self.named_generic_type("typing.Awaitable", [AnyType(TypeOfAny.special_form)]) if is_subtype(at, typ): return True else: any_type = AnyType(TypeOfAny.special_form) - gt = self.named_generic_type('typing.Generator', [any_type, any_type, any_type]) + gt = self.named_generic_type("typing.Generator", [any_type, any_type, any_type]) if is_subtype(gt, typ): return True - return isinstance(typ, Instance) and typ.type.fullname == 'typing.AwaitableGenerator' + return isinstance(typ, Instance) and typ.type.fullname == "typing.AwaitableGenerator" def is_async_generator_return_type(self, typ: Type) -> bool: """Is `typ` a valid type for an async generator? @@ -686,7 +829,7 @@ def is_async_generator_return_type(self, typ: Type) -> bool: """ try: any_type = AnyType(TypeOfAny.special_form) - agt = self.named_generic_type('typing.AsyncGenerator', [any_type, any_type]) + agt = self.named_generic_type("typing.AsyncGenerator", [any_type, any_type]) except KeyError: # we're running on a version of typing that doesn't have AsyncGenerator yet return False @@ -698,15 +841,16 @@ def get_generator_yield_type(self, return_type: Type, is_coroutine: bool) -> Typ if isinstance(return_type, AnyType): return AnyType(TypeOfAny.from_another_any, source_any=return_type) - elif (not self.is_generator_return_type(return_type, is_coroutine) - and not self.is_async_generator_return_type(return_type)): + elif not self.is_generator_return_type( + return_type, is_coroutine + ) and not self.is_async_generator_return_type(return_type): # If the function doesn't have a proper Generator (or # Awaitable) return type, anything is permissible. return AnyType(TypeOfAny.from_error) elif not isinstance(return_type, Instance): # Same as above, but written as a separate branch so the typechecker can understand. return AnyType(TypeOfAny.from_error) - elif return_type.type.fullname == 'typing.Awaitable': + elif return_type.type.fullname == "typing.Awaitable": # Awaitable: ty is Any. return AnyType(TypeOfAny.special_form) elif return_type.args: @@ -727,22 +871,25 @@ def get_generator_receive_type(self, return_type: Type, is_coroutine: bool) -> T if isinstance(return_type, AnyType): return AnyType(TypeOfAny.from_another_any, source_any=return_type) - elif (not self.is_generator_return_type(return_type, is_coroutine) - and not self.is_async_generator_return_type(return_type)): + elif not self.is_generator_return_type( + return_type, is_coroutine + ) and not self.is_async_generator_return_type(return_type): # If the function doesn't have a proper Generator (or # Awaitable) return type, anything is permissible. return AnyType(TypeOfAny.from_error) elif not isinstance(return_type, Instance): # Same as above, but written as a separate branch so the typechecker can understand. return AnyType(TypeOfAny.from_error) - elif return_type.type.fullname == 'typing.Awaitable': + elif return_type.type.fullname == "typing.Awaitable": # Awaitable, AwaitableGenerator: tc is Any. return AnyType(TypeOfAny.special_form) - elif (return_type.type.fullname in ('typing.Generator', 'typing.AwaitableGenerator') - and len(return_type.args) >= 3): + elif ( + return_type.type.fullname in ("typing.Generator", "typing.AwaitableGenerator") + and len(return_type.args) >= 3 + ): # Generator: tc is args[1]. return return_type.args[1] - elif return_type.type.fullname == 'typing.AsyncGenerator' and len(return_type.args) >= 2: + elif return_type.type.fullname == "typing.AsyncGenerator" and len(return_type.args) >= 2: return return_type.args[1] else: # `return_type` is a supertype of Generator, so callers won't be able to send it @@ -770,11 +917,13 @@ def get_generator_return_type(self, return_type: Type, is_coroutine: bool) -> Ty elif not isinstance(return_type, Instance): # Same as above, but written as a separate branch so the typechecker can understand. return AnyType(TypeOfAny.from_error) - elif return_type.type.fullname == 'typing.Awaitable' and len(return_type.args) == 1: + elif return_type.type.fullname == "typing.Awaitable" and len(return_type.args) == 1: # Awaitable: tr is args[0]. return return_type.args[0] - elif (return_type.type.fullname in ('typing.Generator', 'typing.AwaitableGenerator') - and len(return_type.args) >= 3): + elif ( + return_type.type.fullname in ("typing.Generator", "typing.AwaitableGenerator") + and len(return_type.args) >= 3 + ): # AwaitableGenerator, Generator: tr is args[2]. return return_type.args[2] else: @@ -831,14 +980,21 @@ def _visit_func_def(self, defn: FuncDef) -> None: self.fail(message_registry.INCOMPATIBLE_REDEFINITION, defn) else: # TODO: Update conditional type binder. - self.check_subtype(new_type, orig_type, defn, - message_registry.INCOMPATIBLE_REDEFINITION, - 'redefinition with type', - 'original type') - - def check_func_item(self, defn: FuncItem, - type_override: Optional[CallableType] = None, - name: Optional[str] = None) -> None: + self.check_subtype( + new_type, + orig_type, + defn, + message_registry.INCOMPATIBLE_REDEFINITION, + "redefinition with type", + "original type", + ) + + def check_func_item( + self, + defn: FuncItem, + type_override: Optional[CallableType] = None, + name: Optional[str] = None, + ) -> None: """Type check a function. If type_override is provided, use it as the function type. @@ -853,12 +1009,12 @@ def check_func_item(self, defn: FuncItem, with self.enter_attribute_inference_context(): self.check_func_def(defn, typ, name) else: - raise RuntimeError('Not supported') + raise RuntimeError("Not supported") self.dynamic_funcs.pop() self.current_node_deferred = False - if name == '__exit__': + if name == "__exit__": self.check__exit__return_type(defn) @contextmanager @@ -884,14 +1040,18 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) if isinstance(item, FuncDef): fdef = item # Check if __init__ has an invalid, non-None return type. - if (fdef.info and fdef.name in ('__init__', '__init_subclass__') and - not isinstance(get_proper_type(typ.ret_type), NoneType) and - not self.dynamic_funcs[-1]): - self.fail(message_registry.MUST_HAVE_NONE_RETURN_TYPE.format(fdef.name), - item) + if ( + fdef.info + and fdef.name in ("__init__", "__init_subclass__") + and not isinstance(get_proper_type(typ.ret_type), NoneType) + and not self.dynamic_funcs[-1] + ): + self.fail( + message_registry.MUST_HAVE_NONE_RETURN_TYPE.format(fdef.name), item + ) # Check validity of __new__ signature - if fdef.info and fdef.name == '__new__': + if fdef.info and fdef.name == "__new__": self.check___new___signature(fdef, typ) self.check_for_missing_annotations(fdef) @@ -904,41 +1064,47 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) if has_any_from_unimported_type(arg_type): prefix = f'Argument {idx + 1} to "{fdef.name}"' self.msg.unimported_type_becomes_any(prefix, arg_type, fdef) - check_for_explicit_any(fdef.type, self.options, self.is_typeshed_stub, - self.msg, context=fdef) + check_for_explicit_any( + fdef.type, self.options, self.is_typeshed_stub, self.msg, context=fdef + ) if name: # Special method names if defn.info and self.is_reverse_op_method(name): self.check_reverse_op_method(item, typ, name, defn) - elif name in ('__getattr__', '__getattribute__'): + elif name in ("__getattr__", "__getattribute__"): self.check_getattr_method(typ, defn, name) - elif name == '__setattr__': + elif name == "__setattr__": self.check_setattr_method(typ, defn) # Refuse contravariant return type variable if isinstance(typ.ret_type, TypeVarType): if typ.ret_type.variance == CONTRAVARIANT: - self.fail(message_registry.RETURN_TYPE_CANNOT_BE_CONTRAVARIANT, - typ.ret_type) + self.fail( + message_registry.RETURN_TYPE_CANNOT_BE_CONTRAVARIANT, typ.ret_type + ) self.check_unbound_return_typevar(typ) # Check that Generator functions have the appropriate return type. if defn.is_generator: if defn.is_async_generator: if not self.is_async_generator_return_type(typ.ret_type): - self.fail(message_registry.INVALID_RETURN_TYPE_FOR_ASYNC_GENERATOR, - typ) + self.fail( + message_registry.INVALID_RETURN_TYPE_FOR_ASYNC_GENERATOR, typ + ) else: if not self.is_generator_return_type(typ.ret_type, defn.is_coroutine): self.fail(message_registry.INVALID_RETURN_TYPE_FOR_GENERATOR, typ) # Python 2 generators aren't allowed to return values. orig_ret_type = get_proper_type(typ.ret_type) - if (self.options.python_version[0] == 2 and - isinstance(orig_ret_type, Instance) and - orig_ret_type.type.fullname == 'typing.Generator'): - if not isinstance(get_proper_type(orig_ret_type.args[2]), - (NoneType, AnyType)): + if ( + self.options.python_version[0] == 2 + and isinstance(orig_ret_type, Instance) + and orig_ret_type.type.fullname == "typing.Generator" + ): + if not isinstance( + get_proper_type(orig_ret_type.args[2]), (NoneType, AnyType) + ): self.fail(message_registry.INVALID_GENERATOR_RETURN_ITEM_TYPE, typ) # Fix the type if decorated with `@types.coroutine` or `@asyncio.coroutine`. @@ -953,8 +1119,9 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) tr = self.get_coroutine_return_type(t) else: tr = self.get_generator_return_type(t, c) - ret_type = self.named_generic_type('typing.AwaitableGenerator', - [ty, tc, tr, t]) + ret_type = self.named_generic_type( + "typing.AwaitableGenerator", [ty, tc, tr, t] + ) typ = typ.copy_modified(ret_type=ret_type) defn.type = typ @@ -968,31 +1135,42 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) # We temporary push the definition to get the self type as # visible from *inside* of this function/method. ref_type: Optional[Type] = self.scope.active_self_type() - if (isinstance(defn, FuncDef) and ref_type is not None and i == 0 - and not defn.is_static - and typ.arg_kinds[0] not in [nodes.ARG_STAR, nodes.ARG_STAR2]): - isclass = defn.is_class or defn.name in ('__new__', '__init_subclass__') + if ( + isinstance(defn, FuncDef) + and ref_type is not None + and i == 0 + and not defn.is_static + and typ.arg_kinds[0] not in [nodes.ARG_STAR, nodes.ARG_STAR2] + ): + isclass = defn.is_class or defn.name in ("__new__", "__init_subclass__") if isclass: ref_type = mypy.types.TypeType.make_normalized(ref_type) erased = get_proper_type(erase_to_bound(arg_type)) if not is_subtype(ref_type, erased, ignore_type_params=True): note = None - if (isinstance(erased, Instance) and erased.type.is_protocol or - isinstance(erased, TypeType) and - isinstance(erased.item, Instance) and - erased.item.type.is_protocol): + if ( + isinstance(erased, Instance) + and erased.type.is_protocol + or isinstance(erased, TypeType) + and isinstance(erased.item, Instance) + and erased.item.type.is_protocol + ): # We allow the explicit self-type to be not a supertype of # the current class if it is a protocol. For such cases # the consistency check will be performed at call sites. msg = None - elif typ.arg_names[i] in {'self', 'cls'}: - if (self.options.python_version[0] < 3 - and is_same_type(erased, arg_type) and not isclass): + elif typ.arg_names[i] in {"self", "cls"}: + if ( + self.options.python_version[0] < 3 + and is_same_type(erased, arg_type) + and not isclass + ): msg = message_registry.INVALID_SELF_TYPE_OR_EXTRA_ARG - note = '(Hint: typically annotations omit the type for self)' + note = "(Hint: typically annotations omit the type for self)" else: msg = message_registry.ERASED_SELF_TYPE_NOT_SUPERTYPE.format( - erased, ref_type) + erased, ref_type + ) else: msg = message_registry.MISSING_OR_INVALID_SELF_TYPE if msg: @@ -1002,9 +1180,9 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) elif isinstance(arg_type, TypeVarType): # Refuse covariant parameter type variables # TODO: check recursively for inner type variables - if ( - arg_type.variance == COVARIANT and - defn.name not in ('__init__', '__new__') + if arg_type.variance == COVARIANT and defn.name not in ( + "__init__", + "__new__", ): ctx: Context = arg_type if ctx.line < 0: @@ -1013,13 +1191,12 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) if typ.arg_kinds[i] == nodes.ARG_STAR: if not isinstance(arg_type, ParamSpecType): # builtins.tuple[T] is typing.Tuple[T, ...] - arg_type = self.named_generic_type('builtins.tuple', - [arg_type]) + arg_type = self.named_generic_type("builtins.tuple", [arg_type]) elif typ.arg_kinds[i] == nodes.ARG_STAR2: if not isinstance(arg_type, ParamSpecType): - arg_type = self.named_generic_type('builtins.dict', - [self.str_type(), - arg_type]) + arg_type = self.named_generic_type( + "builtins.dict", [self.str_type(), arg_type] + ) item.arguments[i].variable.type = arg_type # Type check initialization expressions. @@ -1041,10 +1218,12 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) unreachable = self.binder.is_unreachable() if self.options.warn_no_return and not unreachable: - if (defn.is_generator or - is_named_instance(self.return_types[-1], 'typing.AwaitableGenerator')): - return_type = self.get_generator_return_type(self.return_types[-1], - defn.is_coroutine) + if defn.is_generator or is_named_instance( + self.return_types[-1], "typing.AwaitableGenerator" + ): + return_type = self.get_generator_return_type( + self.return_types[-1], defn.is_coroutine + ) elif defn.is_coroutine: return_type = self.get_coroutine_return_type(self.return_types[-1]) else: @@ -1067,7 +1246,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: Optional[str]) def check_unbound_return_typevar(self, typ: CallableType) -> None: """Fails when the return typevar is not defined in arguments.""" - if (typ.ret_type in typ.variables): + if typ.ret_type in typ.variables: arg_type_visitor = CollectArgTypes() for argtype in typ.arg_types: argtype.accept(arg_type_visitor) @@ -1082,8 +1261,8 @@ def check_default_args(self, item: FuncItem, body_is_trivial: bool) -> None: if body_is_trivial and isinstance(arg.initializer, EllipsisExpr): continue name = arg.variable.name - msg = 'Incompatible default for ' - if name.startswith('__tuple_arg_'): + msg = "Incompatible default for " + if name.startswith("__tuple_arg_"): msg += f"tuple argument {name[12:]}" else: msg += f'argument "{name}"' @@ -1092,18 +1271,19 @@ def check_default_args(self, item: FuncItem, body_is_trivial: bool) -> None: arg.initializer, context=arg.initializer, msg=msg, - lvalue_name='argument', - rvalue_name='default', - code=codes.ASSIGNMENT) + lvalue_name="argument", + rvalue_name="default", + code=codes.ASSIGNMENT, + ) def is_forward_op_method(self, method_name: str) -> bool: - if self.options.python_version[0] == 2 and method_name == '__div__': + if self.options.python_version[0] == 2 and method_name == "__div__": return True else: return method_name in operators.reverse_op_methods def is_reverse_op_method(self, method_name: str) -> bool: - if self.options.python_version[0] == 2 and method_name == '__rdiv__': + if self.options.python_version[0] == 2 and method_name == "__rdiv__": return True else: return method_name in operators.reverse_op_method_set @@ -1115,20 +1295,25 @@ def is_unannotated_any(t: Type) -> bool: return False return isinstance(t, AnyType) and t.type_of_any == TypeOfAny.unannotated - has_explicit_annotation = (isinstance(fdef.type, CallableType) - and any(not is_unannotated_any(t) - for t in fdef.type.arg_types + [fdef.type.ret_type])) + has_explicit_annotation = isinstance(fdef.type, CallableType) and any( + not is_unannotated_any(t) for t in fdef.type.arg_types + [fdef.type.ret_type] + ) show_untyped = not self.is_typeshed_stub or self.options.warn_incomplete_stub check_incomplete_defs = self.options.disallow_incomplete_defs and has_explicit_annotation if show_untyped and (self.options.disallow_untyped_defs or check_incomplete_defs): if fdef.type is None and self.options.disallow_untyped_defs: - if (not fdef.arguments or (len(fdef.arguments) == 1 and - (fdef.arg_names[0] == 'self' or fdef.arg_names[0] == 'cls'))): + if not fdef.arguments or ( + len(fdef.arguments) == 1 + and (fdef.arg_names[0] == "self" or fdef.arg_names[0] == "cls") + ): self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef) if not has_return_statement(fdef) and not fdef.is_generator: - self.note('Use "-> None" if function does not return a value', fdef, - code=codes.NO_UNTYPED_DEF) + self.note( + 'Use "-> None" if function does not return a value', + fdef, + code=codes.NO_UNTYPED_DEF, + ) else: self.fail(message_registry.FUNCTION_TYPE_EXPECTED, fdef) elif isinstance(fdef.type, CallableType): @@ -1136,8 +1321,9 @@ def is_unannotated_any(t: Type) -> bool: if is_unannotated_any(ret_type): self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef) elif fdef.is_generator: - if is_unannotated_any(self.get_generator_return_type(ret_type, - fdef.is_coroutine)): + if is_unannotated_any( + self.get_generator_return_type(ret_type, fdef.is_coroutine) + ): self.fail(message_registry.RETURN_TYPE_EXPECTED, fdef) elif fdef.is_coroutine and isinstance(ret_type, Instance): if is_unannotated_any(self.get_coroutine_return_type(ret_type)): @@ -1157,15 +1343,14 @@ def check___new___signature(self, fdef: FuncDef, typ: CallableType) -> None: self.type_type(), fdef, message_registry.INVALID_NEW_TYPE, - 'returns', - 'but must return a subtype of' + "returns", + "but must return a subtype of", ) - elif not isinstance(get_proper_type(bound_type.ret_type), - (AnyType, Instance, TupleType)): + elif not isinstance(get_proper_type(bound_type.ret_type), (AnyType, Instance, TupleType)): self.fail( - message_registry.NON_INSTANCE_NEW_TYPE.format( - format_type(bound_type.ret_type)), - fdef) + message_registry.NON_INSTANCE_NEW_TYPE.format(format_type(bound_type.ret_type)), + fdef, + ) else: # And that it returns a subtype of the class self.check_subtype( @@ -1173,8 +1358,8 @@ def check___new___signature(self, fdef: FuncDef, typ: CallableType) -> None: self_type, fdef, message_registry.INVALID_NEW_TYPE, - 'returns', - 'but must return a subtype of' + "returns", + "but must return a subtype of", ) def is_trivial_body(self, block: Block) -> bool: @@ -1196,8 +1381,11 @@ def halt(self, reason: str = ...) -> NoReturn: body = block.body # Skip a docstring - if (body and isinstance(body[0], ExpressionStmt) and - isinstance(body[0].expr, (StrExpr, UnicodeExpr))): + if ( + body + and isinstance(body[0], ExpressionStmt) + and isinstance(body[0].expr, (StrExpr, UnicodeExpr)) + ): body = block.body[1:] if len(body) == 0: @@ -1215,16 +1403,15 @@ def halt(self, reason: str = ...) -> NoReturn: if isinstance(expr, CallExpr): expr = expr.callee - return (isinstance(expr, NameExpr) - and expr.fullname == 'builtins.NotImplementedError') + return isinstance(expr, NameExpr) and expr.fullname == "builtins.NotImplementedError" - return (isinstance(stmt, PassStmt) or - (isinstance(stmt, ExpressionStmt) and - isinstance(stmt.expr, EllipsisExpr))) + return isinstance(stmt, PassStmt) or ( + isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, EllipsisExpr) + ) - def check_reverse_op_method(self, defn: FuncItem, - reverse_type: CallableType, reverse_name: str, - context: Context) -> None: + def check_reverse_op_method( + self, defn: FuncItem, reverse_type: CallableType, reverse_name: str, context: Context + ) -> None: """Check a reverse operator method such as __radd__.""" # Decides whether it's worth calling check_overlapping_op_methods(). @@ -1235,17 +1422,18 @@ def check_reverse_op_method(self, defn: FuncItem, assert defn.info # First check for a valid signature - method_type = CallableType([AnyType(TypeOfAny.special_form), - AnyType(TypeOfAny.special_form)], - [nodes.ARG_POS, nodes.ARG_POS], - [None, None], - AnyType(TypeOfAny.special_form), - self.named_type('builtins.function')) + method_type = CallableType( + [AnyType(TypeOfAny.special_form), AnyType(TypeOfAny.special_form)], + [nodes.ARG_POS, nodes.ARG_POS], + [None, None], + AnyType(TypeOfAny.special_form), + self.named_type("builtins.function"), + ) if not is_subtype(reverse_type, method_type): self.msg.invalid_signature(reverse_type, context) return - if reverse_name in ('__eq__', '__ne__'): + if reverse_name in ("__eq__", "__ne__"): # These are defined for all objects => can't cause trouble. return @@ -1255,16 +1443,18 @@ def check_reverse_op_method(self, defn: FuncItem, if isinstance(ret_type, AnyType): return if isinstance(ret_type, Instance): - if ret_type.type.fullname == 'builtins.object': + if ret_type.type.fullname == "builtins.object": return if reverse_type.arg_kinds[0] == ARG_STAR: - reverse_type = reverse_type.copy_modified(arg_types=[reverse_type.arg_types[0]] * 2, - arg_kinds=[ARG_POS] * 2, - arg_names=[reverse_type.arg_names[0], "_"]) + reverse_type = reverse_type.copy_modified( + arg_types=[reverse_type.arg_types[0]] * 2, + arg_kinds=[ARG_POS] * 2, + arg_names=[reverse_type.arg_names[0], "_"], + ) assert len(reverse_type.arg_types) >= 2 - if self.options.python_version[0] == 2 and reverse_name == '__rdiv__': - forward_name = '__div__' + if self.options.python_version[0] == 2 and reverse_name == "__rdiv__": + forward_name = "__div__" else: forward_name = operators.normal_from_reverse_op[reverse_name] forward_inst = get_proper_type(reverse_type.arg_types[1]) @@ -1280,24 +1470,35 @@ def check_reverse_op_method(self, defn: FuncItem, opt_meta = item.type.metaclass_type if opt_meta is not None: forward_inst = opt_meta - if not (isinstance(forward_inst, (Instance, UnionType)) - and forward_inst.has_readable_member(forward_name)): + if not ( + isinstance(forward_inst, (Instance, UnionType)) + and forward_inst.has_readable_member(forward_name) + ): return forward_base = reverse_type.arg_types[1] - forward_type = self.expr_checker.analyze_external_member_access(forward_name, forward_base, - context=defn) - self.check_overlapping_op_methods(reverse_type, reverse_name, defn.info, - forward_type, forward_name, forward_base, - context=defn) - - def check_overlapping_op_methods(self, - reverse_type: CallableType, - reverse_name: str, - reverse_class: TypeInfo, - forward_type: Type, - forward_name: str, - forward_base: Type, - context: Context) -> None: + forward_type = self.expr_checker.analyze_external_member_access( + forward_name, forward_base, context=defn + ) + self.check_overlapping_op_methods( + reverse_type, + reverse_name, + defn.info, + forward_type, + forward_name, + forward_base, + context=defn, + ) + + def check_overlapping_op_methods( + self, + reverse_type: CallableType, + reverse_name: str, + reverse_class: TypeInfo, + forward_type: Type, + forward_name: str, + forward_base: Type, + context: Context, + ) -> None: """Check for overlapping method and reverse method signatures. This function assumes that: @@ -1340,22 +1541,20 @@ def check_overlapping_op_methods(self, if isinstance(forward_item, CallableType): if self.is_unsafe_overlapping_op(forward_item, forward_base, reverse_type): self.msg.operator_method_signatures_overlap( - reverse_class, reverse_name, - forward_base, forward_name, context) + reverse_class, reverse_name, forward_base, forward_name, context + ) elif isinstance(forward_item, Overloaded): for item in forward_item.items: if self.is_unsafe_overlapping_op(item, forward_base, reverse_type): self.msg.operator_method_signatures_overlap( - reverse_class, reverse_name, - forward_base, forward_name, - context) + reverse_class, reverse_name, forward_base, forward_name, context + ) elif not isinstance(forward_item, AnyType): self.msg.forward_operator_not_callable(forward_name, context) - def is_unsafe_overlapping_op(self, - forward_item: CallableType, - forward_base: Type, - reverse_type: CallableType) -> bool: + def is_unsafe_overlapping_op( + self, forward_item: CallableType, forward_base: Type, reverse_type: CallableType + ) -> bool: # TODO: check argument kinds? if len(forward_item.arg_types) < 1: # Not a valid operator method -- can't succeed anyway. @@ -1409,11 +1608,12 @@ def check_inplace_operator_method(self, defn: FuncBase) -> None: return typ = bind_self(self.function_type(defn)) cls = defn.info - other_method = '__' + method[3:] + other_method = "__" + method[3:] if cls.has_readable_member(other_method): instance = fill_typevars(cls) - typ2 = get_proper_type(self.expr_checker.analyze_external_member_access( - other_method, instance, defn)) + typ2 = get_proper_type( + self.expr_checker.analyze_external_member_access(other_method, instance, defn) + ) fail = False if isinstance(typ2, FunctionLike): if not is_more_general_arg_prefix(typ, typ2): @@ -1427,24 +1627,27 @@ def check_inplace_operator_method(self, defn: FuncBase) -> None: def check_getattr_method(self, typ: Type, context: Context, name: str) -> None: if len(self.scope.stack) == 1: # module scope - if name == '__getattribute__': + if name == "__getattribute__": self.fail(message_registry.MODULE_LEVEL_GETATTRIBUTE, context) return # __getattr__ is fine at the module level as of Python 3.7 (PEP 562). We could # show an error for Python < 3.7, but that would be annoying in code that supports # both 3.7 and older versions. - method_type = CallableType([self.named_type('builtins.str')], - [nodes.ARG_POS], - [None], - AnyType(TypeOfAny.special_form), - self.named_type('builtins.function')) + method_type = CallableType( + [self.named_type("builtins.str")], + [nodes.ARG_POS], + [None], + AnyType(TypeOfAny.special_form), + self.named_type("builtins.function"), + ) elif self.scope.active_class(): - method_type = CallableType([AnyType(TypeOfAny.special_form), - self.named_type('builtins.str')], - [nodes.ARG_POS, nodes.ARG_POS], - [None, None], - AnyType(TypeOfAny.special_form), - self.named_type('builtins.function')) + method_type = CallableType( + [AnyType(TypeOfAny.special_form), self.named_type("builtins.str")], + [nodes.ARG_POS, nodes.ARG_POS], + [None, None], + AnyType(TypeOfAny.special_form), + self.named_type("builtins.function"), + ) else: return if not is_subtype(typ, method_type): @@ -1453,37 +1656,52 @@ def check_getattr_method(self, typ: Type, context: Context, name: str) -> None: def check_setattr_method(self, typ: Type, context: Context) -> None: if not self.scope.active_class(): return - method_type = CallableType([AnyType(TypeOfAny.special_form), - self.named_type('builtins.str'), - AnyType(TypeOfAny.special_form)], - [nodes.ARG_POS, nodes.ARG_POS, nodes.ARG_POS], - [None, None, None], - NoneType(), - self.named_type('builtins.function')) + method_type = CallableType( + [ + AnyType(TypeOfAny.special_form), + self.named_type("builtins.str"), + AnyType(TypeOfAny.special_form), + ], + [nodes.ARG_POS, nodes.ARG_POS, nodes.ARG_POS], + [None, None, None], + NoneType(), + self.named_type("builtins.function"), + ) if not is_subtype(typ, method_type): - self.msg.invalid_signature_for_special_method(typ, context, '__setattr__') + self.msg.invalid_signature_for_special_method(typ, context, "__setattr__") def check_slots_definition(self, typ: Type, context: Context) -> None: """Check the type of __slots__.""" str_type = self.named_type("builtins.str") - expected_type = UnionType([str_type, - self.named_generic_type("typing.Iterable", [str_type])]) - self.check_subtype(typ, expected_type, context, - message_registry.INVALID_TYPE_FOR_SLOTS, - 'actual type', - 'expected type', - code=codes.ASSIGNMENT) + expected_type = UnionType( + [str_type, self.named_generic_type("typing.Iterable", [str_type])] + ) + self.check_subtype( + typ, + expected_type, + context, + message_registry.INVALID_TYPE_FOR_SLOTS, + "actual type", + "expected type", + code=codes.ASSIGNMENT, + ) def check_match_args(self, var: Var, typ: Type, context: Context) -> None: """Check that __match_args__ contains literal strings""" typ = get_proper_type(typ) - if not isinstance(typ, TupleType) or \ - not all([is_string_literal(item) for item in typ.items]): - self.msg.note("__match_args__ must be a tuple containing string literals for checking " - "of match statements to work", context, code=codes.LITERAL_REQ) + if not isinstance(typ, TupleType) or not all( + [is_string_literal(item) for item in typ.items] + ): + self.msg.note( + "__match_args__ must be a tuple containing string literals for checking " + "of match statements to work", + context, + code=codes.LITERAL_REQ, + ) - def expand_typevars(self, defn: FuncItem, - typ: CallableType) -> List[Tuple[FuncItem, CallableType]]: + def expand_typevars( + self, defn: FuncItem, typ: CallableType + ) -> List[Tuple[FuncItem, CallableType]]: # TODO use generator subst: List[List[Tuple[TypeVarId, Type]]] = [] tvars = list(typ.variables) or [] @@ -1518,10 +1736,9 @@ def check_method_override(self, defn: Union[FuncDef, OverloadedFuncDef, Decorato # Node was deferred, we will have another attempt later. return - def check_method_or_accessor_override_for_base(self, defn: Union[FuncDef, - OverloadedFuncDef, - Decorator], - base: TypeInfo) -> bool: + def check_method_or_accessor_override_for_base( + self, defn: Union[FuncDef, OverloadedFuncDef, Decorator], base: TypeInfo + ) -> bool: """Check if method definition is compatible with a base class. Return True if the node was deferred because one of the corresponding @@ -1539,25 +1756,24 @@ def check_method_or_accessor_override_for_base(self, defn: Union[FuncDef, self.check_if_final_var_override_writable(name, base_attr.node, defn) # Check the type of override. - if name not in ('__init__', '__new__', '__init_subclass__'): + if name not in ("__init__", "__new__", "__init_subclass__"): # Check method override # (__init__, __new__, __init_subclass__ are special). if self.check_method_override_for_base_with_name(defn, name, base): return True if name in operators.inplace_operator_methods: # Figure out the name of the corresponding operator method. - method = '__' + name[3:] + method = "__" + name[3:] # An inplace operator method such as __iadd__ might not be # always introduced safely if a base class defined __add__. # TODO can't come up with an example where this is # necessary; now it's "just in case" - return self.check_method_override_for_base_with_name(defn, method, - base) + return self.check_method_override_for_base_with_name(defn, method, base) return False def check_method_override_for_base_with_name( - self, defn: Union[FuncDef, OverloadedFuncDef, Decorator], - name: str, base: TypeInfo) -> bool: + self, defn: Union[FuncDef, OverloadedFuncDef, Decorator], name: str, base: TypeInfo + ) -> bool: """Check if overriding an attribute `name` of `base` with `defn` is valid. Return True if the supertype node was not analysed yet, and `defn` was deferred. @@ -1586,8 +1802,7 @@ def check_method_override_for_base_with_name( override_class = defn.func.is_class typ = get_proper_type(typ) if isinstance(typ, FunctionLike) and not is_static(context): - typ = bind_self(typ, self.scope.active_self_type(), - is_classmethod=override_class) + typ = bind_self(typ, self.scope.active_self_type(), is_classmethod=override_class) # Map the overridden method type to subtype context so that # it can be checked for compatibility. original_type = get_proper_type(base_attr.type) @@ -1627,35 +1842,39 @@ def check_method_override_for_base_with_name( if isinstance(original_type, AnyType) or isinstance(typ, AnyType): pass elif isinstance(original_type, FunctionLike) and isinstance(typ, FunctionLike): - original = self.bind_and_map_method(base_attr, original_type, - defn.info, base) + original = self.bind_and_map_method(base_attr, original_type, defn.info, base) # Check that the types are compatible. # TODO overloaded signatures - self.check_override(typ, - original, - defn.name, - name, - base.name, - original_class_or_static, - override_class_or_static, - context) + self.check_override( + typ, + original, + defn.name, + name, + base.name, + original_class_or_static, + override_class_or_static, + context, + ) elif is_equivalent(original_type, typ): # Assume invariance for a non-callable attribute here. Note # that this doesn't affect read-only properties which can have # covariant overrides. # pass - elif (base_attr.node and not self.is_writable_attribute(base_attr.node) - and is_subtype(typ, original_type)): + elif ( + base_attr.node + and not self.is_writable_attribute(base_attr.node) + and is_subtype(typ, original_type) + ): # If the attribute is read-only, allow covariance pass else: - self.msg.signature_incompatible_with_supertype( - defn.name, name, base.name, context) + self.msg.signature_incompatible_with_supertype(defn.name, name, base.name, context) return False - def bind_and_map_method(self, sym: SymbolTableNode, typ: FunctionLike, - sub_info: TypeInfo, super_info: TypeInfo) -> FunctionLike: + def bind_and_map_method( + self, sym: SymbolTableNode, typ: FunctionLike, sub_info: TypeInfo, super_info: TypeInfo + ) -> FunctionLike: """Bind self-type and map type variables for a method. Arguments: @@ -1664,8 +1883,9 @@ def bind_and_map_method(self, sym: SymbolTableNode, typ: FunctionLike, sub_info: class where the method is used super_info: class where the method was defined """ - if (isinstance(sym.node, (FuncDef, OverloadedFuncDef, Decorator)) - and not is_static(sym.node)): + if isinstance(sym.node, (FuncDef, OverloadedFuncDef, Decorator)) and not is_static( + sym.node + ): if isinstance(sym.node, Decorator): is_class_method = sym.node.func.is_class else: @@ -1689,11 +1909,17 @@ def get_op_other_domain(self, tp: FunctionLike) -> Optional[Type]: else: assert False, "Need to check all FunctionLike subtypes here" - def check_override(self, override: FunctionLike, original: FunctionLike, - name: str, name_in_super: str, supertype: str, - original_class_or_static: bool, - override_class_or_static: bool, - node: Context) -> None: + def check_override( + self, + override: FunctionLike, + original: FunctionLike, + name: str, + name_in_super: str, + supertype: str, + original_class_or_static: bool, + override_class_or_static: bool, + node: Context, + ) -> None: """Check a method override with given signatures. Arguments: @@ -1713,8 +1939,11 @@ def check_override(self, override: FunctionLike, original: FunctionLike, # this could be unsafe with reverse operator methods. original_domain = self.get_op_other_domain(original) override_domain = self.get_op_other_domain(override) - if (original_domain and override_domain and - not is_subtype(override_domain, original_domain)): + if ( + original_domain + and override_domain + and not is_subtype(override_domain, original_domain) + ): fail = True op_method_wider_note = True if isinstance(original, FunctionLike) and isinstance(override, FunctionLike): @@ -1729,10 +1958,12 @@ def check_override(self, override: FunctionLike, original: FunctionLike, if fail: emitted_msg = False - if (isinstance(override, CallableType) and - isinstance(original, CallableType) and - len(override.arg_types) == len(original.arg_types) and - override.min_args == original.min_args): + if ( + isinstance(override, CallableType) + and isinstance(original, CallableType) + and len(override.arg_types) == len(original.arg_types) + and override.min_args == original.min_args + ): # Give more detailed messages for the common case of both # signatures having the same number of arguments and no # overloads. @@ -1752,8 +1983,9 @@ def erase_override(t: Type) -> Type: return erase_typevars(t, ids_to_erase=override_ids) for i in range(len(override.arg_types)): - if not is_subtype(original.arg_types[i], - erase_override(override.arg_types[i])): + if not is_subtype( + original.arg_types[i], erase_override(override.arg_types[i]) + ): arg_type_in_super = original.arg_types[i] self.msg.argument_incompatible_with_supertype( i + 1, @@ -1762,14 +1994,14 @@ def erase_override(t: Type) -> Type: name_in_super, arg_type_in_super, supertype, - node + node, ) emitted_msg = True - if not is_subtype(erase_override(override.ret_type), - original.ret_type): + if not is_subtype(erase_override(override.ret_type), original.ret_type): self.msg.return_type_incompatible_with_supertype( - name, name_in_super, supertype, original.ret_type, override.ret_type, node) + name, name_in_super, supertype, original.ret_type, override.ret_type, node + ) emitted_msg = True elif isinstance(override, Overloaded) and isinstance(original, Overloaded): # Give a more detailed message in the case where the user is trying to @@ -1788,16 +2020,21 @@ def erase_override(t: Type) -> Type: if len(order) == len(original.items) and order != sorted(order): self.msg.overload_signature_incompatible_with_supertype( - name, name_in_super, supertype, node) + name, name_in_super, supertype, node + ) emitted_msg = True if not emitted_msg: # Fall back to generic incompatibility message. self.msg.signature_incompatible_with_supertype( - name, name_in_super, supertype, node, original=original, override=override) + name, name_in_super, supertype, node, original=original, override=override + ) if op_method_wider_note: - self.note("Overloaded operator methods can't have wider argument types" - " in overrides", node, code=codes.OVERRIDE) + self.note( + "Overloaded operator methods can't have wider argument types" " in overrides", + node, + code=codes.OVERRIDE, + ) def check__exit__return_type(self, defn: FuncItem) -> None: """Generate error if the return type of __exit__ is problematic. @@ -1818,8 +2055,10 @@ def check__exit__return_type(self, defn: FuncItem) -> None: if not returns: return - if all(isinstance(ret.expr, NameExpr) and ret.expr.fullname == 'builtins.False' - for ret in returns): + if all( + isinstance(ret.expr, NameExpr) and ret.expr.fullname == "builtins.False" + for ret in returns + ): self.msg.incorrect__exit__return(defn) def visit_class_def(self, defn: ClassDef) -> None: @@ -1847,8 +2086,9 @@ def visit_class_def(self, defn: ClassDef) -> None: sig: Type = type_object_type(defn.info, self.named_type) # Decorators are applied in reverse order. for decorator in reversed(defn.decorators): - if (isinstance(decorator, CallExpr) - and isinstance(decorator.analyzed, PromoteExpr)): + if isinstance(decorator, CallExpr) and isinstance( + decorator.analyzed, PromoteExpr + ): # _promote is a special type checking related construct. continue @@ -1860,9 +2100,9 @@ def visit_class_def(self, defn: ClassDef) -> None: # TODO: Figure out how to have clearer error messages. # (e.g. "class decorator must be a function that accepts a type." - sig, _ = self.expr_checker.check_call(dec, [temp], - [nodes.ARG_POS], defn, - callable_name=fullname) + sig, _ = self.expr_checker.check_call( + dec, [temp], [nodes.ARG_POS], defn, callable_name=fullname + ) # TODO: Apply the sig to the actual TypeInfo so we can handle decorators # that completely swap out the type. (e.g. Callable[[Type[A]], Type[B]]) if typ.is_protocol and typ.defn.type_vars: @@ -1893,25 +2133,27 @@ def check_init_subclass(self, defn: ClassDef) -> None: Base.__init_subclass__(thing=5) is called at line 4. This is what we simulate here. Child.__init_subclass__ is never called. """ - if (defn.info.metaclass_type and - defn.info.metaclass_type.type.fullname not in ('builtins.type', 'abc.ABCMeta')): + if defn.info.metaclass_type and defn.info.metaclass_type.type.fullname not in ( + "builtins.type", + "abc.ABCMeta", + ): # We can't safely check situations when both __init_subclass__ and a custom # metaclass are present. return # At runtime, only Base.__init_subclass__ will be called, so # we skip the current class itself. for base in defn.info.mro[1:]: - if '__init_subclass__' not in base.names: + if "__init_subclass__" not in base.names: continue name_expr = NameExpr(defn.name) name_expr.node = base - callee = MemberExpr(name_expr, '__init_subclass__') + callee = MemberExpr(name_expr, "__init_subclass__") args = list(defn.keywords.values()) arg_names: List[Optional[str]] = list(defn.keywords.keys()) # 'metaclass' keyword is consumed by the rest of the type machinery, # and is never passed to __init_subclass__ implementations - if 'metaclass' in arg_names: - idx = arg_names.index('metaclass') + if "metaclass" in arg_names: + idx = arg_names.index("metaclass") arg_names.pop(idx) args.pop(idx) arg_kinds = [ARG_NAMED] * len(args) @@ -1919,9 +2161,7 @@ def check_init_subclass(self, defn: ClassDef) -> None: call_expr.line = defn.line call_expr.column = defn.column call_expr.end_line = defn.end_line - self.expr_checker.accept(call_expr, - allow_none_return=True, - always_allow_any=True) + self.expr_checker.accept(call_expr, allow_none_return=True, always_allow_any=True) # We are only interested in the first Base having __init_subclass__, # all other bases have already been checked. break @@ -1930,13 +2170,14 @@ def check_enum(self, defn: ClassDef) -> None: assert defn.info.is_enum if defn.info.fullname not in ENUM_BASES: for sym in defn.info.names.values(): - if (isinstance(sym.node, Var) and sym.node.has_explicit_value and - sym.node.name == '__members__'): + if ( + isinstance(sym.node, Var) + and sym.node.has_explicit_value + and sym.node.name == "__members__" + ): # `__members__` will always be overwritten by `Enum` and is considered # read-only so we disallow assigning a value to it - self.fail( - message_registry.ENUM_MEMBERS_ATTR_WILL_BE_OVERRIDEN, sym.node - ) + self.fail(message_registry.ENUM_MEMBERS_ATTR_WILL_BE_OVERRIDEN, sym.node) for base in defn.info.mro[1:-1]: # we don't need self and `object` if base.is_enum and base.fullname not in ENUM_BASES: self.check_final_enum(defn, base) @@ -1947,10 +2188,7 @@ def check_enum(self, defn: ClassDef) -> None: def check_final_enum(self, defn: ClassDef, base: TypeInfo) -> None: for sym in base.names.values(): if self.is_final_enum_value(sym): - self.fail( - f'Cannot extend enum with existing members: "{base.name}"', - defn, - ) + self.fail(f'Cannot extend enum with existing members: "{base.name}"', defn) break def is_final_enum_value(self, sym: SymbolTableNode) -> bool: @@ -2001,19 +2239,16 @@ class Baz(int, Foo, Bar, enum.Flag): ... enum_base = base continue elif enum_base is not None and not base.type.is_enum: - self.fail( - f'No non-enum mixin classes are allowed after "{enum_base}"', - defn, - ) + self.fail(f'No non-enum mixin classes are allowed after "{enum_base}"', defn) break def check_enum_new(self, defn: ClassDef) -> None: def has_new_method(info: TypeInfo) -> bool: - new_method = info.get('__new__') + new_method = info.get("__new__") return bool( new_method and new_method.node - and new_method.node.fullname != 'builtins.object.__new__' + and new_method.node.fullname != "builtins.object.__new__" ) has_new = False @@ -2022,16 +2257,13 @@ def has_new_method(info: TypeInfo) -> bool: if base.type.is_enum: # If we have an `Enum`, then we need to check all its bases. - candidate = any( - not b.is_enum and has_new_method(b) - for b in base.type.mro[1:-1] - ) + candidate = any(not b.is_enum and has_new_method(b) for b in base.type.mro[1:-1]) else: candidate = has_new_method(base.type) if candidate and has_new: self.fail( - 'Only a single data type mixin is allowed for Enum subtypes, ' + "Only a single data type mixin is allowed for Enum subtypes, " 'found extra "{}"'.format(base), defn, ) @@ -2083,7 +2315,7 @@ def check_multiple_inheritance(self, typ: TypeInfo) -> None: for name in non_overridden_attrs: if is_private(name): continue - for base2 in mro[i + 1:]: + for base2 in mro[i + 1 :]: # We only need to check compatibility of attributes from classes not # in a subclass relationship. For subclasses, normal (single inheritance) # checks suffice (these are implemented elsewhere). @@ -2104,8 +2336,9 @@ def determine_type_of_class_member(self, sym: SymbolTableNode) -> Optional[Type] return AnyType(TypeOfAny.special_form) return None - def check_compatibility(self, name: str, base1: TypeInfo, - base2: TypeInfo, ctx: TypeInfo) -> None: + def check_compatibility( + self, name: str, base1: TypeInfo, base2: TypeInfo, ctx: TypeInfo + ) -> None: """Check if attribute name in base1 is compatible with base2 in multiple inheritance. Assume base1 comes before base2 in the MRO, and that base1 and base2 don't have @@ -2126,7 +2359,7 @@ class C(B, A[int]): ... # this is unsafe because... x: A[int] = C() x.foo # ...runtime type is (str) -> None, while static type is (int) -> None """ - if name in ('__init__', '__new__', '__init_subclass__'): + if name in ("__init__", "__new__", "__init_subclass__"): # __init__ and friends can be incompatible -- it's a special case. return first = base1.names[name] @@ -2134,13 +2367,14 @@ class C(B, A[int]): ... # this is unsafe because... first_type = get_proper_type(self.determine_type_of_class_member(first)) second_type = get_proper_type(self.determine_type_of_class_member(second)) - if (isinstance(first_type, FunctionLike) and - isinstance(second_type, FunctionLike)): + if isinstance(first_type, FunctionLike) and isinstance(second_type, FunctionLike): if first_type.is_type_obj() and second_type.is_type_obj(): # For class objects only check the subtype relationship of the classes, # since we allow incompatible overrides of '__init__'/'__new__' - ok = is_subtype(left=fill_typevars_with_any(first_type.type_object()), - right=fill_typevars_with_any(second_type.type_object())) + ok = is_subtype( + left=fill_typevars_with_any(first_type.type_object()), + right=fill_typevars_with_any(second_type.type_object()), + ) else: # First bind/map method types when necessary. first_sig = self.bind_and_map_method(first, first_type, ctx, base1) @@ -2169,8 +2403,7 @@ class C(B, A[int]): ... # this is unsafe because... if isinstance(second.node, Var) and second.node.allow_incompatible_override: ok = True if not ok: - self.msg.base_class_definitions_incompatible(name, base1, base2, - ctx) + self.msg.base_class_definitions_incompatible(name, base1, base2, ctx) def visit_import_from(self, node: ImportFrom) -> None: self.check_import(node) @@ -2188,11 +2421,17 @@ def check_import(self, node: ImportBase) -> None: if lvalue_type is None: # TODO: This is broken. lvalue_type = AnyType(TypeOfAny.special_form) - message = '{} "{}"'.format(message_registry.INCOMPATIBLE_IMPORT_OF, - cast(NameExpr, assign.rvalue).name) - self.check_simple_assignment(lvalue_type, assign.rvalue, node, - msg=message, lvalue_name='local name', - rvalue_name='imported name') + message = '{} "{}"'.format( + message_registry.INCOMPATIBLE_IMPORT_OF, cast(NameExpr, assign.rvalue).name + ) + self.check_simple_assignment( + lvalue_type, + assign.rvalue, + node, + msg=message, + lvalue_name="local name", + rvalue_name="imported name", + ) # # Statements @@ -2213,9 +2452,11 @@ def visit_block(self, b: Block) -> None: self.accept(s) def should_report_unreachable_issues(self) -> bool: - return (self.in_checked_function() - and self.options.warn_unreachable - and not self.binder.is_unreachable_warning_suppressed()) + return ( + self.in_checked_function() + and self.options.warn_unreachable + and not self.binder.is_unreachable_warning_suppressed() + ) def is_raising_or_empty(self, s: Statement) -> bool: """Returns 'true' if the given statement either throws an error of some kind @@ -2235,8 +2476,11 @@ def is_raising_or_empty(self, s: Statement) -> bool: return True elif isinstance(s.expr, CallExpr): with self.expr_checker.msg.filter_errors(): - typ = get_proper_type(self.expr_checker.accept( - s.expr, allow_none_return=True, always_allow_any=True)) + typ = get_proper_type( + self.expr_checker.accept( + s.expr, allow_none_return=True, always_allow_any=True + ) + ) if isinstance(typ, UninhabitedType): return True @@ -2257,14 +2501,17 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: if s.is_alias_def: self.check_type_alias_rvalue(s) - if (s.type is not None and - self.options.disallow_any_unimported and - has_any_from_unimported_type(s.type)): + if ( + s.type is not None + and self.options.disallow_any_unimported + and has_any_from_unimported_type(s.type) + ): if isinstance(s.lvalues[-1], TupleExpr): # This is a multiple assignment. Instead of figuring out which type is problematic, # give a generic error message. - self.msg.unimported_type_becomes_any("A type on this line", - AnyType(TypeOfAny.special_form), s) + self.msg.unimported_type_becomes_any( + "A type on this line", AnyType(TypeOfAny.special_form), s + ) else: self.msg.unimported_type_becomes_any("Type of variable", s.type, s) check_for_explicit_any(s.type, self.options, self.is_typeshed_stub, self.msg, context=s) @@ -2280,12 +2527,16 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: self.check_assignment(lv, rvalue, s.type is None) self.check_final(s) - if (s.is_final_def and s.type and not has_no_typevars(s.type) - and self.scope.active_class() is not None): + if ( + s.is_final_def + and s.type + and not has_no_typevars(s.type) + and self.scope.active_class() is not None + ): self.fail(message_registry.DEPENDENT_FINAL_IN_CLASS_BODY, s) def check_type_alias_rvalue(self, s: AssignmentStmt) -> None: - if not (self.is_stub and isinstance(s.rvalue, OpExpr) and s.rvalue.op == '|'): + if not (self.is_stub and isinstance(s.rvalue, OpExpr) and s.rvalue.op == "|"): # We do this mostly for compatibility with old semantic analyzer. # TODO: should we get rid of this? alias_type = self.expr_checker.accept(s.rvalue) @@ -2295,7 +2546,7 @@ def check_type_alias_rvalue(self, s: AssignmentStmt) -> None: alias_type = AnyType(TypeOfAny.special_form) def accept_items(e: Expression) -> None: - if isinstance(e, OpExpr) and e.op == '|': + if isinstance(e, OpExpr) and e.op == "|": accept_items(e.left) accept_items(e.right) else: @@ -2307,48 +2558,55 @@ def accept_items(e: Expression) -> None: accept_items(s.rvalue) self.store_type(s.lvalues[-1], alias_type) - def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type: bool = True, - new_syntax: bool = False) -> None: + def check_assignment( + self, + lvalue: Lvalue, + rvalue: Expression, + infer_lvalue_type: bool = True, + new_syntax: bool = False, + ) -> None: """Type check a single assignment: lvalue = rvalue.""" if isinstance(lvalue, TupleExpr) or isinstance(lvalue, ListExpr): - self.check_assignment_to_multiple_lvalues(lvalue.items, rvalue, rvalue, - infer_lvalue_type) + self.check_assignment_to_multiple_lvalues( + lvalue.items, rvalue, rvalue, infer_lvalue_type + ) else: - self.try_infer_partial_generic_type_from_assignment(lvalue, rvalue, '=') + self.try_infer_partial_generic_type_from_assignment(lvalue, rvalue, "=") lvalue_type, index_lvalue, inferred = self.check_lvalue(lvalue) # If we're assigning to __getattr__ or similar methods, check that the signature is # valid. if isinstance(lvalue, NameExpr) and lvalue.node: name = lvalue.node.name - if name in ('__setattr__', '__getattribute__', '__getattr__'): + if name in ("__setattr__", "__getattribute__", "__getattr__"): # If an explicit type is given, use that. if lvalue_type: signature = lvalue_type else: signature = self.expr_checker.accept(rvalue) if signature: - if name == '__setattr__': + if name == "__setattr__": self.check_setattr_method(signature, lvalue) else: self.check_getattr_method(signature, lvalue, name) - if name == '__slots__': + if name == "__slots__": typ = lvalue_type or self.expr_checker.accept(rvalue) self.check_slots_definition(typ, lvalue) - if name == '__match_args__' and inferred is not None: + if name == "__match_args__" and inferred is not None: typ = self.expr_checker.accept(rvalue) self.check_match_args(inferred, typ, lvalue) # Defer PartialType's super type checking. - if (isinstance(lvalue, RefExpr) and - not (isinstance(lvalue_type, PartialType) and - lvalue_type.type is None) and - not (isinstance(lvalue, NameExpr) and lvalue.name == '__match_args__')): + if ( + isinstance(lvalue, RefExpr) + and not (isinstance(lvalue_type, PartialType) and lvalue_type.type is None) + and not (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__") + ): if self.check_compatibility_all_supers(lvalue, lvalue_type, rvalue): # We hit an error on this line; don't check for any others return - if isinstance(lvalue, MemberExpr) and lvalue.name == '__match_args__': + if isinstance(lvalue, MemberExpr) and lvalue.name == "__match_args__": self.fail(message_registry.CANNOT_MODIFY_MATCH_ARGS, lvalue) if lvalue_type: @@ -2367,8 +2625,7 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type if not self.current_node_deferred: # Partial type can't be final, so strip any literal values. rvalue_type = remove_instance_last_known_values(rvalue_type) - inferred_type = make_simplified_union( - [rvalue_type, NoneType()]) + inferred_type = make_simplified_union([rvalue_type, NoneType()]) self.set_inferred_type(var, lvalue, inferred_type) else: var.type = None @@ -2379,22 +2636,27 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type # an error will be reported elsewhere. self.infer_partial_type(lvalue_type.var, lvalue, rvalue_type) # Handle None PartialType's super type checking here, after it's resolved. - if (isinstance(lvalue, RefExpr) and - self.check_compatibility_all_supers(lvalue, lvalue_type, rvalue)): + if isinstance(lvalue, RefExpr) and self.check_compatibility_all_supers( + lvalue, lvalue_type, rvalue + ): # We hit an error on this line; don't check for any others return - elif (is_literal_none(rvalue) and - isinstance(lvalue, NameExpr) and - isinstance(lvalue.node, Var) and - lvalue.node.is_initialized_in_class and - not new_syntax): + elif ( + is_literal_none(rvalue) + and isinstance(lvalue, NameExpr) + and isinstance(lvalue.node, Var) + and lvalue.node.is_initialized_in_class + and not new_syntax + ): # Allow None's to be assigned to class variables with non-Optional types. rvalue_type = lvalue_type - elif (isinstance(lvalue, MemberExpr) and - lvalue.kind is None): # Ignore member access to modules + elif ( + isinstance(lvalue, MemberExpr) and lvalue.kind is None + ): # Ignore member access to modules instance_type = self.expr_checker.accept(lvalue.expr) rvalue_type, lvalue_type, infer_lvalue_type = self.check_member_assignment( - instance_type, lvalue_type, rvalue, context=rvalue) + instance_type, lvalue_type, rvalue, context=rvalue + ) else: # Hacky special case for assigning a literal None # to a variable defined in a previous if @@ -2402,13 +2664,15 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type # make the type optional. This is somewhat # unpleasant, and a generalization of this would # be an improvement! - if (is_literal_none(rvalue) and - isinstance(lvalue, NameExpr) and - lvalue.kind == LDEF and - isinstance(lvalue.node, Var) and - lvalue.node.type and - lvalue.node in self.var_decl_frames and - not isinstance(get_proper_type(lvalue_type), AnyType)): + if ( + is_literal_none(rvalue) + and isinstance(lvalue, NameExpr) + and lvalue.kind == LDEF + and isinstance(lvalue.node, Var) + and lvalue.node.type + and lvalue.node in self.var_decl_frames + and not isinstance(get_proper_type(lvalue_type), AnyType) + ): decl_frame_map = self.var_decl_frames[lvalue.node] # Check if the nearest common ancestor frame for the definition site # and the current site is the enclosing frame of an if/elif/else block. @@ -2421,20 +2685,25 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type lvalue_type = make_optional_type(lvalue_type) self.set_inferred_type(lvalue.node, lvalue, lvalue_type) - rvalue_type = self.check_simple_assignment(lvalue_type, rvalue, context=rvalue, - code=codes.ASSIGNMENT) + rvalue_type = self.check_simple_assignment( + lvalue_type, rvalue, context=rvalue, code=codes.ASSIGNMENT + ) # Special case: only non-abstract non-protocol classes can be assigned to # variables with explicit type Type[A], where A is protocol or abstract. rvalue_type = get_proper_type(rvalue_type) lvalue_type = get_proper_type(lvalue_type) - if (isinstance(rvalue_type, CallableType) and rvalue_type.is_type_obj() and - (rvalue_type.type_object().is_abstract or - rvalue_type.type_object().is_protocol) and - isinstance(lvalue_type, TypeType) and - isinstance(lvalue_type.item, Instance) and - (lvalue_type.item.type.is_abstract or - lvalue_type.item.type.is_protocol)): + if ( + isinstance(rvalue_type, CallableType) + and rvalue_type.is_type_obj() + and ( + rvalue_type.type_object().is_abstract + or rvalue_type.type_object().is_protocol + ) + and isinstance(lvalue_type, TypeType) + and isinstance(lvalue_type.item, Instance) + and (lvalue_type.item.type.is_abstract or lvalue_type.item.type.is_protocol) + ): self.msg.concrete_only_assign(lvalue_type, rvalue) return if rvalue_type and infer_lvalue_type and not isinstance(lvalue_type, PartialType): @@ -2447,22 +2716,20 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type if inferred: rvalue_type = self.expr_checker.accept(rvalue) - if not (inferred.is_final or (isinstance(lvalue, NameExpr) and - lvalue.name == '__match_args__')): + if not ( + inferred.is_final + or (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__") + ): rvalue_type = remove_instance_last_known_values(rvalue_type) self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue) self.check_assignment_to_slots(lvalue) # (type, operator) tuples for augmented assignments supported with partial types - partial_type_augmented_ops: Final = { - ('builtins.list', '+'), - ('builtins.set', '|'), - } - - def try_infer_partial_generic_type_from_assignment(self, - lvalue: Lvalue, - rvalue: Expression, - op: str) -> None: + partial_type_augmented_ops: Final = {("builtins.list", "+"), ("builtins.set", "|")} + + def try_infer_partial_generic_type_from_assignment( + self, lvalue: Lvalue, rvalue: Expression, op: str + ) -> None: """Try to infer a precise type for partial generic type from assignment. 'op' is '=' for normal assignment and a binary operator ('+', ...) for @@ -2475,9 +2742,11 @@ def try_infer_partial_generic_type_from_assignment(self, x = [1] # Infer List[int] as type of 'x' """ var = None - if (isinstance(lvalue, NameExpr) - and isinstance(lvalue.node, Var) - and isinstance(lvalue.node.type, PartialType)): + if ( + isinstance(lvalue, NameExpr) + and isinstance(lvalue.node, Var) + and isinstance(lvalue.node.type, PartialType) + ): var = lvalue.node elif isinstance(lvalue, MemberExpr): var = self.expr_checker.get_partial_self_var(lvalue) @@ -2487,7 +2756,7 @@ def try_infer_partial_generic_type_from_assignment(self, if typ.type is None: return # Return if this is an unsupported augmented assignment. - if op != '=' and (typ.type.fullname, op) not in self.partial_type_augmented_ops: + if op != "=" and (typ.type.fullname, op) not in self.partial_type_augmented_ops: return # TODO: some logic here duplicates the None partial type counterpart # inlined in check_assignment(), see #8043. @@ -2504,26 +2773,25 @@ def try_infer_partial_generic_type_from_assignment(self, var.type = fill_typevars_with_any(typ.type) del partial_types[var] - def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type], - rvalue: Expression) -> bool: + def check_compatibility_all_supers( + self, lvalue: RefExpr, lvalue_type: Optional[Type], rvalue: Expression + ) -> bool: lvalue_node = lvalue.node # Check if we are a class variable with at least one base class - if (isinstance(lvalue_node, Var) and - lvalue.kind in (MDEF, None) and # None for Vars defined via self - len(lvalue_node.info.bases) > 0): + if ( + isinstance(lvalue_node, Var) + and lvalue.kind in (MDEF, None) + and len(lvalue_node.info.bases) > 0 # None for Vars defined via self + ): for base in lvalue_node.info.mro[1:]: tnode = base.names.get(lvalue_node.name) if tnode is not None: - if not self.check_compatibility_classvar_super(lvalue_node, - base, - tnode.node): + if not self.check_compatibility_classvar_super(lvalue_node, base, tnode.node): # Show only one error per variable break - if not self.check_compatibility_final_super(lvalue_node, - base, - tnode.node): + if not self.check_compatibility_final_super(lvalue_node, base, tnode.node): # Show only one error per variable break @@ -2534,9 +2802,13 @@ def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[ # The type of "__slots__" and some other attributes usually doesn't need to # be compatible with a base class. We'll still check the type of "__slots__" # against "object" as an exception. - if (isinstance(lvalue_node, Var) and lvalue_node.allow_incompatible_override and - not (lvalue_node.name == "__slots__" and - base.fullname == "builtins.object")): + if ( + isinstance(lvalue_node, Var) + and lvalue_node.allow_incompatible_override + and not ( + lvalue_node.name == "__slots__" and base.fullname == "builtins.object" + ) + ): continue if is_private(lvalue_node.name): @@ -2546,12 +2818,9 @@ def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[ if base_type: assert base_node is not None - if not self.check_compatibility_super(lvalue, - lvalue_type, - rvalue, - base, - base_type, - base_node): + if not self.check_compatibility_super( + lvalue, lvalue_type, rvalue, base, base_type, base_node + ): # Only show one error per variable; even if other # base classes are also incompatible return True @@ -2561,9 +2830,15 @@ def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[ break return False - def check_compatibility_super(self, lvalue: RefExpr, lvalue_type: Optional[Type], - rvalue: Expression, base: TypeInfo, base_type: Type, - base_node: Node) -> bool: + def check_compatibility_super( + self, + lvalue: RefExpr, + lvalue_type: Optional[Type], + rvalue: Expression, + base: TypeInfo, + base_type: Type, + base_node: Node, + ) -> bool: lvalue_node = lvalue.node assert isinstance(lvalue_node, Var) @@ -2585,8 +2860,7 @@ def check_compatibility_super(self, lvalue: RefExpr, lvalue_type: Optional[Type] base_type = get_proper_type(base_type) compare_type = get_proper_type(compare_type) if compare_type: - if (isinstance(base_type, CallableType) and - isinstance(compare_type, CallableType)): + if isinstance(base_type, CallableType) and isinstance(compare_type, CallableType): base_static = is_node_static(base_node) compare_static = is_node_static(compare_node) @@ -2611,15 +2885,20 @@ def check_compatibility_super(self, lvalue: RefExpr, lvalue_type: Optional[Type] if base_static and compare_static: lvalue_node.is_staticmethod = True - return self.check_subtype(compare_type, base_type, rvalue, - message_registry.INCOMPATIBLE_TYPES_IN_ASSIGNMENT, - 'expression has type', - f'base class "{base.name}" defined the type as', - code=codes.ASSIGNMENT) + return self.check_subtype( + compare_type, + base_type, + rvalue, + message_registry.INCOMPATIBLE_TYPES_IN_ASSIGNMENT, + "expression has type", + f'base class "{base.name}" defined the type as', + code=codes.ASSIGNMENT, + ) return True - def lvalue_type_from_base(self, expr_node: Var, - base: TypeInfo) -> Tuple[Optional[Type], Optional[Node]]: + def lvalue_type_from_base( + self, expr_node: Var, base: TypeInfo + ) -> Tuple[Optional[Type], Optional[Node]]: """For a NameExpr that is part of a class, walk all base classes and try to find the first class that defines a Type for the same name.""" expr_name = expr_node.name @@ -2649,8 +2928,9 @@ def lvalue_type_from_base(self, expr_node: Var, # value, not the Callable if base_node.is_property: base_type = get_proper_type(base_type.ret_type) - if isinstance(base_type, FunctionLike) and isinstance(base_node, - OverloadedFuncDef): + if isinstance(base_type, FunctionLike) and isinstance( + base_node, OverloadedFuncDef + ): # Same for properties with setter if base_node.is_property: base_type = base_type.items[0].ret_type @@ -2659,8 +2939,9 @@ def lvalue_type_from_base(self, expr_node: Var, return None, None - def check_compatibility_classvar_super(self, node: Var, - base: TypeInfo, base_node: Optional[Node]) -> bool: + def check_compatibility_classvar_super( + self, node: Var, base: TypeInfo, base_node: Optional[Node] + ) -> bool: if not isinstance(base_node, Var): return True if node.is_classvar and not base_node.is_classvar: @@ -2671,8 +2952,9 @@ def check_compatibility_classvar_super(self, node: Var, return False return True - def check_compatibility_final_super(self, node: Var, - base: TypeInfo, base_node: Optional[Node]) -> bool: + def check_compatibility_final_super( + self, node: Var, base: TypeInfo, base_node: Optional[Node] + ) -> bool: """Check if an assignment overrides a final attribute in a base class. This only checks situations where either a node in base class is not a variable @@ -2697,10 +2979,9 @@ def check_compatibility_final_super(self, node: Var, self.check_if_final_var_override_writable(node.name, base_node, node) return True - def check_if_final_var_override_writable(self, - name: str, - base_node: Optional[Node], - ctx: Context) -> None: + def check_if_final_var_override_writable( + self, name: str, base_node: Optional[Node], ctx: Context + ) -> None: """Check that a final variable doesn't override writeable attribute. This is done to prevent situations like this: @@ -2732,8 +3013,9 @@ def enter_final_context(self, is_final_def: bool) -> Iterator[None]: finally: self._is_final_def = old_ctx - def check_final(self, - s: Union[AssignmentStmt, OperatorAssignmentStmt, AssignmentExpr]) -> None: + def check_final( + self, s: Union[AssignmentStmt, OperatorAssignmentStmt, AssignmentExpr] + ) -> None: """Check if this assignment does not assign to a final attribute. This function performs the check only for name assignments at module @@ -2752,11 +3034,16 @@ def check_final(self, assert isinstance(lv, RefExpr) if lv.node is not None: assert isinstance(lv.node, Var) - if (lv.node.final_unset_in_class and not lv.node.final_set_in_init and - not self.is_stub and # It is OK to skip initializer in stub files. - # Avoid extra error messages, if there is no type in Final[...], - # then we already reported the error about missing r.h.s. - isinstance(s, AssignmentStmt) and s.type is not None): + if ( + lv.node.final_unset_in_class + and not lv.node.final_set_in_init + and not self.is_stub + and # It is OK to skip initializer in stub files. + # Avoid extra error messages, if there is no type in Final[...], + # then we already reported the error about missing r.h.s. + isinstance(s, AssignmentStmt) + and s.type is not None + ): self.msg.final_without_value(s) for lv in lvs: if isinstance(lv, RefExpr) and isinstance(lv.node, Var): @@ -2791,7 +3078,7 @@ def check_assignment_to_slots(self, lvalue: Lvalue) -> None: if lvalue.name in inst.type.slots: return # We are assigning to an existing slot for base_info in inst.type.mro[:-1]: - if base_info.names.get('__setattr__') is not None: + if base_info.names.get("__setattr__") is not None: # When type has `__setattr__` defined, # we can assign any dynamic value. # We exclude object, because it always has `__setattr__`. @@ -2807,14 +3094,11 @@ def check_assignment_to_slots(self, lvalue: Lvalue) -> None: return self.fail( - message_registry.NAME_NOT_IN_SLOTS.format( - lvalue.name, inst.type.fullname, - ), - lvalue, + message_registry.NAME_NOT_IN_SLOTS.format(lvalue.name, inst.type.fullname), lvalue ) def is_assignable_slot(self, lvalue: Lvalue, typ: Optional[Type]) -> bool: - if getattr(lvalue, 'node', None): + if getattr(lvalue, "node", None): return False # This is a definition typ = get_proper_type(typ) @@ -2825,16 +3109,20 @@ def is_assignable_slot(self, lvalue: Lvalue, typ: Optional[Type]) -> bool: # `__set__` special method. Like `@property` does. # This makes assigning to properties possible, # even without extra slot spec. - return typ.type.get('__set__') is not None + return typ.type.get("__set__") is not None if isinstance(typ, FunctionLike): return True # Can be a property, or some other magic if isinstance(typ, UnionType): return all(self.is_assignable_slot(lvalue, u) for u in typ.items) return False - def check_assignment_to_multiple_lvalues(self, lvalues: List[Lvalue], rvalue: Expression, - context: Context, - infer_lvalue_type: bool = True) -> None: + def check_assignment_to_multiple_lvalues( + self, + lvalues: List[Lvalue], + rvalue: Expression, + context: Context, + infer_lvalue_type: bool = True, + ) -> None: if isinstance(rvalue, TupleExpr) or isinstance(rvalue, ListExpr): # Recursively go into Tuple or List expression rhs instead of # using the type of rhs, because this allowed more fine grained @@ -2849,8 +3137,9 @@ def check_assignment_to_multiple_lvalues(self, lvalues: List[Lvalue], rvalue: Ex if isinstance(typs, TupleType): rvalues.extend([TempNode(typ) for typ in typs.items]) elif self.type_is_iterable(typs) and isinstance(typs, Instance): - if (iterable_type is not None - and iterable_type != self.iterable_item_type(typs)): + if iterable_type is not None and iterable_type != self.iterable_item_type( + typs + ): self.fail(message_registry.CONTIGUOUS_ITERABLE_EXPECTED, context) else: if last_idx is None or last_idx + 1 == idx_rval: @@ -2860,8 +3149,7 @@ def check_assignment_to_multiple_lvalues(self, lvalues: List[Lvalue], rvalue: Ex else: self.fail(message_registry.CONTIGUOUS_ITERABLE_EXPECTED, context) else: - self.fail(message_registry.ITERABLE_TYPE_EXPECTED.format(typs), - context) + self.fail(message_registry.ITERABLE_TYPE_EXPECTED.format(typs), context) else: rvalues.append(rval) iterable_start: Optional[int] = None @@ -2873,26 +3161,34 @@ def check_assignment_to_multiple_lvalues(self, lvalues: List[Lvalue], rvalue: Ex if iterable_start is None: iterable_start = i iterable_end = i - if (iterable_start is not None - and iterable_end is not None - and iterable_type is not None): + if ( + iterable_start is not None + and iterable_end is not None + and iterable_type is not None + ): iterable_num = iterable_end - iterable_start + 1 rvalue_needed = len(lvalues) - (len(rvalues) - iterable_num) if rvalue_needed > 0: - rvalues = rvalues[0: iterable_start] + [TempNode(iterable_type) - for i in range(rvalue_needed)] + rvalues[iterable_end + 1:] + rvalues = ( + rvalues[0:iterable_start] + + [TempNode(iterable_type) for i in range(rvalue_needed)] + + rvalues[iterable_end + 1 :] + ) if self.check_rvalue_count_in_assignment(lvalues, len(rvalues), context): - star_index = next((i for i, lv in enumerate(lvalues) if - isinstance(lv, StarExpr)), len(lvalues)) + star_index = next( + (i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), len(lvalues) + ) left_lvs = lvalues[:star_index] - star_lv = cast(StarExpr, - lvalues[star_index]) if star_index != len(lvalues) else None - right_lvs = lvalues[star_index + 1:] + star_lv = ( + cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None + ) + right_lvs = lvalues[star_index + 1 :] left_rvs, star_rvs, right_rvs = self.split_around_star( - rvalues, star_index, len(lvalues)) + rvalues, star_index, len(lvalues) + ) lr_pairs = list(zip(left_lvs, left_rvs)) if star_lv: @@ -2906,24 +3202,27 @@ def check_assignment_to_multiple_lvalues(self, lvalues: List[Lvalue], rvalue: Ex else: self.check_multi_assignment(lvalues, rvalue, context, infer_lvalue_type) - def check_rvalue_count_in_assignment(self, lvalues: List[Lvalue], rvalue_count: int, - context: Context) -> bool: + def check_rvalue_count_in_assignment( + self, lvalues: List[Lvalue], rvalue_count: int, context: Context + ) -> bool: if any(isinstance(lvalue, StarExpr) for lvalue in lvalues): if len(lvalues) - 1 > rvalue_count: - self.msg.wrong_number_values_to_unpack(rvalue_count, - len(lvalues) - 1, context) + self.msg.wrong_number_values_to_unpack(rvalue_count, len(lvalues) - 1, context) return False elif rvalue_count != len(lvalues): self.msg.wrong_number_values_to_unpack(rvalue_count, len(lvalues), context) return False return True - def check_multi_assignment(self, lvalues: List[Lvalue], - rvalue: Expression, - context: Context, - infer_lvalue_type: bool = True, - rv_type: Optional[Type] = None, - undefined_rvalue: bool = False) -> None: + def check_multi_assignment( + self, + lvalues: List[Lvalue], + rvalue: Expression, + context: Context, + infer_lvalue_type: bool = True, + rv_type: Optional[Type] = None, + undefined_rvalue: bool = False, + ) -> None: """Check the assignment of one rvalue to a number of lvalues.""" # Infer the type of an ordinary rvalue expression. @@ -2940,24 +3239,33 @@ def check_multi_assignment(self, lvalues: List[Lvalue], for lv in lvalues: if isinstance(lv, StarExpr): lv = lv.expr - temp_node = self.temp_node(AnyType(TypeOfAny.from_another_any, - source_any=rvalue_type), context) + temp_node = self.temp_node( + AnyType(TypeOfAny.from_another_any, source_any=rvalue_type), context + ) self.check_assignment(lv, temp_node, infer_lvalue_type) elif isinstance(rvalue_type, TupleType): - self.check_multi_assignment_from_tuple(lvalues, rvalue, rvalue_type, - context, undefined_rvalue, infer_lvalue_type) + self.check_multi_assignment_from_tuple( + lvalues, rvalue, rvalue_type, context, undefined_rvalue, infer_lvalue_type + ) elif isinstance(rvalue_type, UnionType): - self.check_multi_assignment_from_union(lvalues, rvalue, rvalue_type, context, - infer_lvalue_type) - elif isinstance(rvalue_type, Instance) and rvalue_type.type.fullname == 'builtins.str': + self.check_multi_assignment_from_union( + lvalues, rvalue, rvalue_type, context, infer_lvalue_type + ) + elif isinstance(rvalue_type, Instance) and rvalue_type.type.fullname == "builtins.str": self.msg.unpacking_strings_disallowed(context) else: - self.check_multi_assignment_from_iterable(lvalues, rvalue_type, - context, infer_lvalue_type) + self.check_multi_assignment_from_iterable( + lvalues, rvalue_type, context, infer_lvalue_type + ) - def check_multi_assignment_from_union(self, lvalues: List[Expression], rvalue: Expression, - rvalue_type: UnionType, context: Context, - infer_lvalue_type: bool) -> None: + def check_multi_assignment_from_union( + self, + lvalues: List[Expression], + rvalue: Expression, + rvalue_type: UnionType, + context: Context, + infer_lvalue_type: bool, + ) -> None: """Check assignment to multiple lvalue targets when rvalue type is a Union[...]. For example: @@ -2977,9 +3285,14 @@ def check_multi_assignment_from_union(self, lvalues: List[Expression], rvalue: E for item in rvalue_type.items: # Type check the assignment separately for each union item and collect # the inferred lvalue types for each union item. - self.check_multi_assignment(lvalues, rvalue, context, - infer_lvalue_type=infer_lvalue_type, - rv_type=item, undefined_rvalue=True) + self.check_multi_assignment( + lvalues, + rvalue, + context, + infer_lvalue_type=infer_lvalue_type, + rv_type=item, + undefined_rvalue=True, + ) for t, lv in zip(transposed, self.flatten_lvalues(lvalues)): # We can access _type_maps directly since temporary type maps are # only created within expressions. @@ -2999,10 +3312,12 @@ def check_multi_assignment_from_union(self, lvalues: List[Expression], rvalue: E # TODO: fix signature of zip() in typeshed. types, declared_types = cast(Any, zip)(*clean_items) - self.binder.assign_type(expr, - make_simplified_union(list(types)), - make_simplified_union(list(declared_types)), - False) + self.binder.assign_type( + expr, + make_simplified_union(list(types)), + make_simplified_union(list(declared_types)), + False, + ) for union, lv in zip(union_types, self.flatten_lvalues(lvalues)): # Properly store the inferred types. _1, _2, inferred = self.check_lvalue(lv) @@ -3023,23 +3338,30 @@ def flatten_lvalues(self, lvalues: List[Expression]) -> List[Expression]: res.append(lv) return res - def check_multi_assignment_from_tuple(self, lvalues: List[Lvalue], rvalue: Expression, - rvalue_type: TupleType, context: Context, - undefined_rvalue: bool, - infer_lvalue_type: bool = True) -> None: + def check_multi_assignment_from_tuple( + self, + lvalues: List[Lvalue], + rvalue: Expression, + rvalue_type: TupleType, + context: Context, + undefined_rvalue: bool, + infer_lvalue_type: bool = True, + ) -> None: if self.check_rvalue_count_in_assignment(lvalues, len(rvalue_type.items), context): - star_index = next((i for i, lv in enumerate(lvalues) - if isinstance(lv, StarExpr)), len(lvalues)) + star_index = next( + (i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), len(lvalues) + ) left_lvs = lvalues[:star_index] star_lv = cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None - right_lvs = lvalues[star_index + 1:] + right_lvs = lvalues[star_index + 1 :] if not undefined_rvalue: # Infer rvalue again, now in the correct type context. lvalue_type = self.lvalue_type_for_inference(lvalues, rvalue_type) - reinferred_rvalue_type = get_proper_type(self.expr_checker.accept(rvalue, - lvalue_type)) + reinferred_rvalue_type = get_proper_type( + self.expr_checker.accept(rvalue, lvalue_type) + ) if isinstance(reinferred_rvalue_type, UnionType): # If this is an Optional type in non-strict Optional code, unwrap it. @@ -3047,9 +3369,9 @@ def check_multi_assignment_from_tuple(self, lvalues: List[Lvalue], rvalue: Expre if len(relevant_items) == 1: reinferred_rvalue_type = get_proper_type(relevant_items[0]) if isinstance(reinferred_rvalue_type, UnionType): - self.check_multi_assignment_from_union(lvalues, rvalue, - reinferred_rvalue_type, context, - infer_lvalue_type) + self.check_multi_assignment_from_union( + lvalues, rvalue, reinferred_rvalue_type, context, infer_lvalue_type + ) return if isinstance(reinferred_rvalue_type, AnyType): # We can get Any if the current node is @@ -3063,26 +3385,30 @@ def check_multi_assignment_from_tuple(self, lvalues: List[Lvalue], rvalue: Expre rvalue_type = reinferred_rvalue_type left_rv_types, star_rv_types, right_rv_types = self.split_around_star( - rvalue_type.items, star_index, len(lvalues)) + rvalue_type.items, star_index, len(lvalues) + ) for lv, rv_type in zip(left_lvs, left_rv_types): self.check_assignment(lv, self.temp_node(rv_type, context), infer_lvalue_type) if star_lv: - list_expr = ListExpr([self.temp_node(rv_type, context) - for rv_type in star_rv_types]) + list_expr = ListExpr( + [self.temp_node(rv_type, context) for rv_type in star_rv_types] + ) list_expr.set_line(context.get_line()) self.check_assignment(star_lv.expr, list_expr, infer_lvalue_type) for lv, rv_type in zip(right_lvs, right_rv_types): self.check_assignment(lv, self.temp_node(rv_type, context), infer_lvalue_type) def lvalue_type_for_inference(self, lvalues: List[Lvalue], rvalue_type: TupleType) -> Type: - star_index = next((i for i, lv in enumerate(lvalues) - if isinstance(lv, StarExpr)), len(lvalues)) + star_index = next( + (i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), len(lvalues) + ) left_lvs = lvalues[:star_index] star_lv = cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None - right_lvs = lvalues[star_index + 1:] + right_lvs = lvalues[star_index + 1 :] left_rv_types, star_rv_types, right_rv_types = self.split_around_star( - rvalue_type.items, star_index, len(lvalues)) + rvalue_type.items, star_index, len(lvalues) + ) type_parameters: List[Type] = [] @@ -3109,10 +3435,11 @@ def append_types_for_inference(lvs: List[Expression], rv_types: List[Type]) -> N append_types_for_inference(right_lvs, right_rv_types) - return TupleType(type_parameters, self.named_type('builtins.tuple')) + return TupleType(type_parameters, self.named_type("builtins.tuple")) - def split_around_star(self, items: List[T], star_index: int, - length: int) -> Tuple[List[T], List[T], List[T]]: + def split_around_star( + self, items: List[T], star_index: int, length: int + ) -> Tuple[List[T], List[T], List[T]]: """Splits a list of items in three to match another list of length 'length' that contains a starred expression at 'star_index' in the following way: @@ -3130,29 +3457,36 @@ def type_is_iterable(self, type: Type) -> bool: type = get_proper_type(type) if isinstance(type, CallableType) and type.is_type_obj(): type = type.fallback - return is_subtype(type, self.named_generic_type('typing.Iterable', - [AnyType(TypeOfAny.special_form)])) + return is_subtype( + type, self.named_generic_type("typing.Iterable", [AnyType(TypeOfAny.special_form)]) + ) - def check_multi_assignment_from_iterable(self, lvalues: List[Lvalue], rvalue_type: Type, - context: Context, - infer_lvalue_type: bool = True) -> None: + def check_multi_assignment_from_iterable( + self, + lvalues: List[Lvalue], + rvalue_type: Type, + context: Context, + infer_lvalue_type: bool = True, + ) -> None: rvalue_type = get_proper_type(rvalue_type) if self.type_is_iterable(rvalue_type) and isinstance(rvalue_type, Instance): item_type = self.iterable_item_type(rvalue_type) for lv in lvalues: if isinstance(lv, StarExpr): - items_type = self.named_generic_type('builtins.list', [item_type]) - self.check_assignment(lv.expr, self.temp_node(items_type, context), - infer_lvalue_type) + items_type = self.named_generic_type("builtins.list", [item_type]) + self.check_assignment( + lv.expr, self.temp_node(items_type, context), infer_lvalue_type + ) else: - self.check_assignment(lv, self.temp_node(item_type, context), - infer_lvalue_type) + self.check_assignment( + lv, self.temp_node(item_type, context), infer_lvalue_type + ) else: self.msg.type_not_iterable(rvalue_type, context) - def check_lvalue(self, lvalue: Lvalue) -> Tuple[Optional[Type], - Optional[IndexExpr], - Optional[Var]]: + def check_lvalue( + self, lvalue: Lvalue + ) -> Tuple[Optional[Type], Optional[IndexExpr], Optional[Var]]: lvalue_type = None index_lvalue = None inferred = None @@ -3176,11 +3510,14 @@ def check_lvalue(self, lvalue: Lvalue) -> Tuple[Optional[Type], lvalue_type = self.expr_checker.analyze_ref_expr(lvalue, lvalue=True) self.store_type(lvalue, lvalue_type) elif isinstance(lvalue, TupleExpr) or isinstance(lvalue, ListExpr): - types = [self.check_lvalue(sub_expr)[0] or - # This type will be used as a context for further inference of rvalue, - # we put Uninhabited if there is no information available from lvalue. - UninhabitedType() for sub_expr in lvalue.items] - lvalue_type = TupleType(types, self.named_type('builtins.tuple')) + types = [ + self.check_lvalue(sub_expr)[0] or + # This type will be used as a context for further inference of rvalue, + # we put Uninhabited if there is no information available from lvalue. + UninhabitedType() + for sub_expr in lvalue.items + ] + lvalue_type = TupleType(types, self.named_type("builtins.tuple")) elif isinstance(lvalue, StarExpr): typ, _, _ = self.check_lvalue(lvalue.expr) lvalue_type = StarType(typ) if typ else None @@ -3205,8 +3542,9 @@ def is_definition(self, s: Lvalue) -> bool: return s.is_inferred_def return False - def infer_variable_type(self, name: Var, lvalue: Lvalue, - init_type: Type, context: Context) -> None: + def infer_variable_type( + self, name: Var, lvalue: Lvalue, init_type: Type, context: Context + ) -> None: """Infer the type of initialized variables from initializer type.""" init_type = get_proper_type(init_type) if isinstance(init_type, DeletedType): @@ -3219,9 +3557,13 @@ def infer_variable_type(self, name: Var, lvalue: Lvalue, if not self.infer_partial_type(name, lvalue, init_type): self.msg.need_annotation_for_var(name, context, self.options.python_version) self.set_inference_error_fallback_type(name, lvalue, init_type) - elif (isinstance(lvalue, MemberExpr) and self.inferred_attribute_types is not None - and lvalue.def_var and lvalue.def_var in self.inferred_attribute_types - and not is_same_type(self.inferred_attribute_types[lvalue.def_var], init_type)): + elif ( + isinstance(lvalue, MemberExpr) + and self.inferred_attribute_types is not None + and lvalue.def_var + and lvalue.def_var in self.inferred_attribute_types + and not is_same_type(self.inferred_attribute_types[lvalue.def_var], init_type) + ): # Multiple, inconsistent types inferred for an attribute. self.msg.need_annotation_for_var(name, context, self.options.python_version) name.type = AnyType(TypeOfAny.from_error) @@ -3240,19 +3582,26 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool elif isinstance(init_type, Instance): fullname = init_type.type.fullname is_ref = isinstance(lvalue, RefExpr) - if (is_ref and - (fullname == 'builtins.list' or - fullname == 'builtins.set' or - fullname == 'builtins.dict' or - fullname == 'collections.OrderedDict') and - all(isinstance(t, (NoneType, UninhabitedType)) - for t in get_proper_types(init_type.args))): + if ( + is_ref + and ( + fullname == "builtins.list" + or fullname == "builtins.set" + or fullname == "builtins.dict" + or fullname == "collections.OrderedDict" + ) + and all( + isinstance(t, (NoneType, UninhabitedType)) + for t in get_proper_types(init_type.args) + ) + ): partial_type = PartialType(init_type.type, name) - elif is_ref and fullname == 'collections.defaultdict': + elif is_ref and fullname == "collections.defaultdict": arg0 = get_proper_type(init_type.args[0]) arg1 = get_proper_type(init_type.args[1]) - if (isinstance(arg0, (NoneType, UninhabitedType)) and - self.is_valid_defaultdict_partial_value_type(arg1)): + if isinstance( + arg0, (NoneType, UninhabitedType) + ) and self.is_valid_defaultdict_partial_value_type(arg1): arg1 = erase_type(arg1) assert isinstance(arg1, Instance) partial_type = PartialType(init_type.type, name, arg1) @@ -3328,33 +3677,46 @@ def inference_error_fallback_type(self, type: Type) -> Type: # we therefore need to erase them. return erase_typevars(fallback) - def check_simple_assignment(self, lvalue_type: Optional[Type], rvalue: Expression, - context: Context, - msg: str = message_registry.INCOMPATIBLE_TYPES_IN_ASSIGNMENT, - lvalue_name: str = 'variable', - rvalue_name: str = 'expression', *, - code: Optional[ErrorCode] = None) -> Type: + def check_simple_assignment( + self, + lvalue_type: Optional[Type], + rvalue: Expression, + context: Context, + msg: str = message_registry.INCOMPATIBLE_TYPES_IN_ASSIGNMENT, + lvalue_name: str = "variable", + rvalue_name: str = "expression", + *, + code: Optional[ErrorCode] = None, + ) -> Type: if self.is_stub and isinstance(rvalue, EllipsisExpr): # '...' is always a valid initializer in a stub. return AnyType(TypeOfAny.special_form) else: lvalue_type = get_proper_type(lvalue_type) always_allow_any = lvalue_type is not None and not isinstance(lvalue_type, AnyType) - rvalue_type = self.expr_checker.accept(rvalue, lvalue_type, - always_allow_any=always_allow_any) + rvalue_type = self.expr_checker.accept( + rvalue, lvalue_type, always_allow_any=always_allow_any + ) rvalue_type = get_proper_type(rvalue_type) if isinstance(rvalue_type, DeletedType): self.msg.deleted_as_rvalue(rvalue_type, context) if isinstance(lvalue_type, DeletedType): self.msg.deleted_as_lvalue(lvalue_type, context) elif lvalue_type: - self.check_subtype(rvalue_type, lvalue_type, context, msg, - f'{rvalue_name} has type', - f'{lvalue_name} has type', code=code) + self.check_subtype( + rvalue_type, + lvalue_type, + context, + msg, + f"{rvalue_name} has type", + f"{lvalue_name} has type", + code=code, + ) return rvalue_type - def check_member_assignment(self, instance_type: Type, attribute_type: Type, - rvalue: Expression, context: Context) -> Tuple[Type, Type, bool]: + def check_member_assignment( + self, instance_type: Type, attribute_type: Type, rvalue: Expression, context: Context + ) -> Tuple[Type, Type, bool]: """Type member assignment. This defers to check_simple_assignment, unless the member expression @@ -3369,50 +3731,66 @@ def check_member_assignment(self, instance_type: Type, attribute_type: Type, instance_type = get_proper_type(instance_type) attribute_type = get_proper_type(attribute_type) # Descriptors don't participate in class-attribute access - if ((isinstance(instance_type, FunctionLike) and instance_type.is_type_obj()) or - isinstance(instance_type, TypeType)): - rvalue_type = self.check_simple_assignment(attribute_type, rvalue, context, - code=codes.ASSIGNMENT) + if (isinstance(instance_type, FunctionLike) and instance_type.is_type_obj()) or isinstance( + instance_type, TypeType + ): + rvalue_type = self.check_simple_assignment( + attribute_type, rvalue, context, code=codes.ASSIGNMENT + ) return rvalue_type, attribute_type, True if not isinstance(attribute_type, Instance): # TODO: support __set__() for union types. - rvalue_type = self.check_simple_assignment(attribute_type, rvalue, context, - code=codes.ASSIGNMENT) + rvalue_type = self.check_simple_assignment( + attribute_type, rvalue, context, code=codes.ASSIGNMENT + ) return rvalue_type, attribute_type, True mx = MemberContext( - is_lvalue=False, is_super=False, is_operator=False, - original_type=instance_type, context=context, self_type=None, - msg=self.msg, chk=self, + is_lvalue=False, + is_super=False, + is_operator=False, + original_type=instance_type, + context=context, + self_type=None, + msg=self.msg, + chk=self, ) get_type = analyze_descriptor_access(attribute_type, mx) - if not attribute_type.type.has_readable_member('__set__'): + if not attribute_type.type.has_readable_member("__set__"): # If there is no __set__, we type-check that the assigned value matches # the return type of __get__. This doesn't match the python semantics, # (which allow you to override the descriptor with any value), but preserves # the type of accessing the attribute (even after the override). - rvalue_type = self.check_simple_assignment(get_type, rvalue, context, - code=codes.ASSIGNMENT) + rvalue_type = self.check_simple_assignment( + get_type, rvalue, context, code=codes.ASSIGNMENT + ) return rvalue_type, get_type, True - dunder_set = attribute_type.type.get_method('__set__') + dunder_set = attribute_type.type.get_method("__set__") if dunder_set is None: self.fail(message_registry.DESCRIPTOR_SET_NOT_CALLABLE.format(attribute_type), context) return AnyType(TypeOfAny.from_error), get_type, False bound_method = analyze_decorator_or_funcbase_access( - defn=dunder_set, itype=attribute_type, info=attribute_type.type, - self_type=attribute_type, name='__set__', mx=mx) + defn=dunder_set, + itype=attribute_type, + info=attribute_type.type, + self_type=attribute_type, + name="__set__", + mx=mx, + ) typ = map_instance_to_supertype(attribute_type, dunder_set.info) dunder_set_type = expand_type_by_instance(bound_method, typ) callable_name = self.expr_checker.method_fullname(attribute_type, "__set__") dunder_set_type = self.expr_checker.transform_callee_type( - callable_name, dunder_set_type, + callable_name, + dunder_set_type, [TempNode(instance_type, context=context), rvalue], [nodes.ARG_POS, nodes.ARG_POS], - context, object_type=attribute_type, + context, + object_type=attribute_type, ) # For non-overloaded setters, the result should be type-checked like a regular assignment. @@ -3423,8 +3801,10 @@ def check_member_assignment(self, instance_type: Type, attribute_type: Type, dunder_set_type, [TempNode(instance_type, context=context), type_context], [nodes.ARG_POS, nodes.ARG_POS], - context, object_type=attribute_type, - callable_name=callable_name) + context, + object_type=attribute_type, + callable_name=callable_name, + ) # And now we in fact type check the call, to show errors related to wrong arguments # count, etc., replacing the type context for non-overloaded setters only. @@ -3435,12 +3815,15 @@ def check_member_assignment(self, instance_type: Type, attribute_type: Type, dunder_set_type, [TempNode(instance_type, context=context), type_context], [nodes.ARG_POS, nodes.ARG_POS], - context, object_type=attribute_type, - callable_name=callable_name) + context, + object_type=attribute_type, + callable_name=callable_name, + ) # In the following cases, a message already will have been recorded in check_call. - if ((not isinstance(inferred_dunder_set_type, CallableType)) or - (len(inferred_dunder_set_type.arg_types) < 2)): + if (not isinstance(inferred_dunder_set_type, CallableType)) or ( + len(inferred_dunder_set_type.arg_types) < 2 + ): return AnyType(TypeOfAny.from_error), get_type, False set_type = inferred_dunder_set_type.arg_types[1] @@ -3448,13 +3831,15 @@ def check_member_assignment(self, instance_type: Type, attribute_type: Type, # and '__get__' type is narrower than '__set__', then we invoke the binder to narrow type # by this assignment. Technically, this is not safe, but in practice this is # what a user expects. - rvalue_type = self.check_simple_assignment(set_type, rvalue, context, - code=codes.ASSIGNMENT) + rvalue_type = self.check_simple_assignment( + set_type, rvalue, context, code=codes.ASSIGNMENT + ) infer = is_subtype(rvalue_type, get_type) and is_subtype(get_type, set_type) return rvalue_type if infer else set_type, get_type, infer - def check_indexed_assignment(self, lvalue: IndexExpr, - rvalue: Expression, context: Context) -> None: + def check_indexed_assignment( + self, lvalue: IndexExpr, rvalue: Expression, context: Context + ) -> None: """Type check indexed assignment base[index] = rvalue. The lvalue argument is the base[index] expression. @@ -3462,15 +3847,22 @@ def check_indexed_assignment(self, lvalue: IndexExpr, self.try_infer_partial_type_from_indexed_assignment(lvalue, rvalue) basetype = get_proper_type(self.expr_checker.accept(lvalue.base)) method_type = self.expr_checker.analyze_external_member_access( - '__setitem__', basetype, lvalue) + "__setitem__", basetype, lvalue + ) lvalue.method_type = method_type self.expr_checker.check_method_call( - '__setitem__', basetype, method_type, [lvalue.index, rvalue], - [nodes.ARG_POS, nodes.ARG_POS], context) + "__setitem__", + basetype, + method_type, + [lvalue.index, rvalue], + [nodes.ARG_POS, nodes.ARG_POS], + context, + ) def try_infer_partial_type_from_indexed_assignment( - self, lvalue: IndexExpr, rvalue: Expression) -> None: + self, lvalue: IndexExpr, rvalue: Expression + ) -> None: # TODO: Should we share some of this with try_infer_partial_type? var = None if isinstance(lvalue.base, RefExpr) and isinstance(lvalue.base.node, Var): @@ -3486,20 +3878,25 @@ def try_infer_partial_type_from_indexed_assignment( if partial_types is None: return typename = type_type.fullname - if (typename == 'builtins.dict' - or typename == 'collections.OrderedDict' - or typename == 'collections.defaultdict'): + if ( + typename == "builtins.dict" + or typename == "collections.OrderedDict" + or typename == "collections.defaultdict" + ): # TODO: Don't infer things twice. key_type = self.expr_checker.accept(lvalue.index) value_type = self.expr_checker.accept(rvalue) - if (is_valid_inferred_type(key_type) and - is_valid_inferred_type(value_type) and - not self.current_node_deferred and - not (typename == 'collections.defaultdict' and - var.type.value_type is not None and - not is_equivalent(value_type, var.type.value_type))): - var.type = self.named_generic_type(typename, - [key_type, value_type]) + if ( + is_valid_inferred_type(key_type) + and is_valid_inferred_type(value_type) + and not self.current_node_deferred + and not ( + typename == "collections.defaultdict" + and var.type.value_type is not None + and not is_equivalent(value_type, var.type.value_type) + ) + ): + var.type = self.named_generic_type(typename, [key_type, value_type]) del partial_types[var] def type_requires_usage(self, typ: Type) -> Optional[Tuple[str, ErrorCode]]: @@ -3538,8 +3935,9 @@ def check_return_stmt(self, s: ReturnStmt) -> None: defn = self.scope.top_function() if defn is not None: if defn.is_generator: - return_type = self.get_generator_return_type(self.return_types[-1], - defn.is_coroutine) + return_type = self.get_generator_return_type( + self.return_types[-1], defn.is_coroutine + ) elif defn.is_coroutine: return_type = self.get_coroutine_return_type(self.return_types[-1]) else: @@ -3563,8 +3961,11 @@ def check_return_stmt(self, s: ReturnStmt) -> None: allow_none_func_call = is_lambda or declared_none_return or declared_any_return # Return with a value. - typ = get_proper_type(self.expr_checker.accept( - s.expr, return_type, allow_none_return=allow_none_func_call)) + typ = get_proper_type( + self.expr_checker.accept( + s.expr, return_type, allow_none_return=allow_none_func_call + ) + ) if defn.is_async_generator: self.fail(message_registry.RETURN_IN_ASYNC_GENERATOR, s) @@ -3573,13 +3974,19 @@ def check_return_stmt(self, s: ReturnStmt) -> None: if isinstance(typ, AnyType): # (Unless you asked to be warned in that case, and the # function is not declared to return Any) - if (self.options.warn_return_any + if ( + self.options.warn_return_any and not self.current_node_deferred and not is_proper_subtype(AnyType(TypeOfAny.special_form), return_type) - and not (defn.name in BINARY_MAGIC_METHODS and - is_literal_not_implemented(s.expr)) - and not (isinstance(return_type, Instance) and - return_type.type.fullname == 'builtins.object')): + and not ( + defn.name in BINARY_MAGIC_METHODS + and is_literal_not_implemented(s.expr) + ) + and not ( + isinstance(return_type, Instance) + and return_type.type.fullname == "builtins.object" + ) + ): self.msg.incorrectly_returning_any(return_type, s) return @@ -3593,19 +4000,23 @@ def check_return_stmt(self, s: ReturnStmt) -> None: self.fail(message_registry.NO_RETURN_VALUE_EXPECTED, s) else: self.check_subtype( - subtype_label='got', + subtype_label="got", subtype=typ, - supertype_label='expected', + supertype_label="expected", supertype=return_type, context=s.expr, outer_context=s, msg=message_registry.INCOMPATIBLE_RETURN_VALUE_TYPE, - code=codes.RETURN_VALUE) + code=codes.RETURN_VALUE, + ) else: # Empty returns are valid in Generators with Any typed returns, but not in # coroutines. - if (defn.is_generator and not defn.is_coroutine and - isinstance(return_type, AnyType)): + if ( + defn.is_generator + and not defn.is_coroutine + and isinstance(return_type, AnyType) + ): return if isinstance(return_type, (NoneType, AnyType)): @@ -3643,11 +4054,9 @@ def visit_while_stmt(self, s: WhileStmt) -> None: """Type check a while statement.""" if_stmt = IfStmt([s.expr], [s.body], None) if_stmt.set_line(s.get_line(), s.get_column()) - self.accept_loop(if_stmt, s.else_body, - exit_condition=s.expr) + self.accept_loop(if_stmt, s.else_body, exit_condition=s.expr) - def visit_operator_assignment_stmt(self, - s: OperatorAssignmentStmt) -> None: + def visit_operator_assignment_stmt(self, s: OperatorAssignmentStmt) -> None: """Type check an operator assignment statement, e.g. x += 1.""" self.try_infer_partial_generic_type_from_assignment(s.lvalue, s.rvalue, s.op) if isinstance(s.lvalue, MemberExpr): @@ -3659,16 +4068,16 @@ def visit_operator_assignment_stmt(self, inplace, method = infer_operator_assignment_method(lvalue_type, s.op) if inplace: # There is __ifoo__, treat as x = x.__ifoo__(y) - rvalue_type, method_type = self.expr_checker.check_op( - method, lvalue_type, s.rvalue, s) + rvalue_type, method_type = self.expr_checker.check_op(method, lvalue_type, s.rvalue, s) if not is_subtype(rvalue_type, lvalue_type): self.msg.incompatible_operator_assignment(s.op, s) else: # There is no __ifoo__, treat as x = x y expr = OpExpr(s.op, s.lvalue, s.rvalue) expr.set_line(s) - self.check_assignment(lvalue=s.lvalue, rvalue=expr, - infer_lvalue_type=True, new_syntax=False) + self.check_assignment( + lvalue=s.lvalue, rvalue=expr, infer_lvalue_type=True, new_syntax=False + ) self.check_final(s) def visit_assert_stmt(self, s: AssertStmt) -> None: @@ -3691,8 +4100,7 @@ def visit_raise_stmt(self, s: RaiseStmt) -> None: self.type_check_raise(s.from_expr, s, optional=True) self.binder.unreachable() - def type_check_raise(self, e: Expression, s: RaiseStmt, - optional: bool = False) -> None: + def type_check_raise(self, e: Expression, s: RaiseStmt, optional: bool = False) -> None: typ = get_proper_type(self.expr_checker.accept(e)) if isinstance(typ, DeletedType): self.msg.deleted_as_rvalue(typ, e) @@ -3705,7 +4113,7 @@ def type_check_raise(self, e: Expression, s: RaiseStmt, return # Python3 case: - exc_type = self.named_type('builtins.BaseException') + exc_type = self.named_type("builtins.BaseException") expected_type_items = [exc_type, TypeType(exc_type)] if optional: # This is used for `x` part in a case like `raise e from x`, @@ -3713,8 +4121,7 @@ def type_check_raise(self, e: Expression, s: RaiseStmt, expected_type_items.append(NoneType()) self.check_subtype( - typ, UnionType.make_union(expected_type_items), s, - message_registry.INVALID_EXCEPTION, + typ, UnionType.make_union(expected_type_items), s, message_registry.INVALID_EXCEPTION ) if isinstance(typ, FunctionLike): @@ -3733,17 +4140,21 @@ def _type_check_raise_python2(self, e: Expression, s: RaiseStmt, typ: ProperType # - `traceback` is `types.TracebackType | None` # Important note: `raise exc, msg` is not the same as `raise (exc, msg)` # We call `raise exc, msg, traceback` - legacy mode. - exc_type = self.named_type('builtins.BaseException') + exc_type = self.named_type("builtins.BaseException") exc_inst_or_type = UnionType([exc_type, TypeType(exc_type)]) - if (not s.legacy_mode and (isinstance(typ, TupleType) and typ.items - or (isinstance(typ, Instance) and typ.args - and typ.type.fullname == 'builtins.tuple'))): + if not s.legacy_mode and ( + isinstance(typ, TupleType) + and typ.items + or (isinstance(typ, Instance) and typ.args and typ.type.fullname == "builtins.tuple") + ): # `raise (exc, ...)` case: item = typ.items[0] if isinstance(typ, TupleType) else typ.args[0] self.check_subtype( - item, exc_inst_or_type, s, - 'When raising a tuple, first element must by derived from BaseException', + item, + exc_inst_or_type, + s, + "When raising a tuple, first element must by derived from BaseException", ) return elif s.legacy_mode: @@ -3751,14 +4162,12 @@ def _type_check_raise_python2(self, e: Expression, s: RaiseStmt, typ: ProperType # `raise Exception, msg, traceback` case # https://docs.python.org/2/reference/simple_stmts.html#the-raise-statement assert isinstance(typ, TupleType) # Is set in fastparse2.py - if (len(typ.items) >= 2 - and isinstance(get_proper_type(typ.items[1]), NoneType)): + if len(typ.items) >= 2 and isinstance(get_proper_type(typ.items[1]), NoneType): expected_type: Type = exc_inst_or_type else: expected_type = TypeType(exc_type) self.check_subtype( - typ.items[0], expected_type, s, - f'Argument 1 must be "{expected_type}" subtype', + typ.items[0], expected_type, s, f'Argument 1 must be "{expected_type}" subtype' ) # Typecheck `traceback` part: @@ -3767,22 +4176,26 @@ def _type_check_raise_python2(self, e: Expression, s: RaiseStmt, typ: ProperType # We do this after the main check for better error message # and better ordering: first about `BaseException` subtype, # then about `traceback` type. - traceback_type = UnionType.make_union([ - self.named_type('types.TracebackType'), - NoneType(), - ]) + traceback_type = UnionType.make_union( + [self.named_type("types.TracebackType"), NoneType()] + ) self.check_subtype( - typ.items[2], traceback_type, s, + typ.items[2], + traceback_type, + s, f'Argument 3 must be "{traceback_type}" subtype', ) else: expected_type_items = [ # `raise Exception` and `raise Exception()` cases: - exc_type, TypeType(exc_type), + exc_type, + TypeType(exc_type), ] self.check_subtype( - typ, UnionType.make_union(expected_type_items), - s, message_registry.INVALID_EXCEPTION, + typ, + UnionType.make_union(expected_type_items), + s, + message_registry.INVALID_EXCEPTION, ) def visit_try_stmt(self, s: TryStmt) -> None: @@ -3865,9 +4278,11 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None: if self.options.python_version[0] >= 3: source = var.name else: - source = ('(exception variable "{}", which we do not ' - 'accept outside except: blocks even in ' - 'python 2)'.format(var.name)) + source = ( + '(exception variable "{}", which we do not ' + "accept outside except: blocks even in " + "python 2)".format(var.name) + ) if isinstance(var.node, Var): var.node.type = DeletedType(source=source) self.binder.cleanse(var) @@ -3898,7 +4313,7 @@ def check_except_handler_test(self, n: Expression) -> Type: self.fail(message_registry.INVALID_EXCEPTION_TYPE, n) return AnyType(TypeOfAny.from_error) - if not is_subtype(exc_type, self.named_type('builtins.BaseException')): + if not is_subtype(exc_type, self.named_type("builtins.BaseException")): self.fail(message_registry.INVALID_EXCEPTION_TYPE, n) return AnyType(TypeOfAny.from_error) @@ -3917,7 +4332,7 @@ def get_types_from_except_handler(self, typ: Type, n: Expression) -> List[Type]: for item in typ.relevant_items() for union_typ in self.get_types_from_except_handler(item, n) ] - elif isinstance(typ, Instance) and is_named_instance(typ, 'builtins.tuple'): + elif isinstance(typ, Instance) and is_named_instance(typ, "builtins.tuple"): # variadic tuple return [typ.args[0]] else: @@ -3938,17 +4353,18 @@ def analyze_async_iterable_item_type(self, expr: Expression) -> Tuple[Type, Type """Analyse async iterable expression and return iterator and iterator item types.""" echk = self.expr_checker iterable = echk.accept(expr) - iterator = echk.check_method_call_by_name('__aiter__', iterable, [], [], expr)[0] - awaitable = echk.check_method_call_by_name('__anext__', iterator, [], [], expr)[0] - item_type = echk.check_awaitable_expr(awaitable, expr, - message_registry.INCOMPATIBLE_TYPES_IN_ASYNC_FOR) + iterator = echk.check_method_call_by_name("__aiter__", iterable, [], [], expr)[0] + awaitable = echk.check_method_call_by_name("__anext__", iterator, [], [], expr)[0] + item_type = echk.check_awaitable_expr( + awaitable, expr, message_registry.INCOMPATIBLE_TYPES_IN_ASYNC_FOR + ) return iterator, item_type def analyze_iterable_item_type(self, expr: Expression) -> Tuple[Type, Type]: """Analyse iterable expression and return iterator and iterator item types.""" echk = self.expr_checker iterable = get_proper_type(echk.accept(expr)) - iterator = echk.check_method_call_by_name('__iter__', iterable, [], [], expr)[0] + iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], expr)[0] if isinstance(iterable, TupleType): joined: Type = UninhabitedType() @@ -3958,9 +4374,9 @@ def analyze_iterable_item_type(self, expr: Expression) -> Tuple[Type, Type]: else: # Non-tuple iterable. if self.options.python_version[0] >= 3: - nextmethod = '__next__' + nextmethod = "__next__" else: - nextmethod = 'next' + nextmethod = "next" return iterator, echk.check_method_call_by_name(nextmethod, iterator, [], [], expr)[0] def analyze_container_item_type(self, typ: Type) -> Optional[Type]: @@ -3976,8 +4392,8 @@ def analyze_container_item_type(self, typ: Type) -> Optional[Type]: if c_type: types.append(c_type) return UnionType.make_union(types) - if isinstance(typ, Instance) and typ.type.has_base('typing.Container'): - supertype = self.named_type('typing.Container').type + if isinstance(typ, Instance) and typ.type.has_base("typing.Container"): + supertype = self.named_type("typing.Container").type super_instance = map_instance_to_supertype(typ, supertype) assert len(super_instance.args) == 1 return super_instance.args[0] @@ -3985,15 +4401,16 @@ def analyze_container_item_type(self, typ: Type) -> Optional[Type]: return self.analyze_container_item_type(tuple_fallback(typ)) return None - def analyze_index_variables(self, index: Expression, item_type: Type, - infer_lvalue_type: bool, context: Context) -> None: + def analyze_index_variables( + self, index: Expression, item_type: Type, infer_lvalue_type: bool, context: Context + ) -> None: """Type check or infer for loop or list comprehension index vars.""" self.check_assignment(index, self.temp_node(item_type, context), infer_lvalue_type) def visit_del_stmt(self, s: DelStmt) -> None: if isinstance(s.expr, IndexExpr): e = s.expr - m = MemberExpr(e.base, '__delitem__') + m = MemberExpr(e.base, "__delitem__") m.line = s.line m.column = s.column c = CallExpr(m, [e.index], [nodes.ARG_POS], [None]) @@ -4004,13 +4421,14 @@ def visit_del_stmt(self, s: DelStmt) -> None: s.expr.accept(self.expr_checker) for elt in flatten(s.expr): if isinstance(elt, NameExpr): - self.binder.assign_type(elt, DeletedType(source=elt.name), - get_declaration(elt), False) + self.binder.assign_type( + elt, DeletedType(source=elt.name), get_declaration(elt), False + ) def visit_decorator(self, e: Decorator) -> None: for d in e.decorators: if isinstance(d, RefExpr): - if d.fullname == 'typing.no_type_check': + if d.fullname == "typing.no_type_check": e.var.type = AnyType(TypeOfAny.special_form) e.var.is_ready = True return @@ -4038,10 +4456,9 @@ def visit_decorator(self, e: Decorator) -> None: object_type = self.lookup_type(d.expr) fullname = self.expr_checker.method_fullname(object_type, d.name) self.check_for_untyped_decorator(e.func, dec, d) - sig, t2 = self.expr_checker.check_call(dec, [temp], - [nodes.ARG_POS], e, - callable_name=fullname, - object_type=object_type) + sig, t2 = self.expr_checker.check_call( + dec, [temp], [nodes.ARG_POS], e, callable_name=fullname, object_type=object_type + ) self.check_untyped_after_decorator(sig, e.func) sig = set_callable_name(sig, e.func) e.var.type = sig @@ -4051,17 +4468,18 @@ def visit_decorator(self, e: Decorator) -> None: if e.func.info and not e.func.is_dynamic(): self.check_method_override(e) - if e.func.info and e.func.name in ('__init__', '__new__'): + if e.func.info and e.func.name in ("__init__", "__new__"): if e.type and not isinstance(get_proper_type(e.type), (FunctionLike, AnyType)): self.fail(message_registry.BAD_CONSTRUCTOR_TYPE, e) - def check_for_untyped_decorator(self, - func: FuncDef, - dec_type: Type, - dec_expr: Expression) -> None: - if (self.options.disallow_untyped_decorators and - is_typed_callable(func.type) and - is_untyped_decorator(dec_type)): + def check_for_untyped_decorator( + self, func: FuncDef, dec_type: Type, dec_expr: Expression + ) -> None: + if ( + self.options.disallow_untyped_decorators + and is_typed_callable(func.type) + and is_untyped_decorator(dec_type) + ): self.msg.typed_function_untyped_decorator(func.name, dec_expr) def check_incompatible_property_override(self, e: Decorator) -> None: @@ -4071,10 +4489,11 @@ def check_incompatible_property_override(self, e: Decorator) -> None: base_attr = base.names.get(name) if not base_attr: continue - if (isinstance(base_attr.node, OverloadedFuncDef) and - base_attr.node.is_property and - cast(Decorator, - base_attr.node.items[0]).var.is_settable_property): + if ( + isinstance(base_attr.node, OverloadedFuncDef) + and base_attr.node.is_property + and cast(Decorator, base_attr.node.items[0]).var.is_settable_property + ): self.fail(message_registry.READ_ONLY_PROPERTY_OVERRIDES_READ_WRITE, e) def visit_with_stmt(self, s: WithStmt) -> None: @@ -4096,10 +4515,11 @@ def visit_with_stmt(self, s: WithStmt) -> None: if is_literal_type(exit_ret_type, "builtins.bool", False): continue - if (is_literal_type(exit_ret_type, "builtins.bool", True) - or (isinstance(exit_ret_type, Instance) - and exit_ret_type.type.fullname == 'builtins.bool' - and state.strict_optional)): + if is_literal_type(exit_ret_type, "builtins.bool", True) or ( + isinstance(exit_ret_type, Instance) + and exit_ret_type.type.fullname == "builtins.bool" + and state.strict_optional + ): # Note: if strict-optional is disabled, this bool instance # could actually be an Optional[bool]. exceptions_maybe_suppressed = True @@ -4120,31 +4540,37 @@ def check_untyped_after_decorator(self, typ: Type, func: FuncDef) -> None: if mypy.checkexpr.has_any_type(typ): self.msg.untyped_decorated_function(typ, func) - def check_async_with_item(self, expr: Expression, target: Optional[Expression], - infer_lvalue_type: bool) -> Type: + def check_async_with_item( + self, expr: Expression, target: Optional[Expression], infer_lvalue_type: bool + ) -> Type: echk = self.expr_checker ctx = echk.accept(expr) - obj = echk.check_method_call_by_name('__aenter__', ctx, [], [], expr)[0] + obj = echk.check_method_call_by_name("__aenter__", ctx, [], [], expr)[0] obj = echk.check_awaitable_expr( - obj, expr, message_registry.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AENTER) + obj, expr, message_registry.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AENTER + ) if target: self.check_assignment(target, self.temp_node(obj, expr), infer_lvalue_type) arg = self.temp_node(AnyType(TypeOfAny.special_form), expr) res, _ = echk.check_method_call_by_name( - '__aexit__', ctx, [arg] * 3, [nodes.ARG_POS] * 3, expr) + "__aexit__", ctx, [arg] * 3, [nodes.ARG_POS] * 3, expr + ) return echk.check_awaitable_expr( - res, expr, message_registry.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT) + res, expr, message_registry.INCOMPATIBLE_TYPES_IN_ASYNC_WITH_AEXIT + ) - def check_with_item(self, expr: Expression, target: Optional[Expression], - infer_lvalue_type: bool) -> Type: + def check_with_item( + self, expr: Expression, target: Optional[Expression], infer_lvalue_type: bool + ) -> Type: echk = self.expr_checker ctx = echk.accept(expr) - obj = echk.check_method_call_by_name('__enter__', ctx, [], [], expr)[0] + obj = echk.check_method_call_by_name("__enter__", ctx, [], [], expr)[0] if target: self.check_assignment(target, self.temp_node(obj, expr), infer_lvalue_type) arg = self.temp_node(AnyType(TypeOfAny.special_form), expr) res, _ = echk.check_method_call_by_name( - '__exit__', ctx, [arg] * 3, [nodes.ARG_POS] * 3, expr) + "__exit__", ctx, [arg] * 3, [nodes.ARG_POS] * 3, expr + ) return res def visit_print_stmt(self, s: PrintStmt) -> None: @@ -4154,20 +4580,21 @@ def visit_print_stmt(self, s: PrintStmt) -> None: target_type = get_proper_type(self.expr_checker.accept(s.target)) if not isinstance(target_type, NoneType): write_type = self.expr_checker.analyze_external_member_access( - 'write', target_type, s.target) + "write", target_type, s.target + ) required_type = CallableType( - arg_types=[self.named_type('builtins.str')], + arg_types=[self.named_type("builtins.str")], arg_kinds=[ARG_POS], arg_names=[None], ret_type=AnyType(TypeOfAny.implementation_artifact), - fallback=self.named_type('builtins.function'), + fallback=self.named_type("builtins.function"), ) # This has to be hard-coded, since it is a syntax pattern, not a function call. if not is_subtype(write_type, required_type): - self.fail(message_registry.PYTHON2_PRINT_FILE_TYPE.format( - write_type, - required_type, - ), s.target) + self.fail( + message_registry.PYTHON2_PRINT_FILE_TYPE.format(write_type, required_type), + s.target, + ) def visit_break_stmt(self, s: BreakStmt) -> None: self.binder.handle_break() @@ -4194,22 +4621,21 @@ def visit_match_stmt(self, s: MatchStmt) -> None: # The second pass narrows down the types and type checks bodies. for p, g, b in zip(s.patterns, s.guards, s.bodies): - current_subject_type = self.expr_checker.narrow_type_from_binder(s.subject, - subject_type) + current_subject_type = self.expr_checker.narrow_type_from_binder( + s.subject, subject_type + ) pattern_type = self.pattern_checker.accept(p, current_subject_type) with self.binder.frame_context(can_skip=True, fall_through=2): - if b.is_unreachable or isinstance(get_proper_type(pattern_type.type), - UninhabitedType): + if b.is_unreachable or isinstance( + get_proper_type(pattern_type.type), UninhabitedType + ): self.push_type_map(None) else_map: TypeMap = {} else: pattern_map, else_map = conditional_types_to_typemaps( - s.subject, - pattern_type.type, - pattern_type.rest_type + s.subject, pattern_type.type, pattern_type.rest_type ) - self.remove_capture_conflicts(pattern_type.captures, - inferred_types) + self.remove_capture_conflicts(pattern_type.captures, inferred_types) self.push_type_map(pattern_map) self.push_type_map(pattern_type.captures) if g is not None: @@ -4253,10 +4679,14 @@ def infer_variable_types_from_type_maps(self, type_maps: List[TypeMap]) -> Dict[ previous_type, _, _ = self.check_lvalue(expr) if previous_type is not None: already_exists = True - if self.check_subtype(typ, previous_type, expr, - msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE, - subtype_label="pattern captures type", - supertype_label="variable has type"): + if self.check_subtype( + typ, + previous_type, + expr, + msg=message_registry.INCOMPATIBLE_TYPES_IN_CAPTURE, + subtype_label="pattern captures type", + supertype_label="variable has type", + ): inferred_types[var] = previous_type if not already_exists: @@ -4276,12 +4706,13 @@ def remove_capture_conflicts(self, type_map: TypeMap, inferred_types: Dict[Var, if node not in inferred_types or not is_subtype(typ, inferred_types[node]): del type_map[expr] - def make_fake_typeinfo(self, - curr_module_fullname: str, - class_gen_name: str, - class_short_name: str, - bases: List[Instance], - ) -> Tuple[ClassDef, TypeInfo]: + def make_fake_typeinfo( + self, + curr_module_fullname: str, + class_gen_name: str, + class_short_name: str, + bases: List[Instance], + ) -> Tuple[ClassDef, TypeInfo]: # Build the fake ClassDef and TypeInfo together. # The ClassDef is full of lies and doesn't actually contain a body. # Use format_bare to generate a nice name for error messages. @@ -4289,7 +4720,7 @@ def make_fake_typeinfo(self, # should be irrelevant for a generated type like this: # is_protocol, protocol_members, is_abstract cdef = ClassDef(class_short_name, Block([])) - cdef.fullname = curr_module_fullname + '.' + class_gen_name + cdef.fullname = curr_module_fullname + "." + class_gen_name info = TypeInfo(SymbolTable(), cdef, curr_module_fullname) cdef.info = info info.bases = bases @@ -4297,10 +4728,9 @@ def make_fake_typeinfo(self, info.calculate_metaclass_type() return cdef, info - def intersect_instances(self, - instances: Tuple[Instance, Instance], - ctx: Context, - ) -> Optional[Instance]: + def intersect_instances( + self, instances: Tuple[Instance, Instance], ctx: Context + ) -> Optional[Instance]: """Try creating an ad-hoc intersection of the given instances. Note that this function does *not* try and create a full-fledged @@ -4339,17 +4769,13 @@ def _get_base_classes(instances_: Tuple[Instance, Instance]) -> List[Instance]: return base_classes_ def _make_fake_typeinfo_and_full_name( - base_classes_: List[Instance], - curr_module_: MypyFile, + base_classes_: List[Instance], curr_module_: MypyFile ) -> Tuple[TypeInfo, str]: names_list = pretty_seq([x.type.name for x in base_classes_], "and") - short_name = f'' + short_name = f"" full_name_ = gen_unique_name(short_name, curr_module_.names) cdef, info_ = self.make_fake_typeinfo( - curr_module_.fullname, - full_name_, - short_name, - base_classes_, + curr_module_.fullname, full_name_, short_name, base_classes_ ) return info_, full_name_ @@ -4372,13 +4798,15 @@ def _make_fake_typeinfo_and_full_name( except MroError: if self.should_report_unreachable_issues(): self.msg.impossible_intersection( - pretty_names_list, "inconsistent method resolution order", ctx) + pretty_names_list, "inconsistent method resolution order", ctx + ) return None if local_errors.has_new_errors(): if self.should_report_unreachable_issues(): self.msg.impossible_intersection( - pretty_names_list, "incompatible method signatures", ctx) + pretty_names_list, "incompatible method signatures", ctx + ) return None curr_module.names[full_name] = SymbolTableNode(GDEF, info) @@ -4395,18 +4823,17 @@ def intersect_instance_callable(self, typ: Instance, callable_type: CallableType # have a valid fullname and a corresponding entry in a symbol table. We generate # a unique name inside the symbol table of the current module. cur_module = cast(MypyFile, self.scope.stack[0]) - gen_name = gen_unique_name(f"", - cur_module.names) + gen_name = gen_unique_name(f"", cur_module.names) # Synthesize a fake TypeInfo short_name = format_type_bare(typ) cdef, info = self.make_fake_typeinfo(cur_module.fullname, gen_name, short_name, [typ]) # Build up a fake FuncDef so we can populate the symbol table. - func_def = FuncDef('__call__', [], Block([]), callable_type) - func_def._fullname = cdef.fullname + '.__call__' + func_def = FuncDef("__call__", [], Block([]), callable_type) + func_def._fullname = cdef.fullname + ".__call__" func_def.info = info - info.names['__call__'] = SymbolTableNode(MDEF, func_def) + info.names["__call__"] = SymbolTableNode(MDEF, func_def) cur_module.names[gen_name] = SymbolTableNode(GDEF, info) @@ -4415,19 +4842,21 @@ def intersect_instance_callable(self, typ: Instance, callable_type: CallableType def make_fake_callable(self, typ: Instance) -> Instance: """Produce a new type that makes type Callable with a generic callable type.""" - fallback = self.named_type('builtins.function') - callable_type = CallableType([AnyType(TypeOfAny.explicit), - AnyType(TypeOfAny.explicit)], - [nodes.ARG_STAR, nodes.ARG_STAR2], - [None, None], - ret_type=AnyType(TypeOfAny.explicit), - fallback=fallback, - is_ellipsis_args=True) + fallback = self.named_type("builtins.function") + callable_type = CallableType( + [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)], + [nodes.ARG_STAR, nodes.ARG_STAR2], + [None, None], + ret_type=AnyType(TypeOfAny.explicit), + fallback=fallback, + is_ellipsis_args=True, + ) return self.intersect_instance_callable(typ, callable_type) - def partition_by_callable(self, typ: Type, - unsound_partition: bool) -> Tuple[List[Type], List[Type]]: + def partition_by_callable( + self, typ: Type, unsound_partition: bool + ) -> Tuple[List[Type], List[Type]]: """Partitions a type into callable subtypes and uncallable subtypes. Thus, given: @@ -4459,8 +4888,9 @@ def partition_by_callable(self, typ: Type, for subtype in typ.items: # Use unsound_partition when handling unions in order to # allow the expected type discrimination. - subcallables, subuncallables = self.partition_by_callable(subtype, - unsound_partition=True) + subcallables, subuncallables = self.partition_by_callable( + subtype, unsound_partition=True + ) callables.extend(subcallables) uncallables.extend(subuncallables) return callables, uncallables @@ -4474,8 +4904,9 @@ def partition_by_callable(self, typ: Type, # do better. # If it is possible for the false branch to execute, return the original # type to avoid losing type information. - callables, uncallables = self.partition_by_callable(erase_to_union_or_bound(typ), - unsound_partition) + callables, uncallables = self.partition_by_callable( + erase_to_union_or_bound(typ), unsound_partition + ) uncallables = [typ] if len(uncallables) else [] return callables, uncallables @@ -4486,10 +4917,11 @@ def partition_by_callable(self, typ: Type, ityp = tuple_fallback(typ) if isinstance(ityp, Instance): - method = ityp.type.get_method('__call__') + method = ityp.type.get_method("__call__") if method and method.type: - callables, uncallables = self.partition_by_callable(method.type, - unsound_partition=False) + callables, uncallables = self.partition_by_callable( + method.type, unsound_partition=False + ) if len(callables) and not len(uncallables): # Only consider the type callable if its __call__ method is # definitely callable. @@ -4508,9 +4940,9 @@ def partition_by_callable(self, typ: Type, # We don't know how properly make the type callable. return [typ], [typ] - def conditional_callable_type_map(self, expr: Expression, - current_type: Optional[Type], - ) -> Tuple[TypeMap, TypeMap]: + def conditional_callable_type_map( + self, expr: Expression, current_type: Optional[Type] + ) -> Tuple[TypeMap, TypeMap]: """Takes in an expression and the current type of the expression. Returns a 2-tuple: The first element is a map from the expression to @@ -4524,13 +4956,13 @@ def conditional_callable_type_map(self, expr: Expression, if isinstance(get_proper_type(current_type), AnyType): return {}, {} - callables, uncallables = self.partition_by_callable(current_type, - unsound_partition=False) + callables, uncallables = self.partition_by_callable(current_type, unsound_partition=False) if len(callables) and len(uncallables): callable_map = {expr: UnionType.make_union(callables)} if len(callables) else None - uncallable_map = { - expr: UnionType.make_union(uncallables)} if len(uncallables) else None + uncallable_map = ( + {expr: UnionType.make_union(uncallables)} if len(uncallables) else None + ) return callable_map, uncallable_map elif len(callables): @@ -4541,15 +4973,15 @@ def conditional_callable_type_map(self, expr: Expression, def _is_truthy_type(self, t: ProperType) -> bool: return ( ( - isinstance(t, Instance) and - bool(t.type) and - not t.type.has_readable_member('__bool__') and - not t.type.has_readable_member('__len__') + isinstance(t, Instance) + and bool(t.type) + and not t.type.has_readable_member("__bool__") + and not t.type.has_readable_member("__len__") ) or isinstance(t, FunctionLike) or ( - isinstance(t, UnionType) and - all(self._is_truthy_type(t) for t in get_proper_types(t.items)) + isinstance(t, UnionType) + and all(self._is_truthy_type(t) for t in get_proper_types(t.items)) ) ) @@ -4572,20 +5004,20 @@ def format_expr_type() -> str: return f'"{expr.callee.name}" returns {typ}' elif isinstance(expr.callee, RefExpr) and expr.callee.fullname: return f'"{expr.callee.fullname}" returns {typ}' - return f'Call returns {typ}' + return f"Call returns {typ}" else: - return f'Expression has type {typ}' + return f"Expression has type {typ}" if isinstance(t, FunctionLike): self.fail(message_registry.FUNCTION_ALWAYS_TRUE.format(format_type(t)), expr) elif isinstance(t, UnionType): - self.fail(message_registry.TYPE_ALWAYS_TRUE_UNIONTYPE.format(format_expr_type()), - expr) + self.fail(message_registry.TYPE_ALWAYS_TRUE_UNIONTYPE.format(format_expr_type()), expr) else: self.fail(message_registry.TYPE_ALWAYS_TRUE.format(format_expr_type()), expr) - def find_type_equals_check(self, node: ComparisonExpr, expr_indices: List[int] - ) -> Tuple[TypeMap, TypeMap]: + def find_type_equals_check( + self, node: ComparisonExpr, expr_indices: List[int] + ) -> Tuple[TypeMap, TypeMap]: """Narrow types based on any checks of the type ``type(x) == T`` Args: @@ -4593,10 +5025,10 @@ def find_type_equals_check(self, node: ComparisonExpr, expr_indices: List[int] expr_indices: The list of indices of expressions in ``node`` that are being compared """ + def is_type_call(expr: CallExpr) -> bool: """Is expr a call to type with one argument?""" - return (refers_to_fullname(expr.callee, 'builtins.type') - and len(expr.args) == 1) + return refers_to_fullname(expr.callee, "builtins.type") and len(expr.args) == 1 # exprs that are being passed into type exprs_in_type_calls: List[Expression] = [] @@ -4634,13 +5066,11 @@ def is_type_call(expr: CallExpr) -> bool: else_maps: List[TypeMap] = [] for expr in exprs_in_type_calls: current_if_type, current_else_type = self.conditional_types_with_intersection( - self.lookup_type(expr), - type_being_compared, - expr + self.lookup_type(expr), type_being_compared, expr + ) + current_if_map, current_else_map = conditional_types_to_typemaps( + expr, current_if_type, current_else_type ) - current_if_map, current_else_map = conditional_types_to_typemaps(expr, - current_if_type, - current_else_type) if_maps.append(current_if_map) else_maps.append(current_else_map) @@ -4663,8 +5093,7 @@ def combine_maps(list_maps: List[TypeMap]) -> TypeMap: else_map = {} return if_map, else_map - def find_isinstance_check(self, node: Expression - ) -> Tuple[TypeMap, TypeMap]: + def find_isinstance_check(self, node: Expression) -> Tuple[TypeMap, TypeMap]: """Find any isinstance checks (within a chain of ands). Includes implicit and explicit checks for None and calls to callable. Also includes TypeGuard functions. @@ -4691,24 +5120,22 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM if isinstance(node, CallExpr) and len(node.args) != 0: expr = collapse_walrus(node.args[0]) - if refers_to_fullname(node.callee, 'builtins.isinstance'): + if refers_to_fullname(node.callee, "builtins.isinstance"): if len(node.args) != 2: # the error will be reported elsewhere return {}, {} if literal(expr) == LITERAL_TYPE: return conditional_types_to_typemaps( expr, *self.conditional_types_with_intersection( - self.lookup_type(expr), - self.get_isinstance_type(node.args[1]), - expr - ) + self.lookup_type(expr), self.get_isinstance_type(node.args[1]), expr + ), ) - elif refers_to_fullname(node.callee, 'builtins.issubclass'): + elif refers_to_fullname(node.callee, "builtins.issubclass"): if len(node.args) != 2: # the error will be reported elsewhere return {}, {} if literal(expr) == LITERAL_TYPE: return self.infer_issubclass_maps(node, expr) - elif refers_to_fullname(node.callee, 'builtins.callable'): + elif refers_to_fullname(node.callee, "builtins.callable"): if len(node.args) != 1: # the error will be reported elsewhere return {}, {} if literal(expr) == LITERAL_TYPE: @@ -4741,9 +5168,11 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM expr_type = self.lookup_type(expr) operand_types.append(expr_type) - if (literal(expr) == LITERAL_TYPE - and not is_literal_none(expr) - and not self.is_literal_enum(expr)): + if ( + literal(expr) == LITERAL_TYPE + and not is_literal_none(expr) + and not self.is_literal_enum(expr) + ): h = literal_hash(expr) if h is not None: narrowable_operand_index_to_hash[i] = h @@ -4765,9 +5194,7 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM # in practice. simplified_operator_list = group_comparison_operands( - node.pairwise(), - narrowable_operand_index_to_hash, - {'==', 'is'}, + node.pairwise(), narrowable_operand_index_to_hash, {"==", "is"} ) # Step 3: Analyze each group and infer more precise type maps for each @@ -4776,7 +5203,7 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM partial_type_maps = [] for operator, expr_indices in simplified_operator_list: - if operator in {'is', 'is not', '==', '!='}: + if operator in {"is", "is not", "==", "!="}: # is_valid_target: # Controls which types we're allowed to narrow exprs to. Note that # we cannot use 'is_literal_type_like' in both cases since doing @@ -4793,17 +5220,19 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM # should_narrow_by_identity: # Set to 'false' only if the user defines custom __eq__ or __ne__ methods # that could cause identity-based narrowing to produce invalid results. - if operator in {'is', 'is not'}: + if operator in {"is", "is not"}: is_valid_target: Callable[[Type], bool] = is_singleton_type coerce_only_in_literal_context = False should_narrow_by_identity = True else: + def is_exactly_literal_type(t: Type) -> bool: return isinstance(get_proper_type(t), LiteralType) def has_no_custom_eq_checks(t: Type) -> bool: - return (not custom_special_method(t, '__eq__', check_all=False) - and not custom_special_method(t, '__ne__', check_all=False)) + return not custom_special_method( + t, "__eq__", check_all=False + ) and not custom_special_method(t, "__ne__", check_all=False) is_valid_target = is_exactly_literal_type coerce_only_in_literal_context = True @@ -4839,7 +5268,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: # explicit type(x) == some_type check if if_map == {} and else_map == {}: if_map, else_map = self.find_type_equals_check(node, expr_indices) - elif operator in {'in', 'not in'}: + elif operator in {"in", "not in"}: assert len(expr_indices) == 2 left_index, right_index = expr_indices if left_index not in narrowable_operand_index_to_hash: @@ -4855,8 +5284,10 @@ def has_no_custom_eq_checks(t: Type) -> bool: collection_item_type = get_proper_type(builtin_item_type(collection_type)) if collection_item_type is None or is_optional(collection_item_type): continue - if (isinstance(collection_item_type, Instance) - and collection_item_type.type.fullname == 'builtins.object'): + if ( + isinstance(collection_item_type, Instance) + and collection_item_type.type.fullname == "builtins.object" + ): continue if is_overlapping_erased_types(item_type, collection_item_type): if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {} @@ -4866,7 +5297,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: if_map = {} else_map = {} - if operator in {'is not', '!=', 'not in'}: + if operator in {"is not", "!=", "not in"}: if_map, else_map = else_map, if_map partial_type_maps.append((if_map, else_map)) @@ -4894,23 +5325,27 @@ def has_no_custom_eq_checks(t: Type) -> bool: (None if if_assignment_map is None or if_condition_map is None else if_map), (None if else_assignment_map is None or else_condition_map is None else else_map), ) - elif isinstance(node, OpExpr) and node.op == 'and': + elif isinstance(node, OpExpr) and node.op == "and": left_if_vars, left_else_vars = self.find_isinstance_check(node.left) right_if_vars, right_else_vars = self.find_isinstance_check(node.right) # (e1 and e2) is true if both e1 and e2 are true, # and false if at least one of e1 and e2 is false. - return (and_conditional_maps(left_if_vars, right_if_vars), - or_conditional_maps(left_else_vars, right_else_vars)) - elif isinstance(node, OpExpr) and node.op == 'or': + return ( + and_conditional_maps(left_if_vars, right_if_vars), + or_conditional_maps(left_else_vars, right_else_vars), + ) + elif isinstance(node, OpExpr) and node.op == "or": left_if_vars, left_else_vars = self.find_isinstance_check(node.left) right_if_vars, right_else_vars = self.find_isinstance_check(node.right) # (e1 or e2) is true if at least one of e1 or e2 is true, # and false if both e1 and e2 are false. - return (or_conditional_maps(left_if_vars, right_if_vars), - and_conditional_maps(left_else_vars, right_else_vars)) - elif isinstance(node, UnaryExpr) and node.op == 'not': + return ( + or_conditional_maps(left_if_vars, right_if_vars), + and_conditional_maps(left_else_vars, right_else_vars), + ) + elif isinstance(node, UnaryExpr) and node.op == "not": left, right = self.find_isinstance_check(node.expr) return right, left @@ -4922,20 +5357,11 @@ def has_no_custom_eq_checks(t: Type) -> bool: if_type = true_only(vartype) else_type = false_only(vartype) - if_map = ( - {node: if_type} - if not isinstance(if_type, UninhabitedType) - else None - ) - else_map = ( - {node: else_type} - if not isinstance(else_type, UninhabitedType) - else None - ) + if_map = {node: if_type} if not isinstance(if_type, UninhabitedType) else None + else_map = {node: else_type} if not isinstance(else_type, UninhabitedType) else None return if_map, else_map - def propagate_up_typemap_info(self, - new_types: TypeMap) -> TypeMap: + def propagate_up_typemap_info(self, new_types: TypeMap) -> TypeMap: """Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types. Specifically, this function accepts two mappings of expression to original types: @@ -4978,9 +5404,7 @@ def propagate_up_typemap_info(self, output_map[parent_expr] = proposed_parent_type return output_map - def refine_parent_types(self, - expr: Expression, - expr_type: Type) -> Mapping[Expression, Type]: + def refine_parent_types(self, expr: Expression, expr_type: Type) -> Mapping[Expression, Type]: """Checks if the given expr is a 'lookup operation' into a union and iteratively refines the parent types based on the 'expr_type'. @@ -5022,6 +5446,7 @@ def replay_lookup(new_parent_type: ProperType) -> Optional[Type]: return None else: return member_type + elif isinstance(expr, IndexExpr): parent_expr = expr.base parent_type = self.lookup_type_or_none(parent_expr) @@ -5044,9 +5469,11 @@ def replay_lookup(new_parent_type: ProperType) -> Optional[Type]: except KeyError: return None return make_simplified_union(member_types) + else: int_literals = try_getting_int_literals_from_type(index_type) if int_literals is not None: + def replay_lookup(new_parent_type: ProperType) -> Optional[Type]: if not isinstance(new_parent_type, TupleType): return None @@ -5056,6 +5483,7 @@ def replay_lookup(new_parent_type: ProperType) -> Optional[Type]: except IndexError: return None return make_simplified_union(member_types) + else: return output else: @@ -5096,14 +5524,15 @@ def replay_lookup(new_parent_type: ProperType) -> Optional[Type]: expr = parent_expr expr_type = output[parent_expr] = make_simplified_union(new_parent_types) - def refine_identity_comparison_expression(self, - operands: List[Expression], - operand_types: List[Type], - chain_indices: List[int], - narrowable_operand_indices: AbstractSet[int], - is_valid_target: Callable[[ProperType], bool], - coerce_only_in_literal_context: bool, - ) -> Tuple[TypeMap, TypeMap]: + def refine_identity_comparison_expression( + self, + operands: List[Expression], + operand_types: List[Type], + chain_indices: List[int], + narrowable_operand_indices: AbstractSet[int], + is_valid_target: Callable[[ProperType], bool], + coerce_only_in_literal_context: bool, + ) -> Tuple[TypeMap, TypeMap]: """Produce conditional type maps refining expressions by an identity/equality comparison. The 'operands' and 'operand_types' lists should be the full list of operands used @@ -5191,8 +5620,9 @@ def refine_identity_comparison_expression(self, sum_type_name = None target = get_proper_type(target) - if (isinstance(target, LiteralType) and - (target.is_enum_literal() or isinstance(target.value, bool))): + if isinstance(target, LiteralType) and ( + target.is_enum_literal() or isinstance(target.value, bool) + ): sum_type_name = target.fallback.type.fullname target_type = [TypeRange(target, is_upper_bound=False)] @@ -5225,12 +5655,13 @@ def refine_identity_comparison_expression(self, return reduce_conditional_maps(partial_type_maps) - def refine_away_none_in_comparison(self, - operands: List[Expression], - operand_types: List[Type], - chain_indices: List[int], - narrowable_operand_indices: AbstractSet[int], - ) -> Tuple[TypeMap, TypeMap]: + def refine_away_none_in_comparison( + self, + operands: List[Expression], + operand_types: List[Type], + chain_indices: List[int], + narrowable_operand_indices: AbstractSet[int], + ) -> Tuple[TypeMap, TypeMap]: """Produces conditional type maps refining away None in an identity/equality chain. For more details about what the different arguments mean, see the @@ -5260,16 +5691,18 @@ def refine_away_none_in_comparison(self, # Helpers # - def check_subtype(self, - subtype: Type, - supertype: Type, - context: Context, - msg: Union[str, ErrorMessage] = message_registry.INCOMPATIBLE_TYPES, - subtype_label: Optional[str] = None, - supertype_label: Optional[str] = None, - *, - code: Optional[ErrorCode] = None, - outer_context: Optional[Context] = None) -> bool: + def check_subtype( + self, + subtype: Type, + supertype: Type, + context: Context, + msg: Union[str, ErrorMessage] = message_registry.INCOMPATIBLE_TYPES, + subtype_label: Optional[str] = None, + supertype_label: Optional[str] = None, + *, + code: Optional[ErrorCode] = None, + outer_context: Optional[Context] = None, + ) -> bool: """Generate an error if the subtype is not compatible with supertype.""" if is_subtype(subtype, supertype, options=self.options): return True @@ -5281,26 +5714,28 @@ def check_subtype(self, msg_text = msg subtype = get_proper_type(subtype) supertype = get_proper_type(supertype) - if self.msg.try_report_long_tuple_assignment_error(subtype, supertype, context, msg_text, - subtype_label, supertype_label, code=code): + if self.msg.try_report_long_tuple_assignment_error( + subtype, supertype, context, msg_text, subtype_label, supertype_label, code=code + ): return False if self.should_suppress_optional_error([subtype]): return False extra_info: List[str] = [] - note_msg = '' + note_msg = "" notes: List[str] = [] if subtype_label is not None or supertype_label is not None: subtype_str, supertype_str = format_type_distinctly(subtype, supertype) if subtype_label is not None: - extra_info.append(subtype_label + ' ' + subtype_str) + extra_info.append(subtype_label + " " + subtype_str) if supertype_label is not None: - extra_info.append(supertype_label + ' ' + supertype_str) - note_msg = make_inferred_type_note(outer_context or context, subtype, - supertype, supertype_str) + extra_info.append(supertype_label + " " + supertype_str) + note_msg = make_inferred_type_note( + outer_context or context, subtype, supertype, supertype_str + ) if isinstance(subtype, Instance) and isinstance(supertype, Instance): notes = append_invariance_notes([], subtype, supertype) if extra_info: - msg_text += ' (' + ', '.join(extra_info) + ')' + msg_text += " (" + ", ".join(extra_info) + ")" self.fail(ErrorMessage(msg_text, code=code), context) for note in notes: @@ -5308,16 +5743,19 @@ def check_subtype(self, if note_msg: self.note(note_msg, context, code=code) self.msg.maybe_note_concatenate_pos_args(subtype, supertype, context, code=code) - if (isinstance(supertype, Instance) and supertype.type.is_protocol and - isinstance(subtype, (Instance, TupleType, TypedDictType))): + if ( + isinstance(supertype, Instance) + and supertype.type.is_protocol + and isinstance(subtype, (Instance, TupleType, TypedDictType)) + ): self.msg.report_protocol_problems(subtype, supertype, context, code=code) if isinstance(supertype, CallableType) and isinstance(subtype, Instance): - call = find_member('__call__', subtype, subtype, is_operator=True) + call = find_member("__call__", subtype, subtype, is_operator=True) if call: self.msg.note_call(subtype, call, context, code=code) if isinstance(subtype, (CallableType, Overloaded)) and isinstance(supertype, Instance): - if supertype.type.is_protocol and supertype.type.protocol_members == ['__call__']: - call = find_member('__call__', supertype, subtype, is_operator=True) + if supertype.type.is_protocol and supertype.type.protocol_members == ["__call__"]: + call = find_member("__call__", supertype, subtype, is_operator=True) assert call is not None self.msg.note_call(supertype, call, context, code=code) self.check_possible_missing_await(subtype, supertype, context) @@ -5334,7 +5772,7 @@ def get_precise_awaitable_type(self, typ: Type, local_errors: ErrorWatcher) -> O return None try: aw_type = self.expr_checker.check_awaitable_expr( - typ, Context(), '', ignore_binder=True + typ, Context(), "", ignore_binder=True ) except KeyError: # This is a hack to speed up tests by not including Awaitable in all typing stubs. @@ -5354,7 +5792,7 @@ def checking_await_set(self) -> Iterator[None]: self.checking_missing_await = False def check_possible_missing_await( - self, subtype: Type, supertype: Type, context: Context + self, subtype: Type, supertype: Type, context: Context ) -> None: """Check if the given type becomes a subtype when awaited.""" if self.checking_missing_await: @@ -5371,11 +5809,14 @@ def check_possible_missing_await( def contains_none(self, t: Type) -> bool: t = get_proper_type(t) return ( - isinstance(t, NoneType) or - (isinstance(t, UnionType) and any(self.contains_none(ut) for ut in t.items)) or - (isinstance(t, TupleType) and any(self.contains_none(tt) for tt in t.items)) or - (isinstance(t, Instance) and bool(t.args) - and any(self.contains_none(it) for it in t.args)) + isinstance(t, NoneType) + or (isinstance(t, UnionType) and any(self.contains_none(ut) for ut in t.items)) + or (isinstance(t, TupleType) and any(self.contains_none(tt) for tt in t.items)) + or ( + isinstance(t, Instance) + and bool(t.args) + and any(self.contains_none(it) for it in t.args) + ) ) def should_suppress_optional_error(self, related_types: List[Type]) -> bool: @@ -5416,11 +5857,11 @@ def lookup_typeinfo(self, fullname: str) -> TypeInfo: def type_type(self) -> Instance: """Return instance type 'type'.""" - return self.named_type('builtins.type') + return self.named_type("builtins.type") def str_type(self) -> Instance: """Return instance type 'str'.""" - return self.named_type('builtins.str') + return self.named_type("builtins.str") def store_type(self, node: Expression, typ: Type) -> None: """Store the type of a node in the type map.""" @@ -5468,28 +5909,27 @@ def in_checked_function(self) -> bool: - Yes in annotated functions. - No otherwise. """ - return (self.options.check_untyped_defs - or not self.dynamic_funcs - or not self.dynamic_funcs[-1]) + return ( + self.options.check_untyped_defs or not self.dynamic_funcs or not self.dynamic_funcs[-1] + ) def lookup(self, name: str) -> SymbolTableNode: - """Look up a definition from the symbol table with the given name. - """ + """Look up a definition from the symbol table with the given name.""" if name in self.globals: return self.globals[name] else: - b = self.globals.get('__builtins__', None) + b = self.globals.get("__builtins__", None) if b: table = cast(MypyFile, b.node).names if name in table: return table[name] - raise KeyError(f'Failed lookup: {name}') + raise KeyError(f"Failed lookup: {name}") def lookup_qualified(self, name: str) -> SymbolTableNode: - if '.' not in name: + if "." not in name: return self.lookup(name) else: - parts = name.split('.') + parts = name.split(".") n = self.modules[parts[0]] for i in range(1, len(parts) - 1): sym = n.names.get(parts[i]) @@ -5498,23 +5938,27 @@ def lookup_qualified(self, name: str) -> SymbolTableNode: last = parts[-1] if last in n.names: return n.names[last] - elif len(parts) == 2 and parts[0] == 'builtins': - fullname = 'builtins.' + last + elif len(parts) == 2 and parts[0] == "builtins": + fullname = "builtins." + last if fullname in SUGGESTED_TEST_FIXTURES: suggestion = ", e.g. add '[builtins fixtures/{}]' to your test".format( - SUGGESTED_TEST_FIXTURES[fullname]) + SUGGESTED_TEST_FIXTURES[fullname] + ) else: - suggestion = '' - raise KeyError("Could not find builtin symbol '{}' (If you are running a " - "test case, use a fixture that " - "defines this symbol{})".format(last, suggestion)) + suggestion = "" + raise KeyError( + "Could not find builtin symbol '{}' (If you are running a " + "test case, use a fixture that " + "defines this symbol{})".format(last, suggestion) + ) else: msg = "Failed qualified lookup: '{}' (fullname = '{}')." raise KeyError(msg.format(last, name)) @contextmanager - def enter_partial_types(self, *, is_function: bool = False, - is_class: bool = False) -> Iterator[None]: + def enter_partial_types( + self, *, is_function: bool = False, is_class: bool = False + ) -> Iterator[None]: """Enter a new scope for collecting partial types. Also report errors for (some) variables which still have partial @@ -5528,9 +5972,7 @@ def enter_partial_types(self, *, is_function: bool = False, # at the toplevel (with allow_untyped_globals) or if it is in an # untyped function being checked with check_untyped_defs. permissive = (self.options.allow_untyped_globals and not is_local) or ( - self.options.check_untyped_defs - and self.dynamic_funcs - and self.dynamic_funcs[-1] + self.options.check_untyped_defs and self.dynamic_funcs and self.dynamic_funcs[-1] ) partial_types, _, _ = self.partial_types.pop() @@ -5549,13 +5991,17 @@ def enter_partial_types(self, *, is_function: bool = False, # checked for compatibility with base classes elsewhere. Without this exception # mypy could require an annotation for an attribute that already has been # declared in a base class, which would be bad. - allow_none = (not self.options.local_partial_types - or is_function - or (is_class and self.is_defined_in_base_class(var))) - if (allow_none - and isinstance(var.type, PartialType) - and var.type.type is None - and not permissive): + allow_none = ( + not self.options.local_partial_types + or is_function + or (is_class and self.is_defined_in_base_class(var)) + ) + if ( + allow_none + and isinstance(var.type, PartialType) + and var.type.type is None + and not permissive + ): var.type = NoneType() else: if var not in self.partial_reported and not permissive: @@ -5565,7 +6011,8 @@ def enter_partial_types(self, *, is_function: bool = False, var.type = self.fixup_partial_type(var.type) def handle_partial_var_type( - self, typ: PartialType, is_lvalue: bool, node: Var, context: Context) -> Type: + self, typ: PartialType, is_lvalue: bool, node: Var, context: Context + ) -> Type: """Handle a reference to a partial type through a var. (Used by checkexpr and checkmember.) @@ -5583,8 +6030,9 @@ def handle_partial_var_type( if in_scope: context = partial_types[node] if is_local or not self.options.allow_untyped_globals: - self.msg.need_annotation_for_var(node, context, - self.options.python_version) + self.msg.need_annotation_for_var( + node, context, self.options.python_version + ) self.partial_reported.add(node) else: # Defer the node -- we might get a better type in the outer scope @@ -5602,9 +6050,7 @@ def fixup_partial_type(self, typ: Type) -> Type: if typ.type is None: return UnionType.make_union([AnyType(TypeOfAny.unannotated), NoneType()]) else: - return Instance( - typ.type, - [AnyType(TypeOfAny.unannotated)] * len(typ.type.type_vars)) + return Instance(typ.type, [AnyType(TypeOfAny.unannotated)] * len(typ.type.type_vars)) def is_defined_in_base_class(self, var: Var) -> bool: if var.info: @@ -5628,7 +6074,8 @@ def find_partial_types(self, var: Var) -> Optional[Dict[Var, Context]]: return None def find_partial_types_in_all_scopes( - self, var: Var) -> Tuple[bool, bool, Optional[Dict[Var, Context]]]: + self, var: Var + ) -> Tuple[bool, bool, Optional[Dict[Var, Context]]]: """Look for partial type scope containing variable. Return tuple (is the scope active, is the scope a local scope, scope). @@ -5645,8 +6092,9 @@ def find_partial_types_in_all_scopes( # as if --local-partial-types is always on (because it used to be like this). disallow_other_scopes = True - scope_active = (not disallow_other_scopes - or scope.is_local == self.partial_types[-1].is_local) + scope_active = ( + not disallow_other_scopes or scope.is_local == self.partial_types[-1].is_local + ) return scope_active, scope.is_local, scope.map return False, False, None @@ -5654,55 +6102,50 @@ def temp_node(self, t: Type, context: Optional[Context] = None) -> TempNode: """Create a temporary node with the given, fixed type.""" return TempNode(t, context=context) - def fail(self, msg: Union[str, ErrorMessage], context: Context, *, - code: Optional[ErrorCode] = None) -> None: + def fail( + self, msg: Union[str, ErrorMessage], context: Context, *, code: Optional[ErrorCode] = None + ) -> None: """Produce an error message.""" if isinstance(msg, ErrorMessage): self.msg.fail(msg.value, context, code=msg.code) return self.msg.fail(msg, context, code=code) - def note(self, - msg: str, - context: Context, - offset: int = 0, - *, - code: Optional[ErrorCode] = None) -> None: + def note( + self, msg: str, context: Context, offset: int = 0, *, code: Optional[ErrorCode] = None + ) -> None: """Produce a note.""" self.msg.note(msg, context, offset=offset, code=code) def iterable_item_type(self, instance: Instance) -> Type: - iterable = map_instance_to_supertype( - instance, - self.lookup_typeinfo('typing.Iterable')) + iterable = map_instance_to_supertype(instance, self.lookup_typeinfo("typing.Iterable")) item_type = iterable.args[0] if not isinstance(get_proper_type(item_type), AnyType): # This relies on 'map_instance_to_supertype' returning 'Iterable[Any]' # in case there is no explicit base class. return item_type # Try also structural typing. - iter_type = get_proper_type(find_member('__iter__', instance, instance, is_operator=True)) + iter_type = get_proper_type(find_member("__iter__", instance, instance, is_operator=True)) if iter_type and isinstance(iter_type, CallableType): ret_type = get_proper_type(iter_type.ret_type) if isinstance(ret_type, Instance): - iterator = map_instance_to_supertype(ret_type, - self.lookup_typeinfo('typing.Iterator')) + iterator = map_instance_to_supertype( + ret_type, self.lookup_typeinfo("typing.Iterator") + ) item_type = iterator.args[0] return item_type def function_type(self, func: FuncBase) -> FunctionLike: - return function_type(func, self.named_type('builtins.function')) + return function_type(func, self.named_type("builtins.function")) - def push_type_map(self, type_map: 'TypeMap') -> None: + def push_type_map(self, type_map: "TypeMap") -> None: if type_map is None: self.binder.unreachable() else: for expr, type in type_map.items(): self.binder.put(expr, type) - def infer_issubclass_maps(self, node: CallExpr, - expr: Expression, - ) -> Tuple[TypeMap, TypeMap]: + def infer_issubclass_maps(self, node: CallExpr, expr: Expression) -> Tuple[TypeMap, TypeMap]: """Infer type restrictions for an expression in issubclass call.""" vartype = self.lookup_type(expr) type = self.get_isinstance_type(node.args[1]) @@ -5721,9 +6164,8 @@ def infer_issubclass_maps(self, node: CallExpr, vartype = UnionType(union_list) elif isinstance(vartype, TypeType): vartype = vartype.item - elif (isinstance(vartype, Instance) and - vartype.type.fullname == 'builtins.type'): - vartype = self.named_type('builtins.object') + elif isinstance(vartype, Instance) and vartype.type.fullname == "builtins.type": + vartype = self.named_type("builtins.object") else: # Any other object whose type we don't know precisely # for example, Any or a custom metaclass. @@ -5734,27 +6176,28 @@ def infer_issubclass_maps(self, node: CallExpr, return yes_map, no_map @overload - def conditional_types_with_intersection(self, - expr_type: Type, - type_ranges: Optional[List[TypeRange]], - ctx: Context, - default: None = None - ) -> Tuple[Optional[Type], Optional[Type]]: ... + def conditional_types_with_intersection( + self, + expr_type: Type, + type_ranges: Optional[List[TypeRange]], + ctx: Context, + default: None = None, + ) -> Tuple[Optional[Type], Optional[Type]]: + ... @overload - def conditional_types_with_intersection(self, - expr_type: Type, - type_ranges: Optional[List[TypeRange]], - ctx: Context, - default: Type - ) -> Tuple[Type, Type]: ... - - def conditional_types_with_intersection(self, - expr_type: Type, - type_ranges: Optional[List[TypeRange]], - ctx: Context, - default: Optional[Type] = None - ) -> Tuple[Optional[Type], Optional[Type]]: + def conditional_types_with_intersection( + self, expr_type: Type, type_ranges: Optional[List[TypeRange]], ctx: Context, default: Type + ) -> Tuple[Type, Type]: + ... + + def conditional_types_with_intersection( + self, + expr_type: Type, + type_ranges: Optional[List[TypeRange]], + ctx: Context, + default: Optional[Type] = None, + ) -> Tuple[Optional[Type], Optional[Type]]: initial_types = conditional_types(expr_type, type_ranges, default) # For some reason, doing "yes_map, no_map = conditional_types_to_typemaps(...)" # doesn't work: mypyc will decide that 'yes_map' is of type None if we try. @@ -5806,7 +6249,7 @@ def is_writable_attribute(self, node: Node) -> bool: return False def get_isinstance_type(self, expr: Expression) -> Optional[List[TypeRange]]: - if isinstance(expr, OpExpr) and expr.op == '|': + if isinstance(expr, OpExpr) and expr.op == "|": left = self.get_isinstance_type(expr.left) right = self.get_isinstance_type(expr.right) if left is None or right is None: @@ -5824,7 +6267,7 @@ def get_isinstance_type(self, expr: Expression) -> Optional[List[TypeRange]]: # Type[A] means "any type that is a subtype of A" rather than "precisely type A" # we indicate this by setting is_upper_bound flag types.append(TypeRange(typ.item, is_upper_bound=True)) - elif isinstance(typ, Instance) and typ.type.fullname == 'builtins.type': + elif isinstance(typ, Instance) and typ.type.fullname == "builtins.type": object_type = Instance(typ.type.mro[-1], []) types.append(TypeRange(object_type, is_upper_bound=True)) elif isinstance(typ, AnyType): @@ -5871,12 +6314,15 @@ class Foo(Enum): if not parent_type.is_type_obj(): return False - return (member_type.is_enum_literal() - and member_type.fallback.type == parent_type.type_object()) + return ( + member_type.is_enum_literal() + and member_type.fallback.type == parent_type.type_object() + ) class CollectArgTypes(TypeTraverserVisitor): """Collects the non-nested argument types in a set.""" + def __init__(self) -> None: self.arg_types: Set[TypeVarType] = set() @@ -5885,23 +6331,24 @@ def visit_type_var(self, t: TypeVarType) -> None: @overload -def conditional_types(current_type: Type, - proposed_type_ranges: Optional[List[TypeRange]], - default: None = None - ) -> Tuple[Optional[Type], Optional[Type]]: ... +def conditional_types( + current_type: Type, proposed_type_ranges: Optional[List[TypeRange]], default: None = None +) -> Tuple[Optional[Type], Optional[Type]]: + ... @overload -def conditional_types(current_type: Type, - proposed_type_ranges: Optional[List[TypeRange]], - default: Type - ) -> Tuple[Type, Type]: ... +def conditional_types( + current_type: Type, proposed_type_ranges: Optional[List[TypeRange]], default: Type +) -> Tuple[Type, Type]: + ... -def conditional_types(current_type: Type, - proposed_type_ranges: Optional[List[TypeRange]], - default: Optional[Type] = None - ) -> Tuple[Optional[Type], Optional[Type]]: +def conditional_types( + current_type: Type, + proposed_type_ranges: Optional[List[TypeRange]], + default: Optional[Type] = None, +) -> Tuple[Optional[Type], Optional[Type]]: """Takes in the current type and a proposed type of an expression. Returns a 2-tuple: The first element is the proposed type, if the expression @@ -5913,11 +6360,11 @@ def conditional_types(current_type: Type, if len(proposed_type_ranges) == 1: target = proposed_type_ranges[0].item target = get_proper_type(target) - if isinstance(target, LiteralType) and (target.is_enum_literal() - or isinstance(target.value, bool)): + if isinstance(target, LiteralType) and ( + target.is_enum_literal() or isinstance(target.value, bool) + ): enum_name = target.fallback.type.fullname - current_type = try_expanding_sum_type_to_union(current_type, - enum_name) + current_type = try_expanding_sum_type_to_union(current_type, enum_name) proposed_items = [type_range.item for type_range in proposed_type_ranges] proposed_type = make_simplified_union(proposed_items) if isinstance(proposed_type, AnyType): @@ -5925,19 +6372,25 @@ def conditional_types(current_type: Type, # attempt to narrow anything. Instead, we broaden the expr to Any to # avoid false positives return proposed_type, default - elif (not any(type_range.is_upper_bound for type_range in proposed_type_ranges) - and is_proper_subtype(current_type, proposed_type)): + elif not any( + type_range.is_upper_bound for type_range in proposed_type_ranges + ) and is_proper_subtype(current_type, proposed_type): # Expression is always of one of the types in proposed_type_ranges return default, UninhabitedType() - elif not is_overlapping_types(current_type, proposed_type, - prohibit_none_typevar_overlap=True): + elif not is_overlapping_types( + current_type, proposed_type, prohibit_none_typevar_overlap=True + ): # Expression is never of any type in proposed_type_ranges return UninhabitedType(), default else: # we can only restrict when the type is precise, not bounded - proposed_precise_type = UnionType.make_union([type_range.item - for type_range in proposed_type_ranges - if not type_range.is_upper_bound]) + proposed_precise_type = UnionType.make_union( + [ + type_range.item + for type_range in proposed_type_ranges + if not type_range.is_upper_bound + ] + ) remaining_type = restrict_subtype_away(current_type, proposed_precise_type) return proposed_type, remaining_type else: @@ -5945,10 +6398,9 @@ def conditional_types(current_type: Type, return current_type, default -def conditional_types_to_typemaps(expr: Expression, - yes_type: Optional[Type], - no_type: Optional[Type] - ) -> Tuple[TypeMap, TypeMap]: +def conditional_types_to_typemaps( + expr: Expression, yes_type: Optional[Type], no_type: Optional[Type] +) -> Tuple[TypeMap, TypeMap]: maps: List[TypeMap] = [] for typ in (yes_type, no_type): proper_type = get_proper_type(typ) @@ -5975,23 +6427,21 @@ def gen_unique_name(base: str, table: SymbolTable) -> str: def is_true_literal(n: Expression) -> bool: """Returns true if this expression is the 'True' literal/keyword.""" - return (refers_to_fullname(n, 'builtins.True') - or isinstance(n, IntExpr) and n.value != 0) + return refers_to_fullname(n, "builtins.True") or isinstance(n, IntExpr) and n.value != 0 def is_false_literal(n: Expression) -> bool: """Returns true if this expression is the 'False' literal/keyword.""" - return (refers_to_fullname(n, 'builtins.False') - or isinstance(n, IntExpr) and n.value == 0) + return refers_to_fullname(n, "builtins.False") or isinstance(n, IntExpr) and n.value == 0 def is_literal_none(n: Expression) -> bool: """Returns true if this expression is the 'None' literal/keyword.""" - return isinstance(n, NameExpr) and n.fullname == 'builtins.None' + return isinstance(n, NameExpr) and n.fullname == "builtins.None" def is_literal_not_implemented(n: Expression) -> bool: - return isinstance(n, NameExpr) and n.fullname == 'builtins.NotImplemented' + return isinstance(n, NameExpr) and n.fullname == "builtins.NotImplemented" def builtin_item_type(tp: Type) -> Optional[Type]: @@ -6012,24 +6462,28 @@ def builtin_item_type(tp: Type) -> Optional[Type]: if isinstance(tp, Instance): if tp.type.fullname in [ - 'builtins.list', 'builtins.tuple', 'builtins.dict', - 'builtins.set', 'builtins.frozenset', + "builtins.list", + "builtins.tuple", + "builtins.dict", + "builtins.set", + "builtins.frozenset", ]: if not tp.args: # TODO: fix tuple in lib-stub/builtins.pyi (it should be generic). return None if not isinstance(get_proper_type(tp.args[0]), AnyType): return tp.args[0] - elif isinstance(tp, TupleType) and all(not isinstance(it, AnyType) - for it in get_proper_types(tp.items)): + elif isinstance(tp, TupleType) and all( + not isinstance(it, AnyType) for it in get_proper_types(tp.items) + ): return make_simplified_union(tp.items) # this type is not externally visible elif isinstance(tp, TypedDictType): # TypedDict always has non-optional string keys. Find the key type from the Mapping # base class. for base in tp.fallback.type.mro: - if base.fullname == 'typing.Mapping': + if base.fullname == "typing.Mapping": return map_instance_to_supertype(tp.fallback, base).args[0] - assert False, 'No Mapping base class found for TypedDict fallback' + assert False, "No Mapping base class found for TypedDict fallback" return None @@ -6077,8 +6531,7 @@ def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: return result -def reduce_conditional_maps(type_maps: List[Tuple[TypeMap, TypeMap]], - ) -> Tuple[TypeMap, TypeMap]: +def reduce_conditional_maps(type_maps: List[Tuple[TypeMap, TypeMap]]) -> Tuple[TypeMap, TypeMap]: """Reduces a list containing pairs of if/else TypeMaps into a single pair. We "and" together all of the if TypeMaps and "or" together the else TypeMaps. So @@ -6167,15 +6620,15 @@ def type(self, type: Type) -> Type: def are_argument_counts_overlapping(t: CallableType, s: CallableType) -> bool: - """Can a single call match both t and s, based just on positional argument counts? - """ + """Can a single call match both t and s, based just on positional argument counts?""" min_args = max(t.min_args, s.min_args) max_args = min(t.max_possible_positional_args(), s.max_possible_positional_args()) return min_args <= max_args -def is_unsafe_overlapping_overload_signatures(signature: CallableType, - other: CallableType) -> bool: +def is_unsafe_overlapping_overload_signatures( + signature: CallableType, other: CallableType +) -> bool: """Check if two overloaded signatures are unsafely overlapping or partially overlapping. We consider two functions 's' and 't' to be unsafely overlapping if both @@ -6206,18 +6659,23 @@ def is_unsafe_overlapping_overload_signatures(signature: CallableType, # # This discrepancy is unfortunately difficult to get rid of, so we repeat the # checks twice in both directions for now. - return (is_callable_compatible(signature, other, - is_compat=is_overlapping_types_no_promote, - is_compat_return=lambda l, r: not is_subtype_no_promote(l, r), - ignore_return=False, - check_args_covariantly=True, - allow_partial_overlap=True) or - is_callable_compatible(other, signature, - is_compat=is_overlapping_types_no_promote, - is_compat_return=lambda l, r: not is_subtype_no_promote(r, l), - ignore_return=False, - check_args_covariantly=False, - allow_partial_overlap=True)) + return is_callable_compatible( + signature, + other, + is_compat=is_overlapping_types_no_promote, + is_compat_return=lambda l, r: not is_subtype_no_promote(l, r), + ignore_return=False, + check_args_covariantly=True, + allow_partial_overlap=True, + ) or is_callable_compatible( + other, + signature, + is_compat=is_overlapping_types_no_promote, + is_compat_return=lambda l, r: not is_subtype_no_promote(r, l), + ignore_return=False, + check_args_covariantly=False, + allow_partial_overlap=True, + ) def detach_callable(typ: CallableType) -> CallableType: @@ -6251,18 +6709,18 @@ def detach_callable(typ: CallableType) -> CallableType: for var in set(all_type_vars): if var.fullname not in used_type_var_names: continue - new_variables.append(TypeVarType( - name=var.name, - fullname=var.fullname, - id=var.id, - values=var.values, - upper_bound=var.upper_bound, - variance=var.variance, - )) + new_variables.append( + TypeVarType( + name=var.name, + fullname=var.fullname, + id=var.id, + values=var.values, + upper_bound=var.upper_bound, + variance=var.variance, + ) + ) out = typ.copy_modified( - variables=new_variables, - arg_types=type_list[:-1], - ret_type=type_list[-1], + variables=new_variables, arg_types=type_list[:-1], ret_type=type_list[-1] ) return out @@ -6282,13 +6740,14 @@ def overload_can_never_match(signature: CallableType, other: CallableType) -> bo # the below subtype check and (surprisingly?) `is_proper_subtype(Any, Any)` # returns `True`. # TODO: find a cleaner solution instead of this ad-hoc erasure. - exp_signature = expand_type(signature, {tvar.id: erase_def_to_union_or_bound(tvar) - for tvar in signature.variables}) + exp_signature = expand_type( + signature, {tvar.id: erase_def_to_union_or_bound(tvar) for tvar in signature.variables} + ) assert isinstance(exp_signature, ProperType) assert isinstance(exp_signature, CallableType) - return is_callable_compatible(exp_signature, other, - is_compat=is_more_precise, - ignore_return=True) + return is_callable_compatible( + exp_signature, other, is_compat=is_more_precise, ignore_return=True + ) def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool: @@ -6297,23 +6756,25 @@ def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool: # general than one with fewer items (or just one item)? if isinstance(t, CallableType): if isinstance(s, CallableType): - return is_callable_compatible(t, s, - is_compat=is_proper_subtype, - ignore_return=True) + return is_callable_compatible(t, s, is_compat=is_proper_subtype, ignore_return=True) elif isinstance(t, FunctionLike): if isinstance(s, FunctionLike): if len(t.items) == len(s.items): - return all(is_same_arg_prefix(items, itemt) - for items, itemt in zip(t.items, s.items)) + return all( + is_same_arg_prefix(items, itemt) for items, itemt in zip(t.items, s.items) + ) return False def is_same_arg_prefix(t: CallableType, s: CallableType) -> bool: - return is_callable_compatible(t, s, - is_compat=is_same_type, - ignore_return=True, - check_args_covariantly=True, - ignore_pos_arg_names=True) + return is_callable_compatible( + t, + s, + is_compat=is_same_type, + ignore_return=True, + check_args_covariantly=True, + ignore_pos_arg_names=True, + ) def infer_operator_assignment_method(typ: Type, operator: str) -> Tuple[bool, str]: @@ -6326,7 +6787,7 @@ def infer_operator_assignment_method(typ: Type, operator: str) -> Tuple[bool, st method = operators.op_methods[operator] if isinstance(typ, Instance): if operator in operators.ops_with_inplace_method: - inplace_method = '__i' + method[2:] + inplace_method = "__i" + method[2:] if typ.type.has_readable_member(inplace_method): return True, inplace_method return False, method @@ -6447,8 +6908,8 @@ def push_class(self, info: TypeInfo) -> Iterator[None]: self.stack.pop() -TKey = TypeVar('TKey') -TValue = TypeVar('TValue') +TKey = TypeVar("TKey") +TValue = TypeVar("TValue") class DisjointDict(Generic[TKey, TValue]): @@ -6477,6 +6938,7 @@ class DisjointDict(Generic[TKey, TValue]): tree of height log_2(n). This makes root lookups no longer amoritized constant time when we finally call 'items()'. """ + def __init__(self) -> None: # Each key maps to a unique ID self._key_to_id: Dict[TKey, int] = {} @@ -6545,10 +7007,11 @@ def _lookup_root_id(self, key: TKey) -> int: return i -def group_comparison_operands(pairwise_comparisons: Iterable[Tuple[str, Expression, Expression]], - operand_to_literal_hash: Mapping[int, Key], - operators_to_group: Set[str], - ) -> List[Tuple[str, List[int]]]: +def group_comparison_operands( + pairwise_comparisons: Iterable[Tuple[str, Expression, Expression]], + operand_to_literal_hash: Mapping[int, Key], + operators_to_group: Set[str], +) -> List[Tuple[str, List[int]]]: """Group a series of comparison operands together chained by any operand in the 'operators_to_group' set. All other pairwise operands are kept in groups of size 2. @@ -6648,8 +7111,10 @@ def is_typed_callable(c: Optional[Type]) -> bool: c = get_proper_type(c) if not c or not isinstance(c, CallableType): return False - return not all(isinstance(t, AnyType) and t.type_of_any == TypeOfAny.unannotated - for t in get_proper_types(c.arg_types + [c.ret_type])) + return not all( + isinstance(t, AnyType) and t.type_of_any == TypeOfAny.unannotated + for t in get_proper_types(c.arg_types + [c.ret_type]) + ) def is_untyped_decorator(typ: Optional[Type]) -> bool: @@ -6659,12 +7124,11 @@ def is_untyped_decorator(typ: Optional[Type]) -> bool: elif isinstance(typ, CallableType): return not is_typed_callable(typ) elif isinstance(typ, Instance): - method = typ.type.get_method('__call__') + method = typ.type.get_method("__call__") if method: if isinstance(method, Decorator): - return ( - is_untyped_decorator(method.func.type) - or is_untyped_decorator(method.var.type) + return is_untyped_decorator(method.func.type) or is_untyped_decorator( + method.var.type ) if isinstance(method.type, Overloaded): @@ -6696,7 +7160,7 @@ def is_overlapping_types_no_promote(left: Type, right: Type) -> bool: def is_private(node_name: str) -> bool: """Check if node is private to class definition.""" - return node_name.startswith('__') and not node_name.endswith('__') + return node_name.startswith("__") and not node_name.endswith("__") def is_string_literal(typ: Type) -> bool: @@ -6706,11 +7170,10 @@ def is_string_literal(typ: Type) -> bool: def has_bool_item(typ: ProperType) -> bool: """Return True if type is 'bool' or a union with a 'bool' item.""" - if is_named_instance(typ, 'builtins.bool'): + if is_named_instance(typ, "builtins.bool"): return True if isinstance(typ, UnionType): - return any(is_named_instance(item, 'builtins.bool') - for item in typ.items) + return any(is_named_instance(item, "builtins.bool") for item in typ.items) return False diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 055aba8de08b6..ca0db74b32bf2 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1,95 +1,167 @@ """Expression type checker. This file is conceptually part of TypeChecker.""" -from mypy.backports import OrderedDict -from contextlib import contextmanager import itertools -from typing import ( - cast, Dict, Set, List, Tuple, Callable, Union, Optional, Sequence, Iterator -) -from typing_extensions import ClassVar, Final, overload, TypeAlias as _TypeAlias +from contextlib import contextmanager +from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast + +from typing_extensions import ClassVar, Final, TypeAlias as _TypeAlias, overload -from mypy.errors import report_internal_error, ErrorWatcher -from mypy.typeanal import ( - has_any_from_unimported_type, check_for_explicit_any, set_any_tvars, expand_type_alias, - make_optional_type, -) -from mypy.semanal_enum import ENUM_BASES -from mypy.traverser import has_await_expression -from mypy.types import ( - Type, AnyType, CallableType, Overloaded, NoneType, TypeVarType, - TupleType, TypedDictType, Instance, ErasedType, UnionType, - PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, LiteralType, LiteralValue, - is_named_instance, FunctionLike, ParamSpecType, ParamSpecFlavor, - StarType, is_optional, remove_optional, is_generic_instance, get_proper_type, ProperType, - get_proper_types, flatten_nested_unions, LITERAL_TYPE_NAMES, -) -from mypy.nodes import ( - AssertTypeExpr, NameExpr, RefExpr, Var, FuncDef, OverloadedFuncDef, TypeInfo, CallExpr, - MemberExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, FloatExpr, - OpExpr, UnaryExpr, IndexExpr, CastExpr, RevealExpr, TypeApplication, ListExpr, - TupleExpr, DictExpr, LambdaExpr, SuperExpr, SliceExpr, Context, Expression, - ListComprehension, GeneratorExpr, SetExpr, MypyFile, Decorator, - ConditionalExpr, ComparisonExpr, TempNode, SetComprehension, AssignmentExpr, - DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr, - YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr, - TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode, - ParamSpecExpr, TypeVarTupleExpr, - ArgKind, ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE, -) -from mypy.literals import literal -from mypy import nodes -from mypy import operators import mypy.checker -from mypy import types -from mypy.sametypes import is_same_type -from mypy.erasetype import replace_meta_vars, erase_type, remove_instance_last_known_values -from mypy.maptype import map_instance_to_supertype -from mypy.messages import MessageBuilder -from mypy import message_registry -from mypy.infer import ( - ArgumentInferContext, infer_type_arguments, infer_function_type_arguments, -) -from mypy import join -from mypy.meet import narrow_declared_type, is_overlapping_types -from mypy.subtypes import is_subtype, is_proper_subtype, is_equivalent, non_method_protocol_members -from mypy import applytype -from mypy import erasetype -from mypy.checkmember import analyze_member_access, type_object_type +import mypy.errorcodes as codes +from mypy import applytype, erasetype, join, message_registry, nodes, operators, types from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals +from mypy.backports import OrderedDict +from mypy.checkmember import analyze_member_access, type_object_type from mypy.checkstrformat import StringFormatterChecker +from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars +from mypy.errors import ErrorWatcher, report_internal_error from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars -from mypy.util import split_module_names -from mypy.typevars import fill_typevars -from mypy.visitor import ExpressionVisitor +from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments +from mypy.literals import literal +from mypy.maptype import map_instance_to_supertype +from mypy.meet import is_overlapping_types, narrow_declared_type +from mypy.message_registry import ErrorMessage +from mypy.messages import MessageBuilder +from mypy.nodes import ( + ARG_NAMED, + ARG_POS, + ARG_STAR, + ARG_STAR2, + LITERAL_TYPE, + REVEAL_TYPE, + ArgKind, + AssertTypeExpr, + AssignmentExpr, + AwaitExpr, + BackquoteExpr, + BytesExpr, + CallExpr, + CastExpr, + ComparisonExpr, + ComplexExpr, + ConditionalExpr, + Context, + Decorator, + DictExpr, + DictionaryComprehension, + EllipsisExpr, + EnumCallExpr, + Expression, + FloatExpr, + FuncDef, + GeneratorExpr, + IndexExpr, + IntExpr, + LambdaExpr, + ListComprehension, + ListExpr, + MemberExpr, + MypyFile, + NamedTupleExpr, + NameExpr, + NewTypeExpr, + OpExpr, + OverloadedFuncDef, + ParamSpecExpr, + PlaceholderNode, + PromoteExpr, + RefExpr, + RevealExpr, + SetComprehension, + SetExpr, + SliceExpr, + StarExpr, + StrExpr, + SuperExpr, + SymbolNode, + TempNode, + TupleExpr, + TypeAlias, + TypeAliasExpr, + TypeApplication, + TypedDictExpr, + TypeInfo, + TypeVarExpr, + TypeVarTupleExpr, + UnaryExpr, + UnicodeExpr, + Var, + YieldExpr, + YieldFromExpr, +) from mypy.plugin import ( + FunctionContext, + FunctionSigContext, + MethodContext, + MethodSigContext, Plugin, - MethodContext, MethodSigContext, - FunctionContext, FunctionSigContext, +) +from mypy.sametypes import is_same_type +from mypy.semanal_enum import ENUM_BASES +from mypy.subtypes import is_equivalent, is_proper_subtype, is_subtype, non_method_protocol_members +from mypy.traverser import has_await_expression +from mypy.typeanal import ( + check_for_explicit_any, + expand_type_alias, + has_any_from_unimported_type, + make_optional_type, + set_any_tvars, ) from mypy.typeops import ( - try_expanding_sum_type_to_union, tuple_fallback, make_simplified_union, - true_only, false_only, erase_to_union_or_bound, function_type, - callable_type, try_getting_str_literals, custom_special_method, - is_literal_type_like, simple_literal_type, + callable_type, + custom_special_method, + erase_to_union_or_bound, + false_only, + function_type, + is_literal_type_like, + make_simplified_union, + simple_literal_type, + true_only, + try_expanding_sum_type_to_union, + try_getting_str_literals, + tuple_fallback, ) -from mypy.message_registry import ErrorMessage -import mypy.errorcodes as codes +from mypy.types import ( + LITERAL_TYPE_NAMES, + AnyType, + CallableType, + DeletedType, + ErasedType, + FunctionLike, + Instance, + LiteralType, + LiteralValue, + NoneType, + Overloaded, + ParamSpecFlavor, + ParamSpecType, + PartialType, + ProperType, + StarType, + TupleType, + Type, + TypedDictType, + TypeOfAny, + TypeType, + TypeVarType, + UninhabitedType, + UnionType, + flatten_nested_unions, + get_proper_type, + get_proper_types, + is_generic_instance, + is_named_instance, + is_optional, + remove_optional, +) +from mypy.typevars import fill_typevars +from mypy.util import split_module_names +from mypy.visitor import ExpressionVisitor # Type of callback user for checking individual function arguments. See # check_args() below for details. -ArgChecker: _TypeAlias = Callable[[ - Type, - Type, - ArgKind, - Type, - int, - int, - CallableType, - Optional[Type], - Context, - Context, - ], - None, +ArgChecker: _TypeAlias = Callable[ + [Type, Type, ArgKind, Type, int, int, CallableType, Optional[Type], Context, Context], None, ] # Maximum nesting level for math union in overloads, setting this to large values @@ -120,12 +192,9 @@ class TooManyUnions(Exception): def allow_fast_container_literal(t: ProperType) -> bool: - return ( - isinstance(t, Instance) - or ( - isinstance(t, TupleType) - and all(allow_fast_container_literal(get_proper_type(it)) for it in t.items) - ) + return isinstance(t, Instance) or ( + isinstance(t, TupleType) + and all(allow_fast_container_literal(get_proper_type(it)) for it in t.items) ) @@ -147,9 +216,9 @@ def extract_refexpr_names(expr: RefExpr) -> Set[str]: if isinstance(expr.node, TypeInfo): # Reference to a class or a nested class output.update(split_module_names(expr.node.module_name)) - elif expr.fullname is not None and '.' in expr.fullname and not is_suppressed_import: + elif expr.fullname is not None and "." in expr.fullname and not is_suppressed_import: # Everything else (that is not a silenced import within a class) - output.add(expr.fullname.rsplit('.', 1)[0]) + output.add(expr.fullname.rsplit(".", 1)[0]) break elif isinstance(expr, MemberExpr): if isinstance(expr.expr, RefExpr): @@ -184,10 +253,9 @@ class ExpressionChecker(ExpressionVisitor[Type]): strfrm_checker: StringFormatterChecker plugin: Plugin - def __init__(self, - chk: 'mypy.checker.TypeChecker', - msg: MessageBuilder, - plugin: Plugin) -> None: + def __init__( + self, chk: "mypy.checker.TypeChecker", msg: MessageBuilder, plugin: Plugin + ) -> None: """Construct an expression type checker.""" self.chk = chk self.msg = msg @@ -230,7 +298,7 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type: result = self.chk.handle_partial_var_type(result, lvalue, node, e) elif isinstance(node, FuncDef): # Reference to a global function. - result = function_type(node, self.named_type('builtins.function')) + result = function_type(node, self.named_type("builtins.function")) elif isinstance(node, OverloadedFuncDef) and node.type is not None: # node.type is None when there are multiple definitions of a function # and it's decorated by something that is not typing.overload @@ -240,8 +308,9 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type: elif isinstance(node, TypeInfo): # Reference to a type object. result = type_object_type(node, self.named_type) - if (isinstance(result, CallableType) and - isinstance(result.ret_type, Instance)): # type: ignore + if isinstance(result, CallableType) and isinstance( # type: ignore + result.ret_type, Instance + ): # We need to set correct line and column # TODO: always do this in type_object_type by passing the original context result.ret_type.line = e.line @@ -253,26 +322,26 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type: elif isinstance(node, MypyFile): # Reference to a module object. try: - result = self.named_type('types.ModuleType') + result = self.named_type("types.ModuleType") except KeyError: # In test cases might 'types' may not be available. # Fall back to a dummy 'object' type instead to # avoid a crash. - result = self.named_type('builtins.object') + result = self.named_type("builtins.object") elif isinstance(node, Decorator): result = self.analyze_var_ref(node.var, e) elif isinstance(node, TypeAlias): # Something that refers to a type alias appears in runtime context. # Note that we suppress bogus errors for alias redefinitions, # they are already reported in semanal.py. - result = self.alias_type_in_runtime_context(node, node.no_args, e, - alias_definition=e.is_alias_rvalue - or lvalue) + result = self.alias_type_in_runtime_context( + node, node.no_args, e, alias_definition=e.is_alias_rvalue or lvalue + ) elif isinstance(node, (TypeVarExpr, ParamSpecExpr)): result = self.object_type() else: if isinstance(node, PlaceholderNode): - assert False, f'PlaceholderNode {node.fullname!r} leaked to checker' + assert False, f"PlaceholderNode {node.fullname!r} leaked to checker" # Unknown reference; use any type implicitly to avoid # generating extra type errors. result = AnyType(TypeOfAny.from_error) @@ -285,8 +354,8 @@ def analyze_var_ref(self, var: Var, context: Context) -> Type: if isinstance(var_type, Instance): if self.is_literal_context() and var_type.last_known_value is not None: return var_type.last_known_value - if var.name in {'True', 'False'}: - return self.infer_literal_expr_type(var.name == 'True', 'builtins.bool') + if var.name in {"True", "False"}: + return self.infer_literal_expr_type(var.name == "True", "builtins.bool") return var.type else: if not var.is_ready and self.chk.in_checked_function(): @@ -306,14 +375,21 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type: return self.visit_call_expr_inner(e, allow_none_return=allow_none_return) def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) -> Type: - if isinstance(e.callee, RefExpr) and isinstance(e.callee.node, TypeInfo) and \ - e.callee.node.typeddict_type is not None: + if ( + isinstance(e.callee, RefExpr) + and isinstance(e.callee.node, TypeInfo) + and e.callee.node.typeddict_type is not None + ): # Use named fallback for better error messages. typeddict_type = e.callee.node.typeddict_type.copy_modified( - fallback=Instance(e.callee.node, [])) + fallback=Instance(e.callee.node, []) + ) return self.check_typeddict_call(typeddict_type, e.arg_kinds, e.arg_names, e.args, e) - if (isinstance(e.callee, NameExpr) and e.callee.name in ('isinstance', 'issubclass') - and len(e.args) == 2): + if ( + isinstance(e.callee, NameExpr) + and e.callee.name in ("isinstance", "issubclass") + and len(e.args) == 2 + ): for typ in mypy.checker.flatten(e.args[1]): node = None if isinstance(typ, NameExpr): @@ -325,14 +401,22 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) -> if is_expr_literal_type(typ): self.msg.cannot_use_function_with_type(e.callee.name, "Literal", e) continue - if (node and isinstance(node.node, TypeAlias) - and isinstance(get_proper_type(node.node.target), AnyType)): + if ( + node + and isinstance(node.node, TypeAlias) + and isinstance(get_proper_type(node.node.target), AnyType) + ): self.msg.cannot_use_function_with_type(e.callee.name, "Any", e) continue - if ((isinstance(typ, IndexExpr) - and isinstance(typ.analyzed, (TypeApplication, TypeAliasExpr))) - or (isinstance(typ, NameExpr) and node and - isinstance(node.node, TypeAlias) and not node.node.no_args)): + if ( + isinstance(typ, IndexExpr) + and isinstance(typ.analyzed, (TypeApplication, TypeAliasExpr)) + ) or ( + isinstance(typ, NameExpr) + and node + and isinstance(node.node, TypeAlias) + and not node.node.no_args + ): self.msg.type_arguments_not_allowed(e) if isinstance(typ, RefExpr) and isinstance(typ.node, TypeInfo): if typ.node.typeddict_type: @@ -343,20 +427,31 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) -> type_context = None if isinstance(e.callee, LambdaExpr): formal_to_actual = map_actuals_to_formals( - e.arg_kinds, e.arg_names, - e.callee.arg_kinds, e.callee.arg_names, - lambda i: self.accept(e.args[i])) - - arg_types = [join.join_type_list([self.accept(e.args[j]) for j in formal_to_actual[i]]) - for i in range(len(e.callee.arg_kinds))] - type_context = CallableType(arg_types, e.callee.arg_kinds, e.callee.arg_names, - ret_type=self.object_type(), - fallback=self.named_type('builtins.function')) + e.arg_kinds, + e.arg_names, + e.callee.arg_kinds, + e.callee.arg_names, + lambda i: self.accept(e.args[i]), + ) + + arg_types = [ + join.join_type_list([self.accept(e.args[j]) for j in formal_to_actual[i]]) + for i in range(len(e.callee.arg_kinds)) + ] + type_context = CallableType( + arg_types, + e.callee.arg_kinds, + e.callee.arg_names, + ret_type=self.object_type(), + fallback=self.named_type("builtins.function"), + ) callee_type = get_proper_type(self.accept(e.callee, type_context, always_allow_any=True)) - if (self.chk.options.disallow_untyped_calls and - self.chk.in_checked_function() and - isinstance(callee_type, CallableType) - and callee_type.implicit): + if ( + self.chk.options.disallow_untyped_calls + and self.chk.in_checked_function() + and isinstance(callee_type, CallableType) + and callee_type.implicit + ): self.msg.untyped_function_call(callee_type, e) # Figure out the full name of the callee for plugin lookup. @@ -376,19 +471,22 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) -> # method_fullname() for details on supported objects); # get_method_hook() and get_method_signature_hook() will # be invoked for these. - if (fullname is None - and isinstance(e.callee, MemberExpr) - and self.chk.has_type(e.callee.expr)): + if ( + fullname is None + and isinstance(e.callee, MemberExpr) + and self.chk.has_type(e.callee.expr) + ): member = e.callee.name object_type = self.chk.lookup_type(e.callee.expr) - ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname, - object_type, member) + ret_type = self.check_call_expr_with_callee_type( + callee_type, e, fullname, object_type, member + ) if isinstance(e.callee, RefExpr) and len(e.args) == 2: - if e.callee.fullname in ('builtins.isinstance', 'builtins.issubclass'): + if e.callee.fullname in ("builtins.isinstance", "builtins.issubclass"): self.check_runtime_protocol_test(e) - if e.callee.fullname == 'builtins.issubclass': + if e.callee.fullname == "builtins.issubclass": self.check_protocol_issubclass(e) - if isinstance(e.callee, MemberExpr) and e.callee.name == 'format': + if isinstance(e.callee, MemberExpr) and e.callee.name == "format": self.check_str_format_call(e) ret_type = get_proper_type(ret_type) if isinstance(ret_type, UnionType): @@ -398,8 +496,11 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) -> # Warn on calls to functions that always return None. The check # of ret_type is both a common-case optimization and prevents reporting # the error in dynamic functions (where it will be Any). - if (not allow_none_return and isinstance(ret_type, NoneType) - and self.always_returns_none(e.callee)): + if ( + not allow_none_return + and isinstance(ret_type, NoneType) + and self.always_returns_none(e.callee) + ): self.chk.msg.does_not_return_value(callee_type, e) return AnyType(TypeOfAny.from_error) return ret_type @@ -441,7 +542,7 @@ def method_fullname(self, object_type: Type, method_name: str) -> Optional[str]: type_name = tuple_fallback(object_type).type.fullname if type_name is not None: - return f'{type_name}.{method_name}' + return f"{type_name}.{method_name}" else: return None @@ -470,17 +571,21 @@ def always_returns_none(self, node: Expression) -> bool: def defn_returns_none(self, defn: Optional[SymbolNode]) -> bool: """Check if `defn` can _only_ return None.""" if isinstance(defn, FuncDef): - return (isinstance(defn.type, CallableType) and - isinstance(get_proper_type(defn.type.ret_type), NoneType)) + return isinstance(defn.type, CallableType) and isinstance( + get_proper_type(defn.type.ret_type), NoneType + ) if isinstance(defn, OverloadedFuncDef): return all(self.defn_returns_none(item) for item in defn.items) if isinstance(defn, Var): typ = get_proper_type(defn.type) - if (not defn.is_inferred and isinstance(typ, CallableType) and - isinstance(get_proper_type(typ.ret_type), NoneType)): + if ( + not defn.is_inferred + and isinstance(typ, CallableType) + and isinstance(get_proper_type(typ.ret_type), NoneType) + ): return True if isinstance(typ, Instance): - sym = typ.type.get('__call__') + sym = typ.type.get("__call__") if sym and self.defn_returns_none(sym.node): return True return False @@ -488,33 +593,38 @@ def defn_returns_none(self, defn: Optional[SymbolNode]) -> bool: def check_runtime_protocol_test(self, e: CallExpr) -> None: for expr in mypy.checker.flatten(e.args[1]): tp = get_proper_type(self.chk.lookup_type(expr)) - if (isinstance(tp, CallableType) and tp.is_type_obj() and - tp.type_object().is_protocol and - not tp.type_object().runtime_protocol): + if ( + isinstance(tp, CallableType) + and tp.is_type_obj() + and tp.type_object().is_protocol + and not tp.type_object().runtime_protocol + ): self.chk.fail(message_registry.RUNTIME_PROTOCOL_EXPECTED, e) def check_protocol_issubclass(self, e: CallExpr) -> None: for expr in mypy.checker.flatten(e.args[1]): tp = get_proper_type(self.chk.lookup_type(expr)) - if (isinstance(tp, CallableType) and tp.is_type_obj() and - tp.type_object().is_protocol): + if isinstance(tp, CallableType) and tp.is_type_obj() and tp.type_object().is_protocol: attr_members = non_method_protocol_members(tp.type_object()) if attr_members: - self.chk.msg.report_non_method_protocol(tp.type_object(), - attr_members, e) - - def check_typeddict_call(self, callee: TypedDictType, - arg_kinds: List[ArgKind], - arg_names: Sequence[Optional[str]], - args: List[Expression], - context: Context) -> Type: + self.chk.msg.report_non_method_protocol(tp.type_object(), attr_members, e) + + def check_typeddict_call( + self, + callee: TypedDictType, + arg_kinds: List[ArgKind], + arg_names: Sequence[Optional[str]], + args: List[Expression], + context: Context, + ) -> Type: if len(args) >= 1 and all([ak == ARG_NAMED for ak in arg_kinds]): # ex: Point(x=42, y=1337) assert all(arg_name is not None for arg_name in arg_names) item_names = cast(List[str], arg_names) item_args = args return self.check_typeddict_call_with_kwargs( - callee, OrderedDict(zip(item_names, item_args)), context) + callee, OrderedDict(zip(item_names, item_args)), context + ) if len(args) == 1 and arg_kinds[0] == ARG_POS: unique_arg = args[0] @@ -527,14 +637,14 @@ def check_typeddict_call(self, callee: TypedDictType, if len(args) == 0: # ex: EmptyDict() - return self.check_typeddict_call_with_kwargs( - callee, OrderedDict(), context) + return self.check_typeddict_call_with_kwargs(callee, OrderedDict(), context) self.chk.fail(message_registry.INVALID_TYPEDDICT_ARGS, context) return AnyType(TypeOfAny.from_error) def validate_typeddict_kwargs( - self, kwargs: DictExpr) -> 'Optional[OrderedDict[str, Expression]]': + self, kwargs: DictExpr + ) -> "Optional[OrderedDict[str, Expression]]": item_args = [item[1] for item in kwargs.items] item_names = [] # List[str] @@ -547,58 +657,59 @@ def validate_typeddict_kwargs( literal_value = values[0] if literal_value is None: key_context = item_name_expr or item_arg - self.chk.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, - key_context) + self.chk.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, key_context) return None else: item_names.append(literal_value) return OrderedDict(zip(item_names, item_args)) - def match_typeddict_call_with_dict(self, callee: TypedDictType, - kwargs: DictExpr, - context: Context) -> bool: + def match_typeddict_call_with_dict( + self, callee: TypedDictType, kwargs: DictExpr, context: Context + ) -> bool: validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs) if validated_kwargs is not None: - return (callee.required_keys <= set(validated_kwargs.keys()) - <= set(callee.items.keys())) + return callee.required_keys <= set(validated_kwargs.keys()) <= set(callee.items.keys()) else: return False - def check_typeddict_call_with_dict(self, callee: TypedDictType, - kwargs: DictExpr, - context: Context) -> Type: + def check_typeddict_call_with_dict( + self, callee: TypedDictType, kwargs: DictExpr, context: Context + ) -> Type: validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs) if validated_kwargs is not None: return self.check_typeddict_call_with_kwargs( - callee, - kwargs=validated_kwargs, - context=context) + callee, kwargs=validated_kwargs, context=context + ) else: return AnyType(TypeOfAny.from_error) - def check_typeddict_call_with_kwargs(self, callee: TypedDictType, - kwargs: 'OrderedDict[str, Expression]', - context: Context) -> Type: + def check_typeddict_call_with_kwargs( + self, callee: TypedDictType, kwargs: "OrderedDict[str, Expression]", context: Context + ) -> Type: if not (callee.required_keys <= set(kwargs.keys()) <= set(callee.items.keys())): - expected_keys = [key for key in callee.items.keys() - if key in callee.required_keys or key in kwargs.keys()] + expected_keys = [ + key + for key in callee.items.keys() + if key in callee.required_keys or key in kwargs.keys() + ] actual_keys = kwargs.keys() self.msg.unexpected_typeddict_keys( - callee, - expected_keys=expected_keys, - actual_keys=list(actual_keys), - context=context) + callee, expected_keys=expected_keys, actual_keys=list(actual_keys), context=context + ) return AnyType(TypeOfAny.from_error) for (item_name, item_expected_type) in callee.items.items(): if item_name in kwargs: item_value = kwargs[item_name] self.chk.check_simple_assignment( - lvalue_type=item_expected_type, rvalue=item_value, context=item_value, + lvalue_type=item_expected_type, + rvalue=item_value, + context=item_value, msg=message_registry.INCOMPATIBLE_TYPES, lvalue_name=f'TypedDict item "{item_name}"', - rvalue_name='expression', - code=codes.TYPEDDICT_ITEM) + rvalue_name="expression", + code=codes.TYPEDDICT_ITEM, + ) return callee @@ -608,8 +719,11 @@ def get_partial_self_var(self, expr: MemberExpr) -> Optional[Var]: If the expression is not a self attribute, or attribute is not variable, or variable is not partial, return None. """ - if not (isinstance(expr.expr, NameExpr) and - isinstance(expr.expr.node, Var) and expr.expr.node.is_self): + if not ( + isinstance(expr.expr, NameExpr) + and isinstance(expr.expr.node, Var) + and expr.expr.node.is_self + ): # Not a self.attr expression. return None info = self.chk.scope.enclosing_class() @@ -669,8 +783,7 @@ def try_infer_partial_type(self, e: CallExpr) -> None: # Store inferred partial type. assert partial_type.type is not None typename = partial_type.type.fullname - var.type = self.chk.named_generic_type(typename, - [key_type, value_type]) + var.type = self.chk.named_generic_type(typename, [key_type, value_type]) del partial_types[var] def get_partial_var(self, ref: RefExpr) -> Optional[Tuple[Var, Dict[Var, Context]]]: @@ -685,10 +798,8 @@ def get_partial_var(self, ref: RefExpr) -> Optional[Tuple[Var, Dict[Var, Context return var, partial_types def try_infer_partial_value_type_from_call( - self, - e: CallExpr, - methodname: str, - var: Var) -> Optional[Instance]: + self, e: CallExpr, methodname: str, var: Var + ) -> Optional[Instance]: """Try to make partial type precise from a call such as 'x.append(y)'.""" if self.chk.current_node_deferred: return None @@ -702,37 +813,45 @@ def try_infer_partial_value_type_from_call( typename = partial_type.type.fullname # Sometimes we can infer a full type for a partial List, Dict or Set type. # TODO: Don't infer argument expression twice. - if (typename in self.item_args and methodname in self.item_args[typename] - and e.arg_kinds == [ARG_POS]): + if ( + typename in self.item_args + and methodname in self.item_args[typename] + and e.arg_kinds == [ARG_POS] + ): item_type = self.accept(e.args[0]) if mypy.checker.is_valid_inferred_type(item_type): return self.chk.named_generic_type(typename, [item_type]) - elif (typename in self.container_args - and methodname in self.container_args[typename] - and e.arg_kinds == [ARG_POS]): + elif ( + typename in self.container_args + and methodname in self.container_args[typename] + and e.arg_kinds == [ARG_POS] + ): arg_type = get_proper_type(self.accept(e.args[0])) if isinstance(arg_type, Instance): arg_typename = arg_type.type.fullname if arg_typename in self.container_args[typename][methodname]: - if all(mypy.checker.is_valid_inferred_type(item_type) - for item_type in arg_type.args): - return self.chk.named_generic_type(typename, - list(arg_type.args)) + if all( + mypy.checker.is_valid_inferred_type(item_type) + for item_type in arg_type.args + ): + return self.chk.named_generic_type(typename, list(arg_type.args)) elif isinstance(arg_type, AnyType): return self.chk.named_type(typename) return None - def apply_function_plugin(self, - callee: CallableType, - arg_kinds: List[ArgKind], - arg_types: List[Type], - arg_names: Optional[Sequence[Optional[str]]], - formal_to_actual: List[List[int]], - args: List[Expression], - fullname: str, - object_type: Optional[Type], - context: Context) -> Type: + def apply_function_plugin( + self, + callee: CallableType, + arg_kinds: List[ArgKind], + arg_types: List[Type], + arg_names: Optional[Sequence[Optional[str]]], + formal_to_actual: List[List[int]], + args: List[Expression], + fullname: str, + object_type: Optional[Type], + context: Context, + ) -> Type: """Use special case logic to infer the return type of a specific named function/method. Caller must ensure that a plugin hook exists. There are two different cases: @@ -762,34 +881,54 @@ def apply_function_plugin(self, callback = self.plugin.get_function_hook(fullname) assert callback is not None # Assume that caller ensures this return callback( - FunctionContext(formal_arg_types, formal_arg_kinds, - callee.arg_names, formal_arg_names, - callee.ret_type, formal_arg_exprs, context, self.chk)) + FunctionContext( + formal_arg_types, + formal_arg_kinds, + callee.arg_names, + formal_arg_names, + callee.ret_type, + formal_arg_exprs, + context, + self.chk, + ) + ) else: # Apply method plugin method_callback = self.plugin.get_method_hook(fullname) assert method_callback is not None # Assume that caller ensures this object_type = get_proper_type(object_type) return method_callback( - MethodContext(object_type, formal_arg_types, formal_arg_kinds, - callee.arg_names, formal_arg_names, - callee.ret_type, formal_arg_exprs, context, self.chk)) + MethodContext( + object_type, + formal_arg_types, + formal_arg_kinds, + callee.arg_names, + formal_arg_names, + callee.ret_type, + formal_arg_exprs, + context, + self.chk, + ) + ) def apply_signature_hook( - self, callee: FunctionLike, args: List[Expression], - arg_kinds: List[ArgKind], - arg_names: Optional[Sequence[Optional[str]]], - hook: Callable[ - [List[List[Expression]], CallableType], - FunctionLike, - ]) -> FunctionLike: + self, + callee: FunctionLike, + args: List[Expression], + arg_kinds: List[ArgKind], + arg_names: Optional[Sequence[Optional[str]]], + hook: Callable[[List[List[Expression]], CallableType], FunctionLike], + ) -> FunctionLike: """Helper to apply a signature hook for either a function or method""" if isinstance(callee, CallableType): num_formals = len(callee.arg_kinds) formal_to_actual = map_actuals_to_formals( - arg_kinds, arg_names, - callee.arg_kinds, callee.arg_names, - lambda i: self.accept(args[i])) + arg_kinds, + arg_names, + callee.arg_kinds, + callee.arg_names, + lambda i: self.accept(args[i]), + ) formal_arg_exprs: List[List[Expression]] = [[] for _ in range(num_formals)] for formal, actuals in enumerate(formal_to_actual): for actual in actuals: @@ -799,40 +938,63 @@ def apply_signature_hook( assert isinstance(callee, Overloaded) items = [] for item in callee.items: - adjusted = self.apply_signature_hook( - item, args, arg_kinds, arg_names, hook) + adjusted = self.apply_signature_hook(item, args, arg_kinds, arg_names, hook) assert isinstance(adjusted, CallableType) items.append(adjusted) return Overloaded(items) def apply_function_signature_hook( - self, callee: FunctionLike, args: List[Expression], - arg_kinds: List[ArgKind], context: Context, - arg_names: Optional[Sequence[Optional[str]]], - signature_hook: Callable[[FunctionSigContext], FunctionLike]) -> FunctionLike: + self, + callee: FunctionLike, + args: List[Expression], + arg_kinds: List[ArgKind], + context: Context, + arg_names: Optional[Sequence[Optional[str]]], + signature_hook: Callable[[FunctionSigContext], FunctionLike], + ) -> FunctionLike: """Apply a plugin hook that may infer a more precise signature for a function.""" return self.apply_signature_hook( - callee, args, arg_kinds, arg_names, - (lambda args, sig: - signature_hook(FunctionSigContext(args, sig, context, self.chk)))) + callee, + args, + arg_kinds, + arg_names, + (lambda args, sig: signature_hook(FunctionSigContext(args, sig, context, self.chk))), + ) def apply_method_signature_hook( - self, callee: FunctionLike, args: List[Expression], - arg_kinds: List[ArgKind], context: Context, - arg_names: Optional[Sequence[Optional[str]]], object_type: Type, - signature_hook: Callable[[MethodSigContext], FunctionLike]) -> FunctionLike: + self, + callee: FunctionLike, + args: List[Expression], + arg_kinds: List[ArgKind], + context: Context, + arg_names: Optional[Sequence[Optional[str]]], + object_type: Type, + signature_hook: Callable[[MethodSigContext], FunctionLike], + ) -> FunctionLike: """Apply a plugin hook that may infer a more precise signature for a method.""" pobject_type = get_proper_type(object_type) return self.apply_signature_hook( - callee, args, arg_kinds, arg_names, - (lambda args, sig: - signature_hook(MethodSigContext(pobject_type, args, sig, context, self.chk)))) + callee, + args, + arg_kinds, + arg_names, + ( + lambda args, sig: signature_hook( + MethodSigContext(pobject_type, args, sig, context, self.chk) + ) + ), + ) def transform_callee_type( - self, callable_name: Optional[str], callee: Type, args: List[Expression], - arg_kinds: List[ArgKind], context: Context, - arg_names: Optional[Sequence[Optional[str]]] = None, - object_type: Optional[Type] = None) -> Type: + self, + callable_name: Optional[str], + callee: Type, + args: List[Expression], + arg_kinds: List[ArgKind], + context: Context, + arg_names: Optional[Sequence[Optional[str]]] = None, + object_type: Optional[Type] = None, + ) -> Type: """Attempt to determine a more accurate signature for a method call. This is done by looking up and applying a method signature hook (if one exists for the @@ -853,21 +1015,25 @@ def transform_callee_type( method_sig_hook = self.plugin.get_method_signature_hook(callable_name) if method_sig_hook: return self.apply_method_signature_hook( - callee, args, arg_kinds, context, arg_names, object_type, method_sig_hook) + callee, args, arg_kinds, context, arg_names, object_type, method_sig_hook + ) else: function_sig_hook = self.plugin.get_function_signature_hook(callable_name) if function_sig_hook: return self.apply_function_signature_hook( - callee, args, arg_kinds, context, arg_names, function_sig_hook) + callee, args, arg_kinds, context, arg_names, function_sig_hook + ) return callee - def check_call_expr_with_callee_type(self, - callee_type: Type, - e: CallExpr, - callable_name: Optional[str], - object_type: Optional[Type], - member: Optional[str] = None) -> Type: + def check_call_expr_with_callee_type( + self, + callee_type: Type, + e: CallExpr, + callable_name: Optional[str], + object_type: Optional[Type], + member: Optional[str] = None, + ) -> Type: """Type check call expression. The callee_type should be used as the type of callee expression. In particular, @@ -886,52 +1052,71 @@ def check_call_expr_with_callee_type(self, if callable_name: # Try to refine the call signature using plugin hooks before checking the call. callee_type = self.transform_callee_type( - callable_name, callee_type, e.args, e.arg_kinds, e, e.arg_names, object_type) + callable_name, callee_type, e.args, e.arg_kinds, e, e.arg_names, object_type + ) # Unions are special-cased to allow plugins to act on each item in the union. elif member is not None and isinstance(object_type, UnionType): return self.check_union_call_expr(e, object_type, member) ret_type, callee_type = self.check_call( - callee_type, e.args, e.arg_kinds, e, - e.arg_names, callable_node=e.callee, + callee_type, + e.args, + e.arg_kinds, + e, + e.arg_names, + callable_node=e.callee, callable_name=callable_name, object_type=object_type, ) proper_callee = get_proper_type(callee_type) - if (isinstance(e.callee, RefExpr) - and isinstance(proper_callee, CallableType) - and proper_callee.type_guard is not None): + if ( + isinstance(e.callee, RefExpr) + and isinstance(proper_callee, CallableType) + and proper_callee.type_guard is not None + ): # Cache it for find_isinstance_check() e.callee.type_guard = proper_callee.type_guard return ret_type def check_union_call_expr(self, e: CallExpr, object_type: UnionType, member: str) -> Type: - """"Type check calling a member expression where the base type is a union.""" + """ "Type check calling a member expression where the base type is a union.""" res: List[Type] = [] for typ in object_type.relevant_items(): # Member access errors are already reported when visiting the member expression. with self.msg.filter_errors(): - item = analyze_member_access(member, typ, e, False, False, False, - self.msg, original_type=object_type, chk=self.chk, - in_literal_context=self.is_literal_context(), - self_type=typ) + item = analyze_member_access( + member, + typ, + e, + False, + False, + False, + self.msg, + original_type=object_type, + chk=self.chk, + in_literal_context=self.is_literal_context(), + self_type=typ, + ) narrowed = self.narrow_type_from_binder(e.callee, item, skip_non_overlapping=True) if narrowed is None: continue callable_name = self.method_fullname(typ, member) item_object_type = typ if callable_name else None - res.append(self.check_call_expr_with_callee_type(narrowed, e, callable_name, - item_object_type)) + res.append( + self.check_call_expr_with_callee_type(narrowed, e, callable_name, item_object_type) + ) return make_simplified_union(res) - def check_call(self, - callee: Type, - args: List[Expression], - arg_kinds: List[ArgKind], - context: Context, - arg_names: Optional[Sequence[Optional[str]]] = None, - callable_node: Optional[Expression] = None, - callable_name: Optional[str] = None, - object_type: Optional[Type] = None) -> Tuple[Type, Type]: + def check_call( + self, + callee: Type, + args: List[Expression], + arg_kinds: List[ArgKind], + context: Context, + arg_names: Optional[Sequence[Optional[str]]] = None, + callable_node: Optional[Expression] = None, + callable_name: Optional[str] = None, + object_type: Optional[Type] = None, + ) -> Tuple[Type, Type]: """Type check a call. Also infer type arguments if the callee is a generic function. @@ -955,54 +1140,89 @@ def check_call(self, callee = get_proper_type(callee) if isinstance(callee, CallableType): - return self.check_callable_call(callee, args, arg_kinds, context, arg_names, - callable_node, callable_name, object_type) + return self.check_callable_call( + callee, + args, + arg_kinds, + context, + arg_names, + callable_node, + callable_name, + object_type, + ) elif isinstance(callee, Overloaded): - return self.check_overload_call(callee, args, arg_kinds, arg_names, callable_name, - object_type, context) + return self.check_overload_call( + callee, args, arg_kinds, arg_names, callable_name, object_type, context + ) elif isinstance(callee, AnyType) or not self.chk.in_checked_function(): return self.check_any_type_call(args, callee) elif isinstance(callee, UnionType): return self.check_union_call(callee, args, arg_kinds, arg_names, context) elif isinstance(callee, Instance): - call_function = analyze_member_access('__call__', callee, context, is_lvalue=False, - is_super=False, is_operator=True, msg=self.msg, - original_type=callee, chk=self.chk, - in_literal_context=self.is_literal_context()) + call_function = analyze_member_access( + "__call__", + callee, + context, + is_lvalue=False, + is_super=False, + is_operator=True, + msg=self.msg, + original_type=callee, + chk=self.chk, + in_literal_context=self.is_literal_context(), + ) callable_name = callee.type.fullname + ".__call__" # Apply method signature hook, if one exists call_function = self.transform_callee_type( - callable_name, call_function, args, arg_kinds, context, arg_names, callee) - result = self.check_call(call_function, args, arg_kinds, context, arg_names, - callable_node, callable_name, callee) + callable_name, call_function, args, arg_kinds, context, arg_names, callee + ) + result = self.check_call( + call_function, + args, + arg_kinds, + context, + arg_names, + callable_node, + callable_name, + callee, + ) if callable_node: # check_call() stored "call_function" as the type, which is incorrect. # Override the type. self.chk.store_type(callable_node, callee) return result elif isinstance(callee, TypeVarType): - return self.check_call(callee.upper_bound, args, arg_kinds, context, arg_names, - callable_node) + return self.check_call( + callee.upper_bound, args, arg_kinds, context, arg_names, callable_node + ) elif isinstance(callee, TypeType): item = self.analyze_type_type_callee(callee.item, context) - return self.check_call(item, args, arg_kinds, context, arg_names, - callable_node) + return self.check_call(item, args, arg_kinds, context, arg_names, callable_node) elif isinstance(callee, TupleType): - return self.check_call(tuple_fallback(callee), args, arg_kinds, context, - arg_names, callable_node, callable_name, - object_type) + return self.check_call( + tuple_fallback(callee), + args, + arg_kinds, + context, + arg_names, + callable_node, + callable_name, + object_type, + ) else: return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error) - def check_callable_call(self, - callee: CallableType, - args: List[Expression], - arg_kinds: List[ArgKind], - context: Context, - arg_names: Optional[Sequence[Optional[str]]], - callable_node: Optional[Expression], - callable_name: Optional[str], - object_type: Optional[Type]) -> Tuple[Type, Type]: + def check_callable_call( + self, + callee: CallableType, + args: List[Expression], + arg_kinds: List[ArgKind], + context: Context, + arg_names: Optional[Sequence[Optional[str]]], + callable_node: Optional[Expression], + callable_name: Optional[str], + object_type: Optional[Type], + ) -> Tuple[Type, Type]: """Type check a call that targets a callable value. See the docstring of check_call for more information. @@ -1016,76 +1236,104 @@ def check_callable_call(self, # An Enum() call that failed SemanticAnalyzerPass2.check_enum_call(). return callee.ret_type, callee - if (callee.is_type_obj() and callee.type_object().is_abstract - # Exception for Type[...] - and not callee.from_type_type - and not callee.type_object().fallback_to_any): + if ( + callee.is_type_obj() + and callee.type_object().is_abstract + # Exception for Type[...] + and not callee.from_type_type + and not callee.type_object().fallback_to_any + ): type = callee.type_object() self.msg.cannot_instantiate_abstract_class( - callee.type_object().name, type.abstract_attributes, - context) - elif (callee.is_type_obj() and callee.type_object().is_protocol - # Exception for Type[...] - and not callee.from_type_type): - self.chk.fail(message_registry.CANNOT_INSTANTIATE_PROTOCOL - .format(callee.type_object().name), context) + callee.type_object().name, type.abstract_attributes, context + ) + elif ( + callee.is_type_obj() + and callee.type_object().is_protocol + # Exception for Type[...] + and not callee.from_type_type + ): + self.chk.fail( + message_registry.CANNOT_INSTANTIATE_PROTOCOL.format(callee.type_object().name), + context, + ) formal_to_actual = map_actuals_to_formals( - arg_kinds, arg_names, - callee.arg_kinds, callee.arg_names, - lambda i: self.accept(args[i])) + arg_kinds, + arg_names, + callee.arg_kinds, + callee.arg_names, + lambda i: self.accept(args[i]), + ) if callee.is_generic(): need_refresh = any(isinstance(v, ParamSpecType) for v in callee.variables) callee = freshen_function_type_vars(callee) - callee = self.infer_function_type_arguments_using_context( - callee, context) + callee = self.infer_function_type_arguments_using_context(callee, context) callee = self.infer_function_type_arguments( - callee, args, arg_kinds, formal_to_actual, context) + callee, args, arg_kinds, formal_to_actual, context + ) if need_refresh: # Argument kinds etc. may have changed due to # ParamSpec variables being replaced with an arbitrary # number of arguments; recalculate actual-to-formal map formal_to_actual = map_actuals_to_formals( - arg_kinds, arg_names, - callee.arg_kinds, callee.arg_names, - lambda i: self.accept(args[i])) + arg_kinds, + arg_names, + callee.arg_kinds, + callee.arg_names, + lambda i: self.accept(args[i]), + ) param_spec = callee.param_spec() if param_spec is not None and arg_kinds == [ARG_STAR, ARG_STAR2]: arg1 = self.accept(args[0]) arg2 = self.accept(args[1]) - if (isinstance(arg1, ParamSpecType) - and isinstance(arg2, ParamSpecType) - and arg1.flavor == ParamSpecFlavor.ARGS - and arg2.flavor == ParamSpecFlavor.KWARGS - and arg1.id == arg2.id == param_spec.id): + if ( + isinstance(arg1, ParamSpecType) + and isinstance(arg2, ParamSpecType) + and arg1.flavor == ParamSpecFlavor.ARGS + and arg2.flavor == ParamSpecFlavor.KWARGS + and arg1.id == arg2.id == param_spec.id + ): return callee.ret_type, callee - arg_types = self.infer_arg_types_in_context( - callee, args, arg_kinds, formal_to_actual) + arg_types = self.infer_arg_types_in_context(callee, args, arg_kinds, formal_to_actual) - self.check_argument_count(callee, arg_types, arg_kinds, - arg_names, formal_to_actual, context) + self.check_argument_count( + callee, arg_types, arg_kinds, arg_names, formal_to_actual, context + ) - self.check_argument_types(arg_types, arg_kinds, args, callee, formal_to_actual, context, - object_type=object_type) + self.check_argument_types( + arg_types, arg_kinds, args, callee, formal_to_actual, context, object_type=object_type + ) - if (callee.is_type_obj() and (len(arg_types) == 1) - and is_equivalent(callee.ret_type, self.named_type('builtins.type'))): + if ( + callee.is_type_obj() + and (len(arg_types) == 1) + and is_equivalent(callee.ret_type, self.named_type("builtins.type")) + ): callee = callee.copy_modified(ret_type=TypeType.make_normalized(arg_types[0])) if callable_node: # Store the inferred callable type. self.chk.store_type(callable_node, callee) - if (callable_name - and ((object_type is None and self.plugin.get_function_hook(callable_name)) - or (object_type is not None - and self.plugin.get_method_hook(callable_name)))): + if callable_name and ( + (object_type is None and self.plugin.get_function_hook(callable_name)) + or (object_type is not None and self.plugin.get_method_hook(callable_name)) + ): new_ret_type = self.apply_function_plugin( - callee, arg_kinds, arg_types, arg_names, formal_to_actual, args, - callable_name, object_type, context) + callee, + arg_kinds, + arg_types, + arg_names, + formal_to_actual, + args, + callable_name, + object_type, + context, + ) callee = callee.copy_modified(ret_type=new_ret_type) return callee.ret_type, callee @@ -1107,8 +1355,13 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type: expanded = expanded.copy_modified(variables=[]) return expanded if isinstance(item, UnionType): - return UnionType([self.analyze_type_type_callee(get_proper_type(tp), context) - for tp in item.relevant_items()], item.line) + return UnionType( + [ + self.analyze_type_type_callee(get_proper_type(tp), context) + for tp in item.relevant_items() + ], + item.line, + ) if isinstance(item, TypeVarType): # Pretend we're calling the typevar's upper bound, # i.e. its constructor (a poor approximation for reality, @@ -1119,12 +1372,10 @@ def analyze_type_type_callee(self, item: ProperType, context: Context) -> Type: if isinstance(callee, CallableType): callee = callee.copy_modified(ret_type=item) elif isinstance(callee, Overloaded): - callee = Overloaded([c.copy_modified(ret_type=item) - for c in callee.items]) + callee = Overloaded([c.copy_modified(ret_type=item) for c in callee.items]) return callee # We support Type of namedtuples but not of tuples in general - if (isinstance(item, TupleType) - and tuple_fallback(item).type.fullname != 'builtins.tuple'): + if isinstance(item, TupleType) and tuple_fallback(item).type.fullname != "builtins.tuple": return self.analyze_type_type_callee(tuple_fallback(item), context) self.msg.unsupported_type_type(item, context) @@ -1147,8 +1398,12 @@ def infer_arg_types_in_empty_context(self, args: List[Expression]) -> List[Type] return res def infer_arg_types_in_context( - self, callee: CallableType, args: List[Expression], arg_kinds: List[ArgKind], - formal_to_actual: List[List[int]]) -> List[Type]: + self, + callee: CallableType, + args: List[Expression], + arg_kinds: List[ArgKind], + formal_to_actual: List[List[int]], + ) -> List[Type]: """Infer argument expression types using a callable type as context. For example, if callee argument 2 has type List[int], infer the @@ -1171,7 +1426,8 @@ def infer_arg_types_in_context( return cast(List[Type], res) def infer_function_type_arguments_using_context( - self, callable: CallableType, error_context: Context) -> CallableType: + self, callable: CallableType, error_context: Context + ) -> CallableType: """Unify callable return type to type context to infer type vars. For example, if the return type is set[t] where 't' is a type variable @@ -1241,14 +1497,18 @@ def infer_function_type_arguments_using_context( new_args.append(arg) # Don't show errors after we have only used the outer context for inference. # We will use argument context to infer more variables. - return self.apply_generic_arguments(callable, new_args, error_context, - skip_unsatisfied=True) - - def infer_function_type_arguments(self, callee_type: CallableType, - args: List[Expression], - arg_kinds: List[ArgKind], - formal_to_actual: List[List[int]], - context: Context) -> CallableType: + return self.apply_generic_arguments( + callable, new_args, error_context, skip_unsatisfied=True + ) + + def infer_function_type_arguments( + self, + callee_type: CallableType, + args: List[Expression], + arg_kinds: List[ArgKind], + formal_to_actual: List[List[int]], + context: Context, + ) -> CallableType: """Infer the type arguments for a generic callee type. Infer based on the types of arguments. @@ -1262,10 +1522,12 @@ def infer_function_type_arguments(self, callee_type: CallableType, # inferred again later. with self.msg.filter_errors(): arg_types = self.infer_arg_types_in_context( - callee_type, args, arg_kinds, formal_to_actual) + callee_type, args, arg_kinds, formal_to_actual + ) arg_pass_nums = self.get_arg_infer_passes( - callee_type.arg_types, formal_to_actual, len(args)) + callee_type.arg_types, formal_to_actual, len(args) + ) pass1_args: List[Optional[Type]] = [] for i, arg in enumerate(arg_types): @@ -1275,19 +1537,25 @@ def infer_function_type_arguments(self, callee_type: CallableType, pass1_args.append(arg) inferred_args = infer_function_type_arguments( - callee_type, pass1_args, arg_kinds, formal_to_actual, + callee_type, + pass1_args, + arg_kinds, + formal_to_actual, context=self.argument_infer_context(), - strict=self.chk.in_checked_function()) + strict=self.chk.in_checked_function(), + ) if 2 in arg_pass_nums: # Second pass of type inference. - (callee_type, - inferred_args) = self.infer_function_type_arguments_pass2( - callee_type, args, arg_kinds, formal_to_actual, - inferred_args, context) + (callee_type, inferred_args) = self.infer_function_type_arguments_pass2( + callee_type, args, arg_kinds, formal_to_actual, inferred_args, context + ) - if callee_type.special_sig == 'dict' and len(inferred_args) == 2 and ( - ARG_NAMED in arg_kinds or ARG_STAR2 in arg_kinds): + if ( + callee_type.special_sig == "dict" + and len(inferred_args) == 2 + and (ARG_NAMED in arg_kinds or ARG_STAR2 in arg_kinds) + ): # HACK: Infer str key type for dict(...) with keyword args. The type system # can't represent this so we special case it, as this is a pretty common # thing. This doesn't quite work with all possible subclasses of dict @@ -1296,24 +1564,24 @@ def infer_function_type_arguments(self, callee_type: CallableType, # a little tricky to fix so it's left unfixed for now. first_arg = get_proper_type(inferred_args[0]) if isinstance(first_arg, (NoneType, UninhabitedType)): - inferred_args[0] = self.named_type('builtins.str') - elif not first_arg or not is_subtype(self.named_type('builtins.str'), first_arg): - self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, - context) + inferred_args[0] = self.named_type("builtins.str") + elif not first_arg or not is_subtype(self.named_type("builtins.str"), first_arg): + self.chk.fail(message_registry.KEYWORD_ARGUMENT_REQUIRES_STR_KEY_TYPE, context) else: # In dynamically typed functions use implicit 'Any' types for # type variables. inferred_args = [AnyType(TypeOfAny.unannotated)] * len(callee_type.variables) - return self.apply_inferred_arguments(callee_type, inferred_args, - context) + return self.apply_inferred_arguments(callee_type, inferred_args, context) def infer_function_type_arguments_pass2( - self, callee_type: CallableType, - args: List[Expression], - arg_kinds: List[ArgKind], - formal_to_actual: List[List[int]], - old_inferred_args: Sequence[Optional[Type]], - context: Context) -> Tuple[CallableType, List[Optional[Type]]]: + self, + callee_type: CallableType, + args: List[Expression], + arg_kinds: List[ArgKind], + formal_to_actual: List[List[int]], + old_inferred_args: Sequence[Optional[Type]], + context: Context, + ) -> Tuple[CallableType, List[Optional[Type]]]: """Perform second pass of generic function type argument inference. The second pass is needed for arguments with types such as Callable[[T], S], @@ -1334,11 +1602,13 @@ def infer_function_type_arguments_pass2( inferred_args[i] = None callee_type = self.apply_generic_arguments(callee_type, inferred_args, context) - arg_types = self.infer_arg_types_in_context( - callee_type, args, arg_kinds, formal_to_actual) + arg_types = self.infer_arg_types_in_context(callee_type, args, arg_kinds, formal_to_actual) inferred_args = infer_function_type_arguments( - callee_type, arg_types, arg_kinds, formal_to_actual, + callee_type, + arg_types, + arg_kinds, + formal_to_actual, context=self.argument_infer_context(), ) @@ -1346,13 +1616,12 @@ def infer_function_type_arguments_pass2( def argument_infer_context(self) -> ArgumentInferContext: return ArgumentInferContext( - self.chk.named_type('typing.Mapping'), - self.chk.named_type('typing.Iterable'), + self.chk.named_type("typing.Mapping"), self.chk.named_type("typing.Iterable") ) - def get_arg_infer_passes(self, arg_types: List[Type], - formal_to_actual: List[List[int]], - num_actuals: int) -> List[int]: + def get_arg_infer_passes( + self, arg_types: List[Type], formal_to_actual: List[List[int]], num_actuals: int + ) -> List[int]: """Return pass numbers for args for two-pass argument type inference. For each actual, the pass number is either 1 (first pass) or 2 (second @@ -1368,9 +1637,9 @@ def get_arg_infer_passes(self, arg_types: List[Type], res[j] = 2 return res - def apply_inferred_arguments(self, callee_type: CallableType, - inferred_args: Sequence[Optional[Type]], - context: Context) -> CallableType: + def apply_inferred_arguments( + self, callee_type: CallableType, inferred_args: Sequence[Optional[Type]], context: Context + ) -> CallableType: """Apply inferred values of type arguments to a generic function. Inferred_args contains the values of function type arguments. @@ -1381,21 +1650,22 @@ def apply_inferred_arguments(self, callee_type: CallableType, for i, inferred_type in enumerate(inferred_args): if not inferred_type or has_erased_component(inferred_type): # Could not infer a non-trivial type for a type variable. - self.msg.could_not_infer_type_arguments( - callee_type, i + 1, context) + self.msg.could_not_infer_type_arguments(callee_type, i + 1, context) inferred_args = [AnyType(TypeOfAny.from_error)] * len(inferred_args) # Apply the inferred types to the function type. In this case the # return type must be CallableType, since we give the right number of type # arguments. return self.apply_generic_arguments(callee_type, inferred_args, context) - def check_argument_count(self, - callee: CallableType, - actual_types: List[Type], - actual_kinds: List[ArgKind], - actual_names: Optional[Sequence[Optional[str]]], - formal_to_actual: List[List[int]], - context: Optional[Context]) -> bool: + def check_argument_count( + self, + callee: CallableType, + actual_types: List[Type], + actual_kinds: List[ArgKind], + actual_names: Optional[Sequence[Optional[str]]], + formal_to_actual: List[List[int]], + context: Optional[Context], + ) -> bool: """Check that there is a value for all required arguments to a function. Also check that there are no duplicate values for arguments. Report found errors @@ -1416,7 +1686,8 @@ def check_argument_count(self, all_actuals[a] = all_actuals.get(a, 0) + 1 ok, is_unexpected_arg_error = self.check_for_extra_actual_arguments( - callee, actual_types, actual_kinds, actual_names, all_actuals, context) + callee, actual_types, actual_kinds, actual_names, all_actuals, context + ) # Check for too many or few values for formals. for i, kind in enumerate(callee.arg_kinds): @@ -1429,26 +1700,32 @@ def check_argument_count(self, self.msg.missing_named_argument(callee, context, argname) ok = False elif not kind.is_star() and is_duplicate_mapping( - formal_to_actual[i], actual_types, actual_kinds): - if (self.chk.in_checked_function() or - isinstance(get_proper_type(actual_types[formal_to_actual[i][0]]), - TupleType)): + formal_to_actual[i], actual_types, actual_kinds + ): + if self.chk.in_checked_function() or isinstance( + get_proper_type(actual_types[formal_to_actual[i][0]]), TupleType + ): self.msg.duplicate_argument_value(callee, i, context) ok = False - elif (kind.is_named() and formal_to_actual[i] and - actual_kinds[formal_to_actual[i][0]] not in [nodes.ARG_NAMED, nodes.ARG_STAR2]): + elif ( + kind.is_named() + and formal_to_actual[i] + and actual_kinds[formal_to_actual[i][0]] not in [nodes.ARG_NAMED, nodes.ARG_STAR2] + ): # Positional argument when expecting a keyword argument. self.msg.too_many_positional_arguments(callee, context) ok = False return ok - def check_for_extra_actual_arguments(self, - callee: CallableType, - actual_types: List[Type], - actual_kinds: List[ArgKind], - actual_names: Optional[Sequence[Optional[str]]], - all_actuals: Dict[int, int], - context: Context) -> Tuple[bool, bool]: + def check_for_extra_actual_arguments( + self, + callee: CallableType, + actual_types: List[Type], + actual_kinds: List[ArgKind], + actual_names: Optional[Sequence[Optional[str]]], + all_actuals: Dict[int, int], + context: Context, + ) -> Tuple[bool, bool]: """Check for extra actual arguments. Return tuple (was everything ok, @@ -1459,13 +1736,17 @@ def check_for_extra_actual_arguments(self, ok = True # False if we've found any error for i, kind in enumerate(actual_kinds): - if (i not in all_actuals and - # We accept the other iterables than tuple (including Any) - # as star arguments because they could be empty, resulting no arguments. - (kind != nodes.ARG_STAR or is_non_empty_tuple(actual_types[i])) and - # Accept all types for double-starred arguments, because they could be empty - # dictionaries and we can't tell it from their types - kind != nodes.ARG_STAR2): + if ( + i not in all_actuals + and + # We accept the other iterables than tuple (including Any) + # as star arguments because they could be empty, resulting no arguments. + (kind != nodes.ARG_STAR or is_non_empty_tuple(actual_types[i])) + and + # Accept all types for double-starred arguments, because they could be empty + # dictionaries and we can't tell it from their types + kind != nodes.ARG_STAR2 + ): # Extra actual: not matched by a formal argument. ok = False if kind != nodes.ARG_NAMED: @@ -1477,18 +1758,19 @@ def check_for_extra_actual_arguments(self, act_type = actual_types[i] self.msg.unexpected_keyword_argument(callee, act_name, act_type, context) is_unexpected_arg_error = True - elif ((kind == nodes.ARG_STAR and nodes.ARG_STAR not in callee.arg_kinds) - or kind == nodes.ARG_STAR2): + elif ( + kind == nodes.ARG_STAR and nodes.ARG_STAR not in callee.arg_kinds + ) or kind == nodes.ARG_STAR2: actual_type = get_proper_type(actual_types[i]) if isinstance(actual_type, (TupleType, TypedDictType)): if all_actuals.get(i, 0) < len(actual_type.items): # Too many tuple/dict items as some did not match. - if (kind != nodes.ARG_STAR2 - or not isinstance(actual_type, TypedDictType)): + if kind != nodes.ARG_STAR2 or not isinstance(actual_type, TypedDictType): self.msg.too_many_arguments(callee, context) else: - self.msg.too_many_arguments_from_typed_dict(callee, actual_type, - context) + self.msg.too_many_arguments_from_typed_dict( + callee, actual_type, context + ) is_unexpected_arg_error = True ok = False # *args/**kwargs can be applied even if the function takes a fixed @@ -1496,15 +1778,17 @@ def check_for_extra_actual_arguments(self, return ok, is_unexpected_arg_error - def check_argument_types(self, - arg_types: List[Type], - arg_kinds: List[ArgKind], - args: List[Expression], - callee: CallableType, - formal_to_actual: List[List[int]], - context: Context, - check_arg: Optional[ArgChecker] = None, - object_type: Optional[Type] = None) -> None: + def check_argument_types( + self, + arg_types: List[Type], + arg_kinds: List[ArgKind], + args: List[Expression], + callee: CallableType, + formal_to_actual: List[List[int]], + context: Context, + check_arg: Optional[ArgChecker] = None, + object_type: Optional[Type] = None, + ) -> None: """Check argument types against a callable type. Report errors if the argument types are not compatible. @@ -1521,31 +1805,42 @@ def check_argument_types(self, continue # Some kind of error was already reported. actual_kind = arg_kinds[actual] # Check that a *arg is valid as varargs. - if (actual_kind == nodes.ARG_STAR and - not self.is_valid_var_arg(actual_type)): + if actual_kind == nodes.ARG_STAR and not self.is_valid_var_arg(actual_type): self.msg.invalid_var_arg(actual_type, context) - if (actual_kind == nodes.ARG_STAR2 and - not self.is_valid_keyword_var_arg(actual_type)): - is_mapping = is_subtype(actual_type, self.chk.named_type('typing.Mapping')) + if actual_kind == nodes.ARG_STAR2 and not self.is_valid_keyword_var_arg( + actual_type + ): + is_mapping = is_subtype(actual_type, self.chk.named_type("typing.Mapping")) self.msg.invalid_keyword_var_arg(actual_type, is_mapping, context) expanded_actual = mapper.expand_actual_type( - actual_type, actual_kind, - callee.arg_names[i], callee.arg_kinds[i]) - check_arg(expanded_actual, actual_type, arg_kinds[actual], - callee.arg_types[i], - actual + 1, i + 1, callee, object_type, args[actual], context) - - def check_arg(self, - caller_type: Type, - original_caller_type: Type, - caller_kind: ArgKind, - callee_type: Type, - n: int, - m: int, - callee: CallableType, - object_type: Optional[Type], - context: Context, - outer_context: Context) -> None: + actual_type, actual_kind, callee.arg_names[i], callee.arg_kinds[i] + ) + check_arg( + expanded_actual, + actual_type, + arg_kinds[actual], + callee.arg_types[i], + actual + 1, + i + 1, + callee, + object_type, + args[actual], + context, + ) + + def check_arg( + self, + caller_type: Type, + original_caller_type: Type, + caller_kind: ArgKind, + callee_type: Type, + n: int, + m: int, + callee: CallableType, + object_type: Optional[Type], + context: Context, + outer_context: Context, + ) -> None: """Check the type of a single argument in a call.""" caller_type = get_proper_type(caller_type) original_caller_type = get_proper_type(original_caller_type) @@ -1554,40 +1849,49 @@ def check_arg(self, if isinstance(caller_type, DeletedType): self.msg.deleted_as_rvalue(caller_type, context) # Only non-abstract non-protocol class can be given where Type[...] is expected... - elif (isinstance(caller_type, CallableType) and isinstance(callee_type, TypeType) and - caller_type.is_type_obj() and - (caller_type.type_object().is_abstract or caller_type.type_object().is_protocol) and - isinstance(callee_type.item, Instance) and - (callee_type.item.type.is_abstract or callee_type.item.type.is_protocol)): + elif ( + isinstance(caller_type, CallableType) + and isinstance(callee_type, TypeType) + and caller_type.is_type_obj() + and (caller_type.type_object().is_abstract or caller_type.type_object().is_protocol) + and isinstance(callee_type.item, Instance) + and (callee_type.item.type.is_abstract or callee_type.item.type.is_protocol) + ): self.msg.concrete_only_call(callee_type, context) elif not is_subtype(caller_type, callee_type, options=self.chk.options): if self.chk.should_suppress_optional_error([caller_type, callee_type]): return - code = self.msg.incompatible_argument(n, - m, - callee, - original_caller_type, - caller_kind, - object_type=object_type, - context=context, - outer_context=outer_context) - self.msg.incompatible_argument_note(original_caller_type, callee_type, context, - code=code) + code = self.msg.incompatible_argument( + n, + m, + callee, + original_caller_type, + caller_kind, + object_type=object_type, + context=context, + outer_context=outer_context, + ) + self.msg.incompatible_argument_note( + original_caller_type, callee_type, context, code=code + ) self.chk.check_possible_missing_await(caller_type, callee_type, context) - def check_overload_call(self, - callee: Overloaded, - args: List[Expression], - arg_kinds: List[ArgKind], - arg_names: Optional[Sequence[Optional[str]]], - callable_name: Optional[str], - object_type: Optional[Type], - context: Context) -> Tuple[Type, Type]: + def check_overload_call( + self, + callee: Overloaded, + args: List[Expression], + arg_kinds: List[ArgKind], + arg_names: Optional[Sequence[Optional[str]]], + callable_name: Optional[str], + object_type: Optional[Type], + context: Context, + ) -> Tuple[Type, Type]: """Checks a call to an overloaded function.""" arg_types = self.infer_arg_types_in_empty_context(args) # Step 1: Filter call targets to remove ones where the argument counts don't match - plausible_targets = self.plausible_overload_call_targets(arg_types, arg_kinds, - arg_names, callee) + plausible_targets = self.plausible_overload_call_targets( + arg_types, arg_kinds, arg_names, callee + ) # Step 2: If the arguments contain a union, we try performing union math first, # instead of picking the first matching overload. @@ -1600,10 +1904,16 @@ def check_overload_call(self, if any(self.real_union(arg) for arg in arg_types): try: with self.msg.filter_errors(): - unioned_return = self.union_overload_result(plausible_targets, args, - arg_types, arg_kinds, arg_names, - callable_name, object_type, - context) + unioned_return = self.union_overload_result( + plausible_targets, + args, + arg_types, + arg_kinds, + arg_names, + callable_name, + object_type, + context, + ) except TooManyUnions: union_interrupted = True else: @@ -1615,20 +1925,28 @@ def check_overload_call(self, # a union of inferred callables because for example a call # Union[int -> int, str -> str](Union[int, str]) is invalid and # we don't want to introduce internal inconsistencies. - unioned_result = (make_simplified_union(list(returns), - context.line, - context.column), - self.combine_function_signatures(inferred_types)) + unioned_result = ( + make_simplified_union(list(returns), context.line, context.column), + self.combine_function_signatures(inferred_types), + ) # Step 3: We try checking each branch one-by-one. - inferred_result = self.infer_overload_return_type(plausible_targets, args, arg_types, - arg_kinds, arg_names, callable_name, - object_type, context) + inferred_result = self.infer_overload_return_type( + plausible_targets, + args, + arg_types, + arg_kinds, + arg_names, + callable_name, + object_type, + context, + ) # If any of checks succeed, stop early. if inferred_result is not None and unioned_result is not None: # Both unioned and direct checks succeeded, choose the more precise type. - if (is_subtype(inferred_result[0], unioned_result[0]) and - not isinstance(get_proper_type(inferred_result[0]), AnyType)): + if is_subtype(inferred_result[0], unioned_result[0]) and not isinstance( + get_proper_type(inferred_result[0]), AnyType + ): return inferred_result return unioned_result elif unioned_result is not None: @@ -1645,8 +1963,9 @@ def check_overload_call(self, # # Neither alternative matches, but we can guess the user probably wants the # second one. - erased_targets = self.overload_erased_call_targets(plausible_targets, arg_types, - arg_kinds, arg_names, args, context) + erased_targets = self.overload_erased_call_targets( + plausible_targets, arg_types, arg_kinds, arg_names, args, context + ) # Step 5: We try and infer a second-best alternative if possible. If not, fall back # to using 'Any'. @@ -1667,21 +1986,28 @@ def check_overload_call(self, code = None else: code = codes.OPERATOR - self.msg.no_variant_matches_arguments( - callee, arg_types, context, code=code) - - result = self.check_call(target, args, arg_kinds, context, arg_names, - callable_name=callable_name, - object_type=object_type) + self.msg.no_variant_matches_arguments(callee, arg_types, context, code=code) + + result = self.check_call( + target, + args, + arg_kinds, + context, + arg_names, + callable_name=callable_name, + object_type=object_type, + ) if union_interrupted: self.chk.fail(message_registry.TOO_MANY_UNION_COMBINATIONS, context) return result - def plausible_overload_call_targets(self, - arg_types: List[Type], - arg_kinds: List[ArgKind], - arg_names: Optional[Sequence[Optional[str]]], - overload: Overloaded) -> List[CallableType]: + def plausible_overload_call_targets( + self, + arg_types: List[Type], + arg_kinds: List[ArgKind], + arg_names: Optional[Sequence[Optional[str]]], + overload: Overloaded, + ) -> List[CallableType]: """Returns all overload call targets that having matching argument counts. If the given args contains a star-arg (*arg or **kwarg argument), this method @@ -1695,8 +2021,11 @@ def plausible_overload_call_targets(self, def has_shape(typ: Type) -> bool: typ = get_proper_type(typ) - return (isinstance(typ, TupleType) or isinstance(typ, TypedDictType) - or (isinstance(typ, Instance) and typ.type.is_named_tuple)) + return ( + isinstance(typ, TupleType) + or isinstance(typ, TypedDictType) + or (isinstance(typ, Instance) and typ.type.is_named_tuple) + ) matches: List[CallableType] = [] star_matches: List[CallableType] = [] @@ -1710,13 +2039,14 @@ def has_shape(typ: Type) -> bool: args_have_kw_arg = True for typ in overload.items: - formal_to_actual = map_actuals_to_formals(arg_kinds, arg_names, - typ.arg_kinds, typ.arg_names, - lambda i: arg_types[i]) + formal_to_actual = map_actuals_to_formals( + arg_kinds, arg_names, typ.arg_kinds, typ.arg_names, lambda i: arg_types[i] + ) with self.msg.filter_errors(): - if self.check_argument_count(typ, arg_types, arg_kinds, arg_names, - formal_to_actual, None): + if self.check_argument_count( + typ, arg_types, arg_kinds, arg_names, formal_to_actual, None + ): if args_have_var_arg and typ.is_var_arg: star_matches.append(typ) elif args_have_kw_arg and typ.is_kw_arg: @@ -1726,16 +2056,17 @@ def has_shape(typ: Type) -> bool: return star_matches + matches - def infer_overload_return_type(self, - plausible_targets: List[CallableType], - args: List[Expression], - arg_types: List[Type], - arg_kinds: List[ArgKind], - arg_names: Optional[Sequence[Optional[str]]], - callable_name: Optional[str], - object_type: Optional[Type], - context: Context, - ) -> Optional[Tuple[Type, Type]]: + def infer_overload_return_type( + self, + plausible_targets: List[CallableType], + args: List[Expression], + arg_types: List[Type], + arg_kinds: List[ArgKind], + arg_names: Optional[Sequence[Optional[str]]], + callable_name: Optional[str], + object_type: Optional[Type], + context: Context, + ) -> Optional[Tuple[Type, Type]]: """Attempts to find the first matching callable from the given list. If a match is found, returns a tuple containing the result type and the inferred @@ -1763,7 +2094,8 @@ def infer_overload_return_type(self, arg_names=arg_names, context=context, callable_name=callable_name, - object_type=object_type) + object_type=object_type, + ) is_match = not w.has_new_errors() if is_match: # Return early if possible; otherwise record info so we can @@ -1788,47 +2120,53 @@ def infer_overload_return_type(self, self.chk.store_types(type_maps[0]) return erase_type(return_types[0]), erase_type(inferred_types[0]) else: - return self.check_call(callee=AnyType(TypeOfAny.special_form), - args=args, - arg_kinds=arg_kinds, - arg_names=arg_names, - context=context, - callable_name=callable_name, - object_type=object_type) + return self.check_call( + callee=AnyType(TypeOfAny.special_form), + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type, + ) else: # Success! No ambiguity; return the first match. self.chk.store_types(type_maps[0]) return return_types[0], inferred_types[0] - def overload_erased_call_targets(self, - plausible_targets: List[CallableType], - arg_types: List[Type], - arg_kinds: List[ArgKind], - arg_names: Optional[Sequence[Optional[str]]], - args: List[Expression], - context: Context) -> List[CallableType]: + def overload_erased_call_targets( + self, + plausible_targets: List[CallableType], + arg_types: List[Type], + arg_kinds: List[ArgKind], + arg_names: Optional[Sequence[Optional[str]]], + args: List[Expression], + context: Context, + ) -> List[CallableType]: """Returns a list of all targets that match the caller after erasing types. Assumes all of the given targets have argument counts compatible with the caller. """ matches: List[CallableType] = [] for typ in plausible_targets: - if self.erased_signature_similarity(arg_types, arg_kinds, arg_names, args, typ, - context): + if self.erased_signature_similarity( + arg_types, arg_kinds, arg_names, args, typ, context + ): matches.append(typ) return matches - def union_overload_result(self, - plausible_targets: List[CallableType], - args: List[Expression], - arg_types: List[Type], - arg_kinds: List[ArgKind], - arg_names: Optional[Sequence[Optional[str]]], - callable_name: Optional[str], - object_type: Optional[Type], - context: Context, - level: int = 0 - ) -> Optional[List[Tuple[Type, Type]]]: + def union_overload_result( + self, + plausible_targets: List[CallableType], + args: List[Expression], + arg_types: List[Type], + arg_kinds: List[ArgKind], + arg_names: Optional[Sequence[Optional[str]]], + callable_name: Optional[str], + object_type: Optional[Type], + context: Context, + level: int = 0, + ) -> Optional[List[Tuple[Type, Type]]]: """Accepts a list of overload signatures and attempts to match calls by destructuring the first union. @@ -1849,9 +2187,16 @@ def union_overload_result(self, else: # No unions in args, just fall back to normal inference with self.type_overrides_set(args, arg_types): - res = self.infer_overload_return_type(plausible_targets, args, arg_types, - arg_kinds, arg_names, callable_name, - object_type, context) + res = self.infer_overload_return_type( + plausible_targets, + args, + arg_types, + arg_kinds, + arg_names, + callable_name, + object_type, + context, + ) if res is not None: return [res] return None @@ -1859,11 +2204,17 @@ def union_overload_result(self, # Step 3: Try a direct match before splitting to avoid unnecessary union splits # and save performance. with self.type_overrides_set(args, arg_types): - direct = self.infer_overload_return_type(plausible_targets, args, arg_types, - arg_kinds, arg_names, callable_name, - object_type, context) - if direct is not None and not isinstance(get_proper_type(direct[0]), - (UnionType, AnyType)): + direct = self.infer_overload_return_type( + plausible_targets, + args, + arg_types, + arg_kinds, + arg_names, + callable_name, + object_type, + context, + ) + if direct is not None and not isinstance(get_proper_type(direct[0]), (UnionType, AnyType)): # We only return non-unions soon, to avoid greedy match. return [direct] @@ -1875,10 +2226,17 @@ def union_overload_result(self, for item in first_union.relevant_items(): new_arg_types = arg_types.copy() new_arg_types[idx] = item - sub_result = self.union_overload_result(plausible_targets, args, new_arg_types, - arg_kinds, arg_names, callable_name, - object_type, context, - level + 1) + sub_result = self.union_overload_result( + plausible_targets, + args, + new_arg_types, + arg_kinds, + arg_names, + callable_name, + object_type, + context, + level + 1, + ) if sub_result is not None: res_items.extend(sub_result) else: @@ -1899,8 +2257,9 @@ def real_union(self, typ: Type) -> bool: return isinstance(typ, UnionType) and len(typ.relevant_items()) > 1 @contextmanager - def type_overrides_set(self, exprs: Sequence[Expression], - overrides: Sequence[Type]) -> Iterator[None]: + def type_overrides_set( + self, exprs: Sequence[Expression], overrides: Sequence[Type] + ) -> Iterator[None]: """Set _temporary_ type overrides for given expressions.""" assert len(exprs) == len(overrides) for expr, typ in zip(exprs, overrides): @@ -1978,7 +2337,8 @@ def combine_function_signatures(self, types: Sequence[Type]) -> Union[AnyType, C arg_names=[None, None], ret_type=union_return, variables=variables, - implicit=True) + implicit=True, + ) final_args = [] for args_list in new_args: @@ -1990,87 +2350,104 @@ def combine_function_signatures(self, types: Sequence[Type]) -> Union[AnyType, C arg_kinds=new_kinds, ret_type=union_return, variables=variables, - implicit=True) - - def erased_signature_similarity(self, - arg_types: List[Type], - arg_kinds: List[ArgKind], - arg_names: Optional[Sequence[Optional[str]]], - args: List[Expression], - callee: CallableType, - context: Context) -> bool: + implicit=True, + ) + + def erased_signature_similarity( + self, + arg_types: List[Type], + arg_kinds: List[ArgKind], + arg_names: Optional[Sequence[Optional[str]]], + args: List[Expression], + callee: CallableType, + context: Context, + ) -> bool: """Determine whether arguments could match the signature at runtime, after erasing types.""" - formal_to_actual = map_actuals_to_formals(arg_kinds, - arg_names, - callee.arg_kinds, - callee.arg_names, - lambda i: arg_types[i]) + formal_to_actual = map_actuals_to_formals( + arg_kinds, arg_names, callee.arg_kinds, callee.arg_names, lambda i: arg_types[i] + ) with self.msg.filter_errors(): - if not self.check_argument_count(callee, arg_types, arg_kinds, arg_names, - formal_to_actual, None): + if not self.check_argument_count( + callee, arg_types, arg_kinds, arg_names, formal_to_actual, None + ): # Too few or many arguments -> no match. return False - def check_arg(caller_type: Type, - original_ccaller_type: Type, - caller_kind: ArgKind, - callee_type: Type, - n: int, - m: int, - callee: CallableType, - object_type: Optional[Type], - context: Context, - outer_context: Context) -> None: + def check_arg( + caller_type: Type, + original_ccaller_type: Type, + caller_kind: ArgKind, + callee_type: Type, + n: int, + m: int, + callee: CallableType, + object_type: Optional[Type], + context: Context, + outer_context: Context, + ) -> None: if not arg_approximate_similarity(caller_type, callee_type): # No match -- exit early since none of the remaining work can change # the result. raise Finished try: - self.check_argument_types(arg_types, arg_kinds, args, callee, - formal_to_actual, context=context, check_arg=check_arg) + self.check_argument_types( + arg_types, + arg_kinds, + args, + callee, + formal_to_actual, + context=context, + check_arg=check_arg, + ) return True except Finished: return False - def apply_generic_arguments(self, callable: CallableType, types: Sequence[Optional[Type]], - context: Context, skip_unsatisfied: bool = False) -> CallableType: + def apply_generic_arguments( + self, + callable: CallableType, + types: Sequence[Optional[Type]], + context: Context, + skip_unsatisfied: bool = False, + ) -> CallableType: """Simple wrapper around mypy.applytype.apply_generic_arguments.""" - return applytype.apply_generic_arguments(callable, types, - self.msg.incompatible_typevar_value, context, - skip_unsatisfied=skip_unsatisfied) + return applytype.apply_generic_arguments( + callable, + types, + self.msg.incompatible_typevar_value, + context, + skip_unsatisfied=skip_unsatisfied, + ) def check_any_type_call(self, args: List[Expression], callee: Type) -> Tuple[Type, Type]: self.infer_arg_types_in_empty_context(args) callee = get_proper_type(callee) if isinstance(callee, AnyType): - return (AnyType(TypeOfAny.from_another_any, source_any=callee), - AnyType(TypeOfAny.from_another_any, source_any=callee)) + return ( + AnyType(TypeOfAny.from_another_any, source_any=callee), + AnyType(TypeOfAny.from_another_any, source_any=callee), + ) else: return AnyType(TypeOfAny.special_form), AnyType(TypeOfAny.special_form) - def check_union_call(self, - callee: UnionType, - args: List[Expression], - arg_kinds: List[ArgKind], - arg_names: Optional[Sequence[Optional[str]]], - context: Context) -> Tuple[Type, Type]: + def check_union_call( + self, + callee: UnionType, + args: List[Expression], + arg_kinds: List[ArgKind], + arg_names: Optional[Sequence[Optional[str]]], + context: Context, + ) -> Tuple[Type, Type]: with self.msg.disable_type_names(): results = [ - self.check_call( - subtype, - args, - arg_kinds, - context, - arg_names, - ) + self.check_call(subtype, args, arg_kinds, context, arg_names) for subtype in callee.relevant_items() ] - return (make_simplified_union([res[0] for res in results]), - callee) + return (make_simplified_union([res[0] for res in results]), callee) def visit_member_expr(self, e: MemberExpr, is_lvalue: bool = False) -> Type: """Visit member expression (of form e.id).""" @@ -2078,8 +2455,7 @@ def visit_member_expr(self, e: MemberExpr, is_lvalue: bool = False) -> Type: result = self.analyze_ordinary_member_access(e, is_lvalue) return self.narrow_type_from_binder(e, result) - def analyze_ordinary_member_access(self, e: MemberExpr, - is_lvalue: bool) -> Type: + def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type: """Analyse member expression or member lvalue.""" if e.kind is not None: # This is a reference to a module attribute. @@ -2094,22 +2470,40 @@ def analyze_ordinary_member_access(self, e: MemberExpr, module_symbol_table = base.node.names member_type = analyze_member_access( - e.name, original_type, e, is_lvalue, False, False, - self.msg, original_type=original_type, chk=self.chk, + e.name, + original_type, + e, + is_lvalue, + False, + False, + self.msg, + original_type=original_type, + chk=self.chk, in_literal_context=self.is_literal_context(), - module_symbol_table=module_symbol_table) + module_symbol_table=module_symbol_table, + ) return member_type - def analyze_external_member_access(self, member: str, base_type: Type, - context: Context) -> Type: + def analyze_external_member_access( + self, member: str, base_type: Type, context: Context + ) -> Type: """Analyse member access that is external, i.e. it cannot refer to private definitions. Return the result type. """ # TODO remove; no private definitions in mypy - return analyze_member_access(member, base_type, context, False, False, False, - self.msg, original_type=base_type, chk=self.chk, - in_literal_context=self.is_literal_context()) + return analyze_member_access( + member, + base_type, + context, + False, + False, + False, + self.msg, + original_type=base_type, + chk=self.chk, + in_literal_context=self.is_literal_context(), + ) def is_literal_context(self) -> bool: return is_literal_type_like(self.type_context[-1]) @@ -2135,62 +2529,62 @@ def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Ty if self.is_literal_context(): return LiteralType(value=value, fallback=typ) else: - return typ.copy_modified(last_known_value=LiteralType( - value=value, - fallback=typ, - line=typ.line, - column=typ.column, - )) + return typ.copy_modified( + last_known_value=LiteralType( + value=value, fallback=typ, line=typ.line, column=typ.column + ) + ) def concat_tuples(self, left: TupleType, right: TupleType) -> TupleType: """Concatenate two fixed length tuples.""" - return TupleType(items=left.items + right.items, - fallback=self.named_type('builtins.tuple')) + return TupleType( + items=left.items + right.items, fallback=self.named_type("builtins.tuple") + ) def visit_int_expr(self, e: IntExpr) -> Type: """Type check an integer literal (trivial).""" - return self.infer_literal_expr_type(e.value, 'builtins.int') + return self.infer_literal_expr_type(e.value, "builtins.int") def visit_str_expr(self, e: StrExpr) -> Type: """Type check a string literal (trivial).""" - return self.infer_literal_expr_type(e.value, 'builtins.str') + return self.infer_literal_expr_type(e.value, "builtins.str") def visit_bytes_expr(self, e: BytesExpr) -> Type: """Type check a bytes literal (trivial).""" - return self.infer_literal_expr_type(e.value, 'builtins.bytes') + return self.infer_literal_expr_type(e.value, "builtins.bytes") def visit_unicode_expr(self, e: UnicodeExpr) -> Type: """Type check a unicode literal (trivial).""" - return self.infer_literal_expr_type(e.value, 'builtins.unicode') + return self.infer_literal_expr_type(e.value, "builtins.unicode") def visit_float_expr(self, e: FloatExpr) -> Type: """Type check a float literal (trivial).""" - return self.named_type('builtins.float') + return self.named_type("builtins.float") def visit_complex_expr(self, e: ComplexExpr) -> Type: """Type check a complex literal.""" - return self.named_type('builtins.complex') + return self.named_type("builtins.complex") def visit_ellipsis(self, e: EllipsisExpr) -> Type: """Type check '...'.""" if self.chk.options.python_version[0] >= 3: - return self.named_type('builtins.ellipsis') + return self.named_type("builtins.ellipsis") else: # '...' is not valid in normal Python 2 code, but it can # be used in stubs. The parser makes sure that we only # get this far if we are in a stub, and we can safely # return 'object' as ellipsis is special cased elsewhere. # The builtins.ellipsis type does not exist in Python 2. - return self.named_type('builtins.object') + return self.named_type("builtins.object") def visit_op_expr(self, e: OpExpr) -> Type: """Type check a binary operator expression.""" - if e.op == 'and' or e.op == 'or': + if e.op == "and" or e.op == "or": return self.check_boolean_op(e, e) - if e.op == '*' and isinstance(e.left, ListExpr): + if e.op == "*" and isinstance(e.left, ListExpr): # Expressions of form [...] * e get special type inference. return self.check_list_multiply(e) - if e.op == '%': + if e.op == "%": pyversion = self.chk.options.python_version if pyversion[0] == 3: if isinstance(e.left, BytesExpr) and pyversion[1] >= 5: @@ -2203,23 +2597,22 @@ def visit_op_expr(self, e: OpExpr) -> Type: left_type = self.accept(e.left) proper_left_type = get_proper_type(left_type) - if isinstance(proper_left_type, TupleType) and e.op == '+': - left_add_method = proper_left_type.partial_fallback.type.get('__add__') - if left_add_method and left_add_method.fullname == 'builtins.tuple.__add__': + if isinstance(proper_left_type, TupleType) and e.op == "+": + left_add_method = proper_left_type.partial_fallback.type.get("__add__") + if left_add_method and left_add_method.fullname == "builtins.tuple.__add__": proper_right_type = get_proper_type(self.accept(e.right)) if isinstance(proper_right_type, TupleType): - right_radd_method = proper_right_type.partial_fallback.type.get('__radd__') + right_radd_method = proper_right_type.partial_fallback.type.get("__radd__") if right_radd_method is None: return self.concat_tuples(proper_left_type, proper_right_type) if e.op in operators.op_methods: method = self.get_operator_method(e.op) - result, method_type = self.check_op(method, left_type, e.right, e, - allow_reverse=True) + result, method_type = self.check_op(method, left_type, e.right, e, allow_reverse=True) e.method_type = method_type return result else: - raise RuntimeError(f'Unknown operator {e.op}') + raise RuntimeError(f"Unknown operator {e.op}") def visit_comparison_expr(self, e: ComparisonExpr) -> Type: """Type check a comparison expression. @@ -2236,7 +2629,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: method_type: Optional[mypy.types.Type] = None - if operator == 'in' or operator == 'not in': + if operator == "in" or operator == "not in": # If the right operand has partial type, look it up without triggering # a "Need type annotation ..." message, as it would be noise. right_type = self.find_partial_type_ref_fast_path(right) @@ -2247,7 +2640,7 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: # are just to verify whether something is valid typing wise). with self.msg.filter_errors(save_filtered_errors=True) as local_errors: _, method_type = self.check_method_call_by_name( - method='__contains__', + method="__contains__", base_type=right_type, args=[left], arg_kinds=[ARG_POS], @@ -2263,41 +2656,51 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: if isinstance(right_type, PartialType): # We don't really know if this is an error or not, so just shut up. pass - elif (local_errors.has_new_errors() and + elif ( + local_errors.has_new_errors() + and # is_valid_var_arg is True for any Iterable - self.is_valid_var_arg(right_type)): + self.is_valid_var_arg(right_type) + ): _, itertype = self.chk.analyze_iterable_item_type(right) method_type = CallableType( [left_type], [nodes.ARG_POS], [None], self.bool_type(), - self.named_type('builtins.function')) + self.named_type("builtins.function"), + ) if not is_subtype(left_type, itertype): - self.msg.unsupported_operand_types('in', left_type, right_type, e) + self.msg.unsupported_operand_types("in", left_type, right_type, e) # Only show dangerous overlap if there are no other errors. - elif (not local_errors.has_new_errors() and cont_type and - self.dangerous_comparison(left_type, cont_type, - original_container=right_type)): - self.msg.dangerous_comparison(left_type, cont_type, 'container', e) + elif ( + not local_errors.has_new_errors() + and cont_type + and self.dangerous_comparison( + left_type, cont_type, original_container=right_type + ) + ): + self.msg.dangerous_comparison(left_type, cont_type, "container", e) else: self.msg.add_errors(local_errors.filtered_errors()) elif operator in operators.op_methods: method = self.get_operator_method(operator) with ErrorWatcher(self.msg.errors) as w: - sub_result, method_type = self.check_op(method, left_type, right, e, - allow_reverse=True) + sub_result, method_type = self.check_op( + method, left_type, right, e, allow_reverse=True + ) # Only show dangerous overlap if there are no other errors. See # testCustomEqCheckStrictEquality for an example. - if not w.has_new_errors() and operator in ('==', '!='): + if not w.has_new_errors() and operator in ("==", "!="): right_type = self.accept(right) # We suppress the error if there is a custom __eq__() method on either # side. User defined (or even standard library) classes can define this # to return True for comparisons between non-overlapping types. - if (not custom_special_method(left_type, '__eq__') and - not custom_special_method(right_type, '__eq__')): + if not custom_special_method( + left_type, "__eq__" + ) and not custom_special_method(right_type, "__eq__"): # Also flag non-overlapping literals in situations like: # x: Literal['a', 'b'] # if x == 'c': @@ -2305,18 +2708,18 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> Type: left_type = try_getting_literal(left_type) right_type = try_getting_literal(right_type) if self.dangerous_comparison(left_type, right_type): - self.msg.dangerous_comparison(left_type, right_type, 'equality', e) + self.msg.dangerous_comparison(left_type, right_type, "equality", e) - elif operator == 'is' or operator == 'is not': + elif operator == "is" or operator == "is not": right_type = self.accept(right) # validate the right operand sub_result = self.bool_type() left_type = try_getting_literal(left_type) right_type = try_getting_literal(right_type) if self.dangerous_comparison(left_type, right_type): - self.msg.dangerous_comparison(left_type, right_type, 'identity', e) + self.msg.dangerous_comparison(left_type, right_type, "identity", e) method_type = None else: - raise RuntimeError(f'Unknown comparison operator {operator}') + raise RuntimeError(f"Unknown comparison operator {operator}") e.method_types.append(method_type) @@ -2345,8 +2748,9 @@ def find_partial_type_ref_fast_path(self, expr: Expression) -> Optional[Type]: return result return None - def dangerous_comparison(self, left: Type, right: Type, - original_container: Optional[Type] = None) -> bool: + def dangerous_comparison( + self, left: Type, right: Type, original_container: Optional[Type] = None + ) -> bool: """Check for dangerous non-overlapping comparisons like 42 == 'no'. The original_container is the original container type for 'in' checks @@ -2387,17 +2791,22 @@ def dangerous_comparison(self, left: Type, right: Type, right = remove_optional(right) left, right = get_proper_types((left, right)) py2 = self.chk.options.python_version < (3, 0) - if (original_container and has_bytes_component(original_container, py2) and - has_bytes_component(left, py2)): + if ( + original_container + and has_bytes_component(original_container, py2) + and has_bytes_component(left, py2) + ): # We need to special case bytes and bytearray, because 97 in b'abc', b'a' in b'abc', # b'a' in bytearray(b'abc') etc. all return True (and we want to show the error only # if the check can _never_ be True). return False if isinstance(left, Instance) and isinstance(right, Instance): # Special case some builtin implementations of AbstractSet. - if (left.type.fullname in OVERLAPPING_TYPES_ALLOWLIST and - right.type.fullname in OVERLAPPING_TYPES_ALLOWLIST): - abstract_set = self.chk.lookup_typeinfo('typing.AbstractSet') + if ( + left.type.fullname in OVERLAPPING_TYPES_ALLOWLIST + and right.type.fullname in OVERLAPPING_TYPES_ALLOWLIST + ): + abstract_set = self.chk.lookup_typeinfo("typing.AbstractSet") left = map_instance_to_supertype(left, abstract_set) right = map_instance_to_supertype(right, abstract_set) return not is_overlapping_types(left.args[0], right.args[0]) @@ -2408,23 +2817,20 @@ def dangerous_comparison(self, left: Type, right: Type, return not is_overlapping_types(left, right, ignore_promotions=False) def get_operator_method(self, op: str) -> str: - if op == '/' and self.chk.options.python_version[0] == 2: - return ( - '__truediv__' - if self.chk.tree.is_future_flag_set('division') - else '__div__' - ) + if op == "/" and self.chk.options.python_version[0] == 2: + return "__truediv__" if self.chk.tree.is_future_flag_set("division") else "__div__" else: return operators.op_methods[op] - def check_method_call_by_name(self, - method: str, - base_type: Type, - args: List[Expression], - arg_kinds: List[ArgKind], - context: Context, - original_type: Optional[Type] = None - ) -> Tuple[Type, Type]: + def check_method_call_by_name( + self, + method: str, + base_type: Type, + args: List[Expression], + arg_kinds: List[ArgKind], + context: Context, + original_type: Optional[Type] = None, + ) -> Tuple[Type, Type]: """Type check a call to a named method on an object. Return tuple (result type, inferred method type). The 'original_type' @@ -2434,25 +2840,33 @@ def check_method_call_by_name(self, # Unions are special-cased to allow plugins to act on each element of the union. base_type = get_proper_type(base_type) if isinstance(base_type, UnionType): - return self.check_union_method_call_by_name(method, base_type, - args, arg_kinds, - context, original_type) - - method_type = analyze_member_access(method, base_type, context, False, False, True, - self.msg, original_type=original_type, - chk=self.chk, - in_literal_context=self.is_literal_context()) - return self.check_method_call( - method, base_type, method_type, args, arg_kinds, context) - - def check_union_method_call_by_name(self, - method: str, - base_type: UnionType, - args: List[Expression], - arg_kinds: List[ArgKind], - context: Context, - original_type: Optional[Type] = None - ) -> Tuple[Type, Type]: + return self.check_union_method_call_by_name( + method, base_type, args, arg_kinds, context, original_type + ) + + method_type = analyze_member_access( + method, + base_type, + context, + False, + False, + True, + self.msg, + original_type=original_type, + chk=self.chk, + in_literal_context=self.is_literal_context(), + ) + return self.check_method_call(method, base_type, method_type, args, arg_kinds, context) + + def check_union_method_call_by_name( + self, + method: str, + base_type: UnionType, + args: List[Expression], + arg_kinds: List[ArgKind], + context: Context, + original_type: Optional[Type] = None, + ) -> Tuple[Type, Type]: """Type check a call to a named method on an object with union type. This essentially checks the call using check_method_call_by_name() for each @@ -2466,20 +2880,21 @@ def check_union_method_call_by_name(self, # mypy.checkmember.analyze_union_member_access(). with self.msg.disable_type_names(): item, meth_item = self.check_method_call_by_name( - method, typ, args, arg_kinds, - context, original_type, + method, typ, args, arg_kinds, context, original_type ) res.append(item) meth_res.append(meth_item) return make_simplified_union(res), make_simplified_union(meth_res) - def check_method_call(self, - method_name: str, - base_type: Type, - method_type: Type, - args: List[Expression], - arg_kinds: List[ArgKind], - context: Context) -> Tuple[Type, Type]: + def check_method_call( + self, + method_name: str, + base_type: Type, + method_type: Type, + args: List[Expression], + arg_kinds: List[ArgKind], + context: Context, + ) -> Tuple[Type, Type]: """Type check a call to a method with the given name and type on an object. Return tuple (result type, inferred method type). @@ -2489,18 +2904,27 @@ def check_method_call(self, # Try to refine the method signature using plugin hooks before checking the call. method_type = self.transform_callee_type( - callable_name, method_type, args, arg_kinds, context, object_type=object_type) - - return self.check_call(method_type, args, arg_kinds, - context, callable_name=callable_name, object_type=base_type) - - def check_op_reversible(self, - op_name: str, - left_type: Type, - left_expr: Expression, - right_type: Type, - right_expr: Expression, - context: Context) -> Tuple[Type, Type]: + callable_name, method_type, args, arg_kinds, context, object_type=object_type + ) + + return self.check_call( + method_type, + args, + arg_kinds, + context, + callable_name=callable_name, + object_type=base_type, + ) + + def check_op_reversible( + self, + op_name: str, + left_type: Type, + left_expr: Expression, + right_type: Type, + right_expr: Expression, + context: Context, + ) -> Tuple[Type, Type]: def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: """Looks up the given operator and returns the corresponding type, if it exists.""" @@ -2523,7 +2947,7 @@ def lookup_operator(op_name: str, base_type: Type) -> Optional[Type]: context=context, msg=self.msg, chk=self.chk, - in_literal_context=self.is_literal_context() + in_literal_context=self.is_literal_context(), ) return None if w.has_new_errors() else member @@ -2581,14 +3005,14 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]: # This is the case even if the __add__ method is completely missing and the __radd__ # method is defined. - variants_raw = [ - (left_op, left_type, right_expr) - ] - elif (is_subtype(right_type, left_type) - and isinstance(left_type, Instance) - and isinstance(right_type, Instance) - and left_type.type.alt_promote is not right_type.type - and lookup_definer(left_type, op_name) != lookup_definer(right_type, rev_op_name)): + variants_raw = [(left_op, left_type, right_expr)] + elif ( + is_subtype(right_type, left_type) + and isinstance(left_type, Instance) + and isinstance(right_type, Instance) + and left_type.type.alt_promote is not right_type.type + and lookup_definer(left_type, op_name) != lookup_definer(right_type, rev_op_name) + ): # When we do "A() + B()" where B is a subclass of A, we'll actually try calling # B's __radd__ method first, but ONLY if B explicitly defines or overrides the # __radd__ method. @@ -2599,18 +3023,12 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]: # As a special case, the alt_promote check makes sure that we don't use the # __radd__ method of int if the LHS is a native int type. - variants_raw = [ - (right_op, right_type, left_expr), - (left_op, left_type, right_expr), - ] + variants_raw = [(right_op, right_type, left_expr), (left_op, left_type, right_expr)] else: # In all other cases, we do the usual thing and call __add__ first and # __radd__ second when doing "A() + B()". - variants_raw = [ - (left_op, left_type, right_expr), - (right_op, right_type, left_expr), - ] + variants_raw = [(left_op, left_type, right_expr), (right_op, right_type, left_expr)] # STEP 2b: # When running Python 2, we might also try calling the __cmp__ method. @@ -2644,8 +3062,7 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]: results = [] for method, obj, arg in variants: with self.msg.filter_errors(save_filtered_errors=True) as local_errors: - result = self.check_method_call( - op_name, obj, method, [arg], [ARG_POS], context) + result = self.check_method_call(op_name, obj, method, [arg], [ARG_POS], context) if local_errors.has_new_errors(): errors.append(local_errors.filtered_errors()) results.append(result) @@ -2655,8 +3072,9 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]: # We finish invoking above operators and no early return happens. Therefore, # we check if either the LHS or the RHS is Instance and fallbacks to Any, # if so, we also return Any - if ((isinstance(left_type, Instance) and left_type.type.fallback_to_any) or - (isinstance(right_type, Instance) and right_type.type.fallback_to_any)): + if (isinstance(left_type, Instance) and left_type.type.fallback_to_any) or ( + isinstance(right_type, Instance) and right_type.type.fallback_to_any + ): any_type = AnyType(TypeOfAny.special_form) return any_type, any_type @@ -2667,7 +3085,8 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]: if not variants: with self.msg.filter_errors(save_filtered_errors=True) as local_errors: result = self.check_method_call_by_name( - op_name, left_type, [right_expr], [ARG_POS], context) + op_name, left_type, [right_expr], [ARG_POS], context + ) if local_errors.has_new_errors(): errors.append(local_errors.filtered_errors()) @@ -2691,9 +3110,14 @@ def lookup_definer(typ: Instance, attr_name: str) -> Optional[str]: result = error_any, error_any return result - def check_op(self, method: str, base_type: Type, - arg: Expression, context: Context, - allow_reverse: bool = False) -> Tuple[Type, Type]: + def check_op( + self, + method: str, + base_type: Type, + arg: Expression, + context: Context, + allow_reverse: bool = False, + ) -> Tuple[Type, Type]: """Type check a binary operation which maps to a method call. Return tuple (result type, inferred operator method type). @@ -2703,9 +3127,12 @@ def check_op(self, method: str, base_type: Type, left_variants = [base_type] base_type = get_proper_type(base_type) if isinstance(base_type, UnionType): - left_variants = [item for item in - flatten_nested_unions(base_type.relevant_items(), - handle_type_alias_type=True)] + left_variants = [ + item + for item in flatten_nested_unions( + base_type.relevant_items(), handle_type_alias_type=True + ) + ] right_type = self.accept(arg) # Step 1: We first try leaving the right arguments alone and destructure @@ -2722,7 +3149,8 @@ def check_op(self, method: str, base_type: Type, left_expr=TempNode(left_possible_type, context=context), right_type=right_type, right_expr=arg, - context=context) + context=context, + ) all_results.append(result) all_inferred.append(inferred) @@ -2747,8 +3175,9 @@ def check_op(self, method: str, base_type: Type, if isinstance(right_type, UnionType): right_variants = [ (item, TempNode(item, context=context)) - for item in flatten_nested_unions(right_type.relevant_items(), - handle_type_alias_type=True) + for item in flatten_nested_unions( + right_type.relevant_items(), handle_type_alias_type=True + ) ] all_results = [] @@ -2763,7 +3192,8 @@ def check_op(self, method: str, base_type: Type, left_expr=TempNode(left_possible_type, context=context), right_type=right_possible_type, right_expr=right_expr, - context=context) + context=context, + ) all_results.append(result) all_inferred.append(inferred) @@ -2777,11 +3207,11 @@ def check_op(self, method: str, base_type: Type, if len(left_variants) >= 2 and len(right_variants) >= 2: self.msg.warn_both_operands_are_from_unions(recent_context) elif len(left_variants) >= 2: - self.msg.warn_operand_was_from_union( - "Left", base_type, context=recent_context) + self.msg.warn_operand_was_from_union("Left", base_type, context=recent_context) elif len(right_variants) >= 2: self.msg.warn_operand_was_from_union( - "Right", right_type, context=recent_context) + "Right", right_type, context=recent_context + ) # See the comment in 'check_overload_call' for more details on why # we call 'combine_function_signature' instead of just unioning the inferred @@ -2799,8 +3229,8 @@ def check_op(self, method: str, base_type: Type, ) def get_reverse_op_method(self, method: str) -> str: - if method == '__div__' and self.chk.options.python_version[0] == 2: - return '__rdiv__' + if method == "__div__" and self.chk.options.python_version[0] == 2: + return "__rdiv__" else: return operators.reverse_op_methods[method] @@ -2819,15 +3249,15 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: self.accept(e.left, ctx), "builtins.bool" ) - assert e.op in ('and', 'or') # Checked by visit_op_expr + assert e.op in ("and", "or") # Checked by visit_op_expr if e.right_always: left_map, right_map = None, {} # type: mypy.checker.TypeMap, mypy.checker.TypeMap elif e.right_unreachable: left_map, right_map = {}, None - elif e.op == 'and': + elif e.op == "and": right_map, left_map = self.chk.find_isinstance_check(e.left) - elif e.op == 'or': + elif e.op == "or": left_map, right_map = self.chk.find_isinstance_check(e.left) # If left_map is None then we know mypy considers the left expression @@ -2866,10 +3296,10 @@ def check_boolean_op(self, e: OpExpr, context: Context) -> Type: assert right_map is not None return right_type - if e.op == 'and': + if e.op == "and": restricted_left_type = false_only(expanded_left_type) result_is_left = not expanded_left_type.can_be_true - elif e.op == 'or': + elif e.op == "or": restricted_left_type = true_only(expanded_left_type) result_is_left = not expanded_left_type.can_be_false @@ -2888,13 +3318,13 @@ def check_list_multiply(self, e: OpExpr) -> Type: Type inference is special-cased for this common construct. """ right_type = self.accept(e.right) - if is_subtype(right_type, self.named_type('builtins.int')): + if is_subtype(right_type, self.named_type("builtins.int")): # Special case: [...] * . Use the type context of the # OpExpr, since the multiplication does not affect the type. left_type = self.accept(e.left, type_context=self.type_context[-1]) else: left_type = self.accept(e.left) - result, method_type = self.check_op('__mul__', left_type, e.right, e) + result, method_type = self.check_op("__mul__", left_type, e.right, e) e.method_type = method_type return result @@ -2910,7 +3340,7 @@ def visit_unary_expr(self, e: UnaryExpr) -> Type: """Type check an unary operation ('not', '-', '+' or '~').""" operand_type = self.accept(e.expr) op = e.op - if op == 'not': + if op == "not": result: Type = self.bool_type() else: method = operators.unary_op_methods[op] @@ -2925,8 +3355,11 @@ def visit_index_expr(self, e: IndexExpr) -> Type: """ result = self.visit_index_expr_helper(e) result = get_proper_type(self.narrow_type_from_binder(e, result)) - if (self.is_literal_context() and isinstance(result, Instance) - and result.last_known_value is not None): + if ( + self.is_literal_context() + and isinstance(result, Instance) + and result.last_known_value is not None + ): result = result.last_known_value return result @@ -2937,8 +3370,9 @@ def visit_index_expr_helper(self, e: IndexExpr) -> Type: left_type = self.accept(e.base) return self.visit_index_with_type(left_type, e) - def visit_index_with_type(self, left_type: Type, e: IndexExpr, - original_type: Optional[ProperType] = None) -> Type: + def visit_index_with_type( + self, left_type: Type, e: IndexExpr, original_type: Optional[ProperType] = None + ) -> Type: """Analyze type of an index expression for a given type of base expression. The 'original_type' is used for error messages (currently used for union types). @@ -2952,10 +3386,13 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr, if isinstance(left_type, UnionType): original_type = original_type or left_type # Don't combine literal types, since we may need them for type narrowing. - return make_simplified_union([self.visit_index_with_type(typ, e, - original_type) - for typ in left_type.relevant_items()], - contract_literals=False) + return make_simplified_union( + [ + self.visit_index_with_type(typ, e, original_type) + for typ in left_type.relevant_items() + ], + contract_literals=False, + ) elif isinstance(left_type, TupleType) and self.chk.in_checked_function(): # Special case for tuples. They return a more specific type when # indexed by an integer literal. @@ -2978,16 +3415,20 @@ def visit_index_with_type(self, left_type: Type, e: IndexExpr, return self.nonliteral_tuple_index_helper(left_type, index) elif isinstance(left_type, TypedDictType): return self.visit_typeddict_index_expr(left_type, e.index) - elif (isinstance(left_type, CallableType) - and left_type.is_type_obj() and left_type.type_object().is_enum): + elif ( + isinstance(left_type, CallableType) + and left_type.is_type_obj() + and left_type.type_object().is_enum + ): return self.visit_enum_index_expr(left_type.type_object(), e.index, e) - elif (isinstance(left_type, TypeVarType) - and not self.has_member(left_type.upper_bound, "__getitem__")): + elif isinstance(left_type, TypeVarType) and not self.has_member( + left_type.upper_bound, "__getitem__" + ): return self.visit_index_with_type(left_type.upper_bound, e, original_type) else: result, method_type = self.check_method_call_by_name( - '__getitem__', left_type, [e.index], [ARG_POS], e, - original_type=original_type) + "__getitem__", left_type, [e.index], [ARG_POS], e, original_type=original_type + ) e.method_type = method_type return result @@ -3034,7 +3475,7 @@ def try_getting_int_literals(self, index: Expression) -> Optional[List[int]]: if isinstance(index, IntExpr): return [index.value] elif isinstance(index, UnaryExpr): - if index.op == '-': + if index.op == "-": operand = index.expr if isinstance(operand, IntExpr): return [-1 * operand.value] @@ -3055,22 +3496,26 @@ def try_getting_int_literals(self, index: Expression) -> Optional[List[int]]: def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression) -> Type: index_type = self.accept(index) - expected_type = UnionType.make_union([self.named_type('builtins.int'), - self.named_type('builtins.slice')]) - if not self.chk.check_subtype(index_type, expected_type, index, - message_registry.INVALID_TUPLE_INDEX_TYPE, - 'actual type', 'expected type'): + expected_type = UnionType.make_union( + [self.named_type("builtins.int"), self.named_type("builtins.slice")] + ) + if not self.chk.check_subtype( + index_type, + expected_type, + index, + message_registry.INVALID_TUPLE_INDEX_TYPE, + "actual type", + "expected type", + ): return AnyType(TypeOfAny.from_error) else: union = make_simplified_union(left_type.items) if isinstance(index, SliceExpr): - return self.chk.named_generic_type('builtins.tuple', [union]) + return self.chk.named_generic_type("builtins.tuple", [union]) else: return union - def visit_typeddict_index_expr(self, td_type: TypedDictType, - index: Expression, - ) -> Type: + def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression) -> Type: if isinstance(index, (StrExpr, UnicodeExpr)): key_names = [index.value] else: @@ -3085,9 +3530,11 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, if isinstance(key_type, Instance) and key_type.last_known_value is not None: key_type = key_type.last_known_value - if (isinstance(key_type, LiteralType) - and isinstance(key_type.value, str) - and key_type.fallback.type.fullname != 'builtins.bytes'): + if ( + isinstance(key_type, LiteralType) + and isinstance(key_type.value, str) + and key_type.fallback.type.fullname != "builtins.bytes" + ): key_names.append(key_type.value) else: self.msg.typeddict_key_must_be_string_literal(td_type, index) @@ -3108,30 +3555,46 @@ def visit_enum_index_expr( ) -> Type: string_type: Type = self.named_type("builtins.str") if self.chk.options.python_version[0] < 3: - string_type = UnionType.make_union([string_type, - self.named_type('builtins.unicode')]) - self.chk.check_subtype(self.accept(index), string_type, context, - "Enum index should be a string", "actual index type") + string_type = UnionType.make_union([string_type, self.named_type("builtins.unicode")]) + self.chk.check_subtype( + self.accept(index), + string_type, + context, + "Enum index should be a string", + "actual index type", + ) return Instance(enum_type, []) def visit_cast_expr(self, expr: CastExpr) -> Type: """Type check a cast expression.""" - source_type = self.accept(expr.expr, type_context=AnyType(TypeOfAny.special_form), - allow_none_return=True, always_allow_any=True) + source_type = self.accept( + expr.expr, + type_context=AnyType(TypeOfAny.special_form), + allow_none_return=True, + always_allow_any=True, + ) target_type = expr.type options = self.chk.options - if (options.warn_redundant_casts and not isinstance(get_proper_type(target_type), AnyType) - and is_same_type(source_type, target_type)): + if ( + options.warn_redundant_casts + and not isinstance(get_proper_type(target_type), AnyType) + and is_same_type(source_type, target_type) + ): self.msg.redundant_cast(target_type, expr) if options.disallow_any_unimported and has_any_from_unimported_type(target_type): self.msg.unimported_type_becomes_any("Target type of cast", target_type, expr) - check_for_explicit_any(target_type, self.chk.options, self.chk.is_typeshed_stub, self.msg, - context=expr) + check_for_explicit_any( + target_type, self.chk.options, self.chk.is_typeshed_stub, self.msg, context=expr + ) return target_type def visit_assert_type_expr(self, expr: AssertTypeExpr) -> Type: - source_type = self.accept(expr.expr, type_context=self.type_context[-1], - allow_none_return=True, always_allow_any=True) + source_type = self.accept( + expr.expr, + type_context=self.type_context[-1], + allow_none_return=True, + always_allow_any=True, + ) target_type = expr.type if not is_same_type(source_type, target_type): self.msg.assert_type_fail(source_type, target_type, expr) @@ -3141,13 +3604,15 @@ def visit_reveal_expr(self, expr: RevealExpr) -> Type: """Type check a reveal_type expression.""" if expr.kind == REVEAL_TYPE: assert expr.expr is not None - revealed_type = self.accept(expr.expr, type_context=self.type_context[-1], - allow_none_return=True) + revealed_type = self.accept( + expr.expr, type_context=self.type_context[-1], allow_none_return=True + ) if not self.chk.current_node_deferred: self.msg.reveal_type(revealed_type, expr.expr) if not self.chk.in_checked_function(): - self.msg.note("'reveal_type' always outputs 'Any' in unchecked functions", - expr.expr) + self.msg.note( + "'reveal_type' always outputs 'Any' in unchecked functions", expr.expr + ) return revealed_type else: # REVEAL_LOCALS @@ -3155,9 +3620,11 @@ def visit_reveal_expr(self, expr: RevealExpr) -> Type: # the RevealExpr contains a local_nodes attribute, # calculated at semantic analysis time. Use it to pull out the # corresponding subset of variables in self.chk.type_map - names_to_types = { - var_node.name: var_node.type for var_node in expr.local_nodes - } if expr.local_nodes is not None else {} + names_to_types = ( + {var_node.name: var_node.type for var_node in expr.local_nodes} + if expr.local_nodes is not None + else {} + ) self.msg.reveal_locals(names_to_types, expr) return NoneType() @@ -3174,8 +3641,9 @@ def visit_type_application(self, tapp: TypeApplication) -> Type: """ if isinstance(tapp.expr, RefExpr) and isinstance(tapp.expr.node, TypeAlias): # Subscription of a (generic) alias in runtime context, expand the alias. - item = expand_type_alias(tapp.expr.node, tapp.types, self.chk.fail, - tapp.expr.node.no_args, tapp) + item = expand_type_alias( + tapp.expr.node, tapp.types, self.chk.fail, tapp.expr.node.no_args, tapp + ) item = get_proper_type(item) if isinstance(item, Instance): tp = type_object_type(item.type, self.named_type) @@ -3206,13 +3674,13 @@ def visit_type_alias_expr(self, alias: TypeAliasExpr) -> Type: both `reveal_type` instances will reveal the same type `def (...) -> builtins.list[Any]`. Note that type variables are implicitly substituted with `Any`. """ - return self.alias_type_in_runtime_context(alias.node, alias.no_args, - alias, alias_definition=True) + return self.alias_type_in_runtime_context( + alias.node, alias.no_args, alias, alias_definition=True + ) - def alias_type_in_runtime_context(self, alias: TypeAlias, - no_args: bool, ctx: Context, - *, - alias_definition: bool = False) -> Type: + def alias_type_in_runtime_context( + self, alias: TypeAlias, no_args: bool, ctx: Context, *, alias_definition: bool = False + ) -> Type: """Get type of a type alias (could be generic) in a runtime expression. Note that this function can be called only if the alias appears _not_ @@ -3241,9 +3709,12 @@ class LongName(Generic[T]): ... if no_args: return tp return self.apply_type_arguments_to_callable(tp, item.args, ctx) - elif (isinstance(item, TupleType) and - # Tuple[str, int]() fails at runtime, only named tuples and subclasses work. - tuple_fallback(item).type.fullname != 'builtins.tuple'): + elif ( + isinstance(item, TupleType) + and + # Tuple[str, int]() fails at runtime, only named tuples and subclasses work. + tuple_fallback(item).type.fullname != "builtins.tuple" + ): return type_object_type(tuple_fallback(item).type, self.named_type) elif isinstance(item, AnyType): return AnyType(TypeOfAny.from_another_any, source_any=item) @@ -3251,7 +3722,7 @@ class LongName(Generic[T]): ... if alias_definition: return AnyType(TypeOfAny.special_form) # This type is invalid in most runtime contexts, give it an 'object' type. - return self.named_type('builtins.object') + return self.named_type("builtins.object") def apply_type_arguments_to_callable( self, tp: Type, args: Sequence[Type], ctx: Context @@ -3267,29 +3738,26 @@ def apply_type_arguments_to_callable( if isinstance(tp, CallableType): if len(tp.variables) != len(args): - self.msg.incompatible_type_application(len(tp.variables), - len(args), ctx) + self.msg.incompatible_type_application(len(tp.variables), len(args), ctx) return AnyType(TypeOfAny.from_error) return self.apply_generic_arguments(tp, args, ctx) if isinstance(tp, Overloaded): for it in tp.items: if len(it.variables) != len(args): - self.msg.incompatible_type_application(len(it.variables), - len(args), ctx) + self.msg.incompatible_type_application(len(it.variables), len(args), ctx) return AnyType(TypeOfAny.from_error) - return Overloaded([self.apply_generic_arguments(it, args, ctx) - for it in tp.items]) + return Overloaded([self.apply_generic_arguments(it, args, ctx) for it in tp.items]) return AnyType(TypeOfAny.special_form) def visit_list_expr(self, e: ListExpr) -> Type: """Type check a list expression [...].""" - return self.check_lst_expr(e, 'builtins.list', '') + return self.check_lst_expr(e, "builtins.list", "") def visit_set_expr(self, e: SetExpr) -> Type: - return self.check_lst_expr(e, 'builtins.set', '') + return self.check_lst_expr(e, "builtins.set", "") def fast_container_type( - self, e: Union[ListExpr, SetExpr, TupleExpr], container_fullname: str + self, e: Union[ListExpr, SetExpr, TupleExpr], container_fullname: str ) -> Optional[Type]: """ Fast path to determine the type of a list or set literal, @@ -3322,8 +3790,9 @@ def fast_container_type( self.resolved_type[e] = ct return ct - def check_lst_expr(self, e: Union[ListExpr, SetExpr, TupleExpr], fullname: str, - tag: str) -> Type: + def check_lst_expr( + self, e: Union[ListExpr, SetExpr, TupleExpr], fullname: str, tag: str + ) -> Type: # fast path t = self.fast_container_type(e, fullname) if t: @@ -3333,21 +3802,22 @@ def check_lst_expr(self, e: Union[ListExpr, SetExpr, TupleExpr], fullname: str, # Used for list and set expressions, as well as for tuples # containing star expressions that don't refer to a # Tuple. (Note: "lst" stands for list-set-tuple. :-) - tv = TypeVarType('T', 'T', -1, [], self.object_type()) + tv = TypeVarType("T", "T", -1, [], self.object_type()) constructor = CallableType( [tv], [nodes.ARG_STAR], [None], self.chk.named_generic_type(fullname, [tv]), - self.named_type('builtins.function'), + self.named_type("builtins.function"), name=tag, - variables=[tv]) - out = self.check_call(constructor, - [(i.expr if isinstance(i, StarExpr) else i) - for i in e.items], - [(nodes.ARG_STAR if isinstance(i, StarExpr) else nodes.ARG_POS) - for i in e.items], - e)[0] + variables=[tv], + ) + out = self.check_call( + constructor, + [(i.expr if isinstance(i, StarExpr) else i) for i in e.items], + [(nodes.ARG_STAR if isinstance(i, StarExpr) else nodes.ARG_POS) for i in e.items], + e, + )[0] return remove_instance_last_known_values(out) def visit_tuple_expr(self, e: TupleExpr) -> Type: @@ -3356,9 +3826,12 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: type_context = get_proper_type(self.type_context[-1]) type_context_items = None if isinstance(type_context, UnionType): - tuples_in_context = [t for t in get_proper_types(type_context.items) - if (isinstance(t, TupleType) and len(t.items) == len(e.items)) or - is_named_instance(t, 'builtins.tuple')] + tuples_in_context = [ + t + for t in get_proper_types(type_context.items) + if (isinstance(t, TupleType) and len(t.items) == len(e.items)) + or is_named_instance(t, "builtins.tuple") + ] if len(tuples_in_context) == 1: type_context = tuples_in_context[0] else: @@ -3368,7 +3841,7 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: if isinstance(type_context, TupleType): type_context_items = type_context.items - elif type_context and is_named_instance(type_context, 'builtins.tuple'): + elif type_context and is_named_instance(type_context, "builtins.tuple"): assert isinstance(type_context, Instance) if type_context.args: type_context_items = [type_context.args[0]] * len(e.items) @@ -3397,7 +3870,7 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: else: # A star expression that's not a Tuple. # Treat the whole thing as a variable-length tuple. - return self.check_lst_expr(e, 'builtins.tuple', '') + return self.check_lst_expr(e, "builtins.tuple", "") else: if not type_context_items or j >= len(type_context_items): tt = self.accept(item) @@ -3407,7 +3880,7 @@ def visit_tuple_expr(self, e: TupleExpr) -> Type: items.append(tt) # This is a partial fallback item type. A precise type will be calculated on demand. fallback_item = AnyType(TypeOfAny.special_form) - return TupleType(items, self.chk.named_generic_type('builtins.tuple', [fallback_item])) + return TupleType(items, self.chk.named_generic_type("builtins.tuple", [fallback_item])) def fast_dict_type(self, e: DictExpr) -> Optional[Type]: """ @@ -3433,9 +3906,9 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]: if key is None: st = get_proper_type(self.accept(value)) if ( - isinstance(st, Instance) - and st.type.fullname == 'builtins.dict' - and len(st.args) == 2 + isinstance(st, Instance) + and st.type.fullname == "builtins.dict" + and len(st.args) == 2 ): stargs = (st.args[0], st.args[1]) else: @@ -3452,7 +3925,7 @@ def fast_dict_type(self, e: DictExpr) -> Optional[Type]: if stargs and (stargs[0] != kt or stargs[1] != vt): self.resolved_type[e] = NoneType() return None - dt = self.chk.named_generic_type('builtins.dict', [kt, vt]) + dt = self.chk.named_generic_type("builtins.dict", [kt, vt]) self.resolved_type[e] = dt return dt @@ -3467,11 +3940,7 @@ def visit_dict_expr(self, e: DictExpr) -> Type: # to avoid the second error, we always return TypedDict type that was requested typeddict_context = self.find_typeddict_context(self.type_context[-1], e) if typeddict_context: - self.check_typeddict_call_with_dict( - callee=typeddict_context, - kwargs=e, - context=e - ) + self.check_typeddict_call_with_dict(callee=typeddict_context, kwargs=e, context=e) return typeddict_context.copy_modified() # fast path attempt @@ -3495,8 +3964,8 @@ def visit_dict_expr(self, e: DictExpr) -> Type: tup.column = value.column args.append(tup) # Define type variables (used in constructors below). - kt = TypeVarType('KT', 'KT', -1, [], self.object_type()) - vt = TypeVarType('VT', 'VT', -2, [], self.object_type()) + kt = TypeVarType("KT", "KT", -1, [], self.object_type()) + vt = TypeVarType("VT", "VT", -2, [], self.object_type()) rv = None # Call dict(*args), unless it's empty and stargs is not. if args or not stargs: @@ -3504,13 +3973,14 @@ def visit_dict_expr(self, e: DictExpr) -> Type: # # def (*v: Tuple[kt, vt]) -> Dict[kt, vt]: ... constructor = CallableType( - [TupleType([kt, vt], self.named_type('builtins.tuple'))], + [TupleType([kt, vt], self.named_type("builtins.tuple"))], [nodes.ARG_STAR], [None], - self.chk.named_generic_type('builtins.dict', [kt, vt]), - self.named_type('builtins.function'), - name='', - variables=[kt, vt]) + self.chk.named_generic_type("builtins.dict", [kt, vt]), + self.named_type("builtins.function"), + name="", + variables=[kt, vt], + ) rv = self.check_call(constructor, args, [nodes.ARG_POS] * len(args), e)[0] else: # dict(...) will be called below. @@ -3521,21 +3991,23 @@ def visit_dict_expr(self, e: DictExpr) -> Type: for arg in stargs: if rv is None: constructor = CallableType( - [self.chk.named_generic_type('typing.Mapping', [kt, vt])], + [self.chk.named_generic_type("typing.Mapping", [kt, vt])], [nodes.ARG_POS], [None], - self.chk.named_generic_type('builtins.dict', [kt, vt]), - self.named_type('builtins.function'), - name='', - variables=[kt, vt]) + self.chk.named_generic_type("builtins.dict", [kt, vt]), + self.named_type("builtins.function"), + name="", + variables=[kt, vt], + ) rv = self.check_call(constructor, [arg], [nodes.ARG_POS], arg)[0] else: - self.check_method_call_by_name('update', rv, [arg], [nodes.ARG_POS], arg) + self.check_method_call_by_name("update", rv, [arg], [nodes.ARG_POS], arg) assert rv is not None return rv - def find_typeddict_context(self, context: Optional[Type], - dict_expr: DictExpr) -> Optional[TypedDictType]: + def find_typeddict_context( + self, context: Optional[Type], dict_expr: DictExpr + ) -> Optional[TypedDictType]: context = get_proper_type(context) if isinstance(context, TypedDictType): return context @@ -3543,9 +4015,9 @@ def find_typeddict_context(self, context: Optional[Type], items = [] for item in context.items: item_context = self.find_typeddict_context(item, dict_expr) - if (item_context is not None - and self.match_typeddict_call_with_dict( - item_context, dict_expr, dict_expr)): + if item_context is not None and self.match_typeddict_call_with_dict( + item_context, dict_expr, dict_expr + ): items.append(item_context) if len(items) == 1: # Only one union item is valid TypedDict for the given dict_expr, so use the @@ -3574,7 +4046,7 @@ def visit_lambda_expr(self, e: LambdaExpr) -> Type: # This is important as otherwise the following statements would be # considered unreachable. There's no useful type context. ret_type = self.accept(e.expr(), allow_none_return=True) - fallback = self.named_type('builtins.function') + fallback = self.named_type("builtins.function") self.chk.return_types.pop() return callable_type(e, fallback, ret_type) else: @@ -3588,8 +4060,9 @@ def visit_lambda_expr(self, e: LambdaExpr) -> Type: self.chk.return_types.pop() return replace_callable_return_type(inferred_type, ret_type) - def infer_lambda_type_using_context(self, e: LambdaExpr) -> Tuple[Optional[CallableType], - Optional[CallableType]]: + def infer_lambda_type_using_context( + self, e: LambdaExpr + ) -> Tuple[Optional[CallableType], Optional[CallableType]]: """Try to infer lambda expression type using context. Return None if could not infer type. @@ -3599,8 +4072,9 @@ def infer_lambda_type_using_context(self, e: LambdaExpr) -> Tuple[Optional[Calla ctx = get_proper_type(self.type_context[-1]) if isinstance(ctx, UnionType): - callables = [t for t in get_proper_types(ctx.relevant_items()) - if isinstance(t, CallableType)] + callables = [ + t for t in get_proper_types(ctx.relevant_items()) if isinstance(t, CallableType) + ] if len(callables) == 1: ctx = callables[0] @@ -3680,26 +4154,28 @@ def visit_super_expr(self, e: SuperExpr) -> Type: self.chk.fail(message_registry.TARGET_CLASS_HAS_NO_BASE_CLASS, e) return AnyType(TypeOfAny.from_error) - for base in mro[index+1:]: + for base in mro[index + 1 :]: if e.name in base.names or base == mro[-1]: if e.info and e.info.fallback_to_any and base == mro[-1]: # There's an undefined base class, and we're at the end of the # chain. That's not an error. return AnyType(TypeOfAny.special_form) - return analyze_member_access(name=e.name, - typ=instance_type, - is_lvalue=False, - is_super=True, - is_operator=False, - original_type=instance_type, - override_info=base, - context=e, - msg=self.msg, - chk=self.chk, - in_literal_context=self.is_literal_context()) + return analyze_member_access( + name=e.name, + typ=instance_type, + is_lvalue=False, + is_super=True, + is_operator=False, + original_type=instance_type, + override_info=base, + context=e, + msg=self.msg, + chk=self.chk, + in_literal_context=self.is_literal_context(), + ) - assert False, 'unreachable' + assert False, "unreachable" def _super_arg_types(self, e: SuperExpr) -> Union[Type, Tuple[Type, Type]]: """ @@ -3763,8 +4239,9 @@ def _super_arg_types(self, e: SuperExpr) -> Union[Type, Tuple[Type, Type]]: else: return AnyType(TypeOfAny.from_another_any, source_any=type_item) - if (not isinstance(type_type, TypeType) - and not (isinstance(type_type, FunctionLike) and type_type.is_type_obj())): + if not isinstance(type_type, TypeType) and not ( + isinstance(type_type, FunctionLike) and type_type.is_type_obj() + ): self.msg.first_argument_for_super_must_be_type(type_type, e) return AnyType(TypeOfAny.from_error) @@ -3786,40 +4263,45 @@ def _super_arg_types(self, e: SuperExpr) -> Union[Type, Tuple[Type, Type]]: return type_type, instance_type def visit_slice_expr(self, e: SliceExpr) -> Type: - expected = make_optional_type(self.named_type('builtins.int')) + expected = make_optional_type(self.named_type("builtins.int")) for index in [e.begin_index, e.end_index, e.stride]: if index: t = self.accept(index) - self.chk.check_subtype(t, expected, - index, message_registry.INVALID_SLICE_INDEX) - return self.named_type('builtins.slice') + self.chk.check_subtype(t, expected, index, message_registry.INVALID_SLICE_INDEX) + return self.named_type("builtins.slice") def visit_list_comprehension(self, e: ListComprehension) -> Type: return self.check_generator_or_comprehension( - e.generator, 'builtins.list', '') + e.generator, "builtins.list", "" + ) def visit_set_comprehension(self, e: SetComprehension) -> Type: return self.check_generator_or_comprehension( - e.generator, 'builtins.set', '') + e.generator, "builtins.set", "" + ) def visit_generator_expr(self, e: GeneratorExpr) -> Type: # If any of the comprehensions use async for, the expression will return an async generator # object, or if the left-side expression uses await. if any(e.is_async) or has_await_expression(e.left_expr): - typ = 'typing.AsyncGenerator' + typ = "typing.AsyncGenerator" # received type is always None in async generator expressions additional_args: List[Type] = [NoneType()] else: - typ = 'typing.Generator' + typ = "typing.Generator" # received type and returned type are None additional_args = [NoneType(), NoneType()] - return self.check_generator_or_comprehension(e, typ, '', - additional_args=additional_args) + return self.check_generator_or_comprehension( + e, typ, "", additional_args=additional_args + ) - def check_generator_or_comprehension(self, gen: GeneratorExpr, - type_name: str, - id_for_messages: str, - additional_args: Optional[List[Type]] = None) -> Type: + def check_generator_or_comprehension( + self, + gen: GeneratorExpr, + type_name: str, + id_for_messages: str, + additional_args: Optional[List[Type]] = None, + ) -> Type: """Type check a generator expression or a list comprehension.""" additional_args = additional_args or [] with self.chk.binder.frame_context(can_skip=True, fall_through=0): @@ -3827,16 +4309,17 @@ def check_generator_or_comprehension(self, gen: GeneratorExpr, # Infer the type of the list comprehension by using a synthetic generic # callable type. - tv = TypeVarType('T', 'T', -1, [], self.object_type()) + tv = TypeVarType("T", "T", -1, [], self.object_type()) tv_list: List[Type] = [tv] constructor = CallableType( tv_list, [nodes.ARG_POS], [None], self.chk.named_generic_type(type_name, tv_list + additional_args), - self.chk.named_type('builtins.function'), + self.chk.named_type("builtins.function"), name=id_for_messages, - variables=[tv]) + variables=[tv], + ) return self.check_call(constructor, [gen.left_expr], [nodes.ARG_POS], gen)[0] def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> Type: @@ -3846,18 +4329,20 @@ def visit_dictionary_comprehension(self, e: DictionaryComprehension) -> Type: # Infer the type of the list comprehension by using a synthetic generic # callable type. - ktdef = TypeVarType('KT', 'KT', -1, [], self.object_type()) - vtdef = TypeVarType('VT', 'VT', -2, [], self.object_type()) + ktdef = TypeVarType("KT", "KT", -1, [], self.object_type()) + vtdef = TypeVarType("VT", "VT", -2, [], self.object_type()) constructor = CallableType( [ktdef, vtdef], [nodes.ARG_POS, nodes.ARG_POS], [None, None], - self.chk.named_generic_type('builtins.dict', [ktdef, vtdef]), - self.chk.named_type('builtins.function'), - name='', - variables=[ktdef, vtdef]) - return self.check_call(constructor, - [e.key, e.value], [nodes.ARG_POS, nodes.ARG_POS], e)[0] + self.chk.named_generic_type("builtins.dict", [ktdef, vtdef]), + self.chk.named_type("builtins.function"), + name="", + variables=[ktdef, vtdef], + ) + return self.check_call( + constructor, [e.key, e.value], [nodes.ARG_POS, nodes.ARG_POS], e + )[0] def check_for_comp(self, e: Union[GeneratorExpr, DictionaryComprehension]) -> None: """Check the for_comp part of comprehensions. That is the part from 'for': @@ -3865,8 +4350,9 @@ def check_for_comp(self, e: Union[GeneratorExpr, DictionaryComprehension]) -> No Note: This adds the type information derived from the condlists to the current binder. """ - for index, sequence, conditions, is_async in zip(e.indices, e.sequences, - e.condlists, e.is_async): + for index, sequence, conditions, is_async in zip( + e.indices, e.sequences, e.condlists, e.is_async + ): if is_async: _, sequence_type = self.chk.analyze_async_iterable_item_type(sequence) else: @@ -3900,8 +4386,9 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F elif else_map is None: self.msg.redundant_condition_in_if(True, e.cond) - if_type = self.analyze_cond_branch(if_map, e.if_expr, context=ctx, - allow_none_return=allow_none_return) + if_type = self.analyze_cond_branch( + if_map, e.if_expr, context=ctx, allow_none_return=allow_none_return + ) # we want to keep the narrowest value of if_type for union'ing the branches # however, it would be silly to pass a literal as a type context. Pass the @@ -3909,8 +4396,9 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F if_type_fallback = simple_literal_type(get_proper_type(if_type)) or if_type # Analyze the right branch using full type context and store the type - full_context_else_type = self.analyze_cond_branch(else_map, e.else_expr, context=ctx, - allow_none_return=allow_none_return) + full_context_else_type = self.analyze_cond_branch( + else_map, e.else_expr, context=ctx, allow_none_return=allow_none_return + ) if not mypy.checker.is_valid_inferred_type(if_type): # Analyze the right branch disregarding the left branch. @@ -3926,8 +4414,12 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F # TODO: If it's possible that the previous analysis of # the left branch produced errors that are avoided # using this context, suppress those errors. - if_type = self.analyze_cond_branch(if_map, e.if_expr, context=else_type_fallback, - allow_none_return=allow_none_return) + if_type = self.analyze_cond_branch( + if_map, + e.if_expr, + context=else_type_fallback, + allow_none_return=allow_none_return, + ) elif if_type_fallback == ctx: # There is no point re-running the analysis if if_type is equal to ctx. @@ -3939,8 +4431,12 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F else: # Analyze the right branch in the context of the left # branch's type. - else_type = self.analyze_cond_branch(else_map, e.else_expr, context=if_type_fallback, - allow_none_return=allow_none_return) + else_type = self.analyze_cond_branch( + else_map, + e.else_expr, + context=if_type_fallback, + allow_none_return=allow_none_return, + ) # Only create a union type if the type context is a union, to be mostly # compatible with older mypy versions where we always did a join. @@ -3953,9 +4449,13 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F return res - def analyze_cond_branch(self, map: Optional[Dict[Expression, Type]], - node: Expression, context: Optional[Type], - allow_none_return: bool = False) -> Type: + def analyze_cond_branch( + self, + map: Optional[Dict[Expression, Type]], + node: Expression, + context: Optional[Type], + allow_none_return: bool = False, + ) -> Type: with self.chk.binder.frame_context(can_skip=True, fall_through=0): if map is None: # We still need to type check node, in case we want to @@ -3967,18 +4467,19 @@ def analyze_cond_branch(self, map: Optional[Dict[Expression, Type]], def visit_backquote_expr(self, e: BackquoteExpr) -> Type: self.accept(e.expr) - return self.named_type('builtins.str') + return self.named_type("builtins.str") # # Helpers # - def accept(self, - node: Expression, - type_context: Optional[Type] = None, - allow_none_return: bool = False, - always_allow_any: bool = False, - ) -> Type: + def accept( + self, + node: Expression, + type_context: Optional[Type] = None, + allow_none_return: bool = False, + always_allow_any: bool = False, + ) -> Type: """Type check a node in the given type context. If allow_none_return is True and this expression is a call, allow it to return None. This applies only to this expression and not any subexpressions. @@ -3998,18 +4499,22 @@ def accept(self, else: typ = node.accept(self) except Exception as err: - report_internal_error(err, self.chk.errors.file, - node.line, self.chk.errors, self.chk.options) + report_internal_error( + err, self.chk.errors.file, node.line, self.chk.errors, self.chk.options + ) self.type_context.pop() assert typ is not None self.chk.store_type(node, typ) - if (self.chk.options.disallow_any_expr and - not always_allow_any and - not self.chk.is_stub and - self.chk.in_checked_function() and - has_any_type(typ) and not self.chk.current_node_deferred): + if ( + self.chk.options.disallow_any_expr + and not always_allow_any + and not self.chk.is_stub + and self.chk.in_checked_function() + and has_any_type(typ) + and not self.chk.current_node_deferred + ): self.msg.disallowed_any_type(typ, node) if not self.chk.in_checked_function() or self.chk.current_node_deferred: @@ -4026,24 +4531,42 @@ def named_type(self, name: str) -> Instance: def is_valid_var_arg(self, typ: Type) -> bool: """Is a type valid as a *args argument?""" typ = get_proper_type(typ) - return (isinstance(typ, TupleType) or - is_subtype(typ, self.chk.named_generic_type('typing.Iterable', - [AnyType(TypeOfAny.special_form)])) or - isinstance(typ, AnyType) or - isinstance(typ, ParamSpecType)) + return ( + isinstance(typ, TupleType) + or is_subtype( + typ, + self.chk.named_generic_type("typing.Iterable", [AnyType(TypeOfAny.special_form)]), + ) + or isinstance(typ, AnyType) + or isinstance(typ, ParamSpecType) + ) def is_valid_keyword_var_arg(self, typ: Type) -> bool: """Is a type valid as a **kwargs argument?""" ret = ( - is_subtype(typ, self.chk.named_generic_type('typing.Mapping', - [self.named_type('builtins.str'), AnyType(TypeOfAny.special_form)])) or - is_subtype(typ, self.chk.named_generic_type('typing.Mapping', - [UninhabitedType(), UninhabitedType()])) or - isinstance(typ, ParamSpecType) + is_subtype( + typ, + self.chk.named_generic_type( + "typing.Mapping", + [self.named_type("builtins.str"), AnyType(TypeOfAny.special_form)], + ), + ) + or is_subtype( + typ, + self.chk.named_generic_type( + "typing.Mapping", [UninhabitedType(), UninhabitedType()] + ), + ) + or isinstance(typ, ParamSpecType) ) if self.chk.options.python_version[0] < 3: - ret = ret or is_subtype(typ, self.chk.named_generic_type('typing.Mapping', - [self.named_type('builtins.unicode'), AnyType(TypeOfAny.special_form)])) + ret = ret or is_subtype( + typ, + self.chk.named_generic_type( + "typing.Mapping", + [self.named_type("builtins.unicode"), AnyType(TypeOfAny.special_form)], + ), + ) return ret def has_member(self, typ: Type, member: str) -> bool: @@ -4097,41 +4620,50 @@ def visit_yield_expr(self, e: YieldExpr) -> Type: return_type = self.chk.return_types[-1] expected_item_type = self.chk.get_generator_yield_type(return_type, False) if e.expr is None: - if (not isinstance(get_proper_type(expected_item_type), (NoneType, AnyType)) - and self.chk.in_checked_function()): + if ( + not isinstance(get_proper_type(expected_item_type), (NoneType, AnyType)) + and self.chk.in_checked_function() + ): self.chk.fail(message_registry.YIELD_VALUE_EXPECTED, e) else: actual_item_type = self.accept(e.expr, expected_item_type) - self.chk.check_subtype(actual_item_type, expected_item_type, e, - message_registry.INCOMPATIBLE_TYPES_IN_YIELD, - 'actual type', 'expected type') + self.chk.check_subtype( + actual_item_type, + expected_item_type, + e, + message_registry.INCOMPATIBLE_TYPES_IN_YIELD, + "actual type", + "expected type", + ) return self.chk.get_generator_receive_type(return_type, False) def visit_await_expr(self, e: AwaitExpr, allow_none_return: bool = False) -> Type: expected_type = self.type_context[-1] if expected_type is not None: - expected_type = self.chk.named_generic_type('typing.Awaitable', [expected_type]) + expected_type = self.chk.named_generic_type("typing.Awaitable", [expected_type]) actual_type = get_proper_type(self.accept(e.expr, expected_type)) if isinstance(actual_type, AnyType): return AnyType(TypeOfAny.from_another_any, source_any=actual_type) - ret = self.check_awaitable_expr(actual_type, e, - message_registry.INCOMPATIBLE_TYPES_IN_AWAIT) + ret = self.check_awaitable_expr( + actual_type, e, message_registry.INCOMPATIBLE_TYPES_IN_AWAIT + ) if not allow_none_return and isinstance(get_proper_type(ret), NoneType): self.chk.msg.does_not_return_value(None, e) return ret def check_awaitable_expr( - self, t: Type, ctx: Context, msg: Union[str, ErrorMessage], ignore_binder: bool = False + self, t: Type, ctx: Context, msg: Union[str, ErrorMessage], ignore_binder: bool = False ) -> Type: """Check the argument to `await` and extract the type of value. Also used by `async for` and `async with`. """ - if not self.chk.check_subtype(t, self.named_type('typing.Awaitable'), ctx, - msg, 'actual type', 'expected type'): + if not self.chk.check_subtype( + t, self.named_type("typing.Awaitable"), ctx, msg, "actual type", "expected type" + ): return AnyType(TypeOfAny.special_form) else: - generator = self.check_method_call_by_name('__await__', t, [], [], ctx)[0] + generator = self.check_method_call_by_name("__await__", t, [], [], ctx)[0] ret_type = self.chk.get_generator_return_type(generator, False) ret_type = get_proper_type(ret_type) if ( @@ -4164,31 +4696,38 @@ def visit_yield_from_expr(self, e: YieldFromExpr, allow_none_return: bool = Fals self.chk.msg.yield_from_invalid_operand_type(subexpr_type, e) any_type = AnyType(TypeOfAny.special_form) - generic_generator_type = self.chk.named_generic_type('typing.Generator', - [any_type, any_type, any_type]) + generic_generator_type = self.chk.named_generic_type( + "typing.Generator", [any_type, any_type, any_type] + ) iter_type, _ = self.check_method_call_by_name( - '__iter__', subexpr_type, [], [], context=generic_generator_type) + "__iter__", subexpr_type, [], [], context=generic_generator_type + ) else: if not (is_async_def(subexpr_type) and has_coroutine_decorator(return_type)): self.chk.msg.yield_from_invalid_operand_type(subexpr_type, e) iter_type = AnyType(TypeOfAny.from_error) else: iter_type = self.check_awaitable_expr( - subexpr_type, e, message_registry.INCOMPATIBLE_TYPES_IN_YIELD_FROM) + subexpr_type, e, message_registry.INCOMPATIBLE_TYPES_IN_YIELD_FROM + ) # Check that the iterator's item type matches the type yielded by the Generator function # containing this `yield from` expression. expected_item_type = self.chk.get_generator_yield_type(return_type, False) actual_item_type = self.chk.get_generator_yield_type(iter_type, False) - self.chk.check_subtype(actual_item_type, expected_item_type, e, - message_registry.INCOMPATIBLE_TYPES_IN_YIELD_FROM, - 'actual type', 'expected type') + self.chk.check_subtype( + actual_item_type, + expected_item_type, + e, + message_registry.INCOMPATIBLE_TYPES_IN_YIELD_FROM, + "actual type", + "expected type", + ) # Determine the type of the entire yield from expression. iter_type = get_proper_type(iter_type) - if (isinstance(iter_type, Instance) and - iter_type.type.fullname == 'typing.Generator'): + if isinstance(iter_type, Instance) and iter_type.type.fullname == "typing.Generator": expr_type = self.chk.get_generator_return_type(iter_type, False) else: # Non-Generators don't return anything from `yield from` expressions. @@ -4222,11 +4761,13 @@ def visit_newtype_expr(self, e: NewTypeExpr) -> Type: def visit_namedtuple_expr(self, e: NamedTupleExpr) -> Type: tuple_type = e.info.tuple_type if tuple_type: - if (self.chk.options.disallow_any_unimported and - has_any_from_unimported_type(tuple_type)): + if self.chk.options.disallow_any_unimported and has_any_from_unimported_type( + tuple_type + ): self.msg.unimported_type_becomes_any("NamedTuple type", tuple_type, e) - check_for_explicit_any(tuple_type, self.chk.options, self.chk.is_typeshed_stub, - self.msg, context=e) + check_for_explicit_any( + tuple_type, self.chk.options, self.chk.is_typeshed_stub, self.msg, context=e + ) return AnyType(TypeOfAny.special_form) def visit_enum_call_expr(self, e: EnumCallExpr) -> Type: @@ -4255,21 +4796,25 @@ def visit_star_expr(self, e: StarExpr) -> StarType: def object_type(self) -> Instance: """Return instance type 'object'.""" - return self.named_type('builtins.object') + return self.named_type("builtins.object") def bool_type(self) -> Instance: """Return instance type 'bool'.""" - return self.named_type('builtins.bool') + return self.named_type("builtins.bool") @overload - def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type: ... + def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type: + ... @overload - def narrow_type_from_binder(self, expr: Expression, known_type: Type, - skip_non_overlapping: bool) -> Optional[Type]: ... + def narrow_type_from_binder( + self, expr: Expression, known_type: Type, skip_non_overlapping: bool + ) -> Optional[Type]: + ... - def narrow_type_from_binder(self, expr: Expression, known_type: Type, - skip_non_overlapping: bool = False) -> Optional[Type]: + def narrow_type_from_binder( + self, expr: Expression, known_type: Type, skip_non_overlapping: bool = False + ) -> Optional[Type]: """Narrow down a known type of expression using information in conditional type binder. If 'skip_non_overlapping' is True, return None if the type and restriction are @@ -4280,12 +4825,13 @@ def narrow_type_from_binder(self, expr: Expression, known_type: Type, # If the current node is deferred, some variables may get Any types that they # otherwise wouldn't have. We don't want to narrow down these since it may # produce invalid inferred Optional[Any] types, at least. - if restriction and not (isinstance(get_proper_type(known_type), AnyType) - and self.chk.current_node_deferred): + if restriction and not ( + isinstance(get_proper_type(known_type), AnyType) and self.chk.current_node_deferred + ): # Note: this call should match the one in narrow_declared_type(). - if (skip_non_overlapping and - not is_overlapping_types(known_type, restriction, - prohibit_none_typevar_overlap=True)): + if skip_non_overlapping and not is_overlapping_types( + known_type, restriction, prohibit_none_typevar_overlap=True + ): return None return narrow_declared_type(known_type, restriction) return known_type @@ -4313,7 +4859,7 @@ def visit_callable_type(self, t: CallableType) -> bool: def has_coroutine_decorator(t: Type) -> bool: """Whether t came from a function decorated with `@coroutine`.""" t = get_proper_type(t) - return isinstance(t, Instance) and t.type.fullname == 'typing.AwaitableGenerator' + return isinstance(t, Instance) and t.type.fullname == "typing.AwaitableGenerator" def is_async_def(t: Type) -> bool: @@ -4331,11 +4877,13 @@ def is_async_def(t: Type) -> bool: # function was an `async def`, which is orthogonal to its # decorations.) t = get_proper_type(t) - if (isinstance(t, Instance) - and t.type.fullname == 'typing.AwaitableGenerator' - and len(t.args) >= 4): + if ( + isinstance(t, Instance) + and t.type.fullname == "typing.AwaitableGenerator" + and len(t.args) >= 4 + ): t = get_proper_type(t.args[3]) - return isinstance(t, Instance) and t.type.fullname == 'typing.Coroutine' + return isinstance(t, Instance) and t.type.fullname == "typing.Coroutine" def is_non_empty_tuple(t: Type) -> bool: @@ -4343,24 +4891,28 @@ def is_non_empty_tuple(t: Type) -> bool: return isinstance(t, TupleType) and bool(t.items) -def is_duplicate_mapping(mapping: List[int], - actual_types: List[Type], - actual_kinds: List[ArgKind]) -> bool: +def is_duplicate_mapping( + mapping: List[int], actual_types: List[Type], actual_kinds: List[ArgKind] +) -> bool: return ( len(mapping) > 1 # Multiple actuals can map to the same formal if they both come from # varargs (*args and **kwargs); in this case at runtime it is possible # that here are no duplicates. We need to allow this, as the convention # f(..., *args, **kwargs) is common enough. - and not (len(mapping) == 2 - and actual_kinds[mapping[0]] == nodes.ARG_STAR - and actual_kinds[mapping[1]] == nodes.ARG_STAR2) + and not ( + len(mapping) == 2 + and actual_kinds[mapping[0]] == nodes.ARG_STAR + and actual_kinds[mapping[1]] == nodes.ARG_STAR2 + ) # Multiple actuals can map to the same formal if there are multiple # **kwargs which cannot be mapped with certainty (non-TypedDict # **kwargs). - and not all(actual_kinds[m] == nodes.ARG_STAR2 and - not isinstance(get_proper_type(actual_types[m]), TypedDictType) - for m in mapping) + and not all( + actual_kinds[m] == nodes.ARG_STAR2 + and not isinstance(get_proper_type(actual_types[m]), TypedDictType) + for m in mapping + ) ) @@ -4376,6 +4928,7 @@ class ArgInferSecondPassQuery(types.TypeQuery[bool]): type anywhere. For example, the result for Callable[[], T] is True if t is a type variable. """ + def __init__(self) -> None: super().__init__(any) @@ -4385,6 +4938,7 @@ def visit_callable_type(self, t: CallableType) -> bool: class HasTypeVarQuery(types.TypeQuery[bool]): """Visitor for querying whether a type has a type variable component.""" + def __init__(self) -> None: super().__init__(any) @@ -4398,6 +4952,7 @@ def has_erased_component(t: Optional[Type]) -> bool: class HasErasedComponentsQuery(types.TypeQuery[bool]): """Visitor for querying whether a type has an erased component.""" + def __init__(self) -> None: super().__init__(any) @@ -4411,6 +4966,7 @@ def has_uninhabited_component(t: Optional[Type]) -> bool: class HasUninhabitedComponentsQuery(types.TypeQuery[bool]): """Visitor for querying whether a type has an UninhabitedType component.""" + def __init__(self) -> None: super().__init__(any) @@ -4440,9 +4996,11 @@ def arg_approximate_similarity(actual: Type, formal: Type) -> bool: # Callable or Type[...]-ish types def is_typetype_like(typ: ProperType) -> bool: - return (isinstance(typ, TypeType) - or (isinstance(typ, FunctionLike) and typ.is_type_obj()) - or (isinstance(typ, Instance) and typ.type.fullname == "builtins.type")) + return ( + isinstance(typ, TypeType) + or (isinstance(typ, FunctionLike) and typ.is_type_obj()) + or (isinstance(typ, Instance) and typ.type.fullname == "builtins.type") + ) if isinstance(formal, CallableType): if isinstance(actual, (CallableType, Overloaded, TypeType)): @@ -4479,11 +5037,13 @@ def is_typetype_like(typ: ProperType) -> bool: return is_subtype(erasetype.erase_type(actual), erasetype.erase_type(formal)) -def any_causes_overload_ambiguity(items: List[CallableType], - return_types: List[Type], - arg_types: List[Type], - arg_kinds: List[ArgKind], - arg_names: Optional[Sequence[Optional[str]]]) -> bool: +def any_causes_overload_ambiguity( + items: List[CallableType], + return_types: List[Type], + arg_types: List[Type], + arg_kinds: List[ArgKind], + arg_names: Optional[Sequence[Optional[str]]], +) -> bool: """May an argument containing 'Any' cause ambiguous result type on call to overloaded function? Note that this sometimes returns True even if there is no ambiguity, since a correct @@ -4501,7 +5061,8 @@ def any_causes_overload_ambiguity(items: List[CallableType], actual_to_formal = [ map_formals_to_actuals( - arg_kinds, arg_names, item.arg_kinds, item.arg_names, lambda i: arg_types[i]) + arg_kinds, arg_names, item.arg_kinds, item.arg_names, lambda i: arg_types[i] + ) for item in items ] @@ -4510,9 +5071,11 @@ def any_causes_overload_ambiguity(items: List[CallableType], # creators, since that can lead to falsely claiming ambiguity # for overloads between Type and Callable. if has_any_type(arg_type, ignore_in_type_obj=True): - matching_formals_unfiltered = [(item_idx, lookup[arg_idx]) - for item_idx, lookup in enumerate(actual_to_formal) - if lookup[arg_idx]] + matching_formals_unfiltered = [ + (item_idx, lookup[arg_idx]) + for item_idx, lookup in enumerate(actual_to_formal) + if lookup[arg_idx] + ] matching_returns = [] matching_formals = [] @@ -4539,7 +5102,8 @@ def all_same_types(types: List[Type]) -> bool: def merge_typevars_in_callables_by_name( - callables: Sequence[CallableType]) -> Tuple[List[CallableType], List[TypeVarType]]: + callables: Sequence[CallableType], +) -> Tuple[List[CallableType], List[TypeVarType]]: """Takes all the typevars present in the callables and 'combines' the ones with the same name. For example, suppose we have two callables with signatures "f(x: T, y: S) -> T" and @@ -4597,8 +5161,9 @@ def is_expr_literal_type(node: Expression) -> bool: return isinstance(base, RefExpr) and base.fullname in LITERAL_TYPE_NAMES if isinstance(node, NameExpr): underlying = node.node - return isinstance(underlying, TypeAlias) and isinstance(get_proper_type(underlying.target), - LiteralType) + return isinstance(underlying, TypeAlias) and isinstance( + get_proper_type(underlying.target), LiteralType + ) return False @@ -4606,9 +5171,9 @@ def has_bytes_component(typ: Type, py2: bool = False) -> bool: """Is this one of builtin byte types, or a union that contains it?""" typ = get_proper_type(typ) if py2: - byte_types = {'builtins.str', 'builtins.bytearray'} + byte_types = {"builtins.str", "builtins.bytearray"} else: - byte_types = {'builtins.bytes', 'builtins.bytearray'} + byte_types = {"builtins.bytes", "builtins.bytearray"} if isinstance(typ, UnionType): return any(has_bytes_component(t) for t in typ.items) if isinstance(typ, Instance) and typ.type.fullname in byte_types: @@ -4638,11 +5203,12 @@ def type_info_from_type(typ: Type) -> Optional[TypeInfo]: def is_operator_method(fullname: Optional[str]) -> bool: if fullname is None: return False - short_name = fullname.split('.')[-1] + short_name = fullname.split(".")[-1] return ( - short_name in operators.op_methods.values() or - short_name in operators.reverse_op_methods.values() or - short_name in operators.unary_op_methods.values()) + short_name in operators.op_methods.values() + or short_name in operators.reverse_op_methods.values() + or short_name in operators.unary_op_methods.values() + ) def get_partial_instance_type(t: Optional[Type]) -> Optional[PartialType]: diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 6f089e35e50f1..2b5b8898950eb 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -1,32 +1,70 @@ """Type checking of attribute access""" -from typing import cast, Callable, Optional, Union, Sequence +from typing import Callable, Optional, Sequence, Union, cast + from typing_extensions import TYPE_CHECKING -from mypy.types import ( - Type, Instance, AnyType, TupleType, TypedDictType, CallableType, FunctionLike, - TypeVarLikeType, Overloaded, TypeVarType, UnionType, PartialType, TypeOfAny, LiteralType, - DeletedType, NoneType, TypeType, has_type_vars, get_proper_type, ProperType, ParamSpecType, - TypeVarTupleType, ENUM_REMOVED_PROPS -) +from mypy import meet, message_registry, subtypes +from mypy.erasetype import erase_typevars +from mypy.expandtype import expand_type_by_instance, freshen_function_type_vars +from mypy.maptype import map_instance_to_supertype +from mypy.messages import MessageBuilder from mypy.nodes import ( - TypeInfo, FuncBase, Var, FuncDef, SymbolNode, SymbolTable, Context, - MypyFile, TypeVarExpr, ARG_POS, ARG_STAR, ARG_STAR2, Decorator, - OverloadedFuncDef, TypeAlias, TempNode, is_final_node, - SYMBOL_FUNCBASE_TYPES, IndexExpr + ARG_POS, + ARG_STAR, + ARG_STAR2, + SYMBOL_FUNCBASE_TYPES, + Context, + Decorator, + FuncBase, + FuncDef, + IndexExpr, + MypyFile, + OverloadedFuncDef, + SymbolNode, + SymbolTable, + TempNode, + TypeAlias, + TypeInfo, + TypeVarExpr, + Var, + is_final_node, ) -from mypy.messages import MessageBuilder -from mypy.maptype import map_instance_to_supertype -from mypy.expandtype import expand_type_by_instance, freshen_function_type_vars -from mypy.erasetype import erase_typevars from mypy.plugin import AttributeContext from mypy.typeanal import set_any_tvars -from mypy import message_registry -from mypy import subtypes -from mypy import meet from mypy.typeops import ( - tuple_fallback, bind_self, erase_to_bound, class_callable, type_object_type_from_function, - make_simplified_union, function_type, + bind_self, + class_callable, + erase_to_bound, + function_type, + make_simplified_union, + tuple_fallback, + type_object_type_from_function, +) +from mypy.types import ( + ENUM_REMOVED_PROPS, + AnyType, + CallableType, + DeletedType, + FunctionLike, + Instance, + LiteralType, + NoneType, + Overloaded, + ParamSpecType, + PartialType, + ProperType, + TupleType, + Type, + TypedDictType, + TypeOfAny, + TypeType, + TypeVarLikeType, + TypeVarTupleType, + TypeVarType, + UnionType, + get_proper_type, + has_type_vars, ) if TYPE_CHECKING: # import for forward declaration only @@ -41,16 +79,18 @@ class MemberContext: Look at the docstring of analyze_member_access for more information. """ - def __init__(self, - is_lvalue: bool, - is_super: bool, - is_operator: bool, - original_type: Type, - context: Context, - msg: MessageBuilder, - chk: 'mypy.checker.TypeChecker', - self_type: Optional[Type], - module_symbol_table: Optional[SymbolTable] = None) -> None: + def __init__( + self, + is_lvalue: bool, + is_super: bool, + is_operator: bool, + original_type: Type, + context: Context, + msg: MessageBuilder, + chk: "mypy.checker.TypeChecker", + self_type: Optional[Type], + module_symbol_table: Optional[SymbolTable] = None, + ) -> None: self.is_lvalue = is_lvalue self.is_super = is_super self.is_operator = is_operator @@ -67,12 +107,24 @@ def named_type(self, name: str) -> Instance: def not_ready_callback(self, name: str, context: Context) -> None: self.chk.handle_cannot_determine_type(name, context) - def copy_modified(self, *, messages: Optional[MessageBuilder] = None, - self_type: Optional[Type] = None, - is_lvalue: Optional[bool] = None) -> 'MemberContext': - mx = MemberContext(self.is_lvalue, self.is_super, self.is_operator, - self.original_type, self.context, self.msg, self.chk, - self.self_type, self.module_symbol_table) + def copy_modified( + self, + *, + messages: Optional[MessageBuilder] = None, + self_type: Optional[Type] = None, + is_lvalue: Optional[bool] = None, + ) -> "MemberContext": + mx = MemberContext( + self.is_lvalue, + self.is_super, + self.is_operator, + self.original_type, + self.context, + self.msg, + self.chk, + self.self_type, + self.module_symbol_table, + ) if messages is not None: mx.msg = messages if self_type is not None: @@ -82,19 +134,22 @@ def copy_modified(self, *, messages: Optional[MessageBuilder] = None, return mx -def analyze_member_access(name: str, - typ: Type, - context: Context, - is_lvalue: bool, - is_super: bool, - is_operator: bool, - msg: MessageBuilder, *, - original_type: Type, - chk: 'mypy.checker.TypeChecker', - override_info: Optional[TypeInfo] = None, - in_literal_context: bool = False, - self_type: Optional[Type] = None, - module_symbol_table: Optional[SymbolTable] = None) -> Type: +def analyze_member_access( + name: str, + typ: Type, + context: Context, + is_lvalue: bool, + is_super: bool, + is_operator: bool, + msg: MessageBuilder, + *, + original_type: Type, + chk: "mypy.checker.TypeChecker", + override_info: Optional[TypeInfo] = None, + in_literal_context: bool = False, + self_type: Optional[Type] = None, + module_symbol_table: Optional[SymbolTable] = None, +) -> Type: """Return the type of attribute 'name' of 'typ'. The actual implementation is in '_analyze_member_access' and this docstring @@ -118,28 +173,32 @@ def analyze_member_access(name: str, and we want to keep track of the available attributes of the module (since they are not available via the type object directly) """ - mx = MemberContext(is_lvalue, - is_super, - is_operator, - original_type, - context, - msg, - chk=chk, - self_type=self_type, - module_symbol_table=module_symbol_table) + mx = MemberContext( + is_lvalue, + is_super, + is_operator, + original_type, + context, + msg, + chk=chk, + self_type=self_type, + module_symbol_table=module_symbol_table, + ) result = _analyze_member_access(name, typ, mx, override_info) possible_literal = get_proper_type(result) - if (in_literal_context and isinstance(possible_literal, Instance) and - possible_literal.last_known_value is not None): + if ( + in_literal_context + and isinstance(possible_literal, Instance) + and possible_literal.last_known_value is not None + ): return possible_literal.last_known_value else: return result -def _analyze_member_access(name: str, - typ: Type, - mx: MemberContext, - override_info: Optional[TypeInfo] = None) -> Type: +def _analyze_member_access( + name: str, typ: Type, mx: MemberContext, override_info: Optional[TypeInfo] = None +) -> Type: # TODO: This and following functions share some logic with subtypes.find_member; # consider refactoring. typ = get_proper_type(typ) @@ -175,10 +234,7 @@ def _analyze_member_access(name: str, def may_be_awaitable_attribute( - name: str, - typ: Type, - mx: MemberContext, - override_info: Optional[TypeInfo] = None + name: str, typ: Type, mx: MemberContext, override_info: Optional[TypeInfo] = None ) -> bool: """Check if the given type has the attribute when awaited.""" if mx.chk.checking_missing_await: @@ -197,7 +253,7 @@ def report_missing_attribute( typ: Type, name: str, mx: MemberContext, - override_info: Optional[TypeInfo] = None + override_info: Optional[TypeInfo] = None, ) -> Type: res_type = mx.msg.has_no_attr(original_type, typ, name, mx.context, mx.module_symbol_table) if may_be_awaitable_attribute(name, typ, mx, override_info): @@ -209,11 +265,10 @@ def report_missing_attribute( # types and aren't documented individually. -def analyze_instance_member_access(name: str, - typ: Instance, - mx: MemberContext, - override_info: Optional[TypeInfo]) -> Type: - if name == '__init__' and not mx.is_super: +def analyze_instance_member_access( + name: str, typ: Instance, mx: MemberContext, override_info: Optional[TypeInfo] +) -> Type: + if name == "__init__" and not mx.is_super: # Accessing __init__ in statically typed code would compromise # type safety unless used via super(). mx.msg.fail(message_registry.CANNOT_ACCESS_INIT, mx.context) @@ -225,9 +280,11 @@ def analyze_instance_member_access(name: str, if override_info: info = override_info - if (state.find_occurrences and - info.name == state.find_occurrences[0] and - name == state.find_occurrences[1]): + if ( + state.find_occurrences + and info.name == state.find_occurrences[0] + and name == state.find_occurrences[1] + ): mx.msg.note("Occurrence of '{}.{}'".format(*state.find_occurrences), mx.context) # Look up the member. First look up the method dictionary. @@ -239,19 +296,20 @@ def analyze_instance_member_access(name: str, return analyze_var(name, first_item.var, typ, info, mx) if mx.is_lvalue: mx.msg.cant_assign_to_method(mx.context) - signature = function_type(method, mx.named_type('builtins.function')) + signature = function_type(method, mx.named_type("builtins.function")) signature = freshen_function_type_vars(signature) - if name == '__new__': + if name == "__new__": # __new__ is special and behaves like a static method -- don't strip # the first argument. pass else: - if isinstance(signature, FunctionLike) and name != '__call__': + if isinstance(signature, FunctionLike) and name != "__call__": # TODO: use proper treatment of special methods on unions instead # of this hack here and below (i.e. mx.self_type). dispatched_type = meet.meet_types(mx.original_type, typ) - signature = check_self_arg(signature, dispatched_type, method.is_class, - mx.context, name, mx.msg) + signature = check_self_arg( + signature, dispatched_type, method.is_class, mx.context, name, mx.msg + ) signature = bind_self(signature, mx.self_type, is_classmethod=method.is_class) typ = map_instance_to_supertype(typ, method.info) member_type = expand_type_by_instance(signature, typ) @@ -262,9 +320,7 @@ def analyze_instance_member_access(name: str, return analyze_member_var_access(name, typ, info, mx) -def analyze_type_callable_member_access(name: str, - typ: FunctionLike, - mx: MemberContext) -> Type: +def analyze_type_callable_member_access(name: str, typ: FunctionLike, mx: MemberContext) -> Type: # Class attribute. # TODO super? ret_type = typ.items[0].ret_type @@ -287,23 +343,23 @@ def analyze_type_callable_member_access(name: str, # the corresponding method in the current instance to avoid this edge case. # See https://github.com/python/mypy/pull/1787 for more info. # TODO: do not rely on same type variables being present in all constructor overloads. - result = analyze_class_attribute_access(ret_type, name, mx, - original_vars=typ.items[0].variables) + result = analyze_class_attribute_access( + ret_type, name, mx, original_vars=typ.items[0].variables + ) if result: return result # Look up from the 'type' type. return _analyze_member_access(name, typ.fallback, mx) else: - assert False, f'Unexpected type {ret_type!r}' + assert False, f"Unexpected type {ret_type!r}" -def analyze_type_type_member_access(name: str, - typ: TypeType, - mx: MemberContext, - override_info: Optional[TypeInfo]) -> Type: +def analyze_type_type_member_access( + name: str, typ: TypeType, mx: MemberContext, override_info: Optional[TypeInfo] +) -> Type: # Similar to analyze_type_callable_attribute_access. item = None - fallback = mx.named_type('builtins.type') + fallback = mx.named_type("builtins.type") if isinstance(typ.item, Instance): item = typ.item elif isinstance(typ.item, AnyType): @@ -357,23 +413,24 @@ def analyze_none_member_access(name: str, typ: NoneType, mx: MemberContext) -> T is_python_3 = mx.chk.options.python_version[0] >= 3 # In Python 2 "None" has exactly the same attributes as "object". Python 3 adds a single # extra attribute, "__bool__". - if is_python_3 and name == '__bool__': - literal_false = LiteralType(False, fallback=mx.named_type('builtins.bool')) - return CallableType(arg_types=[], - arg_kinds=[], - arg_names=[], - ret_type=literal_false, - fallback=mx.named_type('builtins.function')) + if is_python_3 and name == "__bool__": + literal_false = LiteralType(False, fallback=mx.named_type("builtins.bool")) + return CallableType( + arg_types=[], + arg_kinds=[], + arg_names=[], + ret_type=literal_false, + fallback=mx.named_type("builtins.function"), + ) elif mx.chk.should_suppress_optional_error([typ]): return AnyType(TypeOfAny.from_error) else: - return _analyze_member_access(name, mx.named_type('builtins.object'), mx) + return _analyze_member_access(name, mx.named_type("builtins.object"), mx) -def analyze_member_var_access(name: str, - itype: Instance, - info: TypeInfo, - mx: MemberContext) -> Type: +def analyze_member_var_access( + name: str, itype: Instance, info: TypeInfo, mx: MemberContext +) -> Type: """Analyse attribute access that does not target a method. This is logically part of analyze_member_access and the arguments are similar. @@ -417,22 +474,31 @@ def analyze_member_var_access(name: str, return analyze_var(name, v, itype, info, mx, implicit=implicit) elif isinstance(v, FuncDef): assert False, "Did not expect a function" - elif (not v and name not in ['__getattr__', '__setattr__', '__getattribute__'] and - not mx.is_operator and mx.module_symbol_table is None): + elif ( + not v + and name not in ["__getattr__", "__setattr__", "__getattribute__"] + and not mx.is_operator + and mx.module_symbol_table is None + ): # Above we skip ModuleType.__getattr__ etc. if we have a # module symbol table, since the symbol table allows precise # checking. if not mx.is_lvalue: - for method_name in ('__getattribute__', '__getattr__'): + for method_name in ("__getattribute__", "__getattr__"): method = info.get_method(method_name) # __getattribute__ is defined on builtins.object and returns Any, so without # the guard this search will always find object.__getattribute__ and conclude # that the attribute exists - if method and method.info.fullname != 'builtins.object': + if method and method.info.fullname != "builtins.object": bound_method = analyze_decorator_or_funcbase_access( - defn=method, itype=itype, info=info, - self_type=mx.self_type, name=method_name, mx=mx) + defn=method, + itype=itype, + info=info, + self_type=mx.self_type, + name=method_name, + mx=mx, + ) typ = map_instance_to_supertype(itype, method.info) getattr_type = get_proper_type(expand_type_by_instance(bound_method, typ)) if isinstance(getattr_type, CallableType): @@ -441,19 +507,26 @@ def analyze_member_var_access(name: str, result = getattr_type # Call the attribute hook before returning. - fullname = f'{method.info.fullname}.{name}' + fullname = f"{method.info.fullname}.{name}" hook = mx.chk.plugin.get_attribute_hook(fullname) if hook: - result = hook(AttributeContext(get_proper_type(mx.original_type), - result, mx.context, mx.chk)) + result = hook( + AttributeContext( + get_proper_type(mx.original_type), result, mx.context, mx.chk + ) + ) return result else: - setattr_meth = info.get_method('__setattr__') - if setattr_meth and setattr_meth.info.fullname != 'builtins.object': + setattr_meth = info.get_method("__setattr__") + if setattr_meth and setattr_meth.info.fullname != "builtins.object": bound_type = analyze_decorator_or_funcbase_access( - defn=setattr_meth, itype=itype, info=info, - self_type=mx.self_type, name=name, - mx=mx.copy_modified(is_lvalue=False)) + defn=setattr_meth, + itype=itype, + info=info, + self_type=mx.self_type, + name=name, + mx=mx.copy_modified(is_lvalue=False), + ) typ = map_instance_to_supertype(itype, setattr_meth.info) setattr_type = get_proper_type(expand_type_by_instance(bound_type, typ)) if isinstance(setattr_type, CallableType) and len(setattr_type.arg_types) > 0: @@ -480,8 +553,7 @@ def check_final_member(name: str, info: TypeInfo, msg: MessageBuilder, ctx: Cont msg.cant_assign_to_final(name, attr_assign=True, ctx=ctx) -def analyze_descriptor_access(descriptor_type: Type, - mx: MemberContext) -> Type: +def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type: """Type check descriptor access. Arguments: @@ -496,25 +568,30 @@ def analyze_descriptor_access(descriptor_type: Type, if isinstance(descriptor_type, UnionType): # Map the access over union types - return make_simplified_union([ - analyze_descriptor_access(typ, mx) - for typ in descriptor_type.items - ]) + return make_simplified_union( + [analyze_descriptor_access(typ, mx) for typ in descriptor_type.items] + ) elif not isinstance(descriptor_type, Instance): return descriptor_type - if not descriptor_type.type.has_readable_member('__get__'): + if not descriptor_type.type.has_readable_member("__get__"): return descriptor_type - dunder_get = descriptor_type.type.get_method('__get__') + dunder_get = descriptor_type.type.get_method("__get__") if dunder_get is None: - mx.msg.fail(message_registry.DESCRIPTOR_GET_NOT_CALLABLE.format(descriptor_type), - mx.context) + mx.msg.fail( + message_registry.DESCRIPTOR_GET_NOT_CALLABLE.format(descriptor_type), mx.context + ) return AnyType(TypeOfAny.from_error) bound_method = analyze_decorator_or_funcbase_access( - defn=dunder_get, itype=descriptor_type, info=descriptor_type.type, - self_type=descriptor_type, name='__set__', mx=mx) + defn=dunder_get, + itype=descriptor_type, + info=descriptor_type.type, + self_type=descriptor_type, + name="__set__", + mx=mx, + ) typ = map_instance_to_supertype(descriptor_type, dunder_get.info) dunder_get_type = expand_type_by_instance(bound_method, typ) @@ -530,18 +607,28 @@ def analyze_descriptor_access(descriptor_type: Type, callable_name = mx.chk.expr_checker.method_fullname(descriptor_type, "__get__") dunder_get_type = mx.chk.expr_checker.transform_callee_type( - callable_name, dunder_get_type, - [TempNode(instance_type, context=mx.context), - TempNode(TypeType.make_normalized(owner_type), context=mx.context)], - [ARG_POS, ARG_POS], mx.context, object_type=descriptor_type, + callable_name, + dunder_get_type, + [ + TempNode(instance_type, context=mx.context), + TempNode(TypeType.make_normalized(owner_type), context=mx.context), + ], + [ARG_POS, ARG_POS], + mx.context, + object_type=descriptor_type, ) _, inferred_dunder_get_type = mx.chk.expr_checker.check_call( dunder_get_type, - [TempNode(instance_type, context=mx.context), - TempNode(TypeType.make_normalized(owner_type), context=mx.context)], - [ARG_POS, ARG_POS], mx.context, object_type=descriptor_type, - callable_name=callable_name) + [ + TempNode(instance_type, context=mx.context), + TempNode(TypeType.make_normalized(owner_type), context=mx.context), + ], + [ARG_POS, ARG_POS], + mx.context, + object_type=descriptor_type, + callable_name=callable_name, + ) inferred_dunder_get_type = get_proper_type(inferred_dunder_get_type) if isinstance(inferred_dunder_get_type, AnyType): @@ -549,34 +636,38 @@ def analyze_descriptor_access(descriptor_type: Type, return inferred_dunder_get_type if not isinstance(inferred_dunder_get_type, CallableType): - mx.msg.fail(message_registry.DESCRIPTOR_GET_NOT_CALLABLE.format(descriptor_type), - mx.context) + mx.msg.fail( + message_registry.DESCRIPTOR_GET_NOT_CALLABLE.format(descriptor_type), mx.context + ) return AnyType(TypeOfAny.from_error) return inferred_dunder_get_type.ret_type -def instance_alias_type(alias: TypeAlias, - named_type: Callable[[str], Instance]) -> Type: +def instance_alias_type(alias: TypeAlias, named_type: Callable[[str], Instance]) -> Type: """Type of a type alias node targeting an instance, when appears in runtime context. As usual, we first erase any unbound type variables to Any. """ target: Type = get_proper_type(alias.target) - assert isinstance(get_proper_type(target), - Instance), "Must be called only with aliases to classes" + assert isinstance( + get_proper_type(target), Instance + ), "Must be called only with aliases to classes" target = get_proper_type(set_any_tvars(alias, alias.line, alias.column)) assert isinstance(target, Instance) tp = type_object_type(target.type, named_type) return expand_type_by_instance(tp, target) -def analyze_var(name: str, - var: Var, - itype: Instance, - info: TypeInfo, - mx: MemberContext, *, - implicit: bool = False) -> Type: +def analyze_var( + name: str, + var: Var, + itype: Instance, + info: TypeInfo, + mx: MemberContext, + *, + implicit: bool = False, +) -> Type: """Analyze access to an attribute via a Var node. This is conceptually part of analyze_member_access and the arguments are similar. @@ -620,8 +711,9 @@ def analyze_var(name: str, # and similarly for B1 when checking against B dispatched_type = meet.meet_types(mx.original_type, itype) signature = freshen_function_type_vars(functype) - signature = check_self_arg(signature, dispatched_type, var.is_classmethod, - mx.context, name, mx.msg) + signature = check_self_arg( + signature, dispatched_type, var.is_classmethod, mx.context, name, mx.msg + ) signature = bind_self(signature, mx.self_type, var.is_classmethod) expanded_signature = get_proper_type(expand_type_by_instance(signature, itype)) freeze_type_vars(expanded_signature) @@ -636,13 +728,14 @@ def analyze_var(name: str, mx.not_ready_callback(var.name, mx.context) # Implicit 'Any' type. result = AnyType(TypeOfAny.special_form) - fullname = f'{var.info.fullname}.{name}' + fullname = f"{var.info.fullname}.{name}" hook = mx.chk.plugin.get_attribute_hook(fullname) if result and not mx.is_lvalue and not implicit: result = analyze_descriptor_access(result, mx) if hook: - result = hook(AttributeContext(get_proper_type(mx.original_type), - result, mx.context, mx.chk)) + result = hook( + AttributeContext(get_proper_type(mx.original_type), result, mx.context, mx.chk) + ) return result @@ -658,8 +751,9 @@ def freeze_type_vars(member_type: Type) -> None: v.id.meta_level = 0 -def lookup_member_var_or_accessor(info: TypeInfo, name: str, - is_lvalue: bool) -> Optional[SymbolNode]: +def lookup_member_var_or_accessor( + info: TypeInfo, name: str, is_lvalue: bool +) -> Optional[SymbolNode]: """Find the attribute/accessor node that refers to a member of a type.""" # TODO handle lvalues node = info.get(name) @@ -669,11 +763,14 @@ def lookup_member_var_or_accessor(info: TypeInfo, name: str, return None -def check_self_arg(functype: FunctionLike, - dispatched_arg_type: Type, - is_classmethod: bool, - context: Context, name: str, - msg: MessageBuilder) -> FunctionLike: +def check_self_arg( + functype: FunctionLike, + dispatched_arg_type: Type, + is_classmethod: bool, + context: Context, + name: str, + msg: MessageBuilder, +) -> FunctionLike: """Check that an instance has a valid type for a method with annotated 'self'. For example if the method is defined as: @@ -712,20 +809,22 @@ def f(self: S) -> T: ... raise NotImplementedError if not new_items: # Choose first item for the message (it may be not very helpful for overloads). - msg.incompatible_self_argument(name, dispatched_arg_type, items[0], - is_classmethod, context) + msg.incompatible_self_argument( + name, dispatched_arg_type, items[0], is_classmethod, context + ) return functype if len(new_items) == 1: return new_items[0] return Overloaded(new_items) -def analyze_class_attribute_access(itype: Instance, - name: str, - mx: MemberContext, - override_info: Optional[TypeInfo] = None, - original_vars: Optional[Sequence[TypeVarLikeType]] = None - ) -> Optional[Type]: +def analyze_class_attribute_access( + itype: Instance, + name: str, + mx: MemberContext, + override_info: Optional[TypeInfo] = None, + original_vars: Optional[Sequence[TypeVarLikeType]] = None, +) -> Optional[Type]: """Analyze access to an attribute on a class object. itype is the return type of the class object callable, original_type is the type @@ -736,7 +835,7 @@ def analyze_class_attribute_access(itype: Instance, if override_info: info = override_info - fullname = '{}.{}'.format(info.fullname, name) + fullname = "{}.{}".format(info.fullname, name) hook = mx.chk.plugin.get_class_attribute_hook(fullname) node = info.get(name) @@ -756,8 +855,9 @@ def analyze_class_attribute_access(itype: Instance, # If a final attribute was declared on `self` in `__init__`, then it # can't be accessed on the class object. if node.implicit and isinstance(node.node, Var) and node.node.is_final: - mx.msg.fail(message_registry.CANNOT_ACCESS_FINAL_INSTANCE_ATTR - .format(node.node.name), mx.context) + mx.msg.fail( + message_registry.CANNOT_ACCESS_FINAL_INSTANCE_ATTR.format(node.node.name), mx.context + ) # An assignment to final attribute on class object is also always an error, # independently of types. @@ -774,9 +874,9 @@ def analyze_class_attribute_access(itype: Instance, if isinstance(t, PartialType): symnode = node.node assert isinstance(symnode, Var) - return apply_class_attr_hook(mx, hook, - mx.chk.handle_partial_var_type(t, mx.is_lvalue, symnode, - mx.context)) + return apply_class_attr_hook( + mx, hook, mx.chk.handle_partial_var_type(t, mx.is_lvalue, symnode, mx.context) + ) # Find the class where method/variable was defined. if isinstance(node.node, Decorator): @@ -818,13 +918,15 @@ def analyze_class_attribute_access(itype: Instance, # C[int].x -> int t = erase_typevars(expand_type_by_instance(t, isuper)) - is_classmethod = ((is_decorated and cast(Decorator, node.node).func.is_class) - or (isinstance(node.node, FuncBase) and node.node.is_class)) + is_classmethod = (is_decorated and cast(Decorator, node.node).func.is_class) or ( + isinstance(node.node, FuncBase) and node.node.is_class + ) t = get_proper_type(t) if isinstance(t, FunctionLike) and is_classmethod: t = check_self_arg(t, mx.self_type, False, mx.context, name, mx.msg) - result = add_class_tvars(t, isuper, is_classmethod, - mx.self_type, original_vars=original_vars) + result = add_class_tvars( + t, isuper, is_classmethod, mx.self_type, original_vars=original_vars + ) if not mx.is_lvalue: result = analyze_descriptor_access(result, mx) @@ -834,8 +936,9 @@ def analyze_class_attribute_access(itype: Instance, return AnyType(TypeOfAny.special_form) if isinstance(node.node, TypeVarExpr): - mx.msg.fail(message_registry.CANNOT_USE_TYPEVAR_AS_EXPRESSION.format( - info.name, name), mx.context) + mx.msg.fail( + message_registry.CANNOT_USE_TYPEVAR_AS_EXPRESSION.format(info.name, name), mx.context + ) return AnyType(TypeOfAny.from_error) if isinstance(node.node, TypeInfo): @@ -843,10 +946,11 @@ def analyze_class_attribute_access(itype: Instance, if isinstance(node.node, MypyFile): # Reference to a module object. - return mx.named_type('types.ModuleType') + return mx.named_type("types.ModuleType") - if (isinstance(node.node, TypeAlias) and - isinstance(get_proper_type(node.node.target), Instance)): + if isinstance(node.node, TypeAlias) and isinstance( + get_proper_type(node.node.target), Instance + ): return instance_alias_type(node.node, mx.named_type) if is_decorated: @@ -858,7 +962,7 @@ def analyze_class_attribute_access(itype: Instance, return AnyType(TypeOfAny.from_error) else: assert isinstance(node.node, FuncBase) - typ = function_type(node.node, mx.named_type('builtins.function')) + typ = function_type(node.node, mx.named_type("builtins.function")) # Note: if we are accessing class method on class object, the cls argument is bound. # Annotated and/or explicit class methods go through other code paths above, for # unannotated implicit class methods we do this here. @@ -867,39 +971,38 @@ def analyze_class_attribute_access(itype: Instance, return apply_class_attr_hook(mx, hook, typ) -def apply_class_attr_hook(mx: MemberContext, - hook: Optional[Callable[[AttributeContext], Type]], - result: Type, - ) -> Optional[Type]: +def apply_class_attr_hook( + mx: MemberContext, hook: Optional[Callable[[AttributeContext], Type]], result: Type +) -> Optional[Type]: if hook: - result = hook(AttributeContext(get_proper_type(mx.original_type), - result, mx.context, mx.chk)) + result = hook( + AttributeContext(get_proper_type(mx.original_type), result, mx.context, mx.chk) + ) return result -def analyze_enum_class_attribute_access(itype: Instance, - name: str, - mx: MemberContext, - ) -> Optional[Type]: +def analyze_enum_class_attribute_access( + itype: Instance, name: str, mx: MemberContext +) -> Optional[Type]: # Skip these since Enum will remove it if name in ENUM_REMOVED_PROPS: return report_missing_attribute(mx.original_type, itype, name, mx) # For other names surrendered by underscores, we don't make them Enum members - if name.startswith('__') and name.endswith("__") and name.replace('_', '') != '': + if name.startswith("__") and name.endswith("__") and name.replace("_", "") != "": return None enum_literal = LiteralType(name, fallback=itype) return itype.copy_modified(last_known_value=enum_literal) -def analyze_typeddict_access(name: str, typ: TypedDictType, - mx: MemberContext, override_info: Optional[TypeInfo]) -> Type: - if name == '__setitem__': +def analyze_typeddict_access( + name: str, typ: TypedDictType, mx: MemberContext, override_info: Optional[TypeInfo] +) -> Type: + if name == "__setitem__": if isinstance(mx.context, IndexExpr): # Since we can get this during `a['key'] = ...` # it is safe to assume that the context is `IndexExpr`. - item_type = mx.chk.expr_checker.visit_typeddict_index_expr( - typ, mx.context.index) + item_type = mx.chk.expr_checker.visit_typeddict_index_expr(typ, mx.context.index) else: # It can also be `a.__setitem__(...)` direct call. # In this case `item_type` can be `Any`, @@ -907,29 +1010,32 @@ def analyze_typeddict_access(name: str, typ: TypedDictType, # TODO: check in `default` plugin that `__setitem__` is correct. item_type = AnyType(TypeOfAny.implementation_artifact) return CallableType( - arg_types=[mx.chk.named_type('builtins.str'), item_type], + arg_types=[mx.chk.named_type("builtins.str"), item_type], arg_kinds=[ARG_POS, ARG_POS], arg_names=[None, None], ret_type=NoneType(), - fallback=mx.chk.named_type('builtins.function'), + fallback=mx.chk.named_type("builtins.function"), name=name, ) - elif name == '__delitem__': + elif name == "__delitem__": return CallableType( - arg_types=[mx.chk.named_type('builtins.str')], + arg_types=[mx.chk.named_type("builtins.str")], arg_kinds=[ARG_POS], arg_names=[None], ret_type=NoneType(), - fallback=mx.chk.named_type('builtins.function'), + fallback=mx.chk.named_type("builtins.function"), name=name, ) return _analyze_member_access(name, typ.fallback, mx, override_info) -def add_class_tvars(t: ProperType, isuper: Optional[Instance], - is_classmethod: bool, - original_type: Type, - original_vars: Optional[Sequence[TypeVarLikeType]] = None) -> Type: +def add_class_tvars( + t: ProperType, + isuper: Optional[Instance], + is_classmethod: bool, + original_type: Type, + original_vars: Optional[Sequence[TypeVarLikeType]] = None, +) -> Type: """Instantiate type variables during analyze_class_attribute_access, e.g T and Q in the following: @@ -975,10 +1081,17 @@ class B(A[str]): pass freeze_type_vars(t) return t.copy_modified(variables=list(tvars) + list(t.variables)) elif isinstance(t, Overloaded): - return Overloaded([cast(CallableType, add_class_tvars(item, isuper, - is_classmethod, original_type, - original_vars=original_vars)) - for item in t.items]) + return Overloaded( + [ + cast( + CallableType, + add_class_tvars( + item, isuper, is_classmethod, original_type, original_vars=original_vars + ), + ) + for item in t.items + ] + ) if isuper is not None: t = cast(ProperType, expand_type_by_instance(t, isuper)) return t @@ -997,8 +1110,8 @@ def type_object_type(info: TypeInfo, named_type: Callable[[str], Instance]) -> P # We take the type from whichever of __init__ and __new__ is first # in the MRO, preferring __init__ if there is a tie. - init_method = info.get('__init__') - new_method = info.get('__new__') + init_method = info.get("__init__") + new_method = info.get("__new__") if not init_method or not is_valid_constructor(init_method.node): # Must be an invalid class definition. return AnyType(TypeOfAny.from_error) @@ -1016,7 +1129,7 @@ def type_object_type(info: TypeInfo, named_type: Callable[[str], Instance]) -> P init_index = info.mro.index(init_method.node.info) new_index = info.mro.index(new_method.node.info) - fallback = info.metaclass_type or named_type('builtins.type') + fallback = info.metaclass_type or named_type("builtins.type") if init_index < new_index: method: Union[FuncBase, Decorator] = init_method.node is_new = False @@ -1024,17 +1137,19 @@ def type_object_type(info: TypeInfo, named_type: Callable[[str], Instance]) -> P method = new_method.node is_new = True else: - if init_method.node.info.fullname == 'builtins.object': + if init_method.node.info.fullname == "builtins.object": # Both are defined by object. But if we've got a bogus # base class, we can't know for sure, so check for that. if info.fallback_to_any: # Construct a universal callable as the prototype. any_type = AnyType(TypeOfAny.special_form) - sig = CallableType(arg_types=[any_type, any_type], - arg_kinds=[ARG_STAR, ARG_STAR2], - arg_names=["_args", "_kwds"], - ret_type=any_type, - fallback=named_type('builtins.function')) + sig = CallableType( + arg_types=[any_type, any_type], + arg_kinds=[ARG_STAR, ARG_STAR2], + arg_names=["_args", "_kwds"], + ret_type=any_type, + fallback=named_type("builtins.function"), + ) return class_callable(sig, info, fallback, None, is_new=False) # Otherwise prefer __init__ in a tie. It isn't clear that this @@ -1069,8 +1184,7 @@ def analyze_decorator_or_funcbase_access( if isinstance(defn, Decorator): return analyze_var(name, defn.var, itype, info, mx) return bind_self( - function_type(defn, mx.chk.named_type('builtins.function')), - original_type=self_type, + function_type(defn, mx.chk.named_type("builtins.function")), original_type=self_type ) diff --git a/mypy/checkpattern.py b/mypy/checkpattern.py index 978b03b342f59..a6390367d6a72 100644 --- a/mypy/checkpattern.py +++ b/mypy/checkpattern.py @@ -1,30 +1,51 @@ """Pattern checker. This file is conceptually part of TypeChecker.""" from collections import defaultdict -from typing import List, Optional, Tuple, Dict, NamedTuple, Set, Union +from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union + from typing_extensions import Final import mypy.checker +from mypy import message_registry from mypy.checkmember import analyze_member_access from mypy.expandtype import expand_type_by_instance from mypy.join import join_types from mypy.literals import literal_hash from mypy.maptype import map_instance_to_supertype from mypy.meet import narrow_declared_type -from mypy import message_registry from mypy.messages import MessageBuilder -from mypy.nodes import Expression, ARG_POS, TypeAlias, TypeInfo, Var, NameExpr +from mypy.nodes import ARG_POS, Expression, NameExpr, TypeAlias, TypeInfo, Var from mypy.patterns import ( - Pattern, AsPattern, OrPattern, ValuePattern, SequencePattern, StarredPattern, MappingPattern, - ClassPattern, SingletonPattern + AsPattern, + ClassPattern, + MappingPattern, + OrPattern, + Pattern, + SequencePattern, + SingletonPattern, + StarredPattern, + ValuePattern, ) from mypy.plugin import Plugin from mypy.subtypes import is_subtype -from mypy.typeops import try_getting_str_literals_from_type, make_simplified_union, \ - coerce_to_literal +from mypy.typeops import ( + coerce_to_literal, + make_simplified_union, + try_getting_str_literals_from_type, +) from mypy.types import ( - LiteralType, ProperType, AnyType, TypeOfAny, Instance, Type, UninhabitedType, get_proper_type, - TypedDictType, TupleType, NoneType, UnionType + AnyType, + Instance, + LiteralType, + NoneType, + ProperType, + TupleType, + Type, + TypedDictType, + TypeOfAny, + UninhabitedType, + UnionType, + get_proper_type, ) from mypy.typevars import fill_typevars from mypy.visitor import PatternVisitor @@ -43,11 +64,7 @@ "builtins.tuple", ] -non_sequence_match_type_names: Final = [ - "builtins.str", - "builtins.bytes", - "builtins.bytearray" -] +non_sequence_match_type_names: Final = ["builtins.str", "builtins.bytes", "builtins.bytearray"] # For every Pattern a PatternType can be calculated. This requires recursively calculating @@ -67,7 +84,7 @@ class PatternChecker(PatternVisitor[PatternType]): """ # Some services are provided by a TypeChecker instance. - chk: 'mypy.checker.TypeChecker' + chk: "mypy.checker.TypeChecker" # This is shared with TypeChecker, but stored also here for convenience. msg: MessageBuilder # Currently unused @@ -85,10 +102,9 @@ class PatternChecker(PatternVisitor[PatternType]): # non_sequence_match_type_names non_sequence_match_types: List[Type] - def __init__(self, - chk: 'mypy.checker.TypeChecker', - msg: MessageBuilder, plugin: Plugin - ) -> None: + def __init__( + self, chk: "mypy.checker.TypeChecker", msg: MessageBuilder, plugin: Plugin + ) -> None: self.chk = chk self.msg = msg self.plugin = plugin @@ -115,10 +131,9 @@ def visit_as_pattern(self, o: AsPattern) -> PatternType: typ, rest_type, type_map = current_type, UninhabitedType(), {} if not is_uninhabited(typ) and o.name is not None: - typ, _ = self.chk.conditional_types_with_intersection(current_type, - [get_type_range(typ)], - o, - default=current_type) + typ, _ = self.chk.conditional_types_with_intersection( + current_type, [get_type_range(typ)], o, default=current_type + ) if not is_uninhabited(typ): type_map[o.name] = typ @@ -178,10 +193,7 @@ def visit_value_pattern(self, o: ValuePattern) -> PatternType: typ = self.chk.expr_checker.accept(o.expr) typ = coerce_to_literal(typ) narrowed_type, rest_type = self.chk.conditional_types_with_intersection( - current_type, - [get_type_range(typ)], - o, - default=current_type + current_type, [get_type_range(typ)], o, default=current_type ) if not isinstance(get_proper_type(narrowed_type), (LiteralType, UninhabitedType)): return PatternType(narrowed_type, UnionType.make_union([narrowed_type, rest_type]), {}) @@ -198,10 +210,7 @@ def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType: assert False narrowed_type, rest_type = self.chk.conditional_types_with_intersection( - current_type, - [get_type_range(typ)], - o, - default=current_type + current_type, [get_type_range(typ)], o, default=current_type ) return PatternType(narrowed_type, rest_type, {}) @@ -245,9 +254,9 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: contracted_rest_inner_types: List[Type] = [] captures: Dict[Expression, Type] = {} - contracted_inner_types = self.contract_starred_pattern_types(inner_types, - star_position, - required_patterns) + contracted_inner_types = self.contract_starred_pattern_types( + inner_types, star_position, required_patterns + ) can_match = True for p, t in zip(o.patterns, contracted_inner_types): pattern_type = self.accept(p, t) @@ -258,12 +267,12 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: contracted_new_inner_types.append(typ) contracted_rest_inner_types.append(rest) self.update_type_map(captures, type_map) - new_inner_types = self.expand_starred_pattern_types(contracted_new_inner_types, - star_position, - len(inner_types)) - rest_inner_types = self.expand_starred_pattern_types(contracted_rest_inner_types, - star_position, - len(inner_types)) + new_inner_types = self.expand_starred_pattern_types( + contracted_new_inner_types, star_position, len(inner_types) + ) + rest_inner_types = self.expand_starred_pattern_types( + contracted_rest_inner_types, star_position, len(inner_types) + ) # # Calculate new type @@ -276,13 +285,12 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: narrowed_inner_types = [] inner_rest_types = [] for inner_type, new_inner_type in zip(inner_types, new_inner_types): - narrowed_inner_type, inner_rest_type = \ - self.chk.conditional_types_with_intersection( - new_inner_type, - [get_type_range(inner_type)], - o, - default=new_inner_type - ) + ( + narrowed_inner_type, + inner_rest_type, + ) = self.chk.conditional_types_with_intersection( + new_inner_type, [get_type_range(inner_type)], o, default=new_inner_type + ) narrowed_inner_types.append(narrowed_inner_type) inner_rest_types.append(inner_rest_type) if all(not is_uninhabited(typ) for typ in narrowed_inner_types): @@ -300,10 +308,7 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType: new_type = self.construct_sequence_child(current_type, new_inner_type) if is_subtype(new_type, current_type): new_type, _ = self.chk.conditional_types_with_intersection( - current_type, - [get_type_range(new_type)], - o, - default=current_type + current_type, [get_type_range(new_type)], o, default=current_type ) else: new_type = current_type @@ -326,11 +331,9 @@ def get_sequence_type(self, t: Type) -> Optional[Type]: else: return None - def contract_starred_pattern_types(self, - types: List[Type], - star_pos: Optional[int], - num_patterns: int - ) -> List[Type]: + def contract_starred_pattern_types( + self, types: List[Type], star_pos: Optional[int], num_patterns: int + ) -> List[Type]: """ Contracts a list of types in a sequence pattern depending on the position of a starred capture pattern. @@ -344,16 +347,14 @@ def contract_starred_pattern_types(self, return types new_types = types[:star_pos] star_length = len(types) - num_patterns - new_types.append(make_simplified_union(types[star_pos:star_pos+star_length])) - new_types += types[star_pos+star_length:] + new_types.append(make_simplified_union(types[star_pos : star_pos + star_length])) + new_types += types[star_pos + star_length :] return new_types - def expand_starred_pattern_types(self, - types: List[Type], - star_pos: Optional[int], - num_types: int - ) -> List[Type]: + def expand_starred_pattern_types( + self, types: List[Type], star_pos: Optional[int], num_types: int + ) -> List[Type]: """Undoes the contraction done by contract_starred_pattern_types. For example if the sequence pattern is [a, *b, c] and types [bool, int, str] are extended @@ -364,14 +365,14 @@ def expand_starred_pattern_types(self, new_types = types[:star_pos] star_length = num_types - len(types) + 1 new_types += [types[star_pos]] * star_length - new_types += types[star_pos+1:] + new_types += types[star_pos + 1 :] return new_types def visit_starred_pattern(self, o: StarredPattern) -> PatternType: captures: Dict[Expression, Type] = {} if o.capture is not None: - list_type = self.chk.named_generic_type('builtins.list', [self.type_context[-1]]) + list_type = self.chk.named_generic_type("builtins.list", [self.type_context[-1]]) captures[o.capture] = list_type return PatternType(self.type_context[-1], UninhabitedType(), captures) @@ -398,8 +399,9 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: rest_type = Instance(dict_typeinfo, mapping_inst.args) else: object_type = self.chk.named_type("builtins.object") - rest_type = self.chk.named_generic_type("builtins.dict", - [object_type, object_type]) + rest_type = self.chk.named_generic_type( + "builtins.dict", [object_type, object_type] + ) captures[o.rest] = rest_type @@ -410,44 +412,35 @@ def visit_mapping_pattern(self, o: MappingPattern) -> PatternType: new_type = UninhabitedType() return PatternType(new_type, current_type, captures) - def get_mapping_item_type(self, - pattern: MappingPattern, - mapping_type: Type, - key: Expression - ) -> Optional[Type]: + def get_mapping_item_type( + self, pattern: MappingPattern, mapping_type: Type, key: Expression + ) -> Optional[Type]: mapping_type = get_proper_type(mapping_type) if isinstance(mapping_type, TypedDictType): with self.msg.filter_errors() as local_errors: result: Optional[Type] = self.chk.expr_checker.visit_typeddict_index_expr( - mapping_type, key) + mapping_type, key + ) has_local_errors = local_errors.has_new_errors() # If we can't determine the type statically fall back to treating it as a normal # mapping if has_local_errors: with self.msg.filter_errors() as local_errors: - result = self.get_simple_mapping_item_type(pattern, - mapping_type, - key) + result = self.get_simple_mapping_item_type(pattern, mapping_type, key) if local_errors.has_new_errors(): result = None else: with self.msg.filter_errors(): - result = self.get_simple_mapping_item_type(pattern, - mapping_type, - key) + result = self.get_simple_mapping_item_type(pattern, mapping_type, key) return result - def get_simple_mapping_item_type(self, - pattern: MappingPattern, - mapping_type: Type, - key: Expression - ) -> Type: - result, _ = self.chk.expr_checker.check_method_call_by_name('__getitem__', - mapping_type, - [key], - [ARG_POS], - pattern) + def get_simple_mapping_item_type( + self, pattern: MappingPattern, mapping_type: Type, key: Expression + ) -> Type: + result, _ = self.chk.expr_checker.check_method_call_by_name( + "__getitem__", mapping_type, [key], [ARG_POS], pattern + ) return result def visit_class_pattern(self, o: ClassPattern) -> PatternType: @@ -497,17 +490,25 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o) pattern_type = self.accept(o.positionals[0], narrowed_type) if not is_uninhabited(pattern_type.type): - return PatternType(pattern_type.type, - join_types(rest_type, pattern_type.rest_type), - pattern_type.captures) + return PatternType( + pattern_type.type, + join_types(rest_type, pattern_type.rest_type), + pattern_type.captures, + ) captures = pattern_type.captures else: with self.msg.filter_errors() as local_errors: - match_args_type = analyze_member_access("__match_args__", typ, o, - False, False, False, - self.msg, - original_type=typ, - chk=self.chk) + match_args_type = analyze_member_access( + "__match_args__", + typ, + o, + False, + False, + False, + self.msg, + original_type=typ, + chk=self.chk, + ) has_local_errors = local_errors.has_new_errors() if has_local_errors: self.msg.fail(message_registry.MISSING_MATCH_ARGS.format(typ), o) @@ -537,13 +538,13 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: keyword_pairs.append((key, value)) if key in match_arg_set: self.msg.fail( - message_registry.CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL.format(key), - value + message_registry.CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL.format(key), value ) has_duplicates = True elif key in keyword_arg_set: - self.msg.fail(message_registry.CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN.format(key), - value) + self.msg.fail( + message_registry.CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN.format(key), value + ) has_duplicates = True keyword_arg_set.add(key) @@ -558,22 +559,25 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType: key_type: Optional[Type] = None with self.msg.filter_errors() as local_errors: if keyword is not None: - key_type = analyze_member_access(keyword, - narrowed_type, - pattern, - False, - False, - False, - self.msg, - original_type=new_type, - chk=self.chk) + key_type = analyze_member_access( + keyword, + narrowed_type, + pattern, + False, + False, + False, + self.msg, + original_type=new_type, + chk=self.chk, + ) else: key_type = AnyType(TypeOfAny.from_error) has_local_errors = local_errors.has_new_errors() if has_local_errors or key_type is None: key_type = AnyType(TypeOfAny.from_error) - self.msg.fail(message_registry.CLASS_PATTERN_UNKNOWN_KEYWORD.format(typ, keyword), - pattern) + self.msg.fail( + message_registry.CLASS_PATTERN_UNKNOWN_KEYWORD.format(typ, keyword), pattern + ) inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type) if is_uninhabited(inner_type): @@ -615,24 +619,24 @@ def generate_types_from_names(self, type_names: List[str]) -> List[Type]: types.append(self.chk.named_type(name)) except KeyError as e: # Some built in types are not defined in all test cases - if not name.startswith('builtins.'): + if not name.startswith("builtins."): raise e pass return types - def update_type_map(self, - original_type_map: Dict[Expression, Type], - extra_type_map: Dict[Expression, Type] - ) -> None: + def update_type_map( + self, original_type_map: Dict[Expression, Type], extra_type_map: Dict[Expression, Type] + ) -> None: # Calculating this would not be needed if TypeMap directly used literal hashes instead of # expressions, as suggested in the TODO above it's definition already_captured = {literal_hash(expr) for expr in original_type_map} for expr, typ in extra_type_map.items(): if literal_hash(expr) in already_captured: node = get_var(expr) - self.msg.fail(message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), - expr) + self.msg.fail( + message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr + ) else: original_type_map[expr] = typ @@ -648,7 +652,8 @@ def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type: proper_type = get_proper_type(outer_type) if isinstance(proper_type, UnionType): types = [ - self.construct_sequence_child(item, inner_type) for item in proper_type.items + self.construct_sequence_child(item, inner_type) + for item in proper_type.items if self.can_match_sequence(get_proper_type(item)) ] return make_simplified_union(types) @@ -688,11 +693,13 @@ def get_var(expr: Expression) -> Var: return node -def get_type_range(typ: Type) -> 'mypy.checker.TypeRange': +def get_type_range(typ: Type) -> "mypy.checker.TypeRange": typ = get_proper_type(typ) - if (isinstance(typ, Instance) - and typ.last_known_value - and isinstance(typ.last_known_value.value, bool)): + if ( + isinstance(typ, Instance) + and typ.last_known_value + and isinstance(typ.last_known_value.value, bool) + ): typ = typ.last_known_value return mypy.checker.TypeRange(typ, is_upper_bound=False) diff --git a/mypy/checkstrformat.py b/mypy/checkstrformat.py index 60a0d35ede080..dc5e5098630b4 100644 --- a/mypy/checkstrformat.py +++ b/mypy/checkstrformat.py @@ -11,34 +11,59 @@ """ import re +from typing import Callable, Dict, List, Match, Optional, Pattern, Set, Tuple, Union, cast -from typing import ( - cast, List, Tuple, Dict, Callable, Union, Optional, Pattern, Match, Set -) -from typing_extensions import Final, TYPE_CHECKING, TypeAlias as _TypeAlias +from typing_extensions import TYPE_CHECKING, Final, TypeAlias as _TypeAlias +import mypy.errorcodes as codes from mypy.errors import Errors -from mypy.types import ( - Type, AnyType, TupleType, Instance, UnionType, TypeOfAny, get_proper_type, TypeVarType, - LiteralType, get_proper_types -) from mypy.nodes import ( - StrExpr, BytesExpr, UnicodeExpr, TupleExpr, DictExpr, Context, Expression, StarExpr, CallExpr, - IndexExpr, MemberExpr, TempNode, ARG_POS, ARG_STAR, ARG_NAMED, ARG_STAR2, - Node, MypyFile, ExpressionStmt, NameExpr, IntExpr + ARG_NAMED, + ARG_POS, + ARG_STAR, + ARG_STAR2, + BytesExpr, + CallExpr, + Context, + DictExpr, + Expression, + ExpressionStmt, + IndexExpr, + IntExpr, + MemberExpr, + MypyFile, + NameExpr, + Node, + StarExpr, + StrExpr, + TempNode, + TupleExpr, + UnicodeExpr, +) +from mypy.types import ( + AnyType, + Instance, + LiteralType, + TupleType, + Type, + TypeOfAny, + TypeVarType, + UnionType, + get_proper_type, + get_proper_types, ) -import mypy.errorcodes as codes if TYPE_CHECKING: # break import cycle only needed for mypy import mypy.checker import mypy.checkexpr + from mypy import message_registry -from mypy.messages import MessageBuilder from mypy.maptype import map_instance_to_supertype -from mypy.typeops import custom_special_method -from mypy.subtypes import is_subtype +from mypy.messages import MessageBuilder from mypy.parse import parse +from mypy.subtypes import is_subtype +from mypy.typeops import custom_special_method FormatStringExpr: _TypeAlias = Union[StrExpr, BytesExpr, UnicodeExpr] Checkers: _TypeAlias = Tuple[Callable[[Expression], None], Callable[[Type], bool]] @@ -51,13 +76,13 @@ def compile_format_re() -> Pattern[str]: See https://docs.python.org/3/library/stdtypes.html#printf-style-string-formatting The regexp is intentionally a bit wider to report better errors. """ - key_re = r'(\((?P[^)]*)\))?' # (optional) parenthesised sequence of characters. - flags_re = r'(?P[#0\-+ ]*)' # (optional) sequence of flags. - width_re = r'(?P[1-9][0-9]*|\*)?' # (optional) minimum field width (* or numbers). - precision_re = r'(?:\.(?P\*|[0-9]+)?)?' # (optional) . followed by * of numbers. - length_mod_re = r'[hlL]?' # (optional) length modifier (unused). - type_re = r'(?P.)?' # conversion type. - format_re = '%' + key_re + flags_re + width_re + precision_re + length_mod_re + type_re + key_re = r"(\((?P[^)]*)\))?" # (optional) parenthesised sequence of characters. + flags_re = r"(?P[#0\-+ ]*)" # (optional) sequence of flags. + width_re = r"(?P[1-9][0-9]*|\*)?" # (optional) minimum field width (* or numbers). + precision_re = r"(?:\.(?P\*|[0-9]+)?)?" # (optional) . followed by * of numbers. + length_mod_re = r"[hlL]?" # (optional) length modifier (unused). + type_re = r"(?P.)?" # conversion type. + format_re = "%" + key_re + flags_re + width_re + precision_re + length_mod_re + type_re return re.compile(format_re) @@ -70,25 +95,25 @@ def compile_new_format_re(custom_spec: bool) -> Pattern[str]: """ # Field (optional) is an integer/identifier possibly followed by several .attr and [index]. - field = r'(?P(?P[^.[!:]*)([^:!]+)?)' + field = r"(?P(?P[^.[!:]*)([^:!]+)?)" # Conversion (optional) is ! followed by one of letters for forced repr(), str(), or ascii(). - conversion = r'(?P![^:])?' + conversion = r"(?P![^:])?" # Format specification (optional) follows its own mini-language: if not custom_spec: # Fill and align is valid for all builtin types. - fill_align = r'(?P.?[<>=^])?' + fill_align = r"(?P.?[<>=^])?" # Number formatting options are only valid for int, float, complex, and Decimal, # except if only width is given (it is valid for all types). # This contains sign, flags (sign, # and/or 0), width, grouping (_ or ,) and precision. - num_spec = r'(?P[+\- ]?#?0?)(?P\d+)?[_,]?(?P\.\d+)?' + num_spec = r"(?P[+\- ]?#?0?)(?P\d+)?[_,]?(?P\.\d+)?" # The last element is type. - conv_type = r'(?P.)?' # only some are supported, but we want to give a better error - format_spec = r'(?P:' + fill_align + num_spec + conv_type + r')?' + conv_type = r"(?P.)?" # only some are supported, but we want to give a better error + format_spec = r"(?P:" + fill_align + num_spec + conv_type + r")?" else: # Custom types can define their own form_spec using __format__(). - format_spec = r'(?P:.*)?' + format_spec = r"(?P:.*)?" return re.compile(field + conversion + format_spec) @@ -99,8 +124,23 @@ def compile_new_format_re(custom_spec: bool) -> Pattern[str]: DUMMY_FIELD_NAME: Final = "__dummy_name__" # Format types supported by str.format() for builtin classes. -SUPPORTED_TYPES_NEW: Final = {"b", "c", "d", "e", "E", "f", "F", - "g", "G", "n", "o", "s", "x", "X", "%"} +SUPPORTED_TYPES_NEW: Final = { + "b", + "c", + "d", + "e", + "E", + "f", + "F", + "g", + "G", + "n", + "o", + "s", + "x", + "X", + "%", +} # Types that require either int or float. NUMERIC_TYPES_OLD: Final = {"d", "i", "o", "u", "x", "X", "e", "E", "f", "F", "g", "G"} @@ -115,36 +155,36 @@ def compile_new_format_re(custom_spec: bool) -> Pattern[str]: class ConversionSpecifier: - def __init__(self, match: Match[str], - start_pos: int = -1, - non_standard_format_spec: bool = False) -> None: + def __init__( + self, match: Match[str], start_pos: int = -1, non_standard_format_spec: bool = False + ) -> None: self.whole_seq = match.group() self.start_pos = start_pos m_dict = match.groupdict() - self.key = m_dict.get('key') + self.key = m_dict.get("key") # Replace unmatched optional groups with empty matches (for convenience). - self.conv_type = m_dict.get('type', '') - self.flags = m_dict.get('flags', '') - self.width = m_dict.get('width', '') - self.precision = m_dict.get('precision', '') + self.conv_type = m_dict.get("type", "") + self.flags = m_dict.get("flags", "") + self.width = m_dict.get("width", "") + self.precision = m_dict.get("precision", "") # Used only for str.format() calls (it may be custom for types with __format__()). - self.format_spec = m_dict.get('format_spec') + self.format_spec = m_dict.get("format_spec") self.non_standard_format_spec = non_standard_format_spec # Used only for str.format() calls. - self.conversion = m_dict.get('conversion') + self.conversion = m_dict.get("conversion") # Full formatted expression (i.e. key plus following attributes and/or indexes). # Used only for str.format() calls. - self.field = m_dict.get('field') + self.field = m_dict.get("field") def has_key(self) -> bool: return self.key is not None def has_star(self) -> bool: - return self.width == '*' or self.precision == '*' + return self.width == "*" or self.precision == "*" def parse_conversion_specifiers(format_str: str) -> List[ConversionSpecifier]: @@ -155,8 +195,9 @@ def parse_conversion_specifiers(format_str: str) -> List[ConversionSpecifier]: return specifiers -def parse_format_value(format_value: str, ctx: Context, msg: MessageBuilder, - nested: bool = False) -> Optional[List[ConversionSpecifier]]: +def parse_format_value( + format_value: str, ctx: Context, msg: MessageBuilder, nested: bool = False +) -> Optional[List[ConversionSpecifier]]: """Parse format string into list of conversion specifiers. The specifiers may be nested (two levels maximum), in this case they are ordered as @@ -175,36 +216,44 @@ def parse_format_value(format_value: str, ctx: Context, msg: MessageBuilder, custom_match = FORMAT_RE_NEW_CUSTOM.fullmatch(target) if custom_match: conv_spec = ConversionSpecifier( - custom_match, start_pos=start_pos, - non_standard_format_spec=True) + custom_match, start_pos=start_pos, non_standard_format_spec=True + ) else: - msg.fail('Invalid conversion specifier in format string', - ctx, code=codes.STRING_FORMATTING) + msg.fail( + "Invalid conversion specifier in format string", + ctx, + code=codes.STRING_FORMATTING, + ) return None - if conv_spec.key and ('{' in conv_spec.key or '}' in conv_spec.key): - msg.fail('Conversion value must not contain { or }', - ctx, code=codes.STRING_FORMATTING) + if conv_spec.key and ("{" in conv_spec.key or "}" in conv_spec.key): + msg.fail("Conversion value must not contain { or }", ctx, code=codes.STRING_FORMATTING) return None result.append(conv_spec) # Parse nested conversions that are allowed in format specifier. - if (conv_spec.format_spec and conv_spec.non_standard_format_spec and - ('{' in conv_spec.format_spec or '}' in conv_spec.format_spec)): + if ( + conv_spec.format_spec + and conv_spec.non_standard_format_spec + and ("{" in conv_spec.format_spec or "}" in conv_spec.format_spec) + ): if nested: - msg.fail('Formatting nesting must be at most two levels deep', - ctx, code=codes.STRING_FORMATTING) + msg.fail( + "Formatting nesting must be at most two levels deep", + ctx, + code=codes.STRING_FORMATTING, + ) return None - sub_conv_specs = parse_format_value(conv_spec.format_spec, ctx, msg, - nested=True) + sub_conv_specs = parse_format_value(conv_spec.format_spec, ctx, msg, nested=True) if sub_conv_specs is None: return None result.extend(sub_conv_specs) return result -def find_non_escaped_targets(format_value: str, ctx: Context, - msg: MessageBuilder) -> Optional[List[Tuple[str, int]]]: +def find_non_escaped_targets( + format_value: str, ctx: Context, msg: MessageBuilder +) -> Optional[List[Tuple[str, int]]]: """Return list of raw (un-parsed) format specifiers in format string. Format specifiers don't include enclosing braces. We don't use regexp for @@ -215,40 +264,46 @@ def find_non_escaped_targets(format_value: str, ctx: Context, Return None in case of an error. """ result = [] - next_spec = '' + next_spec = "" pos = 0 nesting = 0 while pos < len(format_value): c = format_value[pos] if not nesting: # Skip any paired '{{' and '}}', enter nesting on '{', report error on '}'. - if c == '{': - if pos < len(format_value) - 1 and format_value[pos + 1] == '{': + if c == "{": + if pos < len(format_value) - 1 and format_value[pos + 1] == "{": pos += 1 else: nesting = 1 - if c == '}': - if pos < len(format_value) - 1 and format_value[pos + 1] == '}': + if c == "}": + if pos < len(format_value) - 1 and format_value[pos + 1] == "}": pos += 1 else: - msg.fail('Invalid conversion specifier in format string:' - ' unexpected }', ctx, code=codes.STRING_FORMATTING) + msg.fail( + "Invalid conversion specifier in format string:" " unexpected }", + ctx, + code=codes.STRING_FORMATTING, + ) return None else: # Adjust nesting level, then either continue adding chars or move on. - if c == '{': + if c == "{": nesting += 1 - if c == '}': + if c == "}": nesting -= 1 if nesting: next_spec += c else: result.append((next_spec, pos - len(next_spec))) - next_spec = '' + next_spec = "" pos += 1 if nesting: - msg.fail('Invalid conversion specifier in format string:' - ' unmatched {', ctx, code=codes.STRING_FORMATTING) + msg.fail( + "Invalid conversion specifier in format string:" " unmatched {", + ctx, + code=codes.STRING_FORMATTING, + ) return None return result @@ -266,10 +321,12 @@ class StringFormatterChecker: # Some services are provided by a ExpressionChecker instance. exprchk: "mypy.checkexpr.ExpressionChecker" - def __init__(self, - exprchk: 'mypy.checkexpr.ExpressionChecker', - chk: 'mypy.checker.TypeChecker', - msg: MessageBuilder) -> None: + def __init__( + self, + exprchk: "mypy.checkexpr.ExpressionChecker", + chk: "mypy.checker.TypeChecker", + msg: MessageBuilder, + ) -> None: """Construct an expression type checker.""" self.chk = chk self.exprchk = exprchk @@ -306,8 +363,9 @@ def check_str_format_call(self, call: CallExpr, format_value: str) -> None: return self.check_specs_in_format_call(call, conv_specs, format_value) - def check_specs_in_format_call(self, call: CallExpr, - specs: List[ConversionSpecifier], format_value: str) -> None: + def check_specs_in_format_call( + self, call: CallExpr, specs: List[ConversionSpecifier], format_value: str + ) -> None: """Perform pairwise checks for conversion specifiers vs their replacements. The core logic for format checking is implemented in this method. @@ -321,15 +379,23 @@ def check_specs_in_format_call(self, call: CallExpr, assert actual_type is not None # Special case custom formatting. - if (spec.format_spec and spec.non_standard_format_spec and - # Exclude "dynamic" specifiers (i.e. containing nested formatting). - not ('{' in spec.format_spec or '}' in spec.format_spec)): - if (not custom_special_method(actual_type, '__format__', check_all=True) or - spec.conversion): + if ( + spec.format_spec + and spec.non_standard_format_spec + and + # Exclude "dynamic" specifiers (i.e. containing nested formatting). + not ("{" in spec.format_spec or "}" in spec.format_spec) + ): + if ( + not custom_special_method(actual_type, "__format__", check_all=True) + or spec.conversion + ): # TODO: add support for some custom specs like datetime? - self.msg.fail('Unrecognized format' - ' specification "{}"'.format(spec.format_spec[1:]), - call, code=codes.STRING_FORMATTING) + self.msg.fail( + "Unrecognized format" ' specification "{}"'.format(spec.format_spec[1:]), + call, + code=codes.STRING_FORMATTING, + ) continue # Adjust expected and actual types. if not spec.conv_type: @@ -340,34 +406,44 @@ def check_specs_in_format_call(self, call: CallExpr, format_str = call.callee.expr else: format_str = StrExpr(format_value) - expected_type = self.conversion_type(spec.conv_type, call, format_str, - format_call=True) + expected_type = self.conversion_type( + spec.conv_type, call, format_str, format_call=True + ) if spec.conversion is not None: # If the explicit conversion is given, then explicit conversion is called _first_. - if spec.conversion[1] not in 'rsa': - self.msg.fail('Invalid conversion type "{}",' - ' must be one of "r", "s" or "a"'.format(spec.conversion[1]), - call, code=codes.STRING_FORMATTING) - actual_type = self.named_type('builtins.str') + if spec.conversion[1] not in "rsa": + self.msg.fail( + 'Invalid conversion type "{}",' + ' must be one of "r", "s" or "a"'.format(spec.conversion[1]), + call, + code=codes.STRING_FORMATTING, + ) + actual_type = self.named_type("builtins.str") # Perform the checks for given types. if expected_type is None: continue a_type = get_proper_type(actual_type) - actual_items = (get_proper_types(a_type.items) if isinstance(a_type, UnionType) - else [a_type]) + actual_items = ( + get_proper_types(a_type.items) if isinstance(a_type, UnionType) else [a_type] + ) for a_type in actual_items: - if custom_special_method(a_type, '__format__'): + if custom_special_method(a_type, "__format__"): continue self.check_placeholder_type(a_type, expected_type, call) self.perform_special_format_checks(spec, call, repl, a_type, expected_type) - def perform_special_format_checks(self, spec: ConversionSpecifier, call: CallExpr, - repl: Expression, actual_type: Type, - expected_type: Type) -> None: + def perform_special_format_checks( + self, + spec: ConversionSpecifier, + call: CallExpr, + repl: Expression, + actual_type: Type, + expected_type: Type, + ) -> None: # TODO: try refactoring to combine this logic with % formatting. - if spec.conv_type == 'c': + if spec.conv_type == "c": if isinstance(repl, (StrExpr, BytesExpr)) and len(repl.value) != 1: self.msg.requires_int_or_char(call, format_call=True) c_typ = get_proper_type(self.chk.lookup_type(repl)) @@ -376,26 +452,36 @@ def perform_special_format_checks(self, spec: ConversionSpecifier, call: CallExp if isinstance(c_typ, LiteralType) and isinstance(c_typ.value, str): if len(c_typ.value) != 1: self.msg.requires_int_or_char(call, format_call=True) - if (not spec.conv_type or spec.conv_type == 's') and not spec.conversion: + if (not spec.conv_type or spec.conv_type == "s") and not spec.conversion: if self.chk.options.python_version >= (3, 0): - if (has_type_component(actual_type, 'builtins.bytes') and - not custom_special_method(actual_type, '__str__')): + if has_type_component(actual_type, "builtins.bytes") and not custom_special_method( + actual_type, "__str__" + ): self.msg.fail( 'On Python 3 formatting "b\'abc\'" with "{}" ' 'produces "b\'abc\'", not "abc"; ' 'use "{!r}" if this is desired behavior', - call, code=codes.STR_BYTES_PY3) + call, + code=codes.STR_BYTES_PY3, + ) if spec.flags: - numeric_types = UnionType([self.named_type('builtins.int'), - self.named_type('builtins.float')]) - if (spec.conv_type and spec.conv_type not in NUMERIC_TYPES_NEW or - not spec.conv_type and not is_subtype(actual_type, numeric_types) and - not custom_special_method(actual_type, '__format__')): - self.msg.fail('Numeric flags are only allowed for numeric types', call, - code=codes.STRING_FORMATTING) - - def find_replacements_in_call(self, call: CallExpr, - keys: List[str]) -> List[Expression]: + numeric_types = UnionType( + [self.named_type("builtins.int"), self.named_type("builtins.float")] + ) + if ( + spec.conv_type + and spec.conv_type not in NUMERIC_TYPES_NEW + or not spec.conv_type + and not is_subtype(actual_type, numeric_types) + and not custom_special_method(actual_type, "__format__") + ): + self.msg.fail( + "Numeric flags are only allowed for numeric types", + call, + code=codes.STRING_FORMATTING, + ) + + def find_replacements_in_call(self, call: CallExpr, keys: List[str]) -> List[Expression]: """Find replacement expression for every specifier in str.format() call. In case of an error use TempNode(AnyType). @@ -406,16 +492,21 @@ def find_replacements_in_call(self, call: CallExpr, if key.isdecimal(): expr = self.get_expr_by_position(int(key), call) if not expr: - self.msg.fail('Cannot find replacement for positional' - ' format specifier {}'.format(key), call, - code=codes.STRING_FORMATTING) + self.msg.fail( + "Cannot find replacement for positional" + " format specifier {}".format(key), + call, + code=codes.STRING_FORMATTING, + ) expr = TempNode(AnyType(TypeOfAny.from_error)) else: expr = self.get_expr_by_name(key, call) if not expr: - self.msg.fail('Cannot find replacement for named' - ' format specifier "{}"'.format(key), call, - code=codes.STRING_FORMATTING) + self.msg.fail( + "Cannot find replacement for named" ' format specifier "{}"'.format(key), + call, + code=codes.STRING_FORMATTING, + ) expr = TempNode(AnyType(TypeOfAny.from_error)) result.append(expr) if not isinstance(expr, TempNode): @@ -443,12 +534,14 @@ def get_expr_by_position(self, pos: int, call: CallExpr) -> Optional[Expression] # Fall back to *args when present in call. star_arg = star_args[0] varargs_type = get_proper_type(self.chk.lookup_type(star_arg)) - if (not isinstance(varargs_type, Instance) or not - varargs_type.type.has_base('typing.Sequence')): + if not isinstance(varargs_type, Instance) or not varargs_type.type.has_base( + "typing.Sequence" + ): # Error should be already reported. return TempNode(AnyType(TypeOfAny.special_form)) - iter_info = self.chk.named_generic_type('typing.Sequence', - [AnyType(TypeOfAny.special_form)]).type + iter_info = self.chk.named_generic_type( + "typing.Sequence", [AnyType(TypeOfAny.special_form)] + ).type return TempNode(map_instance_to_supertype(varargs_type, iter_info).args[0]) def get_expr_by_name(self, key: str, call: CallExpr) -> Optional[Expression]: @@ -457,8 +550,11 @@ def get_expr_by_name(self, key: str, call: CallExpr) -> Optional[Expression]: If the type is from **kwargs, return TempNode(). Return None in case of an error. """ - named_args = [arg for arg, kind, name in zip(call.args, call.arg_kinds, call.arg_names) - if kind == ARG_NAMED and name == key] + named_args = [ + arg + for arg, kind, name in zip(call.args, call.arg_kinds, call.arg_names) + if kind == ARG_NAMED and name == key + ] if named_args: return named_args[0] star_args_2 = [arg for arg, kind in zip(call.args, call.arg_kinds) if kind == ARG_STAR2] @@ -466,17 +562,16 @@ def get_expr_by_name(self, key: str, call: CallExpr) -> Optional[Expression]: return None star_arg_2 = star_args_2[0] kwargs_type = get_proper_type(self.chk.lookup_type(star_arg_2)) - if (not isinstance(kwargs_type, Instance) or not - kwargs_type.type.has_base('typing.Mapping')): + if not isinstance(kwargs_type, Instance) or not kwargs_type.type.has_base( + "typing.Mapping" + ): # Error should be already reported. return TempNode(AnyType(TypeOfAny.special_form)) any_type = AnyType(TypeOfAny.special_form) - mapping_info = self.chk.named_generic_type('typing.Mapping', - [any_type, any_type]).type + mapping_info = self.chk.named_generic_type("typing.Mapping", [any_type, any_type]).type return TempNode(map_instance_to_supertype(kwargs_type, mapping_info).args[1]) - def auto_generate_keys(self, all_specs: List[ConversionSpecifier], - ctx: Context) -> bool: + def auto_generate_keys(self, all_specs: List[ConversionSpecifier], ctx: Context) -> bool: """Translate '{} {name} {}' to '{0} {name} {1}'. Return True if generation was successful, otherwise report an error and return false. @@ -484,8 +579,11 @@ def auto_generate_keys(self, all_specs: List[ConversionSpecifier], some_defined = any(s.key and s.key.isdecimal() for s in all_specs) all_defined = all(bool(s.key) for s in all_specs) if some_defined and not all_defined: - self.msg.fail('Cannot combine automatic field numbering and' - ' manual field specification', ctx, code=codes.STRING_FORMATTING) + self.msg.fail( + "Cannot combine automatic field numbering and" " manual field specification", + ctx, + code=codes.STRING_FORMATTING, + ) return False if all_defined: return True @@ -502,8 +600,9 @@ def auto_generate_keys(self, all_specs: List[ConversionSpecifier], next_index += 1 return True - def apply_field_accessors(self, spec: ConversionSpecifier, repl: Expression, - ctx: Context) -> Expression: + def apply_field_accessors( + self, spec: ConversionSpecifier, repl: Expression, ctx: Context + ) -> Expression: """Transform and validate expr in '{.attr[item]}'.format(expr) into expr.attr['item']. If validation fails, return TempNode(AnyType). @@ -514,13 +613,16 @@ def apply_field_accessors(self, spec: ConversionSpecifier, repl: Expression, assert spec.field temp_errors = Errors() - dummy = DUMMY_FIELD_NAME + spec.field[len(spec.key):] + dummy = DUMMY_FIELD_NAME + spec.field[len(spec.key) :] temp_ast: Node = parse( dummy, fnam="", module=None, options=self.chk.options, errors=temp_errors ) if temp_errors.is_errors(): - self.msg.fail(f'Syntax error in format specifier "{spec.field}"', - ctx, code=codes.STRING_FORMATTING) + self.msg.fail( + f'Syntax error in format specifier "{spec.field}"', + ctx, + code=codes.STRING_FORMATTING, + ) return TempNode(AnyType(TypeOfAny.from_error)) # These asserts are guaranteed by the original regexp. @@ -538,8 +640,13 @@ def apply_field_accessors(self, spec: ConversionSpecifier, repl: Expression, self.exprchk.accept(temp_ast) return temp_ast - def validate_and_transform_accessors(self, temp_ast: Expression, original_repl: Expression, - spec: ConversionSpecifier, ctx: Context) -> bool: + def validate_and_transform_accessors( + self, + temp_ast: Expression, + original_repl: Expression, + spec: ConversionSpecifier, + ctx: Context, + ) -> bool: """Validate and transform (in-place) format field accessors. On error, report it and return False. The transformations include replacing the dummy @@ -553,9 +660,12 @@ class User(TypedDict): '{[id]:d} -> {[name]}'.format(u) """ if not isinstance(temp_ast, (MemberExpr, IndexExpr)): - self.msg.fail('Only index and member expressions are allowed in' - ' format field accessors; got "{}"'.format(spec.field), - ctx, code=codes.STRING_FORMATTING) + self.msg.fail( + "Only index and member expressions are allowed in" + ' format field accessors; got "{}"'.format(spec.field), + ctx, + code=codes.STRING_FORMATTING, + ) return False if isinstance(temp_ast, MemberExpr): node = temp_ast.expr @@ -564,9 +674,12 @@ class User(TypedDict): if not isinstance(temp_ast.index, (NameExpr, IntExpr)): assert spec.key, "Call this method only after auto-generating keys!" assert spec.field - self.msg.fail('Invalid index expression in format field' - ' accessor "{}"'.format(spec.field[len(spec.key):]), ctx, - code=codes.STRING_FORMATTING) + self.msg.fail( + "Invalid index expression in format field" + ' accessor "{}"'.format(spec.field[len(spec.key) :]), + ctx, + code=codes.STRING_FORMATTING, + ) return False if isinstance(temp_ast.index, NameExpr): temp_ast.index = StrExpr(temp_ast.index.name) @@ -580,14 +693,13 @@ class User(TypedDict): return True node.line = ctx.line node.column = ctx.column - return self.validate_and_transform_accessors(node, original_repl=original_repl, - spec=spec, ctx=ctx) + return self.validate_and_transform_accessors( + node, original_repl=original_repl, spec=spec, ctx=ctx + ) # TODO: In Python 3, the bytes formatting has a more restricted set of options # compared to string formatting. - def check_str_interpolation(self, - expr: FormatStringExpr, - replacements: Expression) -> Type: + def check_str_interpolation(self, expr: FormatStringExpr, replacements: Expression) -> Type: """Check the types of the 'replacements' in a string interpolation expression: str % replacements. """ @@ -595,8 +707,11 @@ def check_str_interpolation(self, specifiers = parse_conversion_specifiers(expr.value) has_mapping_keys = self.analyze_conversion_specifiers(specifiers, expr) if isinstance(expr, BytesExpr) and (3, 0) <= self.chk.options.python_version < (3, 5): - self.msg.fail('Bytes formatting is only supported in Python 3.5 and later', - replacements, code=codes.STRING_FORMATTING) + self.msg.fail( + "Bytes formatting is only supported in Python 3.5 and later", + replacements, + code=codes.STRING_FORMATTING, + ) return AnyType(TypeOfAny.from_error) self.unicode_upcast = False @@ -608,22 +723,23 @@ def check_str_interpolation(self, self.check_simple_str_interpolation(specifiers, replacements, expr) if isinstance(expr, BytesExpr): - return self.named_type('builtins.bytes') + return self.named_type("builtins.bytes") elif isinstance(expr, UnicodeExpr): - return self.named_type('builtins.unicode') + return self.named_type("builtins.unicode") elif isinstance(expr, StrExpr): if self.unicode_upcast: - return self.named_type('builtins.unicode') - return self.named_type('builtins.str') + return self.named_type("builtins.unicode") + return self.named_type("builtins.str") else: assert False - def analyze_conversion_specifiers(self, specifiers: List[ConversionSpecifier], - context: Context) -> Optional[bool]: + def analyze_conversion_specifiers( + self, specifiers: List[ConversionSpecifier], context: Context + ) -> Optional[bool]: has_star = any(specifier.has_star() for specifier in specifiers) has_key = any(specifier.has_key() for specifier in specifiers) all_have_keys = all( - specifier.has_key() or specifier.conv_type == '%' for specifier in specifiers + specifier.has_key() or specifier.conv_type == "%" for specifier in specifiers ) if has_key and has_star: @@ -634,8 +750,12 @@ def analyze_conversion_specifiers(self, specifiers: List[ConversionSpecifier], return None return has_key - def check_simple_str_interpolation(self, specifiers: List[ConversionSpecifier], - replacements: Expression, expr: FormatStringExpr) -> None: + def check_simple_str_interpolation( + self, + specifiers: List[ConversionSpecifier], + replacements: Expression, + expr: FormatStringExpr, + ) -> None: """Check % string interpolation with positional specifiers '%s, %d' % ('yes, 42').""" checkers = self.build_replacement_checkers(specifiers, replacements, expr) if checkers is None: @@ -647,7 +767,7 @@ def check_simple_str_interpolation(self, specifiers: List[ConversionSpecifier], rep_types = rhs_type.items elif isinstance(rhs_type, AnyType): return - elif isinstance(rhs_type, Instance) and rhs_type.type.fullname == 'builtins.tuple': + elif isinstance(rhs_type, Instance) and rhs_type.type.fullname == "builtins.tuple": # Assume that an arbitrary-length tuple has the right number of items. rep_types = [rhs_type.args[0]] * len(checkers) elif isinstance(rhs_type, UnionType): @@ -661,8 +781,9 @@ def check_simple_str_interpolation(self, specifiers: List[ConversionSpecifier], if len(checkers) > len(rep_types): # Only check the fix-length Tuple type. Other Iterable types would skip. - if (is_subtype(rhs_type, self.chk.named_type("typing.Iterable")) and - not isinstance(rhs_type, TupleType)): + if is_subtype(rhs_type, self.chk.named_type("typing.Iterable")) and not isinstance( + rhs_type, TupleType + ): return else: self.msg.too_few_string_formatting_arguments(replacements) @@ -675,8 +796,9 @@ def check_simple_str_interpolation(self, specifiers: List[ConversionSpecifier], check_type(rhs_type.items[0]) else: check_node(replacements) - elif (isinstance(replacements, TupleExpr) - and not any(isinstance(item, StarExpr) for item in replacements.items)): + elif isinstance(replacements, TupleExpr) and not any( + isinstance(item, StarExpr) for item in replacements.items + ): for checks, rep_node in zip(checkers, replacements.items): check_node, check_type = checks check_node(rep_node) @@ -685,25 +807,31 @@ def check_simple_str_interpolation(self, specifiers: List[ConversionSpecifier], check_node, check_type = checks check_type(rep_type) - def check_mapping_str_interpolation(self, specifiers: List[ConversionSpecifier], - replacements: Expression, - expr: FormatStringExpr) -> None: + def check_mapping_str_interpolation( + self, + specifiers: List[ConversionSpecifier], + replacements: Expression, + expr: FormatStringExpr, + ) -> None: """Check % string interpolation with names specifiers '%(name)s' % {'name': 'John'}.""" - if (isinstance(replacements, DictExpr) and - all(isinstance(k, (StrExpr, BytesExpr, UnicodeExpr)) - for k, v in replacements.items)): + if isinstance(replacements, DictExpr) and all( + isinstance(k, (StrExpr, BytesExpr, UnicodeExpr)) for k, v in replacements.items + ): mapping: Dict[str, Type] = {} for k, v in replacements.items: if self.chk.options.python_version >= (3, 0) and isinstance(expr, BytesExpr): # Special case: for bytes formatting keys must be bytes. if not isinstance(k, BytesExpr): - self.msg.fail('Dictionary keys in bytes formatting must be bytes,' - ' not strings', expr, code=codes.STRING_FORMATTING) + self.msg.fail( + "Dictionary keys in bytes formatting must be bytes," " not strings", + expr, + code=codes.STRING_FORMATTING, + ) key_str = cast(FormatStringExpr, k).value mapping[key_str] = self.accept(v) for specifier in specifiers: - if specifier.conv_type == '%': + if specifier.conv_type == "%": # %% is allowed in mappings, no checking is required continue assert specifier.key is not None @@ -715,47 +843,52 @@ def check_mapping_str_interpolation(self, specifiers: List[ConversionSpecifier], expected_type = self.conversion_type(specifier.conv_type, replacements, expr) if expected_type is None: return - self.chk.check_subtype(rep_type, expected_type, replacements, - message_registry.INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION, - 'expression has type', - f'placeholder with key \'{specifier.key}\' has type', - code=codes.STRING_FORMATTING) - if specifier.conv_type == 's': + self.chk.check_subtype( + rep_type, + expected_type, + replacements, + message_registry.INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION, + "expression has type", + f"placeholder with key '{specifier.key}' has type", + code=codes.STRING_FORMATTING, + ) + if specifier.conv_type == "s": self.check_s_special_cases(expr, rep_type, expr) else: rep_type = self.accept(replacements) dict_type = self.build_dict_type(expr) - self.chk.check_subtype(rep_type, dict_type, replacements, - message_registry.FORMAT_REQUIRES_MAPPING, - 'expression has type', 'expected type for mapping is', - code=codes.STRING_FORMATTING) + self.chk.check_subtype( + rep_type, + dict_type, + replacements, + message_registry.FORMAT_REQUIRES_MAPPING, + "expression has type", + "expected type for mapping is", + code=codes.STRING_FORMATTING, + ) def build_dict_type(self, expr: FormatStringExpr) -> Type: """Build expected mapping type for right operand in % formatting.""" any_type = AnyType(TypeOfAny.special_form) if self.chk.options.python_version >= (3, 0): if isinstance(expr, BytesExpr): - bytes_type = self.chk.named_generic_type('builtins.bytes', []) - return self.chk.named_generic_type('typing.Mapping', - [bytes_type, any_type]) + bytes_type = self.chk.named_generic_type("builtins.bytes", []) + return self.chk.named_generic_type("typing.Mapping", [bytes_type, any_type]) elif isinstance(expr, StrExpr): - str_type = self.chk.named_generic_type('builtins.str', []) - return self.chk.named_generic_type('typing.Mapping', - [str_type, any_type]) + str_type = self.chk.named_generic_type("builtins.str", []) + return self.chk.named_generic_type("typing.Mapping", [str_type, any_type]) else: assert False, "There should not be UnicodeExpr on Python 3" else: - str_type = self.chk.named_generic_type('builtins.str', []) - unicode_type = self.chk.named_generic_type('builtins.unicode', []) - str_map = self.chk.named_generic_type('typing.Mapping', - [str_type, any_type]) - unicode_map = self.chk.named_generic_type('typing.Mapping', - [unicode_type, any_type]) + str_type = self.chk.named_generic_type("builtins.str", []) + unicode_type = self.chk.named_generic_type("builtins.unicode", []) + str_map = self.chk.named_generic_type("typing.Mapping", [str_type, any_type]) + unicode_map = self.chk.named_generic_type("typing.Mapping", [unicode_type, any_type]) return UnionType.make_union([str_map, unicode_map]) - def build_replacement_checkers(self, specifiers: List[ConversionSpecifier], - context: Context, expr: FormatStringExpr - ) -> Optional[List[Checkers]]: + def build_replacement_checkers( + self, specifiers: List[ConversionSpecifier], context: Context, expr: FormatStringExpr + ) -> Optional[List[Checkers]]: checkers: List[Checkers] = [] for specifier in specifiers: checker = self.replacement_checkers(specifier, context, expr) @@ -764,25 +897,26 @@ def build_replacement_checkers(self, specifiers: List[ConversionSpecifier], checkers.extend(checker) return checkers - def replacement_checkers(self, specifier: ConversionSpecifier, context: Context, - expr: FormatStringExpr) -> Optional[List[Checkers]]: + def replacement_checkers( + self, specifier: ConversionSpecifier, context: Context, expr: FormatStringExpr + ) -> Optional[List[Checkers]]: """Returns a list of tuples of two functions that check whether a replacement is of the right type for the specifier. The first function takes a node and checks its type in the right type context. The second function just checks a type. """ checkers: List[Checkers] = [] - if specifier.width == '*': + if specifier.width == "*": checkers.append(self.checkers_for_star(context)) - if specifier.precision == '*': + if specifier.precision == "*": checkers.append(self.checkers_for_star(context)) - if specifier.conv_type == 'c': + if specifier.conv_type == "c": c = self.checkers_for_c_type(specifier.conv_type, context, expr) if c is None: return None checkers.append(c) - elif specifier.conv_type is not None and specifier.conv_type != '%': + elif specifier.conv_type is not None and specifier.conv_type != "%": c = self.checkers_for_regular_type(specifier.conv_type, context, expr) if c is None: return None @@ -793,12 +927,13 @@ def checkers_for_star(self, context: Context) -> Checkers: """Returns a tuple of check functions that check whether, respectively, a node or a type is compatible with a star in a conversion specifier. """ - expected = self.named_type('builtins.int') + expected = self.named_type("builtins.int") def check_type(type: Type) -> bool: - expected = self.named_type('builtins.int') - return self.chk.check_subtype(type, expected, context, '* wants int', - code=codes.STRING_FORMATTING) + expected = self.named_type("builtins.int") + return self.chk.check_subtype( + type, expected, context, "* wants int", code=codes.STRING_FORMATTING + ) def check_expr(expr: Expression) -> None: type = self.accept(expr, expected) @@ -807,14 +942,19 @@ def check_expr(expr: Expression) -> None: return check_expr, check_type def check_placeholder_type(self, typ: Type, expected_type: Type, context: Context) -> bool: - return self.chk.check_subtype(typ, expected_type, context, - message_registry.INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION, - 'expression has type', 'placeholder has type', - code=codes.STRING_FORMATTING) - - def checkers_for_regular_type(self, conv_type: str, - context: Context, - expr: FormatStringExpr) -> Optional[Checkers]: + return self.chk.check_subtype( + typ, + expected_type, + context, + message_registry.INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION, + "expression has type", + "placeholder has type", + code=codes.STRING_FORMATTING, + ) + + def checkers_for_regular_type( + self, conv_type: str, context: Context, expr: FormatStringExpr + ) -> Optional[Checkers]: """Returns a tuple of check functions that check whether, respectively, a node or a type is compatible with 'type'. Return None in case of an error. """ @@ -825,7 +965,7 @@ def checkers_for_regular_type(self, conv_type: str, def check_type(typ: Type) -> bool: assert expected_type is not None ret = self.check_placeholder_type(typ, expected_type, context) - if ret and conv_type == 's': + if ret and conv_type == "s": ret = self.check_s_special_cases(expr, typ, context) return ret @@ -840,28 +980,33 @@ def check_s_special_cases(self, expr: FormatStringExpr, typ: Type, context: Cont if isinstance(expr, StrExpr): # Couple special cases for string formatting. if self.chk.options.python_version >= (3, 0): - if has_type_component(typ, 'builtins.bytes'): + if has_type_component(typ, "builtins.bytes"): self.msg.fail( 'On Python 3 formatting "b\'abc\'" with "%s" ' 'produces "b\'abc\'", not "abc"; ' 'use "%r" if this is desired behavior', - context, code=codes.STR_BYTES_PY3) + context, + code=codes.STR_BYTES_PY3, + ) return False if self.chk.options.python_version < (3, 0): - if has_type_component(typ, 'builtins.unicode'): + if has_type_component(typ, "builtins.unicode"): self.unicode_upcast = True if isinstance(expr, BytesExpr): # A special case for bytes formatting: b'%s' actually requires bytes on Python 3. if self.chk.options.python_version >= (3, 0): - if has_type_component(typ, 'builtins.str'): - self.msg.fail("On Python 3 b'%s' requires bytes, not string", context, - code=codes.STRING_FORMATTING) + if has_type_component(typ, "builtins.str"): + self.msg.fail( + "On Python 3 b'%s' requires bytes, not string", + context, + code=codes.STRING_FORMATTING, + ) return False return True - def checkers_for_c_type(self, type: str, - context: Context, - format_expr: FormatStringExpr) -> Optional[Checkers]: + def checkers_for_c_type( + self, type: str, context: Context, format_expr: FormatStringExpr + ) -> Optional[Checkers]: """Returns a tuple of check functions that check whether, respectively, a node or a type is compatible with 'type' that is a character type. """ @@ -875,9 +1020,14 @@ def check_type(type: Type) -> bool: err_msg = '"%c" requires an integer in range(256) or a single byte' else: err_msg = '"%c" requires int or char' - return self.chk.check_subtype(type, expected_type, context, err_msg, - 'expression has type', - code=codes.STRING_FORMATTING) + return self.chk.check_subtype( + type, + expected_type, + context, + err_msg, + "expression has type", + code=codes.STRING_FORMATTING, + ) def check_expr(expr: Expression) -> None: """int, or str with length 1""" @@ -886,9 +1036,12 @@ def check_expr(expr: Expression) -> None: # it has exact one char or one single byte. if check_type(type): # Python 3 doesn't support b'%c' % str - if (self.chk.options.python_version >= (3, 0) - and isinstance(format_expr, BytesExpr) - and isinstance(expr, BytesExpr) and len(expr.value) != 1): + if ( + self.chk.options.python_version >= (3, 0) + and isinstance(format_expr, BytesExpr) + and isinstance(expr, BytesExpr) + and len(expr.value) != 1 + ): self.msg.requires_int_or_single_byte(context) # In Python 2, b'%c' is the same as '%c' elif isinstance(expr, (StrExpr, BytesExpr)) and len(expr.value) != 1: @@ -896,8 +1049,9 @@ def check_expr(expr: Expression) -> None: return check_expr, check_type - def conversion_type(self, p: str, context: Context, expr: FormatStringExpr, - format_call: bool = False) -> Optional[Type]: + def conversion_type( + self, p: str, context: Context, expr: FormatStringExpr, format_call: bool = False + ) -> Optional[Type]: """Return the type that is accepted for a string interpolation conversion specifier type. Note that both Python's float (e.g. %f) and integer (e.g. %d) @@ -908,44 +1062,57 @@ def conversion_type(self, p: str, context: Context, expr: FormatStringExpr, """ NUMERIC_TYPES = NUMERIC_TYPES_NEW if format_call else NUMERIC_TYPES_OLD INT_TYPES = REQUIRE_INT_NEW if format_call else REQUIRE_INT_OLD - if p == 'b' and not format_call: + if p == "b" and not format_call: if self.chk.options.python_version < (3, 5): - self.msg.fail('Format character "b" is only supported in Python 3.5 and later', - context, code=codes.STRING_FORMATTING) + self.msg.fail( + 'Format character "b" is only supported in Python 3.5 and later', + context, + code=codes.STRING_FORMATTING, + ) return None if not isinstance(expr, BytesExpr): - self.msg.fail('Format character "b" is only supported on bytes patterns', context, - code=codes.STRING_FORMATTING) + self.msg.fail( + 'Format character "b" is only supported on bytes patterns', + context, + code=codes.STRING_FORMATTING, + ) return None - return self.named_type('builtins.bytes') - elif p == 'a': + return self.named_type("builtins.bytes") + elif p == "a": if self.chk.options.python_version < (3, 0): - self.msg.fail('Format character "a" is only supported in Python 3', context, - code=codes.STRING_FORMATTING) + self.msg.fail( + 'Format character "a" is only supported in Python 3', + context, + code=codes.STRING_FORMATTING, + ) return None # TODO: return type object? return AnyType(TypeOfAny.special_form) - elif p in ['s', 'r']: + elif p in ["s", "r"]: return AnyType(TypeOfAny.special_form) elif p in NUMERIC_TYPES: if p in INT_TYPES: - numeric_types = [self.named_type('builtins.int')] + numeric_types = [self.named_type("builtins.int")] else: - numeric_types = [self.named_type('builtins.int'), - self.named_type('builtins.float')] + numeric_types = [ + self.named_type("builtins.int"), + self.named_type("builtins.float"), + ] if not format_call: if p in FLOAT_TYPES: - numeric_types.append(self.named_type('typing.SupportsFloat')) + numeric_types.append(self.named_type("typing.SupportsFloat")) else: - numeric_types.append(self.named_type('typing.SupportsInt')) + numeric_types.append(self.named_type("typing.SupportsInt")) return UnionType.make_union(numeric_types) - elif p in ['c']: + elif p in ["c"]: if isinstance(expr, BytesExpr): - return UnionType([self.named_type('builtins.int'), - self.named_type('builtins.bytes')]) + return UnionType( + [self.named_type("builtins.int"), self.named_type("builtins.bytes")] + ) else: - return UnionType([self.named_type('builtins.int'), - self.named_type('builtins.str')]) + return UnionType( + [self.named_type("builtins.int"), self.named_type("builtins.str")] + ) else: self.msg.unsupported_placeholder(p, context) return None @@ -977,8 +1144,9 @@ def has_type_component(typ: Type, fullname: str) -> bool: if isinstance(typ, Instance): return typ.type.has_base(fullname) elif isinstance(typ, TypeVarType): - return (has_type_component(typ.upper_bound, fullname) or - any(has_type_component(v, fullname) for v in typ.values)) + return has_type_component(typ.upper_bound, fullname) or any( + has_type_component(v, fullname) for v in typ.values + ) elif isinstance(typ, UnionType): return any(has_type_component(t, fullname) for t in typ.relevant_items()) return False diff --git a/mypy/config_parser.py b/mypy/config_parser.py index 90970429ab8c5..52b41d6ee2f79 100644 --- a/mypy/config_parser.py +++ b/mypy/config_parser.py @@ -1,22 +1,35 @@ import argparse import configparser import glob as fileglob -from io import StringIO import os import re import sys +from io import StringIO if sys.version_info >= (3, 11): import tomllib else: import tomli as tomllib -from typing import (Any, Callable, Dict, List, Mapping, MutableMapping, Optional, Sequence, - TextIO, Tuple, Union, Iterable) +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + MutableMapping, + Optional, + Sequence, + TextIO, + Tuple, + Union, +) + from typing_extensions import Final, TypeAlias as _TypeAlias from mypy import defaults -from mypy.options import Options, PER_MODULE_OPTIONS +from mypy.options import PER_MODULE_OPTIONS, Options _CONFIG_VALUE_TYPES: _TypeAlias = Union[ str, bool, int, float, Dict[str, str], List[str], Tuple[int, int], @@ -25,20 +38,17 @@ def parse_version(v: Union[str, float]) -> Tuple[int, int]: - m = re.match(r'\A(\d)\.(\d+)\Z', str(v)) + m = re.match(r"\A(\d)\.(\d+)\Z", str(v)) if not m: - raise argparse.ArgumentTypeError( - f"Invalid python version '{v}' (expected format: 'x.y')") + raise argparse.ArgumentTypeError(f"Invalid python version '{v}' (expected format: 'x.y')") major, minor = int(m.group(1)), int(m.group(2)) if major == 2: if minor != 7: - raise argparse.ArgumentTypeError( - f"Python 2.{minor} is not supported (must be 2.7)") + raise argparse.ArgumentTypeError(f"Python 2.{minor} is not supported (must be 2.7)") elif major == 3: if minor < defaults.PYTHON3_VERSION_MIN[1]: msg = "Python 3.{0} is not supported (must be {1}.{2} or higher)".format( - minor, - *defaults.PYTHON3_VERSION_MIN + minor, *defaults.PYTHON3_VERSION_MIN ) if isinstance(v, float): @@ -47,11 +57,12 @@ def parse_version(v: Union[str, float]) -> Tuple[int, int]: raise argparse.ArgumentTypeError(msg) else: raise argparse.ArgumentTypeError( - f"Python major version '{major}' out of range (must be 2 or 3)") + f"Python major version '{major}' out of range (must be 2 or 3)" + ) return major, minor -def try_split(v: Union[str, Sequence[str]], split_regex: str = '[,]') -> List[str]: +def try_split(v: Union[str, Sequence[str]], split_regex: str = "[,]") -> List[str]: """Split and trim a str or list of str into a list of str""" if isinstance(v, str): return [p.strip() for p in re.split(split_regex, v)] @@ -102,16 +113,17 @@ def split_and_match_files(paths: str) -> List[str]: Returns a list of file paths """ - return split_and_match_files_list(paths.split(',')) + return split_and_match_files_list(paths.split(",")) def check_follow_imports(choice: str) -> str: - choices = ['normal', 'silent', 'skip', 'error'] + choices = ["normal", "silent", "skip", "error"] if choice not in choices: raise argparse.ArgumentTypeError( "invalid choice '{}' (choose from {})".format( - choice, - ', '.join(f"'{x}'" for x in choices))) + choice, ", ".join(f"'{x}'" for x in choices) + ) + ) return choice @@ -120,53 +132,58 @@ def check_follow_imports(choice: str) -> str: # exists to specify types for values initialized to None or container # types. ini_config_types: Final[Dict[str, _INI_PARSER_CALLABLE]] = { - 'python_version': parse_version, - 'strict_optional_whitelist': lambda s: s.split(), - 'custom_typing_module': str, - 'custom_typeshed_dir': expand_path, - 'mypy_path': lambda s: [expand_path(p.strip()) for p in re.split('[,:]', s)], - 'files': split_and_match_files, - 'quickstart_file': expand_path, - 'junit_xml': expand_path, + "python_version": parse_version, + "strict_optional_whitelist": lambda s: s.split(), + "custom_typing_module": str, + "custom_typeshed_dir": expand_path, + "mypy_path": lambda s: [expand_path(p.strip()) for p in re.split("[,:]", s)], + "files": split_and_match_files, + "quickstart_file": expand_path, + "junit_xml": expand_path, # These two are for backwards compatibility - 'silent_imports': bool, - 'almost_silent': bool, - 'follow_imports': check_follow_imports, - 'no_site_packages': bool, - 'plugins': lambda s: [p.strip() for p in s.split(',')], - 'always_true': lambda s: [p.strip() for p in s.split(',')], - 'always_false': lambda s: [p.strip() for p in s.split(',')], - 'disable_error_code': lambda s: [p.strip() for p in s.split(',')], - 'enable_error_code': lambda s: [p.strip() for p in s.split(',')], - 'package_root': lambda s: [p.strip() for p in s.split(',')], - 'cache_dir': expand_path, - 'python_executable': expand_path, - 'strict': bool, - 'exclude': lambda s: [s.strip()], + "silent_imports": bool, + "almost_silent": bool, + "follow_imports": check_follow_imports, + "no_site_packages": bool, + "plugins": lambda s: [p.strip() for p in s.split(",")], + "always_true": lambda s: [p.strip() for p in s.split(",")], + "always_false": lambda s: [p.strip() for p in s.split(",")], + "disable_error_code": lambda s: [p.strip() for p in s.split(",")], + "enable_error_code": lambda s: [p.strip() for p in s.split(",")], + "package_root": lambda s: [p.strip() for p in s.split(",")], + "cache_dir": expand_path, + "python_executable": expand_path, + "strict": bool, + "exclude": lambda s: [s.strip()], } # Reuse the ini_config_types and overwrite the diff toml_config_types: Final[Dict[str, _INI_PARSER_CALLABLE]] = ini_config_types.copy() -toml_config_types.update({ - 'python_version': parse_version, - 'strict_optional_whitelist': try_split, - 'mypy_path': lambda s: [expand_path(p) for p in try_split(s, '[,:]')], - 'files': lambda s: split_and_match_files_list(try_split(s)), - 'follow_imports': lambda s: check_follow_imports(str(s)), - 'plugins': try_split, - 'always_true': try_split, - 'always_false': try_split, - 'disable_error_code': try_split, - 'enable_error_code': try_split, - 'package_root': try_split, - 'exclude': str_or_array_as_list, -}) - - -def parse_config_file(options: Options, set_strict_flags: Callable[[], None], - filename: Optional[str], - stdout: Optional[TextIO] = None, - stderr: Optional[TextIO] = None) -> None: +toml_config_types.update( + { + "python_version": parse_version, + "strict_optional_whitelist": try_split, + "mypy_path": lambda s: [expand_path(p) for p in try_split(s, "[,:]")], + "files": lambda s: split_and_match_files_list(try_split(s)), + "follow_imports": lambda s: check_follow_imports(str(s)), + "plugins": try_split, + "always_true": try_split, + "always_false": try_split, + "disable_error_code": try_split, + "enable_error_code": try_split, + "package_root": try_split, + "exclude": str_or_array_as_list, + } +) + + +def parse_config_file( + options: Options, + set_strict_flags: Callable[[], None], + filename: Optional[str], + stdout: Optional[TextIO] = None, + stderr: Optional[TextIO] = None, +) -> None: """Parse a config file into an Options object. Errors are written to stderr but are not fatal. @@ -192,10 +209,10 @@ def parse_config_file(options: Options, set_strict_flags: Callable[[], None], with open(config_file, "rb") as f: toml_data = tomllib.load(f) # Filter down to just mypy relevant toml keys - toml_data = toml_data.get('tool', {}) - if 'mypy' not in toml_data: + toml_data = toml_data.get("tool", {}) + if "mypy" not in toml_data: continue - toml_data = {'mypy': toml_data['mypy']} + toml_data = {"mypy": toml_data["mypy"]} parser: MutableMapping[str, Any] = destructure_overrides(toml_data) config_types = toml_config_types else: @@ -205,7 +222,7 @@ def parse_config_file(options: Options, set_strict_flags: Callable[[], None], except (tomllib.TOMLDecodeError, configparser.Error, ConfigTOMLValueError) as err: print(f"{config_file}: {err}", file=stderr) else: - if config_file in defaults.SHARED_CONFIG_FILES and 'mypy' not in parser: + if config_file in defaults.SHARED_CONFIG_FILES and "mypy" not in parser: continue file_read = config_file options.config_file = file_read @@ -213,63 +230,70 @@ def parse_config_file(options: Options, set_strict_flags: Callable[[], None], else: return - os.environ['MYPY_CONFIG_FILE_DIR'] = os.path.dirname( - os.path.abspath(config_file)) + os.environ["MYPY_CONFIG_FILE_DIR"] = os.path.dirname(os.path.abspath(config_file)) - if 'mypy' not in parser: + if "mypy" not in parser: if filename or file_read not in defaults.SHARED_CONFIG_FILES: print(f"{file_read}: No [mypy] section in config file", file=stderr) else: - section = parser['mypy'] + section = parser["mypy"] prefix = f"{file_read}: [mypy]: " updates, report_dirs = parse_section( - prefix, options, set_strict_flags, section, config_types, stderr) + prefix, options, set_strict_flags, section, config_types, stderr + ) for k, v in updates.items(): setattr(options, k, v) options.report_dirs.update(report_dirs) for name, section in parser.items(): - if name.startswith('mypy-'): + if name.startswith("mypy-"): prefix = get_prefix(file_read, name) updates, report_dirs = parse_section( - prefix, options, set_strict_flags, section, config_types, stderr) + prefix, options, set_strict_flags, section, config_types, stderr + ) if report_dirs: - print("%sPer-module sections should not specify reports (%s)" % - (prefix, ', '.join(s + '_report' for s in sorted(report_dirs))), - file=stderr) + print( + "%sPer-module sections should not specify reports (%s)" + % (prefix, ", ".join(s + "_report" for s in sorted(report_dirs))), + file=stderr, + ) if set(updates) - PER_MODULE_OPTIONS: - print("%sPer-module sections should only specify per-module flags (%s)" % - (prefix, ', '.join(sorted(set(updates) - PER_MODULE_OPTIONS))), - file=stderr) + print( + "%sPer-module sections should only specify per-module flags (%s)" + % (prefix, ", ".join(sorted(set(updates) - PER_MODULE_OPTIONS))), + file=stderr, + ) updates = {k: v for k, v in updates.items() if k in PER_MODULE_OPTIONS} globs = name[5:] - for glob in globs.split(','): + for glob in globs.split(","): # For backwards compatibility, replace (back)slashes with dots. - glob = glob.replace(os.sep, '.') + glob = glob.replace(os.sep, ".") if os.altsep: - glob = glob.replace(os.altsep, '.') - - if (any(c in glob for c in '?[]!') or - any('*' in x and x != '*' for x in glob.split('.'))): - print("%sPatterns must be fully-qualified module names, optionally " - "with '*' in some components (e.g spam.*.eggs.*)" - % prefix, - file=stderr) + glob = glob.replace(os.altsep, ".") + + if any(c in glob for c in "?[]!") or any( + "*" in x and x != "*" for x in glob.split(".") + ): + print( + "%sPatterns must be fully-qualified module names, optionally " + "with '*' in some components (e.g spam.*.eggs.*)" % prefix, + file=stderr, + ) else: options.per_module_options[glob] = updates def get_prefix(file_read: str, name: str) -> str: if is_toml(file_read): - module_name_str = 'module = "%s"' % '-'.join(name.split('-')[1:]) + module_name_str = 'module = "%s"' % "-".join(name.split("-")[1:]) else: module_name_str = name - return f'{file_read}: [{module_name_str}]: ' + return f"{file_read}: [{module_name_str}]: " def is_toml(filename: str) -> bool: - return filename.lower().endswith('.toml') + return filename.lower().endswith(".toml") def destructure_overrides(toml_data: Dict[str, Any]) -> Dict[str, Any]: @@ -304,54 +328,66 @@ def destructure_overrides(toml_data: Dict[str, Any]) -> Dict[str, Any]: }, } """ - if 'overrides' not in toml_data['mypy']: + if "overrides" not in toml_data["mypy"]: return toml_data - if not isinstance(toml_data['mypy']['overrides'], list): - raise ConfigTOMLValueError("tool.mypy.overrides sections must be an array. Please make " - "sure you are using double brackets like so: [[tool.mypy.overrides]]") + if not isinstance(toml_data["mypy"]["overrides"], list): + raise ConfigTOMLValueError( + "tool.mypy.overrides sections must be an array. Please make " + "sure you are using double brackets like so: [[tool.mypy.overrides]]" + ) result = toml_data.copy() - for override in result['mypy']['overrides']: - if 'module' not in override: - raise ConfigTOMLValueError("toml config file contains a [[tool.mypy.overrides]] " - "section, but no module to override was specified.") - - if isinstance(override['module'], str): - modules = [override['module']] - elif isinstance(override['module'], list): - modules = override['module'] + for override in result["mypy"]["overrides"]: + if "module" not in override: + raise ConfigTOMLValueError( + "toml config file contains a [[tool.mypy.overrides]] " + "section, but no module to override was specified." + ) + + if isinstance(override["module"], str): + modules = [override["module"]] + elif isinstance(override["module"], list): + modules = override["module"] else: - raise ConfigTOMLValueError("toml config file contains a [[tool.mypy.overrides]] " - "section with a module value that is not a string or a list of " - "strings") + raise ConfigTOMLValueError( + "toml config file contains a [[tool.mypy.overrides]] " + "section with a module value that is not a string or a list of " + "strings" + ) for module in modules: module_overrides = override.copy() - del module_overrides['module'] - old_config_name = f'mypy-{module}' + del module_overrides["module"] + old_config_name = f"mypy-{module}" if old_config_name not in result: result[old_config_name] = module_overrides else: for new_key, new_value in module_overrides.items(): - if (new_key in result[old_config_name] and - result[old_config_name][new_key] != new_value): - raise ConfigTOMLValueError("toml config file contains " - "[[tool.mypy.overrides]] sections with conflicting " - "values. Module '%s' has two different values for '%s'" - % (module, new_key)) + if ( + new_key in result[old_config_name] + and result[old_config_name][new_key] != new_value + ): + raise ConfigTOMLValueError( + "toml config file contains " + "[[tool.mypy.overrides]] sections with conflicting " + "values. Module '%s' has two different values for '%s'" + % (module, new_key) + ) result[old_config_name][new_key] = new_value - del result['mypy']['overrides'] + del result["mypy"]["overrides"] return result -def parse_section(prefix: str, template: Options, - set_strict_flags: Callable[[], None], - section: Mapping[str, Any], - config_types: Dict[str, Any], - stderr: TextIO = sys.stderr - ) -> Tuple[Dict[str, object], Dict[str, str]]: +def parse_section( + prefix: str, + template: Options, + set_strict_flags: Callable[[], None], + section: Mapping[str, Any], + config_types: Dict[str, Any], + stderr: TextIO = sys.stderr, +) -> Tuple[Dict[str, object], Dict[str, str]]: """Parse one section of a config file. Returns a dict of option values encountered, and a dict of report directories. @@ -367,34 +403,32 @@ def parse_section(prefix: str, template: Options, dv = None # We have to keep new_semantic_analyzer in Options # for plugin compatibility but it is not a valid option anymore. - assert hasattr(template, 'new_semantic_analyzer') - if key != 'new_semantic_analyzer': + assert hasattr(template, "new_semantic_analyzer") + if key != "new_semantic_analyzer": dv = getattr(template, key, None) if dv is None: - if key.endswith('_report'): - report_type = key[:-7].replace('_', '-') + if key.endswith("_report"): + report_type = key[:-7].replace("_", "-") if report_type in defaults.REPORTER_NAMES: report_dirs[report_type] = str(section[key]) else: - print(f"{prefix}Unrecognized report type: {key}", - file=stderr) + print(f"{prefix}Unrecognized report type: {key}", file=stderr) continue - if key.startswith('x_'): + if key.startswith("x_"): pass # Don't complain about `x_blah` flags - elif key.startswith('no_') and hasattr(template, key[3:]): + elif key.startswith("no_") and hasattr(template, key[3:]): options_key = key[3:] invert = True - elif key.startswith('allow') and hasattr(template, 'dis' + key): - options_key = 'dis' + key + elif key.startswith("allow") and hasattr(template, "dis" + key): + options_key = "dis" + key invert = True - elif key.startswith('disallow') and hasattr(template, key[3:]): + elif key.startswith("disallow") and hasattr(template, key[3:]): options_key = key[3:] invert = True - elif key == 'strict': + elif key == "strict": pass # Special handling below else: - print(f"{prefix}Unrecognized option: {key} = {section[key]}", - file=stderr) + print(f"{prefix}Unrecognized option: {key} = {section[key]}", file=stderr) if invert: dv = getattr(template, options_key, None) else: @@ -411,8 +445,7 @@ def parse_section(prefix: str, template: Options, v = not v elif callable(ct): if invert: - print(f"{prefix}Can not invert non-boolean key {options_key}", - file=stderr) + print(f"{prefix}Can not invert non-boolean key {options_key}", file=stderr) continue try: v = ct(section.get(key)) @@ -425,24 +458,29 @@ def parse_section(prefix: str, template: Options, except ValueError as err: print(f"{prefix}{key}: {err}", file=stderr) continue - if key == 'strict': + if key == "strict": if v: set_strict_flags() continue - if key == 'silent_imports': - print("%ssilent_imports has been replaced by " - "ignore_missing_imports=True; follow_imports=skip" % prefix, file=stderr) + if key == "silent_imports": + print( + "%ssilent_imports has been replaced by " + "ignore_missing_imports=True; follow_imports=skip" % prefix, + file=stderr, + ) if v: - if 'ignore_missing_imports' not in results: - results['ignore_missing_imports'] = True - if 'follow_imports' not in results: - results['follow_imports'] = 'skip' - if key == 'almost_silent': - print("%salmost_silent has been replaced by " - "follow_imports=error" % prefix, file=stderr) + if "ignore_missing_imports" not in results: + results["ignore_missing_imports"] = True + if "follow_imports" not in results: + results["follow_imports"] = "skip" + if key == "almost_silent": + print( + "%salmost_silent has been replaced by " "follow_imports=error" % prefix, + file=stderr, + ) if v: - if 'follow_imports' not in results: - results['follow_imports'] = 'error' + if "follow_imports" not in results: + results["follow_imports"] = "error" results[options_key] = v return results, report_dirs @@ -454,7 +492,7 @@ def convert_to_boolean(value: Optional[Any]) -> bool: if not isinstance(value, str): value = str(value) if value.lower() not in configparser.RawConfigParser.BOOLEAN_STATES: - raise ValueError(f'Not a boolean: {value}') + raise ValueError(f"Not a boolean: {value}") return configparser.RawConfigParser.BOOLEAN_STATES[value.lower()] @@ -467,8 +505,8 @@ def split_directive(s: str) -> Tuple[List[str], List[str]]: errors = [] i = 0 while i < len(s): - if s[i] == ',': - parts.append(''.join(cur).strip()) + if s[i] == ",": + parts.append("".join(cur).strip()) cur = [] elif s[i] == '"': i += 1 @@ -482,13 +520,12 @@ def split_directive(s: str) -> Tuple[List[str], List[str]]: cur.append(s[i]) i += 1 if cur: - parts.append(''.join(cur).strip()) + parts.append("".join(cur).strip()) return parts, errors -def mypy_comments_to_config_map(line: str, - template: Options) -> Tuple[Dict[str, str], List[str]]: +def mypy_comments_to_config_map(line: str, template: Options) -> Tuple[Dict[str, str], List[str]]: """Rewrite the mypy comment syntax into ini file syntax. Returns @@ -496,23 +533,23 @@ def mypy_comments_to_config_map(line: str, options = {} entries, errors = split_directive(line) for entry in entries: - if '=' not in entry: + if "=" not in entry: name = entry value = None else: - name, value = (x.strip() for x in entry.split('=', 1)) + name, value = (x.strip() for x in entry.split("=", 1)) - name = name.replace('-', '_') + name = name.replace("-", "_") if value is None: - value = 'True' + value = "True" options[name] = value return options, errors def parse_mypy_comments( - args: List[Tuple[int, str]], - template: Options) -> Tuple[Dict[str, object], List[Tuple[int, str]]]: + args: List[Tuple[int, str]], template: Options +) -> Tuple[Dict[str, object], List[Tuple[int, str]]]: """Parse a collection of inline mypy: configuration comments. Returns a dictionary of options to be applied and a list of error messages @@ -528,7 +565,7 @@ def parse_mypy_comments( # method is to create a config parser. parser = configparser.RawConfigParser() options, parse_errors = mypy_comments_to_config_map(line, template) - parser['dummy'] = options + parser["dummy"] = options errors.extend((lineno, x) for x in parse_errors) stderr = StringIO() @@ -539,15 +576,20 @@ def set_strict_flags() -> None: strict_found = True new_sections, reports = parse_section( - '', template, set_strict_flags, parser['dummy'], ini_config_types, stderr=stderr) - errors.extend((lineno, x) for x in stderr.getvalue().strip().split('\n') if x) + "", template, set_strict_flags, parser["dummy"], ini_config_types, stderr=stderr + ) + errors.extend((lineno, x) for x in stderr.getvalue().strip().split("\n") if x) if reports: errors.append((lineno, "Reports not supported in inline configuration")) if strict_found: - errors.append((lineno, - 'Setting "strict" not supported in inline configuration: specify it in ' - 'a configuration file instead, or set individual inline flags ' - '(see "mypy -h" for the list of flags enabled in strict mode)')) + errors.append( + ( + lineno, + 'Setting "strict" not supported in inline configuration: specify it in ' + "a configuration file instead, or set individual inline flags " + '(see "mypy -h" for the list of flags enabled in strict mode)', + ) + ) sections.update(new_sections) @@ -556,7 +598,7 @@ def set_strict_flags() -> None: def get_config_module_names(filename: Optional[str], modules: List[str]) -> str: if not filename or not modules: - return '' + return "" if not is_toml(filename): return ", ".join(f"[mypy-{module}]" for module in modules) diff --git a/mypy/constraints.py b/mypy/constraints.py index 9212964071f34..00309462db27a 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -1,29 +1,57 @@ """Type inference constraints.""" from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence + from typing_extensions import Final -from mypy.types import ( - CallableType, Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarType, Instance, - TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, DeletedType, - UninhabitedType, TypeType, TypeVarId, TypeQuery, is_named_instance, TypeOfAny, LiteralType, - ProperType, ParamSpecType, get_proper_type, TypeAliasType, is_union_with_any, - UnpackType, callable_with_ellipsis, Parameters, TUPLE_LIKE_INSTANCE_NAMES, TypeVarTupleType, - TypeList, -) -from mypy.maptype import map_instance_to_supertype -import mypy.subtypes import mypy.sametypes +import mypy.subtypes import mypy.typeops -from mypy.erasetype import erase_typevars -from mypy.nodes import COVARIANT, CONTRAVARIANT, ArgKind from mypy.argmap import ArgTypeExpander +from mypy.erasetype import erase_typevars +from mypy.maptype import map_instance_to_supertype +from mypy.nodes import CONTRAVARIANT, COVARIANT, ArgKind +from mypy.types import ( + TUPLE_LIKE_INSTANCE_NAMES, + AnyType, + CallableType, + DeletedType, + ErasedType, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + PartialType, + ProperType, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeList, + TypeOfAny, + TypeQuery, + TypeType, + TypeVarId, + TypeVarTupleType, + TypeVarType, + TypeVisitor, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, + callable_with_ellipsis, + get_proper_type, + is_named_instance, + is_union_with_any, +) from mypy.typestate import TypeState from mypy.typevartuples import ( - split_with_instance, - split_with_prefix_and_suffix, extract_unpack, find_unpack_in_list, + split_with_instance, + split_with_prefix_and_suffix, ) if TYPE_CHECKING: @@ -40,7 +68,7 @@ class Constraint: """ type_var: TypeVarId - op = 0 # SUBTYPE_OF or SUPERTYPE_OF + op = 0 # SUBTYPE_OF or SUPERTYPE_OF target: Type def __init__(self, type_var: TypeVarId, op: int, target: Type) -> None: @@ -49,10 +77,10 @@ def __init__(self, type_var: TypeVarId, op: int, target: Type) -> None: self.target = target def __repr__(self) -> str: - op_str = '<:' + op_str = "<:" if self.op == SUPERTYPE_OF: - op_str = ':>' - return f'{self.type_var} {op_str} {self.target}' + op_str = ":>" + return f"{self.type_var} {op_str} {self.target}" def __hash__(self) -> int: return hash((self.type_var, self.op, self.target)) @@ -64,11 +92,12 @@ def __eq__(self, other: object) -> bool: def infer_constraints_for_callable( - callee: CallableType, - arg_types: Sequence[Optional[Type]], - arg_kinds: List[ArgKind], - formal_to_actual: List[List[int]], - context: 'ArgumentInferContext') -> List[Constraint]: + callee: CallableType, + arg_types: Sequence[Optional[Type]], + arg_kinds: List[ArgKind], + formal_to_actual: List[List[int]], + context: "ArgumentInferContext", +) -> List[Constraint]: """Infer type variable constraints for a callable and actual arguments. Return a list of constraints. @@ -82,16 +111,16 @@ def infer_constraints_for_callable( if actual_arg_type is None: continue - actual_type = mapper.expand_actual_type(actual_arg_type, arg_kinds[actual], - callee.arg_names[i], callee.arg_kinds[i]) + actual_type = mapper.expand_actual_type( + actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i] + ) c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF) constraints.extend(c) return constraints -def infer_constraints(template: Type, actual: Type, - direction: int) -> List[Constraint]: +def infer_constraints(template: Type, actual: Type, direction: int) -> List[Constraint]: """Infer type constraints. Match a template type, which may contain type variable references, @@ -123,8 +152,7 @@ def infer_constraints(template: Type, actual: Type, return _infer_constraints(template, actual, direction) -def _infer_constraints(template: Type, actual: Type, - direction: int) -> List[Constraint]: +def _infer_constraints(template: Type, actual: Type, direction: int) -> List[Constraint]: orig_template = template template = get_proper_type(template) @@ -180,37 +208,44 @@ def _infer_constraints(template: Type, actual: Type, # variable if possible. This seems to help with some real-world # use cases. return any_constraints( - [infer_constraints_if_possible(template, a_item, direction) - for a_item in items], - eager=True) + [infer_constraints_if_possible(template, a_item, direction) for a_item in items], + eager=True, + ) if direction == SUPERTYPE_OF and isinstance(template, UnionType): # When the template is a union, we are okay with leaving some # type variables indeterminate. This helps with some special # cases, though this isn't very principled. return any_constraints( - [infer_constraints_if_possible(t_item, actual, direction) - for t_item in template.items], - eager=False) + [ + infer_constraints_if_possible(t_item, actual, direction) + for t_item in template.items + ], + eager=False, + ) # Remaining cases are handled by ConstraintBuilderVisitor. return template.accept(ConstraintBuilderVisitor(actual, direction)) -def infer_constraints_if_possible(template: Type, actual: Type, - direction: int) -> Optional[List[Constraint]]: +def infer_constraints_if_possible( + template: Type, actual: Type, direction: int +) -> Optional[List[Constraint]]: """Like infer_constraints, but return None if the input relation is known to be unsatisfiable, for example if template=List[T] and actual=int. (In this case infer_constraints would return [], just like it would for an automatically satisfied relation like template=List[T] and actual=object.) """ - if (direction == SUBTYPE_OF and - not mypy.subtypes.is_subtype(erase_typevars(template), actual)): + if direction == SUBTYPE_OF and not mypy.subtypes.is_subtype(erase_typevars(template), actual): return None - if (direction == SUPERTYPE_OF and - not mypy.subtypes.is_subtype(actual, erase_typevars(template))): + if direction == SUPERTYPE_OF and not mypy.subtypes.is_subtype( + actual, erase_typevars(template) + ): return None - if (direction == SUPERTYPE_OF and isinstance(template, TypeVarType) and - not mypy.subtypes.is_subtype(actual, erase_typevars(template.upper_bound))): + if ( + direction == SUPERTYPE_OF + and isinstance(template, TypeVarType) + and not mypy.subtypes.is_subtype(actual, erase_typevars(template.upper_bound)) + ): # This is not caught by the above branch because of the erase_typevars() call, # that would return 'Any' for a type variable. return None @@ -278,9 +313,7 @@ def any_constraints(options: List[Optional[List[Constraint]]], eager: bool) -> L if option in trivial_options: continue if option is not None: - merged_option: Optional[List[Constraint]] = [ - merge_with_any(c) for c in option - ] + merged_option: Optional[List[Constraint]] = [merge_with_any(c) for c in option] else: merged_option = None merged_options.append(merged_option) @@ -302,13 +335,14 @@ def is_same_constraints(x: List[Constraint], y: List[Constraint]) -> bool: def is_same_constraint(c1: Constraint, c2: Constraint) -> bool: # Ignore direction when comparing constraints against Any. - skip_op_check = ( - isinstance(get_proper_type(c1.target), AnyType) and - isinstance(get_proper_type(c2.target), AnyType) + skip_op_check = isinstance(get_proper_type(c1.target), AnyType) and isinstance( + get_proper_type(c2.target), AnyType + ) + return ( + c1.type_var == c2.type_var + and (c1.op == c2.op or skip_op_check) + and mypy.sametypes.is_same_type(c1.target, c2.target) ) - return (c1.type_var == c2.type_var - and (c1.op == c2.op or skip_op_check) - and mypy.sametypes.is_same_type(c1.target, c2.target)) def is_similar_constraints(x: List[Constraint], y: List[Constraint]) -> bool: @@ -330,9 +364,8 @@ def _is_similar_constraints(x: List[Constraint], y: List[Constraint]) -> bool: has_similar = False for c2 in y: # Ignore direction when either constraint is against Any. - skip_op_check = ( - isinstance(get_proper_type(c1.target), AnyType) or - isinstance(get_proper_type(c2.target), AnyType) + skip_op_check = isinstance(get_proper_type(c1.target), AnyType) or isinstance( + get_proper_type(c2.target), AnyType ) if c1.type_var == c2.type_var and (c1.op == c2.op or skip_op_check): has_similar = True @@ -411,8 +444,10 @@ def visit_partial_type(self, template: PartialType) -> List[Constraint]: # Non-trivial leaf type def visit_type_var(self, template: TypeVarType) -> List[Constraint]: - assert False, ("Unexpected TypeVarType in ConstraintBuilderVisitor" - " (should have been handled in infer_constraints)") + assert False, ( + "Unexpected TypeVarType in ConstraintBuilderVisitor" + " (should have been handled in infer_constraints)" + ) def visit_param_spec(self, template: ParamSpecType) -> List[Constraint]: # Can't infer ParamSpecs from component values (only via Callable[P, T]). @@ -437,13 +472,15 @@ def visit_instance(self, template: Instance) -> List[Constraint]: original_actual = actual = self.actual res: List[Constraint] = [] if isinstance(actual, (CallableType, Overloaded)) and template.type.is_protocol: - if template.type.protocol_members == ['__call__']: + if template.type.protocol_members == ["__call__"]: # Special case: a generic callback protocol - if not any(mypy.sametypes.is_same_type(template, t) - for t in template.type.inferring): + if not any( + mypy.sametypes.is_same_type(template, t) for t in template.type.inferring + ): template.type.inferring.append(template) - call = mypy.subtypes.find_member('__call__', template, actual, - is_operator=True) + call = mypy.subtypes.find_member( + "__call__", template, actual, is_operator=True + ) assert call is not None if mypy.subtypes.is_subtype(actual, erase_typevars(call)): subres = infer_constraints(call, actual, self.direction) @@ -464,8 +501,7 @@ def visit_instance(self, template: Instance) -> List[Constraint]: assert isinstance(erased, Instance) # type: ignore # We always try nominal inference if possible, # it is much faster than the structural one. - if (self.direction == SUBTYPE_OF and - template.type.has_base(instance.type.fullname)): + if self.direction == SUBTYPE_OF and template.type.has_base(instance.type.fullname): mapped = map_instance_to_supertype(template, instance.type) tvars = mapped.type.defn.type_vars # N.B: We use zip instead of indexing because the lengths might have @@ -476,11 +512,11 @@ def visit_instance(self, template: Instance) -> List[Constraint]: # The constraints for generic type parameters depend on variance. # Include constraints from both directions if invariant. if tvar.variance != CONTRAVARIANT: - res.extend(infer_constraints( - mapped_arg, instance_arg, self.direction)) + res.extend(infer_constraints(mapped_arg, instance_arg, self.direction)) if tvar.variance != COVARIANT: - res.extend(infer_constraints( - mapped_arg, instance_arg, neg_op(self.direction))) + res.extend( + infer_constraints(mapped_arg, instance_arg, neg_op(self.direction)) + ) elif isinstance(tvar, ParamSpecType) and isinstance(mapped_arg, ParamSpecType): suffix = get_proper_type(instance_arg) @@ -495,9 +531,10 @@ def visit_instance(self, template: Instance) -> List[Constraint]: # TODO: constraints between prefixes prefix = mapped_arg.prefix suffix = suffix.copy_modified( - suffix.arg_types[len(prefix.arg_types):], - suffix.arg_kinds[len(prefix.arg_kinds):], - suffix.arg_names[len(prefix.arg_names):]) + suffix.arg_types[len(prefix.arg_types) :], + suffix.arg_kinds[len(prefix.arg_kinds) :], + suffix.arg_names[len(prefix.arg_names) :], + ) res.append(Constraint(mapped_arg.id, SUPERTYPE_OF, suffix)) elif isinstance(suffix, ParamSpecType): res.append(Constraint(mapped_arg.id, SUPERTYPE_OF, suffix)) @@ -505,16 +542,13 @@ def visit_instance(self, template: Instance) -> List[Constraint]: raise NotImplementedError return res - elif (self.direction == SUPERTYPE_OF and - instance.type.has_base(template.type.fullname)): + elif self.direction == SUPERTYPE_OF and instance.type.has_base(template.type.fullname): mapped = map_instance_to_supertype(instance, template.type) tvars = template.type.defn.type_vars if template.type.has_type_var_tuple_type: - mapped_prefix, mapped_middle, mapped_suffix = ( - split_with_instance(mapped) - ) - template_prefix, template_middle, template_suffix = ( - split_with_instance(template) + mapped_prefix, mapped_middle, mapped_suffix = split_with_instance(mapped) + template_prefix, template_middle, template_suffix = split_with_instance( + template ) # Add a constraint for the type var tuple, and then @@ -522,14 +556,14 @@ def visit_instance(self, template: Instance) -> List[Constraint]: template_unpack = extract_unpack(template_middle) if template_unpack is not None: if isinstance(template_unpack, TypeVarTupleType): - res.append(Constraint( - template_unpack.id, - SUPERTYPE_OF, - TypeList(list(mapped_middle)) - )) + res.append( + Constraint( + template_unpack.id, SUPERTYPE_OF, TypeList(list(mapped_middle)) + ) + ) elif ( - isinstance(template_unpack, Instance) and - template_unpack.type.fullname == "builtins.tuple" + isinstance(template_unpack, Instance) + and template_unpack.type.fullname == "builtins.tuple" ): # TODO: check homogenous tuple case raise NotImplementedError @@ -559,13 +593,14 @@ def visit_instance(self, template: Instance) -> List[Constraint]: # The constraints for generic type parameters depend on variance. # Include constraints from both directions if invariant. if tvar.variance != CONTRAVARIANT: - res.extend(infer_constraints( - template_arg, mapped_arg, self.direction)) + res.extend(infer_constraints(template_arg, mapped_arg, self.direction)) if tvar.variance != COVARIANT: - res.extend(infer_constraints( - template_arg, mapped_arg, neg_op(self.direction))) - elif (isinstance(tvar, ParamSpecType) and - isinstance(template_arg, ParamSpecType)): + res.extend( + infer_constraints(template_arg, mapped_arg, neg_op(self.direction)) + ) + elif isinstance(tvar, ParamSpecType) and isinstance( + template_arg, ParamSpecType + ): suffix = get_proper_type(mapped_arg) if isinstance(suffix, CallableType): @@ -580,53 +615,66 @@ def visit_instance(self, template: Instance) -> List[Constraint]: prefix = template_arg.prefix suffix = suffix.copy_modified( - suffix.arg_types[len(prefix.arg_types):], - suffix.arg_kinds[len(prefix.arg_kinds):], - suffix.arg_names[len(prefix.arg_names):]) + suffix.arg_types[len(prefix.arg_types) :], + suffix.arg_kinds[len(prefix.arg_kinds) :], + suffix.arg_names[len(prefix.arg_names) :], + ) res.append(Constraint(template_arg.id, SUPERTYPE_OF, suffix)) elif isinstance(suffix, ParamSpecType): res.append(Constraint(template_arg.id, SUPERTYPE_OF, suffix)) return res - if (template.type.is_protocol and self.direction == SUPERTYPE_OF and - # We avoid infinite recursion for structural subtypes by checking - # whether this type already appeared in the inference chain. - # This is a conservative way to break the inference cycles. - # It never produces any "false" constraints but gives up soon - # on purely structural inference cycles, see #3829. - # Note that we use is_protocol_implementation instead of is_subtype - # because some type may be considered a subtype of a protocol - # due to _promote, but still not implement the protocol. - not any(mypy.sametypes.is_same_type(template, t) - for t in template.type.inferring) and - mypy.subtypes.is_protocol_implementation(instance, erased)): + if ( + template.type.is_protocol + and self.direction == SUPERTYPE_OF + and + # We avoid infinite recursion for structural subtypes by checking + # whether this type already appeared in the inference chain. + # This is a conservative way to break the inference cycles. + # It never produces any "false" constraints but gives up soon + # on purely structural inference cycles, see #3829. + # Note that we use is_protocol_implementation instead of is_subtype + # because some type may be considered a subtype of a protocol + # due to _promote, but still not implement the protocol. + not any(mypy.sametypes.is_same_type(template, t) for t in template.type.inferring) + and mypy.subtypes.is_protocol_implementation(instance, erased) + ): template.type.inferring.append(template) - res.extend(self.infer_constraints_from_protocol_members( - instance, template, original_actual, template)) + res.extend( + self.infer_constraints_from_protocol_members( + instance, template, original_actual, template + ) + ) template.type.inferring.pop() return res - elif (instance.type.is_protocol and self.direction == SUBTYPE_OF and - # We avoid infinite recursion for structural subtypes also here. - not any(mypy.sametypes.is_same_type(instance, i) - for i in instance.type.inferring) and - mypy.subtypes.is_protocol_implementation(erased, instance)): + elif ( + instance.type.is_protocol + and self.direction == SUBTYPE_OF + and + # We avoid infinite recursion for structural subtypes also here. + not any(mypy.sametypes.is_same_type(instance, i) for i in instance.type.inferring) + and mypy.subtypes.is_protocol_implementation(erased, instance) + ): instance.type.inferring.append(instance) - res.extend(self.infer_constraints_from_protocol_members( - instance, template, template, instance)) + res.extend( + self.infer_constraints_from_protocol_members( + instance, template, template, instance + ) + ) instance.type.inferring.pop() return res if isinstance(actual, AnyType): return self.infer_against_any(template.args, actual) - if (isinstance(actual, TupleType) - and is_named_instance(template, TUPLE_LIKE_INSTANCE_NAMES) - and self.direction == SUPERTYPE_OF): + if ( + isinstance(actual, TupleType) + and is_named_instance(template, TUPLE_LIKE_INSTANCE_NAMES) + and self.direction == SUPERTYPE_OF + ): for item in actual.items: cb = infer_constraints(template.args[0], item, SUPERTYPE_OF) res.extend(cb) return res elif isinstance(actual, TupleType) and self.direction == SUPERTYPE_OF: - return infer_constraints(template, - mypy.typeops.tuple_fallback(actual), - self.direction) + return infer_constraints(template, mypy.typeops.tuple_fallback(actual), self.direction) elif isinstance(actual, TypeVarType): if not actual.values: return infer_constraints(template, actual.upper_bound, self.direction) @@ -638,10 +686,9 @@ def visit_instance(self, template: Instance) -> List[Constraint]: else: return [] - def infer_constraints_from_protocol_members(self, - instance: Instance, template: Instance, - subtype: Type, protocol: Instance, - ) -> List[Constraint]: + def infer_constraints_from_protocol_members( + self, instance: Instance, template: Instance, subtype: Type, protocol: Instance + ) -> List[Constraint]: """Infer constraints for situations where either 'template' or 'instance' is a protocol. The 'protocol' is the one of two that is an instance of protocol type, 'subtype' @@ -657,8 +704,7 @@ def infer_constraints_from_protocol_members(self, # The above is safe since at this point we know that 'instance' is a subtype # of (erased) 'template', therefore it defines all protocol members res.extend(infer_constraints(temp, inst, self.direction)) - if (mypy.subtypes.IS_SETTABLE in - mypy.subtypes.get_member_flags(member, protocol.type)): + if mypy.subtypes.IS_SETTABLE in mypy.subtypes.get_member_flags(member, protocol.type): # Settable members are invariant, add opposite constraints res.extend(infer_constraints(temp, inst, neg_op(self.direction))) return res @@ -688,13 +734,18 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]: cactual_ps = cactual.param_spec() if not cactual_ps: - res.append(Constraint(param_spec.id, - SUBTYPE_OF, - cactual.copy_modified( - arg_types=cactual.arg_types[prefix_len:], - arg_kinds=cactual.arg_kinds[prefix_len:], - arg_names=cactual.arg_names[prefix_len:], - ret_type=NoneType()))) + res.append( + Constraint( + param_spec.id, + SUBTYPE_OF, + cactual.copy_modified( + arg_types=cactual.arg_types[prefix_len:], + arg_kinds=cactual.arg_kinds[prefix_len:], + arg_names=cactual.arg_names[prefix_len:], + ret_type=NoneType(), + ), + ) + ) else: res.append(Constraint(param_spec.id, SUBTYPE_OF, cactual_ps)) @@ -702,7 +753,8 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]: cactual_prefix = cactual.copy_modified( arg_types=cactual.arg_types[:prefix_len], arg_kinds=cactual.arg_kinds[:prefix_len], - arg_names=cactual.arg_names[:prefix_len]) + arg_names=cactual.arg_names[:prefix_len], + ) # TODO: see above "FIX" comments for param_spec is None case # TODO: this assume positional arguments @@ -715,8 +767,7 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]: if cactual.type_guard is not None: cactual_ret_type = cactual.type_guard - res.extend(infer_constraints(template_ret_type, cactual_ret_type, - self.direction)) + res.extend(infer_constraints(template_ret_type, cactual_ret_type, self.direction)) return res elif isinstance(self.actual, AnyType): param_spec = template.param_spec() @@ -725,9 +776,13 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]: # FIX what if generic res = self.infer_against_any(template.arg_types, self.actual) else: - res = [Constraint(param_spec.id, - SUBTYPE_OF, - callable_with_ellipsis(any_type, any_type, template.fallback))] + res = [ + Constraint( + param_spec.id, + SUBTYPE_OF, + callable_with_ellipsis(any_type, any_type, template.fallback), + ) + ] res.extend(infer_constraints(template.ret_type, any_type, self.direction)) return res elif isinstance(self.actual, Overloaded): @@ -737,8 +792,9 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]: elif isinstance(self.actual, Instance): # Instances with __call__ method defined are considered structural # subtypes of Callable with a compatible signature. - call = mypy.subtypes.find_member('__call__', self.actual, self.actual, - is_operator=True) + call = mypy.subtypes.find_member( + "__call__", self.actual, self.actual, is_operator=True + ) if call: return infer_constraints(template, call, self.direction) else: @@ -746,8 +802,9 @@ def visit_callable_type(self, template: CallableType) -> List[Constraint]: else: return [] - def infer_against_overloaded(self, overloaded: Overloaded, - template: CallableType) -> List[Constraint]: + def infer_against_overloaded( + self, overloaded: Overloaded, template: CallableType + ) -> List[Constraint]: # Create constraints by matching an overloaded type against a template. # This is tricky to do in general. We cheat by only matching against # the first overload item that is callable compatible. This @@ -760,8 +817,7 @@ def visit_tuple_type(self, template: TupleType) -> List[Constraint]: actual = self.actual # TODO: Support subclasses of Tuple is_varlength_tuple = ( - isinstance(actual, Instance) - and actual.type.fullname == "builtins.tuple" + isinstance(actual, Instance) and actual.type.fullname == "builtins.tuple" ) unpack_index = find_unpack_in_list(template.items) @@ -781,10 +837,7 @@ def visit_tuple_type(self, template: TupleType) -> List[Constraint]: # where we expect Tuple[int, Unpack[Ts]], but not for Tuple[str, Unpack[Ts]]. assert len(template.items) == 1 - if ( - isinstance(actual, (TupleType, AnyType)) - or is_varlength_tuple - ): + if isinstance(actual, (TupleType, AnyType)) or is_varlength_tuple: modified_actual = actual if isinstance(actual, TupleType): # Exclude the items from before and after the unpack index. @@ -794,21 +847,17 @@ def visit_tuple_type(self, template: TupleType) -> List[Constraint]: unpack_index, len(template.items) - unpack_index - 1, ) - modified_actual = actual.copy_modified( - items=list(actual_items) + modified_actual = actual.copy_modified(items=list(actual_items)) + return [ + Constraint( + type_var=unpacked_type.id, op=self.direction, target=modified_actual ) - return [Constraint( - type_var=unpacked_type.id, - op=self.direction, - target=modified_actual, - )] + ] if isinstance(actual, TupleType) and len(actual.items) == len(template.items): res: List[Constraint] = [] for i in range(len(template.items)): - res.extend(infer_constraints(template.items[i], - actual.items[i], - self.direction)) + res.extend(infer_constraints(template.items[i], actual.items[i], self.direction)) return res elif isinstance(actual, AnyType): return self.infer_against_any(template.items, actual) @@ -822,9 +871,7 @@ def visit_typeddict_type(self, template: TypedDictType) -> List[Constraint]: # NOTE: Non-matching keys are ignored. Compatibility is checked # elsewhere so this shouldn't be unsafe. for (item_name, template_item_type, actual_item_type) in template.zip(actual): - res.extend(infer_constraints(template_item_type, - actual_item_type, - self.direction)) + res.extend(infer_constraints(template_item_type, actual_item_type, self.direction)) return res elif isinstance(actual, AnyType): return self.infer_against_any(template.items.values(), actual) @@ -832,8 +879,10 @@ def visit_typeddict_type(self, template: TypedDictType) -> List[Constraint]: return [] def visit_union_type(self, template: UnionType) -> List[Constraint]: - assert False, ("Unexpected UnionType in ConstraintBuilderVisitor" - " (should have been handled in infer_constraints)") + assert False, ( + "Unexpected UnionType in ConstraintBuilderVisitor" + " (should have been handled in infer_constraints)" + ) def visit_type_alias_type(self, template: TypeAliasType) -> List[Constraint]: assert False, f"This should be never called, got {template}" @@ -861,8 +910,7 @@ def visit_type_type(self, template: TypeType) -> List[Constraint]: if isinstance(self.actual, CallableType): return infer_constraints(template.item, self.actual.ret_type, self.direction) elif isinstance(self.actual, Overloaded): - return infer_constraints(template.item, self.actual.items[0].ret_type, - self.direction) + return infer_constraints(template.item, self.actual.items[0].ret_type, self.direction) elif isinstance(self.actual, TypeType): return infer_constraints(template.item, self.actual.item, self.direction) elif isinstance(self.actual, AnyType): @@ -879,7 +927,7 @@ def neg_op(op: int) -> int: elif op == SUPERTYPE_OF: return SUBTYPE_OF else: - raise ValueError(f'Invalid operator {op}') + raise ValueError(f"Invalid operator {op}") def find_matching_overload_item(overloaded: Overloaded, template: CallableType) -> CallableType: @@ -888,26 +936,27 @@ def find_matching_overload_item(overloaded: Overloaded, template: CallableType) for item in items: # Return type may be indeterminate in the template, so ignore it when performing a # subtype check. - if mypy.subtypes.is_callable_compatible(item, template, - is_compat=mypy.subtypes.is_subtype, - ignore_return=True): + if mypy.subtypes.is_callable_compatible( + item, template, is_compat=mypy.subtypes.is_subtype, ignore_return=True + ): return item # Fall back to the first item if we can't find a match. This is totally arbitrary -- # maybe we should just bail out at this point. return items[0] -def find_matching_overload_items(overloaded: Overloaded, - template: CallableType) -> List[CallableType]: +def find_matching_overload_items( + overloaded: Overloaded, template: CallableType +) -> List[CallableType]: """Like find_matching_overload_item, but return all matches, not just the first.""" items = overloaded.items res = [] for item in items: # Return type may be indeterminate in the template, so ignore it when performing a # subtype check. - if mypy.subtypes.is_callable_compatible(item, template, - is_compat=mypy.subtypes.is_subtype, - ignore_return=True): + if mypy.subtypes.is_callable_compatible( + item, template, is_compat=mypy.subtypes.is_subtype, ignore_return=True + ): res.append(item) if not res: # Falling back to all items if we can't find a match is pretty arbitrary, but diff --git a/mypy/copytype.py b/mypy/copytype.py index 85d7d531c5a3b..e5a02d811d8b5 100644 --- a/mypy/copytype.py +++ b/mypy/copytype.py @@ -1,12 +1,32 @@ from typing import Any, cast from mypy.types import ( - ProperType, UnboundType, AnyType, NoneType, UninhabitedType, ErasedType, DeletedType, - Instance, TypeVarType, ParamSpecType, PartialType, CallableType, TupleType, TypedDictType, - LiteralType, UnionType, Overloaded, TypeType, TypeAliasType, UnpackType, Parameters, - TypeVarTupleType + AnyType, + CallableType, + DeletedType, + ErasedType, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + PartialType, + ProperType, + TupleType, + TypeAliasType, + TypedDictType, + TypeType, + TypeVarTupleType, + TypeVarType, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, ) -from mypy.type_visitor import TypeVisitor + +# type_visitor needs to be imported after types +from mypy.type_visitor import TypeVisitor # isort: skip def copy_type(t: ProperType) -> ProperType: @@ -62,9 +82,13 @@ def visit_param_spec(self, t: ParamSpecType) -> ProperType: return self.copy_common(t, dup) def visit_parameters(self, t: Parameters) -> ProperType: - dup = Parameters(t.arg_types, t.arg_kinds, t.arg_names, - variables=t.variables, - is_ellipsis_args=t.is_ellipsis_args) + dup = Parameters( + t.arg_types, + t.arg_kinds, + t.arg_names, + variables=t.variables, + is_ellipsis_args=t.is_ellipsis_args, + ) return self.copy_common(t, dup) def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType: diff --git a/mypy/defaults.py b/mypy/defaults.py index dc9e49c2e9c6c..8e1720572ece5 100644 --- a/mypy/defaults.py +++ b/mypy/defaults.py @@ -7,16 +7,9 @@ PYTHON3_VERSION_MIN: Final = (3, 4) CACHE_DIR: Final = ".mypy_cache" CONFIG_FILE: Final = ["mypy.ini", ".mypy.ini"] -PYPROJECT_CONFIG_FILES: Final = [ - "pyproject.toml", -] -SHARED_CONFIG_FILES: Final = [ - "setup.cfg", -] -USER_CONFIG_FILES: Final = [ - "~/.config/mypy/config", - "~/.mypy.ini", -] +PYPROJECT_CONFIG_FILES: Final = ["pyproject.toml"] +SHARED_CONFIG_FILES: Final = ["setup.cfg"] +USER_CONFIG_FILES: Final = ["~/.config/mypy/config", "~/.mypy.ini"] if os.environ.get("XDG_CONFIG_HOME"): USER_CONFIG_FILES.insert(0, os.path.join(os.environ["XDG_CONFIG_HOME"], "mypy/config")) diff --git a/mypy/dmypy/__main__.py b/mypy/dmypy/__main__.py index a8da701799ec6..93fb21eff5b54 100644 --- a/mypy/dmypy/__main__.py +++ b/mypy/dmypy/__main__.py @@ -1,4 +1,4 @@ from mypy.dmypy.client import console_entry -if __name__ == '__main__': +if __name__ == "__main__": console_entry() diff --git a/mypy/dmypy/client.py b/mypy/dmypy/client.py index 3ed85dca9750b..5b6a6a0a072fa 100644 --- a/mypy/dmypy/client.py +++ b/mypy/dmypy/client.py @@ -12,15 +12,14 @@ import sys import time import traceback +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple -from typing import Any, Callable, Dict, Mapping, Optional, Tuple, List from typing_extensions import NoReturn +from mypy.dmypy_os import alive, kill from mypy.dmypy_util import DEFAULT_STATUS_FILE, receive from mypy.ipc import IPCClient, IPCException -from mypy.dmypy_os import alive, kill from mypy.util import check_python_version, get_terminal_width - from mypy.version import __version__ # Argument parser. Subparsers are tied to action functions by the @@ -32,103 +31,150 @@ def __init__(self, prog: str) -> None: super().__init__(prog=prog, max_help_position=30) -parser = argparse.ArgumentParser(prog='dmypy', - description="Client for mypy daemon mode", - fromfile_prefix_chars='@') +parser = argparse.ArgumentParser( + prog="dmypy", description="Client for mypy daemon mode", fromfile_prefix_chars="@" +) parser.set_defaults(action=None) -parser.add_argument('--status-file', default=DEFAULT_STATUS_FILE, - help='status file to retrieve daemon details') -parser.add_argument('-V', '--version', action='version', - version='%(prog)s ' + __version__, - help="Show program's version number and exit") +parser.add_argument( + "--status-file", default=DEFAULT_STATUS_FILE, help="status file to retrieve daemon details" +) +parser.add_argument( + "-V", + "--version", + action="version", + version="%(prog)s " + __version__, + help="Show program's version number and exit", +) subparsers = parser.add_subparsers() -start_parser = p = subparsers.add_parser('start', help="Start daemon") -p.add_argument('--log-file', metavar='FILE', type=str, - help="Direct daemon stdout/stderr to FILE") -p.add_argument('--timeout', metavar='TIMEOUT', type=int, - help="Server shutdown timeout (in seconds)") -p.add_argument('flags', metavar='FLAG', nargs='*', type=str, - help="Regular mypy flags (precede with --)") - -restart_parser = p = subparsers.add_parser('restart', - help="Restart daemon (stop or kill followed by start)") -p.add_argument('--log-file', metavar='FILE', type=str, - help="Direct daemon stdout/stderr to FILE") -p.add_argument('--timeout', metavar='TIMEOUT', type=int, - help="Server shutdown timeout (in seconds)") -p.add_argument('flags', metavar='FLAG', nargs='*', type=str, - help="Regular mypy flags (precede with --)") - -status_parser = p = subparsers.add_parser('status', help="Show daemon status") -p.add_argument('-v', '--verbose', action='store_true', help="Print detailed status") -p.add_argument('--fswatcher-dump-file', help="Collect information about the current file state") - -stop_parser = p = subparsers.add_parser('stop', help="Stop daemon (asks it politely to go away)") - -kill_parser = p = subparsers.add_parser('kill', help="Kill daemon (kills the process)") - -check_parser = p = subparsers.add_parser('check', formatter_class=AugmentedHelpFormatter, - help="Check some files (requires daemon)") -p.add_argument('-v', '--verbose', action='store_true', help="Print detailed status") -p.add_argument('-q', '--quiet', action='store_true', help=argparse.SUPPRESS) # Deprecated -p.add_argument('--junit-xml', help="Write junit.xml to the given file") -p.add_argument('--perf-stats-file', help='write performance information to the given file') -p.add_argument('files', metavar='FILE', nargs='+', help="File (or directory) to check") - -run_parser = p = subparsers.add_parser('run', formatter_class=AugmentedHelpFormatter, - help="Check some files, [re]starting daemon if necessary") -p.add_argument('-v', '--verbose', action='store_true', help="Print detailed status") -p.add_argument('--junit-xml', help="Write junit.xml to the given file") -p.add_argument('--perf-stats-file', help='write performance information to the given file') -p.add_argument('--timeout', metavar='TIMEOUT', type=int, - help="Server shutdown timeout (in seconds)") -p.add_argument('--log-file', metavar='FILE', type=str, - help="Direct daemon stdout/stderr to FILE") -p.add_argument('flags', metavar='ARG', nargs='*', type=str, - help="Regular mypy flags and files (precede with --)") - -recheck_parser = p = subparsers.add_parser('recheck', formatter_class=AugmentedHelpFormatter, - help="Re-check the previous list of files, with optional modifications (requires daemon)") -p.add_argument('-v', '--verbose', action='store_true', help="Print detailed status") -p.add_argument('-q', '--quiet', action='store_true', help=argparse.SUPPRESS) # Deprecated -p.add_argument('--junit-xml', help="Write junit.xml to the given file") -p.add_argument('--perf-stats-file', help='write performance information to the given file') -p.add_argument('--update', metavar='FILE', nargs='*', - help="Files in the run to add or check again (default: all from previous run)") -p.add_argument('--remove', metavar='FILE', nargs='*', - help="Files to remove from the run") - -suggest_parser = p = subparsers.add_parser('suggest', - help="Suggest a signature or show call sites for a specific function") -p.add_argument('function', metavar='FUNCTION', type=str, - help="Function specified as '[package.]module.[class.]function'") -p.add_argument('--json', action='store_true', - help="Produce json that pyannotate can use to apply a suggestion") -p.add_argument('--no-errors', action='store_true', - help="Only produce suggestions that cause no errors") -p.add_argument('--no-any', action='store_true', - help="Only produce suggestions that don't contain Any") -p.add_argument('--flex-any', type=float, - help="Allow anys in types if they go above a certain score (scores are from 0-1)") -p.add_argument('--try-text', action='store_true', - help="Try using unicode wherever str is inferred") -p.add_argument('--callsites', action='store_true', - help="Find callsites instead of suggesting a type") -p.add_argument('--use-fixme', metavar='NAME', type=str, - help="A dummy name to use instead of Any for types that can't be inferred") -p.add_argument('--max-guesses', type=int, - help="Set the maximum number of types to try for a function (default 64)") - -hang_parser = p = subparsers.add_parser('hang', help="Hang for 100 seconds") - -daemon_parser = p = subparsers.add_parser('daemon', help="Run daemon in foreground") -p.add_argument('--timeout', metavar='TIMEOUT', type=int, - help="Server shutdown timeout (in seconds)") -p.add_argument('flags', metavar='FLAG', nargs='*', type=str, - help="Regular mypy flags (precede with --)") -p.add_argument('--options-data', help=argparse.SUPPRESS) -help_parser = p = subparsers.add_parser('help') +start_parser = p = subparsers.add_parser("start", help="Start daemon") +p.add_argument("--log-file", metavar="FILE", type=str, help="Direct daemon stdout/stderr to FILE") +p.add_argument( + "--timeout", metavar="TIMEOUT", type=int, help="Server shutdown timeout (in seconds)" +) +p.add_argument( + "flags", metavar="FLAG", nargs="*", type=str, help="Regular mypy flags (precede with --)" +) + +restart_parser = p = subparsers.add_parser( + "restart", help="Restart daemon (stop or kill followed by start)" +) +p.add_argument("--log-file", metavar="FILE", type=str, help="Direct daemon stdout/stderr to FILE") +p.add_argument( + "--timeout", metavar="TIMEOUT", type=int, help="Server shutdown timeout (in seconds)" +) +p.add_argument( + "flags", metavar="FLAG", nargs="*", type=str, help="Regular mypy flags (precede with --)" +) + +status_parser = p = subparsers.add_parser("status", help="Show daemon status") +p.add_argument("-v", "--verbose", action="store_true", help="Print detailed status") +p.add_argument("--fswatcher-dump-file", help="Collect information about the current file state") + +stop_parser = p = subparsers.add_parser("stop", help="Stop daemon (asks it politely to go away)") + +kill_parser = p = subparsers.add_parser("kill", help="Kill daemon (kills the process)") + +check_parser = p = subparsers.add_parser( + "check", formatter_class=AugmentedHelpFormatter, help="Check some files (requires daemon)" +) +p.add_argument("-v", "--verbose", action="store_true", help="Print detailed status") +p.add_argument("-q", "--quiet", action="store_true", help=argparse.SUPPRESS) # Deprecated +p.add_argument("--junit-xml", help="Write junit.xml to the given file") +p.add_argument("--perf-stats-file", help="write performance information to the given file") +p.add_argument("files", metavar="FILE", nargs="+", help="File (or directory) to check") + +run_parser = p = subparsers.add_parser( + "run", + formatter_class=AugmentedHelpFormatter, + help="Check some files, [re]starting daemon if necessary", +) +p.add_argument("-v", "--verbose", action="store_true", help="Print detailed status") +p.add_argument("--junit-xml", help="Write junit.xml to the given file") +p.add_argument("--perf-stats-file", help="write performance information to the given file") +p.add_argument( + "--timeout", metavar="TIMEOUT", type=int, help="Server shutdown timeout (in seconds)" +) +p.add_argument("--log-file", metavar="FILE", type=str, help="Direct daemon stdout/stderr to FILE") +p.add_argument( + "flags", + metavar="ARG", + nargs="*", + type=str, + help="Regular mypy flags and files (precede with --)", +) + +recheck_parser = p = subparsers.add_parser( + "recheck", + formatter_class=AugmentedHelpFormatter, + help="Re-check the previous list of files, with optional modifications (requires daemon)", +) +p.add_argument("-v", "--verbose", action="store_true", help="Print detailed status") +p.add_argument("-q", "--quiet", action="store_true", help=argparse.SUPPRESS) # Deprecated +p.add_argument("--junit-xml", help="Write junit.xml to the given file") +p.add_argument("--perf-stats-file", help="write performance information to the given file") +p.add_argument( + "--update", + metavar="FILE", + nargs="*", + help="Files in the run to add or check again (default: all from previous run)", +) +p.add_argument("--remove", metavar="FILE", nargs="*", help="Files to remove from the run") + +suggest_parser = p = subparsers.add_parser( + "suggest", help="Suggest a signature or show call sites for a specific function" +) +p.add_argument( + "function", + metavar="FUNCTION", + type=str, + help="Function specified as '[package.]module.[class.]function'", +) +p.add_argument( + "--json", + action="store_true", + help="Produce json that pyannotate can use to apply a suggestion", +) +p.add_argument( + "--no-errors", action="store_true", help="Only produce suggestions that cause no errors" +) +p.add_argument( + "--no-any", action="store_true", help="Only produce suggestions that don't contain Any" +) +p.add_argument( + "--flex-any", + type=float, + help="Allow anys in types if they go above a certain score (scores are from 0-1)", +) +p.add_argument( + "--try-text", action="store_true", help="Try using unicode wherever str is inferred" +) +p.add_argument( + "--callsites", action="store_true", help="Find callsites instead of suggesting a type" +) +p.add_argument( + "--use-fixme", + metavar="NAME", + type=str, + help="A dummy name to use instead of Any for types that can't be inferred", +) +p.add_argument( + "--max-guesses", + type=int, + help="Set the maximum number of types to try for a function (default 64)", +) + +hang_parser = p = subparsers.add_parser("hang", help="Hang for 100 seconds") + +daemon_parser = p = subparsers.add_parser("daemon", help="Run daemon in foreground") +p.add_argument( + "--timeout", metavar="TIMEOUT", type=int, help="Server shutdown timeout (in seconds)" +) +p.add_argument( + "flags", metavar="FLAG", nargs="*", type=str, help="Regular mypy flags (precede with --)" +) +p.add_argument("--options-data", help=argparse.SUPPRESS) +help_parser = p = subparsers.add_parser("help") del p @@ -141,12 +187,13 @@ class BadStatus(Exception): - Status file malformed - Process whose pid is in the status file does not exist """ + pass def main(argv: List[str]) -> None: """The code is top-down.""" - check_python_version('dmypy') + check_python_version("dmypy") args = parser.parse_args(argv) if not args.action: parser.print_usage() @@ -172,14 +219,17 @@ def fail(msg: str) -> NoReturn: def action(subparser: argparse.ArgumentParser) -> Callable[[ActionFunction], ActionFunction]: """Decorator to tie an action function to a subparser.""" + def register(func: ActionFunction) -> ActionFunction: subparser.set_defaults(action=func) return func + return register # Action functions (run in client from command line). + @action(start_parser) def do_start(args: argparse.Namespace) -> None: """Start daemon (it must not already be running). @@ -226,6 +276,7 @@ def start_server(args: argparse.Namespace, allow_sources: bool = False) -> None: """Start the server from command arguments and wait for it.""" # Lazy import so this import doesn't slow down other commands. from mypy.dmypy_server import daemonize, process_start_options + start_options = process_start_options(args.flags, allow_sources) if daemonize(start_options, args.status_file, timeout=args.timeout, log_file=args.log_file): sys.exit(2) @@ -270,15 +321,15 @@ def do_run(args: argparse.Namespace) -> None: # Bad or missing status file or dead process; good to start. start_server(args, allow_sources=True) t0 = time.time() - response = request(args.status_file, 'run', version=__version__, args=args.flags) + response = request(args.status_file, "run", version=__version__, args=args.flags) # If the daemon signals that a restart is necessary, do it - if 'restart' in response: + if "restart" in response: print(f"Restarting: {response['restart']}") restart_server(args, allow_sources=True) - response = request(args.status_file, 'run', version=__version__, args=args.flags) + response = request(args.status_file, "run", version=__version__, args=args.flags) t1 = time.time() - response['roundtrip_time'] = t1 - t0 + response["roundtrip_time"] = t1 - t0 check_output(response, args.verbose, args.junit_xml, args.perf_stats_file) @@ -294,12 +345,12 @@ def do_status(args: argparse.Namespace) -> None: # Both check_status() and request() may raise BadStatus, # which will be handled by main(). check_status(status) - response = request(args.status_file, 'status', - fswatcher_dump_file=args.fswatcher_dump_file, - timeout=5) - if args.verbose or 'error' in response: + response = request( + args.status_file, "status", fswatcher_dump_file=args.fswatcher_dump_file, timeout=5 + ) + if args.verbose or "error" in response: show_stats(response) - if 'error' in response: + if "error" in response: fail(f"Daemon is stuck; consider {sys.argv[0]} kill") print("Daemon is up and running") @@ -308,8 +359,8 @@ def do_status(args: argparse.Namespace) -> None: def do_stop(args: argparse.Namespace) -> None: """Stop daemon via a 'stop' request.""" # May raise BadStatus, which will be handled by main(). - response = request(args.status_file, 'stop', timeout=5) - if 'error' in response: + response = request(args.status_file, "stop", timeout=5) + if "error" in response: show_stats(response) fail(f"Daemon is stuck; consider {sys.argv[0]} kill") else: @@ -332,9 +383,9 @@ def do_kill(args: argparse.Namespace) -> None: def do_check(args: argparse.Namespace) -> None: """Ask the daemon to check a list of files.""" t0 = time.time() - response = request(args.status_file, 'check', files=args.files) + response = request(args.status_file, "check", files=args.files) t1 = time.time() - response['roundtrip_time'] = t1 - t0 + response["roundtrip_time"] = t1 - t0 check_output(response, args.verbose, args.junit_xml, args.perf_stats_file) @@ -355,11 +406,11 @@ def do_recheck(args: argparse.Namespace) -> None: """ t0 = time.time() if args.remove is not None or args.update is not None: - response = request(args.status_file, 'recheck', remove=args.remove, update=args.update) + response = request(args.status_file, "recheck", remove=args.remove, update=args.update) else: - response = request(args.status_file, 'recheck') + response = request(args.status_file, "recheck") t1 = time.time() - response['roundtrip_time'] = t1 - t0 + response["roundtrip_time"] = t1 - t0 check_output(response, args.verbose, args.junit_xml, args.perf_stats_file) @@ -370,24 +421,36 @@ def do_suggest(args: argparse.Namespace) -> None: This just prints whatever the daemon reports as output. For now it may be closer to a list of call sites. """ - response = request(args.status_file, 'suggest', function=args.function, - json=args.json, callsites=args.callsites, no_errors=args.no_errors, - no_any=args.no_any, flex_any=args.flex_any, try_text=args.try_text, - use_fixme=args.use_fixme, max_guesses=args.max_guesses) + response = request( + args.status_file, + "suggest", + function=args.function, + json=args.json, + callsites=args.callsites, + no_errors=args.no_errors, + no_any=args.no_any, + flex_any=args.flex_any, + try_text=args.try_text, + use_fixme=args.use_fixme, + max_guesses=args.max_guesses, + ) check_output(response, verbose=False, junit_xml=None, perf_stats_file=None) -def check_output(response: Dict[str, Any], verbose: bool, - junit_xml: Optional[str], - perf_stats_file: Optional[str]) -> None: +def check_output( + response: Dict[str, Any], + verbose: bool, + junit_xml: Optional[str], + perf_stats_file: Optional[str], +) -> None: """Print the output from a check or recheck command. Call sys.exit() unless the status code is zero. """ - if 'error' in response: - fail(response['error']) + if "error" in response: + fail(response["error"]) try: - out, err, status_code = response['out'], response['err'], response['status'] + out, err, status_code = response["out"], response["err"], response["status"] except KeyError: fail(f"Response: {str(response)}") sys.stdout.write(out) @@ -398,12 +461,19 @@ def check_output(response: Dict[str, Any], verbose: bool, if junit_xml: # Lazy import so this import doesn't slow things down when not writing junit from mypy.util import write_junit_xml + messages = (out + err).splitlines() - write_junit_xml(response['roundtrip_time'], bool(err), messages, junit_xml, - response['python_version'], response['platform']) + write_junit_xml( + response["roundtrip_time"], + bool(err), + messages, + junit_xml, + response["python_version"], + response["platform"], + ) if perf_stats_file: - telemetry = response.get('stats', {}) - with open(perf_stats_file, 'w') as f: + telemetry = response.get("stats", {}) + with open(perf_stats_file, "w") as f: json.dump(telemetry, f) if status_code: @@ -412,19 +482,19 @@ def check_output(response: Dict[str, Any], verbose: bool, def show_stats(response: Mapping[str, object]) -> None: for key, value in sorted(response.items()): - if key not in ('out', 'err'): + if key not in ("out", "err"): print("%-24s: %10s" % (key, "%.3f" % value if isinstance(value, float) else value)) else: value = repr(value)[1:-1] if len(value) > 50: - value = value[:40] + ' ...' + value = value[:40] + " ..." print("%-24s: %s" % (key, value)) @action(hang_parser) def do_hang(args: argparse.Namespace) -> None: """Hang for 100 seconds, as a debug hack.""" - print(request(args.status_file, 'hang', timeout=1)) + print(request(args.status_file, "hang", timeout=1)) @action(daemon_parser) @@ -432,13 +502,15 @@ def do_daemon(args: argparse.Namespace) -> None: """Serve requests in the foreground.""" # Lazy import so this import doesn't slow down other commands. from mypy.dmypy_server import Server, process_start_options + if args.options_data: from mypy.options import Options + options_dict, timeout, log_file = pickle.loads(base64.b64decode(args.options_data)) options_obj = Options() options = options_obj.apply_changes(options_dict) if log_file: - sys.stdout = sys.stderr = open(log_file, 'a', buffering=1) + sys.stdout = sys.stderr = open(log_file, "a", buffering=1) fd = sys.stdout.fileno() os.dup2(fd, 2) os.dup2(fd, 1) @@ -457,8 +529,9 @@ def do_help(args: argparse.Namespace) -> None: # Client-side infrastructure. -def request(status_file: str, command: str, *, timeout: Optional[int] = None, - **kwds: object) -> Dict[str, Any]: +def request( + status_file: str, command: str, *, timeout: Optional[int] = None, **kwds: object +) -> Dict[str, Any]: """Send a request to the daemon. Return the JSON dict with the response. @@ -472,19 +545,19 @@ def request(status_file: str, command: str, *, timeout: Optional[int] = None, """ response: Dict[str, str] = {} args = dict(kwds) - args['command'] = command + args["command"] = command # Tell the server whether this request was initiated from a human-facing terminal, # so that it can format the type checking output accordingly. - args['is_tty'] = sys.stdout.isatty() or int(os.getenv('MYPY_FORCE_COLOR', '0')) > 0 - args['terminal_width'] = get_terminal_width() - bdata = json.dumps(args).encode('utf8') + args["is_tty"] = sys.stdout.isatty() or int(os.getenv("MYPY_FORCE_COLOR", "0")) > 0 + args["terminal_width"] = get_terminal_width() + bdata = json.dumps(args).encode("utf8") _, name = get_status(status_file) try: with IPCClient(name, timeout) as client: client.write(bdata) response = receive(client) except (OSError, IPCException) as err: - return {'error': str(err)} + return {"error": str(err)} # TODO: Other errors, e.g. ValueError, UnicodeError else: return response @@ -508,16 +581,16 @@ def check_status(data: Dict[str, Any]) -> Tuple[int, str]: Raise BadStatus if something's wrong. """ - if 'pid' not in data: + if "pid" not in data: raise BadStatus("Invalid status file (no pid field)") - pid = data['pid'] + pid = data["pid"] if not isinstance(pid, int): raise BadStatus("pid field is not an int") if not alive(pid): raise BadStatus("Daemon has died") - if 'connection_name' not in data: + if "connection_name" not in data: raise BadStatus("Invalid status file (no connection_name field)") - connection_name = data['connection_name'] + connection_name = data["connection_name"] if not isinstance(connection_name, str): raise BadStatus("connection_name field is not a string") return pid, connection_name diff --git a/mypy/dmypy_os.py b/mypy/dmypy_os.py index 1405e0a309e9c..0b823b6f41326 100644 --- a/mypy/dmypy_os.py +++ b/mypy/dmypy_os.py @@ -1,11 +1,10 @@ import sys - from typing import Any, Callable -if sys.platform == 'win32': +if sys.platform == "win32": import ctypes - from ctypes.wintypes import DWORD, HANDLE import subprocess + from ctypes.wintypes import DWORD, HANDLE PROCESS_QUERY_LIMITED_INFORMATION = ctypes.c_ulong(0x1000) @@ -19,12 +18,10 @@ def alive(pid: int) -> bool: """Is the process alive?""" - if sys.platform == 'win32': + if sys.platform == "win32": # why can't anything be easy... status = DWORD() - handle = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, - 0, - pid) + handle = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid) GetExitCodeProcess(handle, ctypes.byref(status)) return status.value == 259 # STILL_ACTIVE else: @@ -37,7 +34,7 @@ def alive(pid: int) -> bool: def kill(pid: int) -> None: """Kill the process.""" - if sys.platform == 'win32': + if sys.platform == "win32": subprocess.check_output(f"taskkill /pid {pid} /f /t") else: os.kill(pid, signal.SIGKILL) diff --git a/mypy/dmypy_server.py b/mypy/dmypy_server.py index 3fbda6b1a7d8f..a271b20792beb 100644 --- a/mypy/dmypy_server.py +++ b/mypy/dmypy_server.py @@ -15,35 +15,37 @@ import time import traceback from contextlib import redirect_stderr, redirect_stdout +from typing import AbstractSet, Any, Callable, Dict, List, Optional, Sequence, Set, Tuple -from typing import AbstractSet, Any, Callable, Dict, List, Optional, Sequence, Tuple, Set from typing_extensions import Final import mypy.build import mypy.errors import mypy.main -from mypy.find_sources import create_source_list, InvalidSourceList -from mypy.server.update import FineGrainedBuildManager, refresh_suppressed_submodules from mypy.dmypy_util import receive -from mypy.ipc import IPCServer +from mypy.find_sources import InvalidSourceList, create_source_list from mypy.fscache import FileSystemCache -from mypy.fswatcher import FileSystemWatcher, FileData -from mypy.modulefinder import BuildSource, compute_search_paths, FindModuleCache, SearchPaths +from mypy.fswatcher import FileData, FileSystemWatcher +from mypy.ipc import IPCServer +from mypy.modulefinder import BuildSource, FindModuleCache, SearchPaths, compute_search_paths from mypy.options import Options -from mypy.suggestions import SuggestionFailure, SuggestionEngine +from mypy.server.update import FineGrainedBuildManager, refresh_suppressed_submodules +from mypy.suggestions import SuggestionEngine, SuggestionFailure from mypy.typestate import reset_global_state -from mypy.version import __version__ from mypy.util import FancyFormatter, count_stats +from mypy.version import __version__ MEM_PROFILE: Final = False # If True, dump memory profile after initialization -if sys.platform == 'win32': +if sys.platform == "win32": from subprocess import STARTUPINFO - def daemonize(options: Options, - status_file: str, - timeout: Optional[int] = None, - log_file: Optional[str] = None) -> int: + def daemonize( + options: Options, + status_file: str, + timeout: Optional[int] = None, + log_file: Optional[str] = None, + ) -> int: """Create the daemon process via "dmypy daemon" and pass options via command line When creating the daemon grandchild, we create it in a new console, which is @@ -54,21 +56,20 @@ def daemonize(options: Options, It also pickles the options to be unpickled by mypy. """ - command = [sys.executable, '-m', 'mypy.dmypy', '--status-file', status_file, 'daemon'] + command = [sys.executable, "-m", "mypy.dmypy", "--status-file", status_file, "daemon"] pickled_options = pickle.dumps((options.snapshot(), timeout, log_file)) command.append(f'--options-data="{base64.b64encode(pickled_options).decode()}"') info = STARTUPINFO() info.dwFlags = 0x1 # STARTF_USESHOWWINDOW aka use wShowWindow's value info.wShowWindow = 0 # SW_HIDE aka make the window invisible try: - subprocess.Popen(command, - creationflags=0x10, # CREATE_NEW_CONSOLE - startupinfo=info) + subprocess.Popen(command, creationflags=0x10, startupinfo=info) # CREATE_NEW_CONSOLE return 0 except subprocess.CalledProcessError as e: return e.returncode else: + def _daemonize_cb(func: Callable[[], None], log_file: Optional[str] = None) -> int: """Arrange to call func() in a grandchild of the current process. @@ -82,7 +83,7 @@ def _daemonize_cb(func: Callable[[], None], log_file: Optional[str] = None) -> i if pid: # Parent process: wait for child in case things go bad there. npid, sts = os.waitpid(pid, 0) - sig = sts & 0xff + sig = sts & 0xFF if sig: print("Child killed by signal", sig) return -sig @@ -94,7 +95,7 @@ def _daemonize_cb(func: Callable[[], None], log_file: Optional[str] = None) -> i try: os.setsid() # Detach controlling terminal os.umask(0o27) - devnull = os.open('/dev/null', os.O_RDWR) + devnull = os.open("/dev/null", os.O_RDWR) os.dup2(devnull, 0) os.dup2(devnull, 1) os.dup2(devnull, 2) @@ -105,7 +106,7 @@ def _daemonize_cb(func: Callable[[], None], log_file: Optional[str] = None) -> i os._exit(0) # Grandchild: run the server. if log_file: - sys.stdout = sys.stderr = open(log_file, 'a', buffering=1) + sys.stdout = sys.stderr = open(log_file, "a", buffering=1) fd = sys.stdout.fileno() os.dup2(fd, 2) os.dup2(fd, 1) @@ -114,10 +115,12 @@ def _daemonize_cb(func: Callable[[], None], log_file: Optional[str] = None) -> i # Make sure we never get back into the caller. os._exit(1) - def daemonize(options: Options, - status_file: str, - timeout: Optional[int] = None, - log_file: Optional[str] = None) -> int: + def daemonize( + options: Options, + status_file: str, + timeout: Optional[int] = None, + log_file: Optional[str] = None, + ) -> int: """Run the mypy daemon in a grandchild of the current process Return 0 for success, exit status for failure, negative if @@ -125,6 +128,7 @@ def daemonize(options: Options, """ return _daemonize_cb(Server(options, status_file, timeout).serve, log_file) + # Server code. CONNECTION_NAME: Final = "dmypy" @@ -132,17 +136,19 @@ def daemonize(options: Options, def process_start_options(flags: List[str], allow_sources: bool) -> Options: _, options = mypy.main.process_options( - ['-i'] + flags, require_targets=False, server_options=True + ["-i"] + flags, require_targets=False, server_options=True ) if options.report_dirs: print("dmypy: Ignoring report generation settings. Start/restart cannot generate reports.") if options.junit_xml: - print("dmypy: Ignoring report generation settings. " - "Start/restart does not support --junit-xml. Pass it to check/recheck instead") + print( + "dmypy: Ignoring report generation settings. " + "Start/restart does not support --junit-xml. Pass it to check/recheck instead" + ) options.junit_xml = None if not options.incremental: sys.exit("dmypy: start/restart should not disable incremental mode") - if options.follow_imports not in ('skip', 'error', 'normal'): + if options.follow_imports not in ("skip", "error", "normal"): sys.exit("dmypy: follow-imports=silent not supported") return options @@ -152,7 +158,7 @@ def ignore_suppressed_imports(module: str) -> bool: # Various submodules of 'encodings' can be suppressed, since it # uses module-level '__getattr__'. Skip them since there are many # of them, and following imports to them is kind of pointless. - return module.startswith('encodings.') + return module.startswith("encodings.") ModulePathPair = Tuple[str, str] @@ -165,9 +171,7 @@ class Server: # NOTE: the instance is constructed in the parent process but # serve() is called in the grandchild (by daemonize()). - def __init__(self, options: Options, - status_file: str, - timeout: Optional[int] = None) -> None: + def __init__(self, options: Options, status_file: str, timeout: Optional[int] = None) -> None: """Initialize the server with the desired mypy flags.""" self.options = options # Snapshot the options info before we muck with it, to detect changes @@ -200,47 +204,44 @@ def __init__(self, options: Options, self.formatter = FancyFormatter(sys.stdout, sys.stderr, options.show_error_codes) def _response_metadata(self) -> Dict[str, str]: - py_version = f'{self.options.python_version[0]}_{self.options.python_version[1]}' - return { - 'platform': self.options.platform, - 'python_version': py_version, - } + py_version = f"{self.options.python_version[0]}_{self.options.python_version[1]}" + return {"platform": self.options.platform, "python_version": py_version} def serve(self) -> None: """Serve requests, synchronously (no thread or fork).""" command = None server = IPCServer(CONNECTION_NAME, self.timeout) try: - with open(self.status_file, 'w') as f: - json.dump({'pid': os.getpid(), 'connection_name': server.connection_name}, f) - f.write('\n') # I like my JSON with a trailing newline + with open(self.status_file, "w") as f: + json.dump({"pid": os.getpid(), "connection_name": server.connection_name}, f) + f.write("\n") # I like my JSON with a trailing newline while True: with server: data = receive(server) resp: Dict[str, Any] = {} - if 'command' not in data: - resp = {'error': "No command found in request"} + if "command" not in data: + resp = {"error": "No command found in request"} else: - command = data['command'] + command = data["command"] if not isinstance(command, str): - resp = {'error': "Command is not a string"} + resp = {"error": "Command is not a string"} else: - command = data.pop('command') + command = data.pop("command") try: resp = self.run_command(command, data) except Exception: # If we are crashing, report the crash to the client tb = traceback.format_exception(*sys.exc_info()) - resp = {'error': "Daemon crashed!\n" + "".join(tb)} + resp = {"error": "Daemon crashed!\n" + "".join(tb)} resp.update(self._response_metadata()) - server.write(json.dumps(resp).encode('utf8')) + server.write(json.dumps(resp).encode("utf8")) raise try: resp.update(self._response_metadata()) - server.write(json.dumps(resp).encode('utf8')) + server.write(json.dumps(resp).encode("utf8")) except OSError: pass # Maybe the client hung up - if command == 'stop': + if command == "stop": reset_global_state() sys.exit(0) finally: @@ -249,7 +250,7 @@ def serve(self) -> None: # simplify the logic and always remove the file, since # that could cause us to remove a future server's # status file.) - if command != 'stop': + if command != "stop": os.unlink(self.status_file) try: server.cleanup() # try to remove the socket dir on Linux @@ -261,15 +262,15 @@ def serve(self) -> None: def run_command(self, command: str, data: Dict[str, object]) -> Dict[str, object]: """Run a specific command from the registry.""" - key = 'cmd_' + command + key = "cmd_" + command method = getattr(self.__class__, key, None) if method is None: - return {'error': f"Unrecognized command '{command}'"} + return {"error": f"Unrecognized command '{command}'"} else: - if command not in {'check', 'recheck', 'run'}: + if command not in {"check", "recheck", "run"}: # Only the above commands use some error formatting. - del data['is_tty'] - del data['terminal_width'] + del data["is_tty"] + del data["terminal_width"] return method(self, **data) # Command functions (run in the server via RPC). @@ -279,10 +280,10 @@ def cmd_status(self, fswatcher_dump_file: Optional[str] = None) -> Dict[str, obj res: Dict[str, object] = {} res.update(get_meminfo()) if fswatcher_dump_file: - data = self.fswatcher.dump_file_data() if hasattr(self, 'fswatcher') else {} + data = self.fswatcher.dump_file_data() if hasattr(self, "fswatcher") else {} # Using .dumps and then writing was noticeably faster than using dump s = json.dumps(data) - with open(fswatcher_dump_file, 'w') as f: + with open(fswatcher_dump_file, "w") as f: f.write(s) return res @@ -295,8 +296,9 @@ def cmd_stop(self) -> Dict[str, object]: os.unlink(self.status_file) return {} - def cmd_run(self, version: str, args: Sequence[str], - is_tty: bool, terminal_width: int) -> Dict[str, object]: + def cmd_run( + self, version: str, args: Sequence[str], is_tty: bool, terminal_width: int + ) -> Dict[str, object]: """Check a list of files, triggering a restart if needed.""" stderr = io.StringIO() stdout = io.StringIO() @@ -306,17 +308,18 @@ def cmd_run(self, version: str, args: Sequence[str], with redirect_stderr(stderr): with redirect_stdout(stdout): sources, options = mypy.main.process_options( - ['-i'] + list(args), + ["-i"] + list(args), require_targets=True, server_options=True, fscache=self.fscache, - program='mypy-daemon', - header=argparse.SUPPRESS) + program="mypy-daemon", + header=argparse.SUPPRESS, + ) # Signal that we need to restart if the options have changed if self.options_snapshot != options.snapshot(): - return {'restart': 'configuration changed'} + return {"restart": "configuration changed"} if __version__ != version: - return {'restart': 'mypy version changed'} + return {"restart": "mypy version changed"} if self.fine_grained_manager: manager = self.fine_grained_manager.manager start_plugins_snapshot = manager.plugins_snapshot @@ -324,27 +327,30 @@ def cmd_run(self, version: str, args: Sequence[str], options, manager.errors, sys.stdout, extra_plugins=() ) if current_plugins_snapshot != start_plugins_snapshot: - return {'restart': 'plugins changed'} + return {"restart": "plugins changed"} except InvalidSourceList as err: - return {'out': '', 'err': str(err), 'status': 2} + return {"out": "", "err": str(err), "status": 2} except SystemExit as e: - return {'out': stdout.getvalue(), 'err': stderr.getvalue(), 'status': e.code} + return {"out": stdout.getvalue(), "err": stderr.getvalue(), "status": e.code} return self.check(sources, is_tty, terminal_width) - def cmd_check(self, files: Sequence[str], - is_tty: bool, terminal_width: int) -> Dict[str, object]: + def cmd_check( + self, files: Sequence[str], is_tty: bool, terminal_width: int + ) -> Dict[str, object]: """Check a list of files.""" try: sources = create_source_list(files, self.options, self.fscache) except InvalidSourceList as err: - return {'out': '', 'err': str(err), 'status': 2} + return {"out": "", "err": str(err), "status": 2} return self.check(sources, is_tty, terminal_width) - def cmd_recheck(self, - is_tty: bool, - terminal_width: int, - remove: Optional[List[str]] = None, - update: Optional[List[str]] = None) -> Dict[str, object]: + def cmd_recheck( + self, + is_tty: bool, + terminal_width: int, + remove: Optional[List[str]] = None, + update: Optional[List[str]] = None, + ) -> Dict[str, object]: """Check the same list of files we checked most recently. If remove/update is given, they modify the previous list; @@ -352,7 +358,7 @@ def cmd_recheck(self, """ t0 = time.time() if not self.fine_grained_manager: - return {'error': "Command 'recheck' is only valid after a 'check' command"} + return {"error": "Command 'recheck' is only valid after a 'check' command"} sources = self.previous_sources if remove: removals = set(remove) @@ -363,7 +369,7 @@ def cmd_recheck(self, try: added_sources = create_source_list(added, self.options, self.fscache) except InvalidSourceList as err: - return {'out': '', 'err': str(err), 'status': 2} + return {"out": "", "err": str(err), "status": 2} sources = sources + added_sources # Make a copy! t1 = time.time() manager = self.fine_grained_manager.manager @@ -378,8 +384,9 @@ def cmd_recheck(self, self.update_stats(res) return res - def check(self, sources: List[BuildSource], - is_tty: bool, terminal_width: int) -> Dict[str, Any]: + def check( + self, sources: List[BuildSource], is_tty: bool, terminal_width: int + ) -> Dict[str, Any]: """Check using fine-grained incremental mode. If is_tty is True format the output nicely with colors and summary line @@ -406,31 +413,30 @@ def update_stats(self, res: Dict[str, Any]) -> None: if self.fine_grained_manager: manager = self.fine_grained_manager.manager manager.dump_stats() - res['stats'] = manager.stats + res["stats"] = manager.stats manager.stats = {} def following_imports(self) -> bool: """Are we following imports?""" # TODO: What about silent? - return self.options.follow_imports == 'normal' + return self.options.follow_imports == "normal" - def initialize_fine_grained(self, sources: List[BuildSource], - is_tty: bool, terminal_width: int) -> Dict[str, Any]: + def initialize_fine_grained( + self, sources: List[BuildSource], is_tty: bool, terminal_width: int + ) -> Dict[str, Any]: self.fswatcher = FileSystemWatcher(self.fscache) t0 = time.time() self.update_sources(sources) t1 = time.time() try: - result = mypy.build.build(sources=sources, - options=self.options, - fscache=self.fscache) + result = mypy.build.build(sources=sources, options=self.options, fscache=self.fscache) except mypy.errors.CompileError as e: - output = ''.join(s + '\n' for s in e.messages) + output = "".join(s + "\n" for s in e.messages) if e.use_stdout: - out, err = output, '' + out, err = output, "" else: - out, err = '', output - return {'out': out, 'err': err, 'status': 2} + out, err = "", output + return {"out": out, "err": err, "status": 2} messages = result.errors self.fine_grained_manager = FineGrainedBuildManager(result) @@ -449,15 +455,20 @@ def initialize_fine_grained(self, sources: List[BuildSource], # the fswatcher, so we pick up the changes. for state in self.fine_grained_manager.graph.values(): meta = state.meta - if meta is None: continue + if meta is None: + continue assert state.path is not None self.fswatcher.set_file_data( state.path, - FileData(st_mtime=float(meta.mtime), st_size=meta.size, hash=meta.hash)) + FileData(st_mtime=float(meta.mtime), st_size=meta.size, hash=meta.hash), + ) changed, removed = self.find_changed(sources) - changed += self.find_added_suppressed(self.fine_grained_manager.graph, set(), - self.fine_grained_manager.manager.search_paths) + changed += self.find_added_suppressed( + self.fine_grained_manager.graph, + set(), + self.fine_grained_manager.manager.search_paths, + ) # Find anything that has had its dependency list change for state in self.fine_grained_manager.graph.values(): @@ -479,7 +490,8 @@ def initialize_fine_grained(self, sources: List[BuildSource], build_time=t2 - t1, find_changes_time=t3 - t2, fg_update_time=t4 - t3, - files_changed=len(removed) + len(changed)) + files_changed=len(removed) + len(changed), + ) else: # Stores the initial state of sources as a side effect. @@ -487,17 +499,19 @@ def initialize_fine_grained(self, sources: List[BuildSource], if MEM_PROFILE: from mypy.memprofile import print_memory_profile + print_memory_profile(run_gc=False) status = 1 if messages else 0 messages = self.pretty_messages(messages, len(sources), is_tty, terminal_width) - return {'out': ''.join(s + '\n' for s in messages), 'err': '', 'status': status} - - def fine_grained_increment(self, - sources: List[BuildSource], - remove: Optional[List[str]] = None, - update: Optional[List[str]] = None, - ) -> List[str]: + return {"out": "".join(s + "\n" for s in messages), "err": "", "status": status} + + def fine_grained_increment( + self, + sources: List[BuildSource], + remove: Optional[List[str]] = None, + update: Optional[List[str]] = None, + ) -> List[str]: """Perform a fine-grained type checking increment. If remove and update are None, determine changed paths by using @@ -521,8 +535,9 @@ def fine_grained_increment(self, # Use the remove/update lists to update fswatcher. # This avoids calling stat() for unchanged files. changed, removed = self.update_changed(sources, remove or [], update or []) - changed += self.find_added_suppressed(self.fine_grained_manager.graph, set(), - manager.search_paths) + changed += self.find_added_suppressed( + self.fine_grained_manager.graph, set(), manager.search_paths + ) manager.search_paths = compute_search_paths(sources, manager.options, manager.data_dir) t1 = time.time() manager.log(f"fine-grained increment: find_changed: {t1 - t0:.3f}s") @@ -532,7 +547,8 @@ def fine_grained_increment(self, manager.add_stats( find_changes_time=t1 - t0, fg_update_time=t2 - t1, - files_changed=len(removed) + len(changed)) + files_changed=len(removed) + len(changed), + ) self.previous_sources = sources return messages @@ -613,11 +629,7 @@ def refresh_file(module: str, path: str) -> List[str]: for module_id, path in new_unsuppressed: new_messages = refresh_suppressed_submodules( - module_id, path, - fine_grained_manager.deps, - graph, - self.fscache, - refresh_file + module_id, path, fine_grained_manager.deps, graph, self.fscache, refresh_file ) if new_messages is not None: messages = new_messages @@ -652,17 +664,18 @@ def refresh_file(module: str, path: str) -> List[str]: fg_update_time=t2 - t1, refresh_suppressed_time=t3 - t2, find_added_supressed_time=t4 - t3, - cleanup_time=t5 - t4) + cleanup_time=t5 - t4, + ) return messages def find_reachable_changed_modules( - self, - roots: List[BuildSource], - graph: mypy.build.Graph, - seen: Set[str], - changed_paths: AbstractSet[str]) -> Tuple[List[Tuple[str, str]], - List[BuildSource]]: + self, + roots: List[BuildSource], + graph: mypy.build.Graph, + seen: Set[str], + changed_paths: AbstractSet[str], + ) -> Tuple[List[Tuple[str, str]], List[BuildSource]]: """Follow imports within graph from given sources until hitting changed modules. If we find a changed module, we can't continue following imports as the imports @@ -694,22 +707,19 @@ def find_reachable_changed_modules( for dep in state.dependencies: if dep not in seen: seen.add(dep) - worklist.append(BuildSource(graph[dep].path, - graph[dep].id)) + worklist.append(BuildSource(graph[dep].path, graph[dep].id)) return changed, new_files - def direct_imports(self, - module: Tuple[str, str], - graph: mypy.build.Graph) -> List[BuildSource]: + def direct_imports( + self, module: Tuple[str, str], graph: mypy.build.Graph + ) -> List[BuildSource]: """Return the direct imports of module not included in seen.""" state = graph[module[0]] - return [BuildSource(graph[dep].path, dep) - for dep in state.dependencies] + return [BuildSource(graph[dep].path, dep) for dep in state.dependencies] - def find_added_suppressed(self, - graph: mypy.build.Graph, - seen: Set[str], - search_paths: SearchPaths) -> List[Tuple[str, str]]: + def find_added_suppressed( + self, graph: mypy.build.Graph, seen: Set[str], search_paths: SearchPaths + ) -> List[Tuple[str, str]]: """Find suppressed modules that have been added (and not included in seen). Args: @@ -724,14 +734,16 @@ def find_added_suppressed(self, # Filter out things that shouldn't actually be considered suppressed. # # TODO: Figure out why these are treated as suppressed - all_suppressed = {module - for module in all_suppressed - if module not in graph and not ignore_suppressed_imports(module)} + all_suppressed = { + module + for module in all_suppressed + if module not in graph and not ignore_suppressed_imports(module) + } # Optimization: skip top-level packages that are obviously not # there, to avoid calling the relatively slow find_module() # below too many times. - packages = {module.split('.', 1)[0] for module in all_suppressed} + packages = {module.split(".", 1)[0] for module in all_suppressed} packages = filter_out_missing_top_level_packages(packages, search_paths, self.fscache) # TODO: Namespace packages @@ -741,42 +753,47 @@ def find_added_suppressed(self, found = [] for module in all_suppressed: - top_level_pkg = module.split('.', 1)[0] + top_level_pkg = module.split(".", 1)[0] if top_level_pkg not in packages: # Fast path: non-existent top-level package continue result = finder.find_module(module, fast_path=True) if isinstance(result, str) and module not in seen: # When not following imports, we only follow imports to .pyi files. - if not self.following_imports() and not result.endswith('.pyi'): + if not self.following_imports() and not result.endswith(".pyi"): continue found.append((module, result)) seen.add(module) return found - def increment_output(self, - messages: List[str], - sources: List[BuildSource], - is_tty: bool, - terminal_width: int) -> Dict[str, Any]: + def increment_output( + self, messages: List[str], sources: List[BuildSource], is_tty: bool, terminal_width: int + ) -> Dict[str, Any]: status = 1 if messages else 0 messages = self.pretty_messages(messages, len(sources), is_tty, terminal_width) - return {'out': ''.join(s + '\n' for s in messages), 'err': '', 'status': status} - - def pretty_messages(self, messages: List[str], n_sources: int, - is_tty: bool = False, terminal_width: Optional[int] = None) -> List[str]: + return {"out": "".join(s + "\n" for s in messages), "err": "", "status": status} + + def pretty_messages( + self, + messages: List[str], + n_sources: int, + is_tty: bool = False, + terminal_width: Optional[int] = None, + ) -> List[str]: use_color = self.options.color_output and is_tty fit_width = self.options.pretty and is_tty if fit_width: - messages = self.formatter.fit_in_terminal(messages, - fixed_terminal_width=terminal_width) + messages = self.formatter.fit_in_terminal( + messages, fixed_terminal_width=terminal_width + ) if self.options.error_summary: summary: Optional[str] = None n_errors, n_notes, n_files = count_stats(messages) if n_errors: - summary = self.formatter.format_error(n_errors, n_files, n_sources, - use_color=use_color) + summary = self.formatter.format_error( + n_errors, n_files, n_sources, use_color=use_color + ) elif not messages or n_notes == len(messages): summary = self.formatter.format_success(n_sources, use_color) if summary: @@ -793,11 +810,9 @@ def update_sources(self, sources: List[BuildSource]) -> None: paths = [path for path in paths if self.fscache.isfile(path)] self.fswatcher.add_watched_paths(paths) - def update_changed(self, - sources: List[BuildSource], - remove: List[str], - update: List[str], - ) -> ChangesAndRemovals: + def update_changed( + self, sources: List[BuildSource], remove: List[str], update: List[str] + ) -> ChangesAndRemovals: changed_paths = self.fswatcher.update_changed(remove, update) return self._find_changed(sources, changed_paths) @@ -806,12 +821,15 @@ def find_changed(self, sources: List[BuildSource]) -> ChangesAndRemovals: changed_paths = self.fswatcher.find_changed() return self._find_changed(sources, changed_paths) - def _find_changed(self, sources: List[BuildSource], - changed_paths: AbstractSet[str]) -> ChangesAndRemovals: + def _find_changed( + self, sources: List[BuildSource], changed_paths: AbstractSet[str] + ) -> ChangesAndRemovals: # Find anything that has been added or modified - changed = [(source.module, source.path) - for source in sources - if source.path and source.path in changed_paths] + changed = [ + (source.module, source.path) + for source in sources + if source.path and source.path in changed_paths + ] # Now find anything that has been removed from the build modules = {source.module for source in sources} @@ -833,15 +851,13 @@ def _find_changed(self, sources: List[BuildSource], return changed, removed - def cmd_suggest(self, - function: str, - callsites: bool, - **kwargs: Any) -> Dict[str, object]: + def cmd_suggest(self, function: str, callsites: bool, **kwargs: Any) -> Dict[str, object]: """Suggest a signature for a function.""" if not self.fine_grained_manager: return { - 'error': "Command 'suggest' is only valid after a 'check' command" - " (that produces no parse errors)"} + "error": "Command 'suggest' is only valid after a 'check' command" + " (that produces no parse errors)" + } engine = SuggestionEngine(self.fine_grained_manager, **kwargs) try: if callsites: @@ -849,13 +865,13 @@ def cmd_suggest(self, else: out = engine.suggest(function) except SuggestionFailure as err: - return {'error': str(err)} + return {"error": str(err)} else: if not out: out = "No suggestions\n" elif not out.endswith("\n"): out += "\n" - return {'out': out, 'err': "", 'status': 0} + return {"out": out, "err": "", "status": 0} finally: self.flush_caches() @@ -868,7 +884,7 @@ def cmd_hang(self) -> Dict[str, object]: # Misc utilities. -MiB: Final = 2 ** 20 +MiB: Final = 2**20 def get_meminfo() -> Dict[str, Any]: @@ -876,31 +892,33 @@ def get_meminfo() -> Dict[str, Any]: try: import psutil # type: ignore # It's not in typeshed yet except ImportError: - res['memory_psutil_missing'] = ( - 'psutil not found, run pip install mypy[dmypy] ' - 'to install the needed components for dmypy' + res["memory_psutil_missing"] = ( + "psutil not found, run pip install mypy[dmypy] " + "to install the needed components for dmypy" ) else: process = psutil.Process() meminfo = process.memory_info() - res['memory_rss_mib'] = meminfo.rss / MiB - res['memory_vms_mib'] = meminfo.vms / MiB - if sys.platform == 'win32': - res['memory_maxrss_mib'] = meminfo.peak_wset / MiB + res["memory_rss_mib"] = meminfo.rss / MiB + res["memory_vms_mib"] = meminfo.vms / MiB + if sys.platform == "win32": + res["memory_maxrss_mib"] = meminfo.peak_wset / MiB else: # See https://stackoverflow.com/questions/938733/total-memory-used-by-python-process import resource # Since it doesn't exist on Windows. + rusage = resource.getrusage(resource.RUSAGE_SELF) - if sys.platform == 'darwin': + if sys.platform == "darwin": factor = 1 else: factor = 1024 # Linux - res['memory_maxrss_mib'] = rusage.ru_maxrss * factor / MiB + res["memory_maxrss_mib"] = rusage.ru_maxrss * factor / MiB return res -def find_all_sources_in_build(graph: mypy.build.Graph, - extra: Sequence[BuildSource] = ()) -> List[BuildSource]: +def find_all_sources_in_build( + graph: mypy.build.Graph, extra: Sequence[BuildSource] = () +) -> List[BuildSource]: result = list(extra) seen = {source.module for source in result} for module, state in graph.items(): @@ -929,9 +947,9 @@ def fix_module_deps(graph: mypy.build.Graph) -> None: state.suppressed_set = set(new_suppressed) -def filter_out_missing_top_level_packages(packages: Set[str], - search_paths: SearchPaths, - fscache: FileSystemCache) -> Set[str]: +def filter_out_missing_top_level_packages( + packages: Set[str], search_paths: SearchPaths, fscache: FileSystemCache +) -> Set[str]: """Quickly filter out obviously missing top-level packages. Return packages with entries that can't be found removed. @@ -942,10 +960,12 @@ def filter_out_missing_top_level_packages(packages: Set[str], # Start with a empty set and add all potential top-level packages. found = set() paths = ( - search_paths.python_path + search_paths.mypy_path + search_paths.package_path + - search_paths.typeshed_path + search_paths.python_path + + search_paths.mypy_path + + search_paths.package_path + + search_paths.typeshed_path ) - paths += tuple(os.path.join(p, '@python2') for p in search_paths.typeshed_path) + paths += tuple(os.path.join(p, "@python2") for p in search_paths.typeshed_path) for p in paths: try: entries = fscache.listdir(p) @@ -954,14 +974,14 @@ def filter_out_missing_top_level_packages(packages: Set[str], for entry in entries: # The code is hand-optimized for mypyc since this may be somewhat # performance-critical. - if entry.endswith('.py'): + if entry.endswith(".py"): entry = entry[:-3] - elif entry.endswith('.pyi'): + elif entry.endswith(".pyi"): entry = entry[:-4] - elif entry.endswith('-stubs'): + elif entry.endswith("-stubs"): # Possible PEP 561 stub package entry = entry[:-6] - if entry.endswith('-python2'): + if entry.endswith("-python2"): entry = entry[:-8] if entry in packages: found.add(entry) diff --git a/mypy/dmypy_util.py b/mypy/dmypy_util.py index 2b458c51e5a4d..31c1aee13860e 100644 --- a/mypy/dmypy_util.py +++ b/mypy/dmypy_util.py @@ -4,8 +4,8 @@ """ import json - from typing import Any + from typing_extensions import Final from mypy.ipc import IPCBase @@ -23,7 +23,7 @@ def receive(connection: IPCBase) -> Any: if not bdata: raise OSError("No data received") try: - data = json.loads(bdata.decode('utf8')) + data = json.loads(bdata.decode("utf8")) except Exception as e: raise OSError("Data received is not valid JSON") from e if not isinstance(data, dict): diff --git a/mypy/erasetype.py b/mypy/erasetype.py index ec0ad1338840c..2d8853bc3d249 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -1,13 +1,37 @@ -from typing import Optional, Container, Callable, List, Dict, cast +from typing import Callable, Container, Dict, List, Optional, cast +from mypy.nodes import ARG_STAR, ARG_STAR2 from mypy.types import ( - Type, TypeVisitor, UnboundType, AnyType, NoneType, TypeVarId, Instance, TypeVarType, - CallableType, TupleType, TypedDictType, UnionType, Overloaded, ErasedType, PartialType, - DeletedType, TypeTranslator, UninhabitedType, TypeType, TypeOfAny, LiteralType, ProperType, - get_proper_type, get_proper_types, TypeAliasType, ParamSpecType, Parameters, UnpackType, - TypeVarTupleType + AnyType, + CallableType, + DeletedType, + ErasedType, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + PartialType, + ProperType, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeOfAny, + TypeTranslator, + TypeType, + TypeVarId, + TypeVarTupleType, + TypeVarType, + TypeVisitor, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, + get_proper_type, + get_proper_types, ) -from mypy.nodes import ARG_STAR, ARG_STAR2 def erase_type(typ: Type) -> ProperType: @@ -27,7 +51,6 @@ def erase_type(typ: Type) -> ProperType: class EraseTypeVisitor(TypeVisitor[ProperType]): - def visit_unbound_type(self, t: UnboundType) -> ProperType: # TODO: replace with an assert after UnboundType can't leak from semantic analysis. return AnyType(TypeOfAny.from_error) @@ -100,6 +123,7 @@ def visit_literal_type(self, t: LiteralType) -> ProperType: def visit_union_type(self, t: UnionType) -> ProperType: erased_items = [erase_type(item) for item in t.items] from mypy.typeops import make_simplified_union + return make_simplified_union(erased_items) def visit_type_type(self, t: TypeType) -> ProperType: @@ -113,10 +137,12 @@ def erase_typevars(t: Type, ids_to_erase: Optional[Container[TypeVarId]] = None) """Replace all type variables in a type with any, or just the ones in the provided collection. """ + def erase_id(id: TypeVarId) -> bool: if ids_to_erase is None: return True return id in ids_to_erase + return t.accept(TypeVarEraser(erase_id, AnyType(TypeOfAny.special_form))) @@ -164,10 +190,7 @@ class LastKnownValueEraser(TypeTranslator): def visit_instance(self, t: Instance) -> Type: if not t.last_known_value and not t.args: return t - new_t = t.copy_modified( - args=[a.accept(self) for a in t.args], - last_known_value=None, - ) + new_t = t.copy_modified(args=[a.accept(self) for a in t.args], last_known_value=None) new_t.can_be_true = t.can_be_true new_t.can_be_false = t.can_be_false return new_t @@ -183,8 +206,7 @@ def visit_union_type(self, t: UnionType) -> Type: # Call make_simplified_union only on lists of instance types # that all have the same fullname, to avoid simplifying too # much. - instances = [item for item in new.items - if isinstance(get_proper_type(item), Instance)] + instances = [item for item in new.items if isinstance(get_proper_type(item), Instance)] # Avoid merge in simple cases such as optional types. if len(instances) > 1: instances_by_name: Dict[str, List[Instance]] = {} @@ -201,6 +223,7 @@ def visit_union_type(self, t: UnionType) -> Type: merged.append(item) else: from mypy.typeops import make_simplified_union + merged.append(make_simplified_union(types)) del instances_by_name[item.type.fullname] else: diff --git a/mypy/errorcodes.py b/mypy/errorcodes.py index e237e818edaea..f3e60064d6164 100644 --- a/mypy/errorcodes.py +++ b/mypy/errorcodes.py @@ -4,8 +4,8 @@ """ from typing import Dict, List -from typing_extensions import Final +from typing_extensions import Final # All created error codes are implicitly stored in this list. all_error_codes: List["ErrorCode"] = [] @@ -14,10 +14,9 @@ class ErrorCode: - def __init__(self, code: str, - description: str, - category: str, - default_enabled: bool = True) -> None: + def __init__( + self, code: str, description: str, category: str, default_enabled: bool = True + ) -> None: self.code = code self.description = description self.category = category @@ -25,7 +24,7 @@ def __init__(self, code: str, error_codes[code] = self def __str__(self) -> str: - return f'' + return f"" ATTR_DEFINED: Final = ErrorCode("attr-defined", "Check that attribute exists", "General") @@ -94,9 +93,7 @@ def __str__(self) -> str: EXIT_RETURN: Final = ErrorCode( "exit-return", "Warn about too general return type for '__exit__'", "General" ) -LITERAL_REQ: Final = ErrorCode( - "literal-required", "Check that value is a literal", 'General' -) +LITERAL_REQ: Final = ErrorCode("literal-required", "Check that value is a literal", "General") UNUSED_COROUTINE: Final = ErrorCode( "unused-coroutine", "Ensure that all coroutines are used", "General" ) @@ -113,9 +110,7 @@ def __str__(self) -> str: REDUNDANT_CAST: Final = ErrorCode( "redundant-cast", "Check that cast changes type of expression", "General" ) -ASSERT_TYPE: Final = ErrorCode( - "assert-type", "Check that assert_type() call succeeds", "General" -) +ASSERT_TYPE: Final = ErrorCode("assert-type", "Check that assert_type() call succeeds", "General") COMPARISON_OVERLAP: Final = ErrorCode( "comparison-overlap", "Check that types in comparisons and 'in' expressions overlap", "General" ) diff --git a/mypy/errors.py b/mypy/errors.py index 978390cb93927..2656a6edf2c5c 100644 --- a/mypy/errors.py +++ b/mypy/errors.py @@ -1,20 +1,19 @@ import os.path import sys import traceback - -from mypy.backports import OrderedDict from collections import defaultdict +from typing import Callable, Dict, List, Optional, Set, TextIO, Tuple, TypeVar, Union -from typing import Tuple, List, TypeVar, Set, Dict, Optional, TextIO, Callable, Union from typing_extensions import Final, Literal, NoReturn -from mypy.scope import Scope -from mypy.options import Options -from mypy.version import __version__ as mypy_version -from mypy.errorcodes import ErrorCode, IMPORT -from mypy.message_registry import ErrorMessage from mypy import errorcodes as codes +from mypy.backports import OrderedDict +from mypy.errorcodes import IMPORT, ErrorCode +from mypy.message_registry import ErrorMessage +from mypy.options import Options +from mypy.scope import Scope from mypy.util import DEFAULT_SOURCE_OFFSET, is_typeshed_file +from mypy.version import __version__ as mypy_version T = TypeVar("T") @@ -33,7 +32,7 @@ class ErrorInfo: import_ctx: List[Tuple[str, int]] # The path to source file that was the source of this error. - file = '' + file = "" # The fully-qualified id of the source module for this error. module: Optional[str] = None @@ -45,22 +44,22 @@ class ErrorInfo: function_or_member: Optional[str] = "" # Unqualified, may be None # The line number related to this error within file. - line = 0 # -1 if unknown + line = 0 # -1 if unknown # The column number related to this error with file. - column = 0 # -1 if unknown + column = 0 # -1 if unknown # The end line number related to this error within file. - end_line = 0 # -1 if unknown + end_line = 0 # -1 if unknown # The end column number related to this error with file. - end_column = 0 # -1 if unknown + end_column = 0 # -1 if unknown # Either 'error' or 'note' - severity = '' + severity = "" # The error message. - message = '' + message = "" # The error code. code: Optional[ErrorCode] = None @@ -85,24 +84,26 @@ class ErrorInfo: # by mypy daemon) hidden = False - def __init__(self, - import_ctx: List[Tuple[str, int]], - file: str, - module: Optional[str], - typ: Optional[str], - function_or_member: Optional[str], - line: int, - column: int, - end_line: int, - end_column: int, - severity: str, - message: str, - code: Optional[ErrorCode], - blocker: bool, - only_once: bool, - allow_dups: bool, - origin: Optional[Tuple[str, int, int]] = None, - target: Optional[str] = None) -> None: + def __init__( + self, + import_ctx: List[Tuple[str, int]], + file: str, + module: Optional[str], + typ: Optional[str], + function_or_member: Optional[str], + line: int, + column: int, + end_line: int, + end_column: int, + severity: str, + message: str, + code: Optional[ErrorCode], + blocker: bool, + only_once: bool, + allow_dups: bool, + origin: Optional[Tuple[str, int, int]] = None, + target: Optional[str] = None, + ) -> None: self.import_ctx = import_ctx self.file = file self.module = module @@ -124,15 +125,7 @@ def __init__(self, # Type used internally to represent errors: # (path, line, column, end_line, end_column, severity, message, allow_dups, code) -ErrorTuple = Tuple[Optional[str], - int, - int, - int, - int, - str, - str, - bool, - Optional[ErrorCode]] +ErrorTuple = Tuple[Optional[str], int, int, int, int, str, str, bool, Optional[ErrorCode]] class ErrorWatcher: @@ -143,15 +136,20 @@ class ErrorWatcher: at the top of the stack, and is propagated down the stack unless filtered out by one of the ErrorWatcher instances. """ - def __init__(self, errors: 'Errors', *, - filter_errors: Union[bool, Callable[[str, ErrorInfo], bool]] = False, - save_filtered_errors: bool = False): + + def __init__( + self, + errors: "Errors", + *, + filter_errors: Union[bool, Callable[[str, ErrorInfo], bool]] = False, + save_filtered_errors: bool = False, + ): self.errors = errors self._has_new_errors = False self._filter = filter_errors self._filtered: Optional[List[ErrorInfo]] = [] if save_filtered_errors else None - def __enter__(self) -> 'ErrorWatcher': + def __enter__(self) -> "ErrorWatcher": self.errors._watchers.append(self) return self @@ -252,17 +250,19 @@ class Errors: _watchers: List[ErrorWatcher] = [] - def __init__(self, - show_error_context: bool = False, - show_column_numbers: bool = False, - show_error_codes: bool = False, - pretty: bool = False, - show_error_end: bool = False, - read_source: Optional[Callable[[str], Optional[List[str]]]] = None, - show_absolute_path: bool = False, - enabled_error_codes: Optional[Set[ErrorCode]] = None, - disabled_error_codes: Optional[Set[ErrorCode]] = None, - many_errors_threshold: int = -1) -> None: + def __init__( + self, + show_error_context: bool = False, + show_column_numbers: bool = False, + show_error_codes: bool = False, + pretty: bool = False, + show_error_end: bool = False, + read_source: Optional[Callable[[str], Optional[List[str]]]] = None, + show_absolute_path: bool = False, + enabled_error_codes: Optional[Set[ErrorCode]] = None, + disabled_error_codes: Optional[Set[ErrorCode]] = None, + many_errors_threshold: int = -1, + ) -> None: self.show_error_context = show_error_context self.show_column_numbers = show_column_numbers self.show_error_codes = show_error_codes @@ -299,7 +299,7 @@ def set_ignore_prefix(self, prefix: str) -> None: """Set path prefix that will be removed from all paths.""" prefix = os.path.normpath(prefix) # Add separator to the end, if not given. - if os.path.basename(prefix) != '': + if os.path.basename(prefix) != "": prefix += os.sep self.ignore_prefix = prefix @@ -310,9 +310,7 @@ def simplify_path(self, file: str) -> str: file = os.path.normpath(file) return remove_path_prefix(file, self.ignore_prefix) - def set_file(self, file: str, - module: Optional[str], - scope: Optional[Scope] = None) -> None: + def set_file(self, file: str, module: Optional[str], scope: Optional[Scope] = None) -> None: """Set the path and module id of the current file.""" # The path will be simplified later, in render_messages. That way # * 'file' is always a key that uniquely identifies a source file @@ -324,9 +322,9 @@ def set_file(self, file: str, self.target_module = module self.scope = scope - def set_file_ignored_lines(self, file: str, - ignored_lines: Dict[int, List[str]], - ignore_all: bool = False) -> None: + def set_file_ignored_lines( + self, file: str, ignored_lines: Dict[int, List[str]], ignore_all: bool = False + ) -> None: self.ignored_lines[file] = ignored_lines if ignore_all: self.ignored_files.add(file) @@ -350,21 +348,23 @@ def set_import_context(self, ctx: List[Tuple[str, int]]) -> None: """Replace the entire import context with a new value.""" self.import_ctx = ctx[:] - def report(self, - line: int, - column: Optional[int], - message: str, - code: Optional[ErrorCode] = None, - *, - blocker: bool = False, - severity: str = 'error', - file: Optional[str] = None, - only_once: bool = False, - allow_dups: bool = False, - origin_line: Optional[int] = None, - offset: int = 0, - end_line: Optional[int] = None, - end_column: Optional[int] = None) -> None: + def report( + self, + line: int, + column: Optional[int], + message: str, + code: Optional[ErrorCode] = None, + *, + blocker: bool = False, + severity: str = "error", + file: Optional[str] = None, + only_once: bool = False, + allow_dups: bool = False, + origin_line: Optional[int] = None, + offset: int = 0, + end_line: Optional[int] = None, + end_column: Optional[int] = None, + ) -> None: """Report message at the given line using the current error context. Args: @@ -410,11 +410,25 @@ def report(self, code = code or (codes.MISC if not blocker else None) - info = ErrorInfo(self.import_context(), file, self.current_module(), type, - function, line, column, end_line, end_column, severity, message, code, - blocker, only_once, allow_dups, - origin=(self.file, origin_line, end_line), - target=self.current_target()) + info = ErrorInfo( + self.import_context(), + file, + self.current_module(), + type, + function, + line, + column, + end_line, + end_column, + severity, + message, + code, + blocker, + only_once, + allow_dups, + origin=(self.file, origin_line, end_line), + target=self.current_target(), + ) self.add_error_info(info) def _add_error_info(self, file: str, info: ErrorInfo) -> None: @@ -466,7 +480,8 @@ def add_error_info(self, info: ErrorInfo) -> None: if self.is_ignored_error(scope_line, info, self.ignored_lines[file]): # Annotation requests us to ignore all errors on this line. self.used_ignored_lines[file][scope_line].append( - (info.code or codes.MISC).code) + (info.code or codes.MISC).code + ) return if file in self.ignored_files: return @@ -493,12 +508,26 @@ def add_error_info(self, info: ErrorInfo) -> None: # code, report a more specific note. old_code = original_error_codes[info.code].code if old_code in ignored_codes: - msg = (f'Error code changed to {info.code.code}; "type: ignore" comment ' + - 'may be out of date') + msg = ( + f'Error code changed to {info.code.code}; "type: ignore" comment ' + + "may be out of date" + ) note = ErrorInfo( - info.import_ctx, info.file, info.module, info.type, info.function_or_member, - info.line, info.column, info.end_line, info.end_column, 'note', msg, - code=None, blocker=False, only_once=False, allow_dups=False + info.import_ctx, + info.file, + info.module, + info.type, + info.function_or_member, + info.line, + info.column, + info.end_line, + info.end_column, + "note", + msg, + code=None, + blocker=False, + only_once=False, + allow_dups=False, ) self._add_error_info(file, note) @@ -507,15 +536,17 @@ def has_many_errors(self) -> bool: return False if len(self.error_info_map) >= self.many_errors_threshold: return True - if sum(len(errors) - for errors in self.error_info_map.values()) >= self.many_errors_threshold: + if ( + sum(len(errors) for errors in self.error_info_map.values()) + >= self.many_errors_threshold + ): return True return False def report_hidden_errors(self, info: ErrorInfo) -> None: message = ( - '(Skipping most remaining errors due to unresolved imports or missing stubs; ' + - 'fix these first)' + "(Skipping most remaining errors due to unresolved imports or missing stubs; " + + "fix these first)" ) if message in self.only_once_messages: return @@ -530,7 +561,7 @@ def report_hidden_errors(self, info: ErrorInfo) -> None: column=info.column, end_line=info.end_line, end_column=info.end_column, - severity='note', + severity="note", message=message, code=None, blocker=False, @@ -599,14 +630,28 @@ def generate_unused_ignore_errors(self, file: str) -> None: unused_codes_message = f"[{', '.join(sorted(unused_ignored_codes))}]" message = f'Unused "type: ignore{unused_codes_message}" comment' # Don't use report since add_error_info will ignore the error! - info = ErrorInfo(self.import_context(), file, self.current_module(), None, - None, line, -1, line, -1, 'error', message, - None, False, False, False) + info = ErrorInfo( + self.import_context(), + file, + self.current_module(), + None, + None, + line, + -1, + line, + -1, + "error", + message, + None, + False, + False, + False, + ) self._add_error_info(file, info) - def generate_ignore_without_code_errors(self, - file: str, - is_warning_unused_ignores: bool) -> None: + def generate_ignore_without_code_errors( + self, file: str, is_warning_unused_ignores: bool + ) -> None: if is_typeshed_file(file) or file in self.ignored_files: return @@ -627,16 +672,30 @@ def generate_ignore_without_code_errors(self, if is_warning_unused_ignores and not used_ignored_lines[line]: continue - codes_hint = '' + codes_hint = "" ignored_codes = sorted(set(used_ignored_lines[line])) if ignored_codes: codes_hint = f' (consider "type: ignore[{", ".join(ignored_codes)}]" instead)' message = f'"type: ignore" comment without error code{codes_hint}' # Don't use report since add_error_info will ignore the error! - info = ErrorInfo(self.import_context(), file, self.current_module(), None, - None, line, -1, line, -1, 'error', message, codes.IGNORE_WITHOUT_CODE, - False, False, False) + info = ErrorInfo( + self.import_context(), + file, + self.current_module(), + None, + None, + line, + -1, + line, + -1, + "error", + message, + codes.IGNORE_WITHOUT_CODE, + False, + False, + False, + ) self._add_error_info(file, info) def num_messages(self) -> int: @@ -670,12 +729,13 @@ def raise_error(self, use_stdout: bool = True) -> NoReturn: """ # self.new_messages() will format all messages that haven't already # been returned from a file_messages() call. - raise CompileError(self.new_messages(), - use_stdout=use_stdout, - module_with_blocker=self.blocker_module()) + raise CompileError( + self.new_messages(), use_stdout=use_stdout, module_with_blocker=self.blocker_module() + ) - def format_messages(self, error_info: List[ErrorInfo], - source_lines: Optional[List[str]]) -> List[str]: + def format_messages( + self, error_info: List[ErrorInfo], source_lines: Optional[List[str]] + ) -> List[str]: """Return a string list that represents the error messages. Use a form suitable for displaying to the user. If self.pretty @@ -687,29 +747,37 @@ def format_messages(self, error_info: List[ErrorInfo], errors = self.render_messages(self.sort_messages(error_info)) errors = self.remove_duplicates(errors) for ( - file, line, column, end_line, end_column, severity, message, allow_dups, code + file, + line, + column, + end_line, + end_column, + severity, + message, + allow_dups, + code, ) in errors: - s = '' + s = "" if file is not None: if self.show_column_numbers and line >= 0 and column >= 0: - srcloc = f'{file}:{line}:{1 + column}' + srcloc = f"{file}:{line}:{1 + column}" if self.show_error_end and end_line >= 0 and end_column >= 0: - srcloc += f':{end_line}:{end_column}' + srcloc += f":{end_line}:{end_column}" elif line >= 0: - srcloc = f'{file}:{line}' + srcloc = f"{file}:{line}" else: srcloc = file - s = f'{srcloc}: {severity}: {message}' + s = f"{srcloc}: {severity}: {message}" else: s = message - if self.show_error_codes and code and severity != 'note': + if self.show_error_codes and code and severity != "note": # If note has an error code, it is related to a previous error. Avoid # displaying duplicate error codes. - s = f'{s} [{code.code}]' + s = f"{s} [{code.code}]" a.append(s) if self.pretty: # Add source code fragment and a location marker. - if severity == 'error' and source_lines and line > 0: + if severity == "error" and source_lines and line > 0: source_line = source_lines[line - 1] source_line_expanded = source_line.expandtabs() if column < 0: @@ -722,11 +790,11 @@ def format_messages(self, error_info: List[ErrorInfo], # Note, currently coloring uses the offset to detect source snippets, # so these offsets should not be arbitrary. - a.append(' ' * DEFAULT_SOURCE_OFFSET + source_line_expanded) - marker = '^' + a.append(" " * DEFAULT_SOURCE_OFFSET + source_line_expanded) + marker = "^" if end_line == line and end_column > column: marker = f'^{"~" * (end_column - column - 1)}' - a.append(' ' * (DEFAULT_SOURCE_OFFSET + column) + marker) + a.append(" " * (DEFAULT_SOURCE_OFFSET + column) + marker) return a def file_messages(self, path: str) -> List[str]: @@ -761,14 +829,10 @@ def targets(self) -> Set[str]: # TODO: Make sure that either target is always defined or that not being defined # is okay for fine-grained incremental checking. return { - info.target - for errs in self.error_info_map.values() - for info in errs - if info.target + info.target for errs in self.error_info_map.values() for info in errs if info.target } - def render_messages(self, - errors: List[ErrorInfo]) -> List[ErrorTuple]: + def render_messages(self, errors: List[ErrorInfo]) -> List[ErrorTuple]: """Translate the messages into a sequence of tuples. Each tuple is of form (path, line, col, severity, message, allow_dups, code). @@ -790,18 +854,19 @@ def render_messages(self, i = last while i >= 0: path, line = e.import_ctx[i] - fmt = '{}:{}: note: In module imported here' + fmt = "{}:{}: note: In module imported here" if i < last: - fmt = '{}:{}: note: ... from here' + fmt = "{}:{}: note: ... from here" if i > 0: - fmt += ',' + fmt += "," else: - fmt += ':' + fmt += ":" # Remove prefix to ignore from path (if present) to # simplify path. path = remove_path_prefix(path, self.ignore_prefix) - result.append((None, -1, -1, -1, -1, 'note', - fmt.format(path, line), e.allow_dups, None)) + result.append( + (None, -1, -1, -1, -1, "note", fmt.format(path, line), e.allow_dups, None) + ) i -= 1 file = self.simplify_path(e.file) @@ -809,40 +874,95 @@ def render_messages(self, # Report context within a source file. if not self.show_error_context: pass - elif (e.function_or_member != prev_function_or_member or - e.type != prev_type): + elif e.function_or_member != prev_function_or_member or e.type != prev_type: if e.function_or_member is None: if e.type is None: result.append( - (file, -1, -1, -1, -1, 'note', 'At top level:', e.allow_dups, None)) + (file, -1, -1, -1, -1, "note", "At top level:", e.allow_dups, None) + ) else: - result.append((file, -1, -1, -1, -1, 'note', 'In class "{}":'.format( - e.type), e.allow_dups, None)) + result.append( + ( + file, + -1, + -1, + -1, + -1, + "note", + 'In class "{}":'.format(e.type), + e.allow_dups, + None, + ) + ) else: if e.type is None: - result.append((file, -1, -1, -1, -1, 'note', - 'In function "{}":'.format( - e.function_or_member), e.allow_dups, None)) + result.append( + ( + file, + -1, + -1, + -1, + -1, + "note", + 'In function "{}":'.format(e.function_or_member), + e.allow_dups, + None, + ) + ) else: - result.append((file, -1, -1, -1, -1, 'note', - 'In member "{}" of class "{}":'.format( - e.function_or_member, e.type), e.allow_dups, None)) + result.append( + ( + file, + -1, + -1, + -1, + -1, + "note", + 'In member "{}" of class "{}":'.format( + e.function_or_member, e.type + ), + e.allow_dups, + None, + ) + ) elif e.type != prev_type: if e.type is None: result.append( - (file, -1, -1, -1, -1, 'note', 'At top level:', e.allow_dups, None)) + (file, -1, -1, -1, -1, "note", "At top level:", e.allow_dups, None) + ) else: - result.append((file, -1, -1, -1, -1, 'note', - f'In class "{e.type}":', e.allow_dups, None)) + result.append( + (file, -1, -1, -1, -1, "note", f'In class "{e.type}":', e.allow_dups, None) + ) if isinstance(e.message, ErrorMessage): result.append( - (file, e.line, e.column, e.end_line, e.end_column, e.severity, e.message.value, - e.allow_dups, e.code)) + ( + file, + e.line, + e.column, + e.end_line, + e.end_column, + e.severity, + e.message.value, + e.allow_dups, + e.code, + ) + ) else: result.append( - (file, e.line, e.column, e.end_line, e.end_column, e.severity, e.message, - e.allow_dups, e.code)) + ( + file, + e.line, + e.column, + e.end_line, + e.end_column, + e.severity, + e.message, + e.allow_dups, + e.code, + ) + ) prev_import_context = e.import_ctx prev_function_or_member = e.function_or_member @@ -862,9 +982,11 @@ def sort_messages(self, errors: List[ErrorInfo]) -> List[ErrorInfo]: while i < len(errors): i0 = i # Find neighbouring errors with the same context and file. - while (i + 1 < len(errors) and - errors[i + 1].import_ctx == errors[i].import_ctx and - errors[i + 1].file == errors[i].file): + while ( + i + 1 < len(errors) + and errors[i + 1].import_ctx == errors[i].import_ctx + and errors[i + 1].file == errors[i].file + ): i += 1 i += 1 @@ -885,19 +1007,21 @@ def remove_duplicates(self, errors: List[ErrorTuple]) -> List[ErrorTuple]: # Find duplicates, unless duplicates are allowed. if not errors[i][7]: while j >= 0 and errors[j][0] == errors[i][0]: - if errors[j][6].strip() == 'Got:': + if errors[j][6].strip() == "Got:": conflicts_notes = True j -= 1 j = i - 1 - while (j >= 0 and errors[j][0] == errors[i][0] and - errors[j][1] == errors[i][1]): - if (errors[j][5] == errors[i][5] and - # Allow duplicate notes in overload conflicts reporting. - not ((errors[i][5] == 'note' and - errors[i][6].strip() in allowed_duplicates) - or (errors[i][6].strip().startswith('def ') and - conflicts_notes)) and - errors[j][6] == errors[i][6]): # ignore column + while j >= 0 and errors[j][0] == errors[i][0] and errors[j][1] == errors[i][1]: + if ( + errors[j][5] == errors[i][5] + and + # Allow duplicate notes in overload conflicts reporting. + not ( + (errors[i][5] == "note" and errors[i][6].strip() in allowed_duplicates) + or (errors[i][6].strip().startswith("def ") and conflicts_notes) + ) + and errors[j][6] == errors[i][6] + ): # ignore column dup = True break j -= 1 @@ -925,11 +1049,13 @@ class CompileError(Exception): # Can be set in case there was a module with a blocking error module_with_blocker: Optional[str] = None - def __init__(self, - messages: List[str], - use_stdout: bool = False, - module_with_blocker: Optional[str] = None) -> None: - super().__init__('\n'.join(messages)) + def __init__( + self, + messages: List[str], + use_stdout: bool = False, + module_with_blocker: Optional[str] = None, + ) -> None: + super().__init__("\n".join(messages)) self.messages = messages self.use_stdout = use_stdout self.module_with_blocker = module_with_blocker @@ -940,25 +1066,26 @@ def remove_path_prefix(path: str, prefix: Optional[str]) -> str: Otherwise, return path. If path is None, return None. """ if prefix is not None and path.startswith(prefix): - return path[len(prefix):] + return path[len(prefix) :] else: return path -def report_internal_error(err: Exception, - file: Optional[str], - line: int, - errors: Errors, - options: Options, - stdout: Optional[TextIO] = None, - stderr: Optional[TextIO] = None, - ) -> NoReturn: +def report_internal_error( + err: Exception, + file: Optional[str], + line: int, + errors: Errors, + options: Options, + stdout: Optional[TextIO] = None, + stderr: Optional[TextIO] = None, +) -> NoReturn: """Report internal error and exit. This optionally starts pdb or shows a traceback. """ - stdout = (stdout or sys.stdout) - stderr = (stderr or sys.stderr) + stdout = stdout or sys.stdout + stderr = stderr or sys.stderr # Dump out errors so far, they often provide a clue. # But catch unexpected errors rendering them. try: @@ -970,32 +1097,35 @@ def report_internal_error(err: Exception, # Compute file:line prefix for official-looking error messages. if file: if line: - prefix = f'{file}:{line}: ' + prefix = f"{file}:{line}: " else: - prefix = f'{file}: ' + prefix = f"{file}: " else: - prefix = '' + prefix = "" # Print "INTERNAL ERROR" message. - print(f'{prefix}error: INTERNAL ERROR --', - 'Please try using mypy master on GitHub:\n' - 'https://mypy.readthedocs.io/en/stable/common_issues.html' - '#using-a-development-mypy-build', - file=stderr) + print( + f"{prefix}error: INTERNAL ERROR --", + "Please try using mypy master on GitHub:\n" + "https://mypy.readthedocs.io/en/stable/common_issues.html" + "#using-a-development-mypy-build", + file=stderr, + ) if options.show_traceback: - print('Please report a bug at https://github.com/python/mypy/issues', - file=stderr) + print("Please report a bug at https://github.com/python/mypy/issues", file=stderr) else: - print('If this issue continues with mypy master, ' - 'please report a bug at https://github.com/python/mypy/issues', - file=stderr) - print(f'version: {mypy_version}', - file=stderr) + print( + "If this issue continues with mypy master, " + "please report a bug at https://github.com/python/mypy/issues", + file=stderr, + ) + print(f"version: {mypy_version}", file=stderr) # If requested, drop into pdb. This overrides show_tb. if options.pdb: - print('Dropping into pdb', file=stderr) + print("Dropping into pdb", file=stderr) import pdb + pdb.post_mortem(sys.exc_info()[2]) # If requested, print traceback, else print note explaining how to get one. @@ -1003,17 +1133,19 @@ def report_internal_error(err: Exception, raise err if not options.show_traceback: if not options.pdb: - print('{}: note: please use --show-traceback to print a traceback ' - 'when reporting a bug'.format(prefix), - file=stderr) + print( + "{}: note: please use --show-traceback to print a traceback " + "when reporting a bug".format(prefix), + file=stderr, + ) else: tb = traceback.extract_stack()[:-2] tb2 = traceback.extract_tb(sys.exc_info()[2]) - print('Traceback (most recent call last):') + print("Traceback (most recent call last):") for s in traceback.format_list(tb + tb2): - print(s.rstrip('\n')) - print(f'{type(err).__name__}: {err}', file=stdout) - print(f'{prefix}: note: use --pdb to drop into pdb', file=stderr) + print(s.rstrip("\n")) + print(f"{type(err).__name__}: {err}", file=stdout) + print(f"{prefix}: note: use --pdb to drop into pdb", file=stderr) # Exit. The caller has nothing more to say. # We use exit code 2 to signal that this is no ordinary error. diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 630c809b46cae..4515a137ced25 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -1,12 +1,36 @@ -from typing import Dict, Iterable, List, TypeVar, Mapping, cast, Union, Optional, Sequence +from typing import Dict, Iterable, List, Mapping, Optional, Sequence, TypeVar, Union, cast from mypy.types import ( - Type, Instance, CallableType, TypeVisitor, UnboundType, AnyType, - NoneType, Overloaded, TupleType, TypedDictType, UnionType, - ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, TypeVarId, - FunctionLike, TypeVarType, LiteralType, get_proper_type, ProperType, - TypeAliasType, ParamSpecType, TypeVarLikeType, Parameters, ParamSpecFlavor, - UnpackType, TypeVarTupleType, TypeList + AnyType, + CallableType, + DeletedType, + ErasedType, + FunctionLike, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecFlavor, + ParamSpecType, + PartialType, + ProperType, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeList, + TypeType, + TypeVarId, + TypeVarLikeType, + TypeVarTupleType, + TypeVarType, + TypeVisitor, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, + get_proper_type, ) from mypy.typevartuples import split_with_instance, split_with_prefix_and_suffix @@ -50,7 +74,7 @@ def expand_type_by_instance(typ: Type, instance: Instance) -> Type: return expand_type(typ, variables) -F = TypeVar('F', bound=FunctionLike) +F = TypeVar("F", bound=FunctionLike) def freshen_function_type_vars(callee: F) -> F: @@ -76,8 +100,7 @@ def freshen_function_type_vars(callee: F) -> F: return cast(F, fresh) else: assert isinstance(callee, Overloaded) - fresh_overload = Overloaded([freshen_function_type_vars(item) - for item in callee.items]) + fresh_overload = Overloaded([freshen_function_type_vars(item) for item in callee.items]) return cast(F, fresh_overload) @@ -132,11 +155,14 @@ def visit_param_spec(self, t: ParamSpecType) -> Type: # TODO: why does this case even happen? Instances aren't plural. return Instance(inst.type, inst.args, line=inst.line, column=inst.column) elif isinstance(repl, ParamSpecType): - return repl.copy_modified(flavor=t.flavor, prefix=t.prefix.copy_modified( - arg_types=t.prefix.arg_types + repl.prefix.arg_types, - arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds, - arg_names=t.prefix.arg_names + repl.prefix.arg_names, - )) + return repl.copy_modified( + flavor=t.flavor, + prefix=t.prefix.copy_modified( + arg_types=t.prefix.arg_types + repl.prefix.arg_types, + arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds, + arg_names=t.prefix.arg_names + repl.prefix.arg_names, + ), + ) elif isinstance(repl, Parameters) or isinstance(repl, CallableType): # if the paramspec is *P.args or **P.kwargs: if t.flavor != ParamSpecFlavor.BARE: @@ -148,10 +174,12 @@ def visit_param_spec(self, t: ParamSpecType) -> Type: else: return repl else: - return Parameters(t.prefix.arg_types + repl.arg_types, - t.prefix.arg_kinds + repl.arg_kinds, - t.prefix.arg_names + repl.arg_names, - variables=[*t.prefix.variables, *repl.variables]) + return Parameters( + t.prefix.arg_types + repl.arg_types, + t.prefix.arg_kinds + repl.arg_kinds, + t.prefix.arg_names + repl.arg_names, + variables=[*t.prefix.variables, *repl.variables], + ) else: # TODO: should this branch be removed? better not to fail silently return repl @@ -220,12 +248,14 @@ def visit_callable_type(self, t: CallableType) -> Type: arg_kinds=prefix.arg_kinds + t.arg_kinds, arg_names=prefix.arg_names + t.arg_names, ret_type=t.ret_type.accept(self), - type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None)) + type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None), + ) - return t.copy_modified(arg_types=self.expand_types(t.arg_types), - ret_type=t.ret_type.accept(self), - type_guard=(t.type_guard.accept(self) - if t.type_guard is not None else None)) + return t.copy_modified( + arg_types=self.expand_types(t.arg_types), + ret_type=t.ret_type.accept(self), + type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None), + ) def visit_overloaded(self, t: Overloaded) -> Type: items: List[CallableType] = [] @@ -284,6 +314,7 @@ def visit_union_type(self, t: UnionType) -> Type: # After substituting for type variables in t.items, # some of the resulting types might be subtypes of others. from mypy.typeops import make_simplified_union # asdf + return make_simplified_union(self.expand_types(t.items), t.line, t.column) def visit_partial_type(self, t: PartialType) -> Type: diff --git a/mypy/exprtotype.py b/mypy/exprtotype.py index 243bbf024faf0..60b8dffb50d63 100644 --- a/mypy/exprtotype.py +++ b/mypy/exprtotype.py @@ -2,17 +2,41 @@ from typing import Optional +from mypy.fastparse import parse_type_string from mypy.nodes import ( - Expression, NameExpr, MemberExpr, IndexExpr, RefExpr, TupleExpr, IntExpr, FloatExpr, UnaryExpr, - ComplexExpr, ListExpr, StrExpr, BytesExpr, UnicodeExpr, EllipsisExpr, CallExpr, OpExpr, - get_member_expr_fullname + BytesExpr, + CallExpr, + ComplexExpr, + EllipsisExpr, + Expression, + FloatExpr, + IndexExpr, + IntExpr, + ListExpr, + MemberExpr, + NameExpr, + OpExpr, + RefExpr, + StrExpr, + TupleExpr, + UnaryExpr, + UnicodeExpr, + get_member_expr_fullname, ) -from mypy.fastparse import parse_type_string +from mypy.options import Options from mypy.types import ( - Type, UnboundType, TypeList, EllipsisType, AnyType, CallableArgument, TypeOfAny, - RawExpressionType, ProperType, UnionType, ANNOTATED_TYPE_NAMES, + ANNOTATED_TYPE_NAMES, + AnyType, + CallableArgument, + EllipsisType, + ProperType, + RawExpressionType, + Type, + TypeList, + TypeOfAny, + UnboundType, + UnionType, ) -from mypy.options import Options class TypeTranslationError(Exception): @@ -20,7 +44,7 @@ class TypeTranslationError(Exception): def _extract_argument_name(expr: Expression) -> Optional[str]: - if isinstance(expr, NameExpr) and expr.name == 'None': + if isinstance(expr, NameExpr) and expr.name == "None": return None elif isinstance(expr, StrExpr): return expr.value @@ -30,10 +54,12 @@ def _extract_argument_name(expr: Expression) -> Optional[str]: raise TypeTranslationError() -def expr_to_unanalyzed_type(expr: Expression, - options: Optional[Options] = None, - allow_new_syntax: bool = False, - _parent: Optional[Expression] = None) -> ProperType: +def expr_to_unanalyzed_type( + expr: Expression, + options: Optional[Options] = None, + allow_new_syntax: bool = False, + _parent: Optional[Expression] = None, +) -> ProperType: """Translate an expression to the corresponding type. The result is not semantically analyzed. It can be UnboundType or TypeList. @@ -47,10 +73,10 @@ def expr_to_unanalyzed_type(expr: Expression, name: Optional[str] = None if isinstance(expr, NameExpr): name = expr.name - if name == 'True': - return RawExpressionType(True, 'builtins.bool', line=expr.line, column=expr.column) - elif name == 'False': - return RawExpressionType(False, 'builtins.bool', line=expr.line, column=expr.column) + if name == "True": + return RawExpressionType(True, "builtins.bool", line=expr.line, column=expr.column) + elif name == "False": + return RawExpressionType(False, "builtins.bool", line=expr.line, column=expr.column) else: return UnboundType(name, line=expr.line, column=expr.column) elif isinstance(expr, MemberExpr): @@ -76,18 +102,25 @@ def expr_to_unanalyzed_type(expr: Expression, return expr_to_unanalyzed_type(args[0], options, allow_new_syntax, expr) else: - base.args = tuple(expr_to_unanalyzed_type(arg, options, allow_new_syntax, expr) - for arg in args) + base.args = tuple( + expr_to_unanalyzed_type(arg, options, allow_new_syntax, expr) for arg in args + ) if not base.args: base.empty_tuple_index = True return base else: raise TypeTranslationError() - elif (isinstance(expr, OpExpr) - and expr.op == '|' - and ((options and options.python_version >= (3, 10)) or allow_new_syntax)): - return UnionType([expr_to_unanalyzed_type(expr.left, options, allow_new_syntax), - expr_to_unanalyzed_type(expr.right, options, allow_new_syntax)]) + elif ( + isinstance(expr, OpExpr) + and expr.op == "|" + and ((options and options.python_version >= (3, 10)) or allow_new_syntax) + ): + return UnionType( + [ + expr_to_unanalyzed_type(expr.left, options, allow_new_syntax), + expr_to_unanalyzed_type(expr.right, options, allow_new_syntax), + ] + ) elif isinstance(expr, CallExpr) and isinstance(_parent, ListExpr): c = expr.callee names = [] @@ -102,7 +135,7 @@ def expr_to_unanalyzed_type(expr: Expression, c = c.expr else: raise TypeTranslationError() - arg_const = '.'.join(reversed(names)) + arg_const = ".".join(reversed(names)) # Go through the constructor args to get its name and type. name = None @@ -132,34 +165,43 @@ def expr_to_unanalyzed_type(expr: Expression, raise TypeTranslationError() return CallableArgument(typ, name, arg_const, expr.line, expr.column) elif isinstance(expr, ListExpr): - return TypeList([expr_to_unanalyzed_type(t, options, allow_new_syntax, expr) - for t in expr.items], - line=expr.line, column=expr.column) + return TypeList( + [expr_to_unanalyzed_type(t, options, allow_new_syntax, expr) for t in expr.items], + line=expr.line, + column=expr.column, + ) elif isinstance(expr, StrExpr): - return parse_type_string(expr.value, 'builtins.str', expr.line, expr.column, - assume_str_is_unicode=expr.from_python_3) + return parse_type_string( + expr.value, + "builtins.str", + expr.line, + expr.column, + assume_str_is_unicode=expr.from_python_3, + ) elif isinstance(expr, BytesExpr): - return parse_type_string(expr.value, 'builtins.bytes', expr.line, expr.column, - assume_str_is_unicode=False) + return parse_type_string( + expr.value, "builtins.bytes", expr.line, expr.column, assume_str_is_unicode=False + ) elif isinstance(expr, UnicodeExpr): - return parse_type_string(expr.value, 'builtins.unicode', expr.line, expr.column, - assume_str_is_unicode=True) + return parse_type_string( + expr.value, "builtins.unicode", expr.line, expr.column, assume_str_is_unicode=True + ) elif isinstance(expr, UnaryExpr): typ = expr_to_unanalyzed_type(expr.expr, options, allow_new_syntax) if isinstance(typ, RawExpressionType): - if isinstance(typ.literal_value, int) and expr.op == '-': + if isinstance(typ.literal_value, int) and expr.op == "-": typ.literal_value *= -1 return typ raise TypeTranslationError() elif isinstance(expr, IntExpr): - return RawExpressionType(expr.value, 'builtins.int', line=expr.line, column=expr.column) + return RawExpressionType(expr.value, "builtins.int", line=expr.line, column=expr.column) elif isinstance(expr, FloatExpr): # Floats are not valid parameters for RawExpressionType , so we just # pass in 'None' for now. We'll report the appropriate error at a later stage. - return RawExpressionType(None, 'builtins.float', line=expr.line, column=expr.column) + return RawExpressionType(None, "builtins.float", line=expr.line, column=expr.column) elif isinstance(expr, ComplexExpr): # Same thing as above with complex numbers. - return RawExpressionType(None, 'builtins.complex', line=expr.line, column=expr.column) + return RawExpressionType(None, "builtins.complex", line=expr.line, column=expr.column) elif isinstance(expr, EllipsisExpr): return EllipsisType(expr.line) else: diff --git a/mypy/fastparse.py b/mypy/fastparse.py index b5b31a60b539d..a73fdf0717bcb 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -1,51 +1,121 @@ -from mypy.util import unnamed_function import copy import re import sys -import warnings - import typing # for typing.Type, which conflicts with types.Type -from typing import ( - Tuple, Union, TypeVar, Callable, Sequence, Optional, Any, Dict, cast, List -) +import warnings +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union, cast from typing_extensions import Final, Literal, overload -from mypy.sharedparse import ( - special_function_elide_names, argument_elide_name, -) +from mypy import defaults, errorcodes as codes, message_registry +from mypy.errors import Errors from mypy.nodes import ( - MypyFile, Node, ImportBase, Import, ImportAll, ImportFrom, FuncDef, - OverloadedFuncDef, OverloadPart, - ClassDef, Decorator, Block, Var, OperatorAssignmentStmt, - ExpressionStmt, AssignmentStmt, ReturnStmt, RaiseStmt, AssertStmt, - DelStmt, BreakStmt, ContinueStmt, PassStmt, GlobalDecl, - WhileStmt, ForStmt, IfStmt, TryStmt, WithStmt, MatchStmt, - TupleExpr, GeneratorExpr, ListComprehension, ListExpr, ConditionalExpr, - DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr, - FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, - UnaryExpr, LambdaExpr, ComparisonExpr, AssignmentExpr, - StarExpr, YieldFromExpr, NonlocalDecl, DictionaryComprehension, - SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument, - AwaitExpr, TempNode, RefExpr, Expression, Statement, - ArgKind, ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR2, - check_arg_names, + ARG_NAMED, + ARG_NAMED_OPT, + ARG_OPT, + ARG_POS, + ARG_STAR, + ARG_STAR2, + ArgKind, + Argument, + AssertStmt, + AssignmentExpr, + AssignmentStmt, + AwaitExpr, + Block, + BreakStmt, + BytesExpr, + CallExpr, + ClassDef, + ComparisonExpr, + ComplexExpr, + ConditionalExpr, + ContinueStmt, + Decorator, + DelStmt, + DictExpr, + DictionaryComprehension, + EllipsisExpr, + Expression, + ExpressionStmt, FakeInfo, + FloatExpr, + ForStmt, + FuncDef, + GeneratorExpr, + GlobalDecl, + IfStmt, + Import, + ImportAll, + ImportBase, + ImportFrom, + IndexExpr, + IntExpr, + LambdaExpr, + ListComprehension, + ListExpr, + MatchStmt, + MemberExpr, + MypyFile, + NameExpr, + Node, + NonlocalDecl, + OperatorAssignmentStmt, + OpExpr, + OverloadedFuncDef, + OverloadPart, + PassStmt, + RaiseStmt, + RefExpr, + ReturnStmt, + SetComprehension, + SetExpr, + SliceExpr, + StarExpr, + Statement, + StrExpr, + SuperExpr, + TempNode, + TryStmt, + TupleExpr, + UnaryExpr, + UnicodeExpr, + Var, + WhileStmt, + WithStmt, + YieldExpr, + YieldFromExpr, + check_arg_names, ) +from mypy.options import Options from mypy.patterns import ( - AsPattern, OrPattern, ValuePattern, SequencePattern, StarredPattern, MappingPattern, - ClassPattern, SingletonPattern + AsPattern, + ClassPattern, + MappingPattern, + OrPattern, + SequencePattern, + SingletonPattern, + StarredPattern, + ValuePattern, ) +from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable +from mypy.sharedparse import argument_elide_name, special_function_elide_names from mypy.types import ( - Type, CallableType, AnyType, UnboundType, TupleType, TypeList, EllipsisType, CallableArgument, - TypeOfAny, Instance, RawExpressionType, ProperType, UnionType, + AnyType, + CallableArgument, + CallableType, + EllipsisType, + Instance, + ProperType, + RawExpressionType, + TupleType, + Type, + TypeList, + TypeOfAny, + UnboundType, + UnionType, ) -from mypy import defaults -from mypy import message_registry, errorcodes as codes -from mypy.errors import Errors -from mypy.options import Options -from mypy.reachability import infer_reachability_of_if_statement, mark_block_unreachable -from mypy.util import bytes_to_human_readable_repr +from mypy.util import bytes_to_human_readable_repr, unnamed_function try: # pull this into a final variable to make mypyc be quiet about the @@ -55,33 +125,43 @@ # Check if we can use the stdlib ast module instead of typed_ast. if sys.version_info >= (3, 8): import ast as ast3 - assert 'kind' in ast3.Constant._fields, \ - f"This 3.8.0 alpha ({sys.version.split()[0]}) is too old; 3.8.0a3 required" + + assert ( + "kind" in ast3.Constant._fields + ), f"This 3.8.0 alpha ({sys.version.split()[0]}) is too old; 3.8.0a3 required" # TODO: Num, Str, Bytes, NameConstant, Ellipsis are deprecated in 3.8. # TODO: Index, ExtSlice are deprecated in 3.9. from ast import ( AST, - Call, - FunctionType, - Name, Attribute, + Bytes, + Call, Ellipsis as ast3_Ellipsis, - Starred, - NameConstant, Expression as ast3_Expression, - Str, - Bytes, + FunctionType, Index, + Name, + NameConstant, Num, + Starred, + Str, UnaryOp, USub, ) - def ast3_parse(source: Union[str, bytes], filename: str, mode: str, - feature_version: int = PY_MINOR_VERSION) -> AST: - return ast3.parse(source, filename, mode, - type_comments=True, # This works the magic - feature_version=feature_version) + def ast3_parse( + source: Union[str, bytes], + filename: str, + mode: str, + feature_version: int = PY_MINOR_VERSION, + ) -> AST: + return ast3.parse( + source, + filename, + mode, + type_comments=True, # This works the magic + feature_version=feature_version, + ) NamedExpr = ast3.NamedExpr Constant = ast3.Constant @@ -89,24 +169,28 @@ def ast3_parse(source: Union[str, bytes], filename: str, mode: str, from typed_ast import ast3 from typed_ast.ast3 import ( AST, - Call, - FunctionType, - Name, Attribute, + Bytes, + Call, Ellipsis as ast3_Ellipsis, - Starred, - NameConstant, Expression as ast3_Expression, - Str, - Bytes, + FunctionType, Index, + Name, + NameConstant, Num, + Starred, + Str, UnaryOp, USub, ) - def ast3_parse(source: Union[str, bytes], filename: str, mode: str, - feature_version: int = PY_MINOR_VERSION) -> AST: + def ast3_parse( + source: Union[str, bytes], + filename: str, + mode: str, + feature_version: int = PY_MINOR_VERSION, + ) -> AST: return ast3.parse(source, filename, mode, feature_version=feature_version) # These don't exist before 3.8 @@ -139,17 +223,21 @@ def ast3_parse(source: Union[str, bytes], filename: str, mode: str, try: from typed_ast import ast35 # type: ignore[attr-defined] # noqa: F401 except ImportError: - print('The typed_ast package is not installed.\n' - 'You can install it with `python3 -m pip install typed-ast`.', - file=sys.stderr) + print( + "The typed_ast package is not installed.\n" + "You can install it with `python3 -m pip install typed-ast`.", + file=sys.stderr, + ) else: - print('You need a more recent version of the typed_ast package.\n' - 'You can update to the latest version with ' - '`python3 -m pip install -U typed-ast`.', - file=sys.stderr) + print( + "You need a more recent version of the typed_ast package.\n" + "You can update to the latest version with " + "`python3 -m pip install -U typed-ast`.", + file=sys.stderr, + ) sys.exit(1) -N = TypeVar('N', bound=Node) +N = TypeVar("N", bound=Node) # There is no way to create reasonable fallbacks at this stage, # they must be patched later. @@ -160,14 +248,16 @@ def ast3_parse(source: Union[str, bytes], filename: str, mode: str, INVALID_TYPE_IGNORE: Final = 'Invalid "type: ignore" comment' -TYPE_IGNORE_PATTERN: Final = re.compile(r'[^#]*#\s*type:\s*ignore\s*(.*)') +TYPE_IGNORE_PATTERN: Final = re.compile(r"[^#]*#\s*type:\s*ignore\s*(.*)") -def parse(source: Union[str, bytes], - fnam: str, - module: Optional[str], - errors: Optional[Errors] = None, - options: Optional[Options] = None) -> MypyFile: +def parse( + source: Union[str, bytes], + fnam: str, + module: Optional[str], + errors: Optional[Errors] = None, + options: Optional[Options] = None, +) -> MypyFile: """Parse a source file, without doing any semantic analysis. @@ -181,7 +271,7 @@ def parse(source: Union[str, bytes], if options is None: options = Options() errors.set_file(fnam, module) - is_stub_file = fnam.endswith('.pyi') + is_stub_file = fnam.endswith(".pyi") try: if is_stub_file: feature_version = defaults.PYTHON3_VERSION[1] @@ -191,12 +281,9 @@ def parse(source: Union[str, bytes], # Disable deprecation warnings about \u with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) - ast = ast3_parse(source, fnam, 'exec', feature_version=feature_version) + ast = ast3_parse(source, fnam, "exec", feature_version=feature_version) - tree = ASTConverter(options=options, - is_stub=is_stub_file, - errors=errors, - ).visit(ast) + tree = ASTConverter(options=options, is_stub=is_stub_file, errors=errors).visit(ast) tree.path = fnam tree.is_stub = is_stub_file except SyntaxError as e: @@ -207,8 +294,13 @@ def parse(source: Union[str, bytes], # start of the f-string. This would be misleading, as mypy will report the error as the # lineno within the file. e.lineno = None - errors.report(e.lineno if e.lineno is not None else -1, e.offset, e.msg, blocker=True, - code=codes.SYNTAX) + errors.report( + e.lineno if e.lineno is not None else -1, + e.offset, + e.msg, + blocker=True, + code=codes.SYNTAX, + ) tree = MypyFile([], [], False, {}) if raise_on_error and errors.is_errors(): @@ -225,28 +317,29 @@ def parse_type_ignore_tag(tag: Optional[str]) -> Optional[List[str]]: * list of ignored error codes if a tag was found * None if the tag was invalid. """ - if not tag or tag.strip() == '' or tag.strip().startswith('#'): + if not tag or tag.strip() == "" or tag.strip().startswith("#"): # No tag -- ignore all errors. return [] - m = re.match(r'\s*\[([^]#]*)\]\s*(#.*)?$', tag) + m = re.match(r"\s*\[([^]#]*)\]\s*(#.*)?$", tag) if m is None: # Invalid "# type: ignore" comment. return None - return [code.strip() for code in m.group(1).split(',')] + return [code.strip() for code in m.group(1).split(",")] -def parse_type_comment(type_comment: str, - line: int, - column: int, - errors: Optional[Errors], - assume_str_is_unicode: bool = True, - ) -> Tuple[Optional[List[str]], Optional[ProperType]]: +def parse_type_comment( + type_comment: str, + line: int, + column: int, + errors: Optional[Errors], + assume_str_is_unicode: bool = True, +) -> Tuple[Optional[List[str]], Optional[ProperType]]: """Parse type portion of a type comment (+ optional type ignore). Return (ignore info, parsed type). """ try: - typ = ast3_parse(type_comment, '', 'eval') + typ = ast3_parse(type_comment, "", "eval") except SyntaxError: if errors is not None: stripped_type = type_comment.split("#", 2)[0].strip() @@ -269,16 +362,23 @@ def parse_type_comment(type_comment: str, else: ignored = None assert isinstance(typ, ast3_Expression) - converted = TypeConverter(errors, - line=line, - override_column=column, - assume_str_is_unicode=assume_str_is_unicode, - is_evaluated=False).visit(typ.body) + converted = TypeConverter( + errors, + line=line, + override_column=column, + assume_str_is_unicode=assume_str_is_unicode, + is_evaluated=False, + ).visit(typ.body) return ignored, converted -def parse_type_string(expr_string: str, expr_fallback_name: str, - line: int, column: int, assume_str_is_unicode: bool = True) -> ProperType: +def parse_type_string( + expr_string: str, + expr_fallback_name: str, + line: int, + column: int, + assume_str_is_unicode: bool = True, +) -> ProperType: """Parses a type that was originally present inside of an explicit string, byte string, or unicode string. @@ -294,8 +394,13 @@ def parse_type_string(expr_string: str, expr_fallback_name: str, code with unicode_literals...) and setting `assume_str_is_unicode` accordingly. """ try: - _, node = parse_type_comment(expr_string.strip(), line=line, column=column, errors=None, - assume_str_is_unicode=assume_str_is_unicode) + _, node = parse_type_comment( + expr_string.strip(), + line=line, + column=column, + errors=None, + assume_str_is_unicode=assume_str_is_unicode, + ) if isinstance(node, UnboundType) and node.original_str_expr is None: node.original_str_expr = expr_string node.original_str_fallback = expr_fallback_name @@ -312,18 +417,15 @@ def parse_type_string(expr_string: str, expr_fallback_name: str, def is_no_type_check_decorator(expr: ast3.expr) -> bool: if isinstance(expr, Name): - return expr.id == 'no_type_check' + return expr.id == "no_type_check" elif isinstance(expr, Attribute): if isinstance(expr.value, Name): - return expr.value.id == 'typing' and expr.attr == 'no_type_check' + return expr.value.id == "typing" and expr.attr == "no_type_check" return False class ASTConverter: - def __init__(self, - options: Options, - is_stub: bool, - errors: Errors) -> None: + def __init__(self, options: Options, is_stub: bool, errors: Errors) -> None: # 'C' for class, 'F' for function self.class_and_function_stack: List[Literal["C", "F"]] = [] self.imports: List[ImportBase] = [] @@ -338,14 +440,16 @@ def __init__(self, self.visitor_cache: Dict[type, Callable[[Optional[AST]], Any]] = {} def note(self, msg: str, line: int, column: int) -> None: - self.errors.report(line, column, msg, severity='note', code=codes.SYNTAX) - - def fail(self, - msg: str, - line: int, - column: int, - blocker: bool = True, - code: codes.ErrorCode = codes.SYNTAX) -> None: + self.errors.report(line, column, msg, severity="note", code=codes.SYNTAX) + + def fail( + self, + msg: str, + line: int, + column: int, + blocker: bool = True, + code: codes.ErrorCode = codes.SYNTAX, + ) -> None: if blocker or not self.options.ignore_errors: self.errors.report(line, column, msg, blocker=blocker, code=code) @@ -364,7 +468,7 @@ def visit(self, node: Optional[AST]) -> Any: typeobj = type(node) visitor = self.visitor_cache.get(typeobj) if visitor is None: - method = 'visit_' + node.__class__.__name__ + method = "visit_" + node.__class__.__name__ visitor = getattr(self, method) self.visitor_cache[typeobj] = visitor return visitor(node) @@ -388,20 +492,27 @@ def translate_expr_list(self, l: Sequence[AST]) -> List[Expression]: return cast(List[Expression], self.translate_opt_expr_list(l)) def get_lineno(self, node: Union[ast3.expr, ast3.stmt]) -> int: - if (isinstance(node, (ast3.AsyncFunctionDef, ast3.ClassDef, ast3.FunctionDef)) - and node.decorator_list): + if ( + isinstance(node, (ast3.AsyncFunctionDef, ast3.ClassDef, ast3.FunctionDef)) + and node.decorator_list + ): return node.decorator_list[0].lineno return node.lineno - def translate_stmt_list(self, - stmts: Sequence[ast3.stmt], - ismodule: bool = False) -> List[Statement]: + def translate_stmt_list( + self, stmts: Sequence[ast3.stmt], ismodule: bool = False + ) -> List[Statement]: # A "# type: ignore" comment before the first statement of a module # ignores the whole module: - if (ismodule and stmts and self.type_ignores - and min(self.type_ignores) < self.get_lineno(stmts[0])): + if ( + ismodule + and stmts + and self.type_ignores + and min(self.type_ignores) < self.get_lineno(stmts[0]) + ): self.errors.used_ignored_lines[self.errors.file][min(self.type_ignores)].append( - codes.FILE.code) + codes.FILE.code + ) block = Block(self.fix_function_overloads(self.translate_stmt_list(stmts))) mark_block_unreachable(block) return [block] @@ -413,61 +524,58 @@ def translate_stmt_list(self, return res - def translate_type_comment(self, - n: Union[ast3.stmt, ast3.arg], - type_comment: Optional[str]) -> Optional[ProperType]: + def translate_type_comment( + self, n: Union[ast3.stmt, ast3.arg], type_comment: Optional[str] + ) -> Optional[ProperType]: if type_comment is None: return None else: lineno = n.lineno - extra_ignore, typ = parse_type_comment(type_comment, - lineno, - n.col_offset, - self.errors) + extra_ignore, typ = parse_type_comment(type_comment, lineno, n.col_offset, self.errors) if extra_ignore is not None: self.type_ignores[lineno] = extra_ignore return typ op_map: Final[Dict[typing.Type[AST], str]] = { - ast3.Add: '+', - ast3.Sub: '-', - ast3.Mult: '*', - ast3.MatMult: '@', - ast3.Div: '/', - ast3.Mod: '%', - ast3.Pow: '**', - ast3.LShift: '<<', - ast3.RShift: '>>', - ast3.BitOr: '|', - ast3.BitXor: '^', - ast3.BitAnd: '&', - ast3.FloorDiv: '//' + ast3.Add: "+", + ast3.Sub: "-", + ast3.Mult: "*", + ast3.MatMult: "@", + ast3.Div: "/", + ast3.Mod: "%", + ast3.Pow: "**", + ast3.LShift: "<<", + ast3.RShift: ">>", + ast3.BitOr: "|", + ast3.BitXor: "^", + ast3.BitAnd: "&", + ast3.FloorDiv: "//", } def from_operator(self, op: ast3.operator) -> str: op_name = ASTConverter.op_map.get(type(op)) if op_name is None: - raise RuntimeError('Unknown operator ' + str(type(op))) + raise RuntimeError("Unknown operator " + str(type(op))) else: return op_name comp_op_map: Final[Dict[typing.Type[AST], str]] = { - ast3.Gt: '>', - ast3.Lt: '<', - ast3.Eq: '==', - ast3.GtE: '>=', - ast3.LtE: '<=', - ast3.NotEq: '!=', - ast3.Is: 'is', - ast3.IsNot: 'is not', - ast3.In: 'in', - ast3.NotIn: 'not in' + ast3.Gt: ">", + ast3.Lt: "<", + ast3.Eq: "==", + ast3.GtE: ">=", + ast3.LtE: "<=", + ast3.NotEq: "!=", + ast3.Is: "is", + ast3.IsNot: "is not", + ast3.In: "in", + ast3.NotIn: "not in", } def from_comp_operator(self, op: ast3.cmpop) -> str: op_name = ASTConverter.comp_op_map.get(type(op)) if op_name is None: - raise RuntimeError('Unknown comparison operator ' + str(type(op))) + raise RuntimeError("Unknown comparison operator " + str(type(op))) else: return op_name @@ -502,12 +610,16 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: # Check IfStmt block to determine if function overloads can be merged if_overload_name = self._check_ifstmt_for_overloads(stmt, current_overload_name) if if_overload_name is not None: - if_block_with_overload, if_unknown_truth_value = \ - self._get_executable_if_block_with_overloads(stmt) + ( + if_block_with_overload, + if_unknown_truth_value, + ) = self._get_executable_if_block_with_overloads(stmt) - if (current_overload_name is not None - and isinstance(stmt, (Decorator, FuncDef)) - and stmt.name == current_overload_name): + if ( + current_overload_name is not None + and isinstance(stmt, (Decorator, FuncDef)) + and stmt.name == current_overload_name + ): if last_if_stmt is not None: skipped_if_stmts.append(last_if_stmt) if last_if_overload is not None: @@ -547,9 +659,7 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: current_overload.append(last_if_overload) last_if_stmt, last_if_overload = None, None if isinstance(if_block_with_overload.body[-1], OverloadedFuncDef): - skipped_if_stmts.extend( - cast(List[IfStmt], if_block_with_overload.body[:-1]) - ) + skipped_if_stmts.extend(cast(List[IfStmt], if_block_with_overload.body[:-1])) current_overload.extend(if_block_with_overload.body[-1].items) else: current_overload.append( @@ -587,10 +697,7 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: if isinstance(stmt, Decorator) and not unnamed_function(stmt.name): current_overload = [stmt] current_overload_name = stmt.name - elif ( - isinstance(stmt, IfStmt) - and if_overload_name is not None - ): + elif isinstance(stmt, IfStmt) and if_overload_name is not None: current_overload = [] current_overload_name = if_overload_name last_if_stmt = stmt @@ -601,7 +708,7 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: ) last_if_overload = cast( Union[Decorator, FuncDef, OverloadedFuncDef], - if_block_with_overload.body[-1] + if_block_with_overload.body[-1], ) last_if_unknown_truth_value = if_unknown_truth_value else: @@ -642,15 +749,13 @@ def _check_ifstmt_for_overloads( ) or len(stmt.body[0].body) > 1 and isinstance(stmt.body[0].body[-1], OverloadedFuncDef) - and all( - self._is_stripped_if_stmt(if_stmt) - for if_stmt in stmt.body[0].body[:-1] - ) + and all(self._is_stripped_if_stmt(if_stmt) for if_stmt in stmt.body[0].body[:-1]) ): return None overload_name = cast( - Union[Decorator, FuncDef, OverloadedFuncDef], stmt.body[0].body[-1]).name + Union[Decorator, FuncDef, OverloadedFuncDef], stmt.body[0].body[-1] + ).name if stmt.else_body is None: return overload_name @@ -663,9 +768,8 @@ def _check_ifstmt_for_overloads( return overload_name if ( isinstance(stmt.else_body.body[0], IfStmt) - and self._check_ifstmt_for_overloads( - stmt.else_body.body[0], current_overload_name - ) == overload_name + and self._check_ifstmt_for_overloads(stmt.else_body.body[0], current_overload_name) + == overload_name ): return overload_name @@ -682,10 +786,7 @@ def _get_executable_if_block_with_overloads( i.e. the truth value is unknown. """ infer_reachability_of_if_statement(stmt, self.options) - if ( - stmt.else_body is None - and stmt.body[0].is_unreachable is True - ): + if stmt.else_body is None and stmt.body[0].is_unreachable is True: # always False condition with no else return None, None if ( @@ -740,7 +841,7 @@ def _is_stripped_if_stmt(self, stmt: Statement) -> bool: return self._is_stripped_if_stmt(stmt.else_body.body[0]) def in_method_scope(self) -> bool: - return self.class_and_function_stack[-2:] == ['C', 'F'] + return self.class_and_function_stack[-2:] == ["C", "F"] def translate_module_id(self, id: str) -> str: """Return the actual, internal module id for a source text id. @@ -748,11 +849,11 @@ def translate_module_id(self, id: str) -> str: For example, translate '__builtin__' in Python 2 to 'builtins'. """ if id == self.options.custom_typing_module: - return 'typing' - elif id == '__builtin__' and self.options.python_version[0] == 2: + return "typing" + elif id == "__builtin__" and self.options.python_version[0] == 2: # HACK: __builtin__ in Python 2 is aliases to builtins. However, the implementation # is named __builtin__.py (there is another layer of translation elsewhere). - return 'builtins' + return "builtins" return id def visit_Module(self, mod: ast3.Module) -> MypyFile: @@ -764,11 +865,7 @@ def visit_Module(self, mod: ast3.Module) -> MypyFile: else: self.fail(INVALID_TYPE_IGNORE, ti.lineno, -1) body = self.fix_function_overloads(self.translate_stmt_list(mod.body, ismodule=True)) - return MypyFile(body, - self.imports, - False, - self.type_ignores, - ) + return MypyFile(body, self.imports, False, self.type_ignores) # --- stmt --- # FunctionDef(identifier name, arguments args, @@ -783,12 +880,14 @@ def visit_FunctionDef(self, n: ast3.FunctionDef) -> Union[FuncDef, Decorator]: def visit_AsyncFunctionDef(self, n: ast3.AsyncFunctionDef) -> Union[FuncDef, Decorator]: return self.do_func_def(n, is_coroutine=True) - def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef], - is_coroutine: bool = False) -> Union[FuncDef, Decorator]: + def do_func_def( + self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef], is_coroutine: bool = False + ) -> Union[FuncDef, Decorator]: """Helper shared between visit_FunctionDef and visit_AsyncFunctionDef.""" - self.class_and_function_stack.append('F') - no_type_check = bool(n.decorator_list and - any(is_no_type_check_decorator(d) for d in n.decorator_list)) + self.class_and_function_stack.append("F") + no_type_check = bool( + n.decorator_list and any(is_no_type_check_decorator(d) for d in n.decorator_list) + ) lineno = n.lineno args = self.transform_args(n.args, lineno, no_type_check=no_type_check) @@ -805,30 +904,33 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef], return_type = None elif n.type_comment is not None: try: - func_type_ast = ast3_parse(n.type_comment, '', 'func_type') + func_type_ast = ast3_parse(n.type_comment, "", "func_type") assert isinstance(func_type_ast, FunctionType) # for ellipsis arg - if (len(func_type_ast.argtypes) == 1 and - isinstance(func_type_ast.argtypes[0], ast3_Ellipsis)): + if len(func_type_ast.argtypes) == 1 and isinstance( + func_type_ast.argtypes[0], ast3_Ellipsis + ): if n.returns: # PEP 484 disallows both type annotations and type comments self.fail(message_registry.DUPLICATE_TYPE_SIGNATURES, lineno, n.col_offset) - arg_types = [a.type_annotation - if a.type_annotation is not None - else AnyType(TypeOfAny.unannotated) - for a in args] + arg_types = [ + a.type_annotation + if a.type_annotation is not None + else AnyType(TypeOfAny.unannotated) + for a in args + ] else: # PEP 484 disallows both type annotations and type comments if n.returns or any(a.type_annotation is not None for a in args): self.fail(message_registry.DUPLICATE_TYPE_SIGNATURES, lineno, n.col_offset) - translated_args = (TypeConverter(self.errors, - line=lineno, - override_column=n.col_offset) - .translate_expr_list(func_type_ast.argtypes)) - arg_types = [a if a is not None else AnyType(TypeOfAny.unannotated) - for a in translated_args] - return_type = TypeConverter(self.errors, - line=lineno).visit(func_type_ast.returns) + translated_args = TypeConverter( + self.errors, line=lineno, override_column=n.col_offset + ).translate_expr_list(func_type_ast.argtypes) + arg_types = [ + a if a is not None else AnyType(TypeOfAny.unannotated) + for a in translated_args + ] + return_type = TypeConverter(self.errors, line=lineno).visit(func_type_ast.returns) # add implicit self type if self.in_method_scope() and len(arg_types) < len(args): @@ -838,44 +940,46 @@ def do_func_def(self, n: Union[ast3.FunctionDef, ast3.AsyncFunctionDef], err_msg = f'{TYPE_COMMENT_SYNTAX_ERROR} "{stripped_type}"' self.fail(err_msg, lineno, n.col_offset) if n.type_comment and n.type_comment[0] not in ["(", "#"]: - self.note('Suggestion: wrap argument types in parentheses', - lineno, n.col_offset) + self.note( + "Suggestion: wrap argument types in parentheses", lineno, n.col_offset + ) arg_types = [AnyType(TypeOfAny.from_error)] * len(args) return_type = AnyType(TypeOfAny.from_error) else: arg_types = [a.type_annotation for a in args] - return_type = TypeConverter(self.errors, line=n.returns.lineno - if n.returns else lineno).visit(n.returns) + return_type = TypeConverter( + self.errors, line=n.returns.lineno if n.returns else lineno + ).visit(n.returns) for arg, arg_type in zip(args, arg_types): self.set_type_optional(arg_type, arg.initializer) func_type = None if any(arg_types) or return_type: - if len(arg_types) != 1 and any(isinstance(t, EllipsisType) - for t in arg_types): - self.fail("Ellipses cannot accompany other argument types " - "in function type signature", lineno, n.col_offset) + if len(arg_types) != 1 and any(isinstance(t, EllipsisType) for t in arg_types): + self.fail( + "Ellipses cannot accompany other argument types " "in function type signature", + lineno, + n.col_offset, + ) elif len(arg_types) > len(arg_kinds): - self.fail('Type signature has too many arguments', lineno, n.col_offset, - blocker=False) + self.fail( + "Type signature has too many arguments", lineno, n.col_offset, blocker=False + ) elif len(arg_types) < len(arg_kinds): - self.fail('Type signature has too few arguments', lineno, n.col_offset, - blocker=False) + self.fail( + "Type signature has too few arguments", lineno, n.col_offset, blocker=False + ) else: - func_type = CallableType([a if a is not None else - AnyType(TypeOfAny.unannotated) for a in arg_types], - arg_kinds, - arg_names, - return_type if return_type is not None else - AnyType(TypeOfAny.unannotated), - _dummy_fallback) - - func_def = FuncDef( - n.name, - args, - self.as_required_block(n.body, lineno), - func_type) + func_type = CallableType( + [a if a is not None else AnyType(TypeOfAny.unannotated) for a in arg_types], + arg_kinds, + arg_names, + return_type if return_type is not None else AnyType(TypeOfAny.unannotated), + _dummy_fallback, + ) + + func_def = FuncDef(n.name, args, self.as_required_block(n.body, lineno), func_type) if isinstance(func_def.type, CallableType): # semanal.py does some in-place modifications we want to avoid func_def.unanalyzed_type = func_def.type.copy_modified() @@ -919,15 +1023,13 @@ def set_type_optional(self, type: Optional[Type], initializer: Optional[Expressi if self.options.no_implicit_optional: return # Indicate that type should be wrapped in an Optional if arg is initialized to None. - optional = isinstance(initializer, NameExpr) and initializer.name == 'None' + optional = isinstance(initializer, NameExpr) and initializer.name == "None" if isinstance(type, UnboundType): type.optional = optional - def transform_args(self, - args: ast3.arguments, - line: int, - no_type_check: bool = False, - ) -> List[Argument]: + def transform_args( + self, args: ast3.arguments, line: int, no_type_check: bool = False + ) -> List[Argument]: new_args = [] names: List[ast3.arg] = [] posonlyargs = getattr(args, "posonlyargs", cast(List[ast3.arg], [])) @@ -953,11 +1055,11 @@ def transform_args(self, # keyword-only arguments with defaults for a, kd in zip(args.kwonlyargs, args.kw_defaults): - new_args.append(self.make_argument( - a, - kd, - ARG_NAMED if kd is None else ARG_NAMED_OPT, - no_type_check)) + new_args.append( + self.make_argument( + a, kd, ARG_NAMED if kd is None else ARG_NAMED_OPT, no_type_check + ) + ) names.append(a) # **kwarg @@ -969,8 +1071,14 @@ def transform_args(self, return new_args - def make_argument(self, arg: ast3.arg, default: Optional[ast3.expr], kind: ArgKind, - no_type_check: bool, pos_only: bool = False) -> Argument: + def make_argument( + self, + arg: ast3.arg, + default: Optional[ast3.expr], + kind: ArgKind, + no_type_check: bool, + pos_only: bool = False, + ) -> Argument: if no_type_check: arg_type = None else: @@ -997,16 +1105,17 @@ def fail_arg(self, msg: str, arg: ast3.arg) -> None: # stmt* body, # expr* decorator_list) def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef: - self.class_and_function_stack.append('C') - keywords = [(kw.arg, self.visit(kw.value)) - for kw in n.keywords if kw.arg] - - cdef = ClassDef(n.name, - self.as_required_block(n.body, n.lineno), - None, - self.translate_expr_list(n.bases), - metaclass=dict(keywords).get('metaclass'), - keywords=keywords) + self.class_and_function_stack.append("C") + keywords = [(kw.arg, self.visit(kw.value)) for kw in n.keywords if kw.arg] + + cdef = ClassDef( + n.name, + self.as_required_block(n.body, n.lineno), + None, + self.translate_expr_list(n.bases), + metaclass=dict(keywords).get("metaclass"), + keywords=keywords, + ) cdef.decorators = self.translate_expr_list(n.decorator_list) # Set end_lineno to the old mypy 0.700 lineno, in order to keep # existing "# type: ignore" comments working: @@ -1060,63 +1169,75 @@ def visit_AnnAssign(self, n: ast3.AnnAssign) -> AssignmentStmt: # AugAssign(expr target, operator op, expr value) def visit_AugAssign(self, n: ast3.AugAssign) -> OperatorAssignmentStmt: - s = OperatorAssignmentStmt(self.from_operator(n.op), - self.visit(n.target), - self.visit(n.value)) + s = OperatorAssignmentStmt( + self.from_operator(n.op), self.visit(n.target), self.visit(n.value) + ) return self.set_line(s, n) # For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) def visit_For(self, n: ast3.For) -> ForStmt: target_type = self.translate_type_comment(n, n.type_comment) - node = ForStmt(self.visit(n.target), - self.visit(n.iter), - self.as_required_block(n.body, n.lineno), - self.as_block(n.orelse, n.lineno), - target_type) + node = ForStmt( + self.visit(n.target), + self.visit(n.iter), + self.as_required_block(n.body, n.lineno), + self.as_block(n.orelse, n.lineno), + target_type, + ) return self.set_line(node, n) # AsyncFor(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) def visit_AsyncFor(self, n: ast3.AsyncFor) -> ForStmt: target_type = self.translate_type_comment(n, n.type_comment) - node = ForStmt(self.visit(n.target), - self.visit(n.iter), - self.as_required_block(n.body, n.lineno), - self.as_block(n.orelse, n.lineno), - target_type) + node = ForStmt( + self.visit(n.target), + self.visit(n.iter), + self.as_required_block(n.body, n.lineno), + self.as_block(n.orelse, n.lineno), + target_type, + ) node.is_async = True return self.set_line(node, n) # While(expr test, stmt* body, stmt* orelse) def visit_While(self, n: ast3.While) -> WhileStmt: - node = WhileStmt(self.visit(n.test), - self.as_required_block(n.body, n.lineno), - self.as_block(n.orelse, n.lineno)) + node = WhileStmt( + self.visit(n.test), + self.as_required_block(n.body, n.lineno), + self.as_block(n.orelse, n.lineno), + ) return self.set_line(node, n) # If(expr test, stmt* body, stmt* orelse) def visit_If(self, n: ast3.If) -> IfStmt: lineno = n.lineno - node = IfStmt([self.visit(n.test)], - [self.as_required_block(n.body, lineno)], - self.as_block(n.orelse, lineno)) + node = IfStmt( + [self.visit(n.test)], + [self.as_required_block(n.body, lineno)], + self.as_block(n.orelse, lineno), + ) return self.set_line(node, n) # With(withitem* items, stmt* body, string? type_comment) def visit_With(self, n: ast3.With) -> WithStmt: target_type = self.translate_type_comment(n, n.type_comment) - node = WithStmt([self.visit(i.context_expr) for i in n.items], - [self.visit(i.optional_vars) for i in n.items], - self.as_required_block(n.body, n.lineno), - target_type) + node = WithStmt( + [self.visit(i.context_expr) for i in n.items], + [self.visit(i.optional_vars) for i in n.items], + self.as_required_block(n.body, n.lineno), + target_type, + ) return self.set_line(node, n) # AsyncWith(withitem* items, stmt* body, string? type_comment) def visit_AsyncWith(self, n: ast3.AsyncWith) -> WithStmt: target_type = self.translate_type_comment(n, n.type_comment) - s = WithStmt([self.visit(i.context_expr) for i in n.items], - [self.visit(i.optional_vars) for i in n.items], - self.as_required_block(n.body, n.lineno), - target_type) + s = WithStmt( + [self.visit(i.context_expr) for i in n.items], + [self.visit(i.optional_vars) for i in n.items], + self.as_required_block(n.body, n.lineno), + target_type, + ) s.is_async = True return self.set_line(s, n) @@ -1133,12 +1254,14 @@ def visit_Try(self, n: ast3.Try) -> TryStmt: types = [self.visit(h.type) for h in n.handlers] handlers = [self.as_required_block(h.body, h.lineno) for h in n.handlers] - node = TryStmt(self.as_required_block(n.body, n.lineno), - vs, - types, - handlers, - self.as_block(n.orelse, n.lineno), - self.as_block(n.finalbody, n.lineno)) + node = TryStmt( + self.as_required_block(n.body, n.lineno), + vs, + types, + handlers, + self.as_block(n.orelse, n.lineno), + self.as_block(n.finalbody, n.lineno), + ) return self.set_line(node, n) # Assert(expr test, expr? msg) @@ -1165,13 +1288,15 @@ def visit_Import(self, n: ast3.Import) -> Import: # ImportFrom(identifier? module, alias* names, int? level) def visit_ImportFrom(self, n: ast3.ImportFrom) -> ImportBase: assert n.level is not None - if len(n.names) == 1 and n.names[0].name == '*': - mod = n.module if n.module is not None else '' + if len(n.names) == 1 and n.names[0].name == "*": + mod = n.module if n.module is not None else "" i: ImportBase = ImportAll(mod, n.level) else: - i = ImportFrom(self.translate_module_id(n.module) if n.module is not None else '', - n.level, - [(a.name, a.asname) for a in n.names]) + i = ImportFrom( + self.translate_module_id(n.module) if n.module is not None else "", + n.level, + [(a.name, a.asname) for a in n.names], + ) self.imports.append(i) return self.set_line(i, n) @@ -1218,11 +1343,11 @@ def visit_BoolOp(self, n: ast3.BoolOp) -> OpExpr: assert len(n.values) >= 2 op_node = n.op if isinstance(op_node, ast3.And): - op = 'and' + op = "and" elif isinstance(op_node, ast3.Or): - op = 'or' + op = "or" else: - raise RuntimeError('unknown BoolOp ' + str(type(n))) + raise RuntimeError("unknown BoolOp " + str(type(n))) # potentially inefficient! return self.group(op, self.translate_expr_list(n.values), n) @@ -1239,7 +1364,7 @@ def visit_BinOp(self, n: ast3.BinOp) -> OpExpr: op = self.from_operator(n.op) if op is None: - raise RuntimeError('cannot translate BinOp ' + str(type(n.op))) + raise RuntimeError("cannot translate BinOp " + str(type(n.op))) e = OpExpr(op, self.visit(n.left), self.visit(n.right)) return self.set_line(e, n) @@ -1248,16 +1373,16 @@ def visit_BinOp(self, n: ast3.BinOp) -> OpExpr: def visit_UnaryOp(self, n: ast3.UnaryOp) -> UnaryExpr: op = None if isinstance(n.op, ast3.Invert): - op = '~' + op = "~" elif isinstance(n.op, ast3.Not): - op = 'not' + op = "not" elif isinstance(n.op, ast3.UAdd): - op = '+' + op = "+" elif isinstance(n.op, ast3.USub): - op = '-' + op = "-" if op is None: - raise RuntimeError('cannot translate UnaryOp ' + str(type(n.op))) + raise RuntimeError("cannot translate UnaryOp " + str(type(n.op))) e = UnaryExpr(op, self.visit(n.operand)) return self.set_line(e, n) @@ -1268,22 +1393,22 @@ def visit_Lambda(self, n: ast3.Lambda) -> LambdaExpr: body.lineno = n.body.lineno body.col_offset = n.body.col_offset - e = LambdaExpr(self.transform_args(n.args, n.lineno), - self.as_required_block([body], n.lineno)) + e = LambdaExpr( + self.transform_args(n.args, n.lineno), self.as_required_block([body], n.lineno) + ) e.set_line(n.lineno, n.col_offset) # Overrides set_line -- can't use self.set_line return e # IfExp(expr test, expr body, expr orelse) def visit_IfExp(self, n: ast3.IfExp) -> ConditionalExpr: - e = ConditionalExpr(self.visit(n.test), - self.visit(n.body), - self.visit(n.orelse)) + e = ConditionalExpr(self.visit(n.test), self.visit(n.body), self.visit(n.orelse)) return self.set_line(e, n) # Dict(expr* keys, expr* values) def visit_Dict(self, n: ast3.Dict) -> DictExpr: - e = DictExpr(list(zip(self.translate_opt_expr_list(n.keys), - self.translate_expr_list(n.values)))) + e = DictExpr( + list(zip(self.translate_opt_expr_list(n.keys), self.translate_expr_list(n.values))) + ) return self.set_line(e, n) # Set(expr* elts) @@ -1307,12 +1432,9 @@ def visit_DictComp(self, n: ast3.DictComp) -> DictionaryComprehension: iters = [self.visit(c.iter) for c in n.generators] ifs_list = [self.translate_expr_list(c.ifs) for c in n.generators] is_async = [bool(c.is_async) for c in n.generators] - e = DictionaryComprehension(self.visit(n.key), - self.visit(n.value), - targets, - iters, - ifs_list, - is_async) + e = DictionaryComprehension( + self.visit(n.key), self.visit(n.value), targets, iters, ifs_list, is_async + ) return self.set_line(e, n) # GeneratorExp(expr elt, comprehension* generators) @@ -1321,11 +1443,7 @@ def visit_GeneratorExp(self, n: ast3.GeneratorExp) -> GeneratorExpr: iters = [self.visit(c.iter) for c in n.generators] ifs_list = [self.translate_expr_list(c.ifs) for c in n.generators] is_async = [bool(c.is_async) for c in n.generators] - e = GeneratorExpr(self.visit(n.elt), - targets, - iters, - ifs_list, - is_async) + e = GeneratorExpr(self.visit(n.elt), targets, iters, ifs_list, is_async) return self.set_line(e, n) # Await(expr value) @@ -1358,14 +1476,17 @@ def visit_Call(self, n: Call) -> CallExpr: keywords = n.keywords keyword_names = [k.arg for k in keywords] arg_types = self.translate_expr_list( - [a.value if isinstance(a, Starred) else a for a in args] + - [k.value for k in keywords]) - arg_kinds = ([ARG_STAR if type(a) is Starred else ARG_POS for a in args] + - [ARG_STAR2 if arg is None else ARG_NAMED for arg in keyword_names]) - e = CallExpr(self.visit(n.func), - arg_types, - arg_kinds, - cast('List[Optional[str]]', [None] * len(args)) + keyword_names) + [a.value if isinstance(a, Starred) else a for a in args] + [k.value for k in keywords] + ) + arg_kinds = [ARG_STAR if type(a) is Starred else ARG_POS for a in args] + [ + ARG_STAR2 if arg is None else ARG_NAMED for arg in keyword_names + ] + e = CallExpr( + self.visit(n.func), + arg_types, + arg_kinds, + cast("List[Optional[str]]", [None] * len(args)) + keyword_names, + ) return self.set_line(e, n) # Constant(object value) -- a constant, in Python 3.8. @@ -1373,7 +1494,7 @@ def visit_Constant(self, n: Constant) -> Any: val = n.value e: Any = None if val is None: - e = NameExpr('None') + e = NameExpr("None") elif isinstance(val, str): e = StrExpr(n.s) elif isinstance(val, bytes): @@ -1389,7 +1510,7 @@ def visit_Constant(self, n: Constant) -> Any: elif val is Ellipsis: e = EllipsisExpr() else: - raise RuntimeError('Constant not implemented for ' + str(type(val))) + raise RuntimeError("Constant not implemented for " + str(type(val))) return self.set_line(e, n) # Num(object n) -- a number as a PyObject. @@ -1406,7 +1527,7 @@ def visit_Num(self, n: ast3.Num) -> Union[IntExpr, FloatExpr, ComplexExpr]: elif isinstance(val, complex): e = ComplexExpr(val) else: - raise RuntimeError('num not implemented for ' + str(type(val))) + raise RuntimeError("num not implemented for " + str(type(val))) return self.set_line(e, n) # Str(string s) @@ -1424,19 +1545,16 @@ def visit_Str(self, n: Str) -> Union[UnicodeExpr, StrExpr]: def visit_JoinedStr(self, n: ast3.JoinedStr) -> Expression: # Each of n.values is a str or FormattedValue; we just concatenate # them all using ''.join. - empty_string = StrExpr('') + empty_string = StrExpr("") empty_string.set_line(n.lineno, n.col_offset) strs_to_join = ListExpr(self.translate_expr_list(n.values)) strs_to_join.set_line(empty_string) # Don't make unnecessary join call if there is only one str to join if len(strs_to_join.items) == 1: return self.set_line(strs_to_join.items[0], n) - join_method = MemberExpr(empty_string, 'join') + join_method = MemberExpr(empty_string, "join") join_method.set_line(empty_string) - result_expression = CallExpr(join_method, - [strs_to_join], - [ARG_POS], - [None]) + result_expression = CallExpr(join_method, [strs_to_join], [ARG_POS], [None]) return self.set_line(result_expression, n) # FormattedValue(expr value) @@ -1447,16 +1565,15 @@ def visit_FormattedValue(self, n: ast3.FormattedValue) -> Expression: # to allow mypyc to support f-strings with format specifiers and conversions. val_exp = self.visit(n.value) val_exp.set_line(n.lineno, n.col_offset) - conv_str = '' if n.conversion is None or n.conversion < 0 else '!' + chr(n.conversion) - format_string = StrExpr('{' + conv_str + ':{}}') - format_spec_exp = self.visit(n.format_spec) if n.format_spec is not None else StrExpr('') + conv_str = "" if n.conversion is None or n.conversion < 0 else "!" + chr(n.conversion) + format_string = StrExpr("{" + conv_str + ":{}}") + format_spec_exp = self.visit(n.format_spec) if n.format_spec is not None else StrExpr("") format_string.set_line(n.lineno, n.col_offset) - format_method = MemberExpr(format_string, 'format') + format_method = MemberExpr(format_string, "format") format_method.set_line(format_string) - result_expression = CallExpr(format_method, - [val_exp, format_spec_exp], - [ARG_POS, ARG_POS], - [None, None]) + result_expression = CallExpr( + format_method, [val_exp, format_spec_exp], [ARG_POS, ARG_POS], [None, None] + ) return self.set_line(result_expression, n) # Bytes(bytes s) @@ -1479,9 +1596,11 @@ def visit_Attribute(self, n: Attribute) -> Union[MemberExpr, SuperExpr]: value = n.value member_expr = MemberExpr(self.visit(value), n.attr) obj = member_expr.expr - if (isinstance(obj, CallExpr) and - isinstance(obj.callee, NameExpr) and - obj.callee.name == 'super'): + if ( + isinstance(obj, CallExpr) + and isinstance(obj.callee, NameExpr) + and obj.callee.name == "super" + ): e: Union[MemberExpr, SuperExpr] = SuperExpr(member_expr.name, obj) else: e = member_expr @@ -1493,9 +1612,8 @@ def visit_Subscript(self, n: ast3.Subscript) -> IndexExpr: self.set_line(e, n) # alias to please mypyc is_py38_or_earlier = sys.version_info < (3, 9) - if ( - isinstance(n.slice, ast3.Slice) or - (is_py38_or_earlier and isinstance(n.slice, ast3.ExtSlice)) + if isinstance(n.slice, ast3.Slice) or ( + is_py38_or_earlier and isinstance(n.slice, ast3.ExtSlice) ): # Before Python 3.9, Slice has no line/column in the raw ast. To avoid incompatibility # visit_Slice doesn't set_line, even in Python 3.9 on. @@ -1534,9 +1652,7 @@ def visit_Tuple(self, n: ast3.Tuple) -> TupleExpr: # Slice(expr? lower, expr? upper, expr? step) def visit_Slice(self, n: ast3.Slice) -> SliceExpr: - return SliceExpr(self.visit(n.lower), - self.visit(n.upper), - self.visit(n.step)) + return SliceExpr(self.visit(n.lower), self.visit(n.upper), self.visit(n.step)) # ExtSlice(slice* dims) def visit_ExtSlice(self, n: ast3.ExtSlice) -> TupleExpr: @@ -1550,10 +1666,12 @@ def visit_Index(self, n: Index) -> Node: # Match(expr subject, match_case* cases) # python 3.10 and later def visit_Match(self, n: Match) -> MatchStmt: - node = MatchStmt(self.visit(n.subject), - [self.visit(c.pattern) for c in n.cases], - [self.visit(c.guard) for c in n.cases], - [self.as_required_block(c.body, n.lineno) for c in n.cases]) + node = MatchStmt( + self.visit(n.subject), + [self.visit(c.pattern) for c in n.cases], + [self.visit(c.guard) for c in n.cases], + [self.as_required_block(c.body, n.lineno) for c in n.cases], + ) return self.set_line(node, n) def visit_MatchValue(self, n: MatchValue) -> ValuePattern: @@ -1619,13 +1737,14 @@ def visit_MatchOr(self, n: MatchOr) -> OrPattern: class TypeConverter: - def __init__(self, - errors: Optional[Errors], - line: int = -1, - override_column: int = -1, - assume_str_is_unicode: bool = True, - is_evaluated: bool = True, - ) -> None: + def __init__( + self, + errors: Optional[Errors], + line: int = -1, + override_column: int = -1, + assume_str_is_unicode: bool = True, + is_evaluated: bool = True, + ) -> None: self.errors = errors self.line = line self.override_column = override_column @@ -1655,18 +1774,16 @@ def invalid_type(self, node: AST, note: Optional[str] = None) -> RawExpressionTy See RawExpressionType's docstring for more details on how it's used. """ return RawExpressionType( - None, - 'typing.Any', - line=self.line, - column=getattr(node, 'col_offset', -1), - note=note, + None, "typing.Any", line=self.line, column=getattr(node, "col_offset", -1), note=note ) @overload - def visit(self, node: ast3.expr) -> ProperType: ... + def visit(self, node: ast3.expr) -> ProperType: + ... @overload - def visit(self, node: Optional[AST]) -> Optional[ProperType]: ... + def visit(self, node: Optional[AST]) -> Optional[ProperType]: + ... def visit(self, node: Optional[AST]) -> Optional[ProperType]: """Modified visit -- keep track of the stack of nodes""" @@ -1674,7 +1791,7 @@ def visit(self, node: Optional[AST]) -> Optional[ProperType]: return None self.node_stack.append(node) try: - method = 'visit_' + node.__class__.__name__ + method = "visit_" + node.__class__.__name__ visitor = getattr(self, method, None) if visitor is not None: return visitor(node) @@ -1695,7 +1812,7 @@ def fail(self, msg: str, line: int, column: int) -> None: def note(self, msg: str, line: int, column: int) -> None: if self.errors: - self.errors.report(line, column, msg, severity='note', code=codes.SYNTAX) + self.errors.report(line, column, msg, severity="note", code=codes.SYNTAX) def translate_expr_list(self, l: Sequence[ast3.expr]) -> List[Type]: return [self.visit(e) for e in l] @@ -1704,11 +1821,9 @@ def visit_raw_str(self, s: str) -> Type: # An escape hatch that allows the AST walker in fastparse2 to # directly hook into the Python 3 type converter in some cases # without needing to create an intermediary `Str` object. - _, typ = parse_type_comment(s.strip(), - self.line, - -1, - self.errors, - self.assume_str_is_unicode) + _, typ = parse_type_comment( + s.strip(), self.line, -1, self.errors, self.assume_str_is_unicode + ) return typ or AnyType(TypeOfAny.from_error) def visit_Call(self, e: Call) -> Type: @@ -1735,26 +1850,37 @@ def visit_Call(self, e: Call) -> Type: elif i == 1: name = self._extract_argument_name(arg) else: - self.fail("Too many arguments for argument constructor", - f.lineno, f.col_offset) + self.fail("Too many arguments for argument constructor", f.lineno, f.col_offset) for k in e.keywords: value = k.value if k.arg == "name": if name is not None: - self.fail('"{}" gets multiple values for keyword argument "name"'.format( - constructor), f.lineno, f.col_offset) + self.fail( + '"{}" gets multiple values for keyword argument "name"'.format( + constructor + ), + f.lineno, + f.col_offset, + ) name = self._extract_argument_name(value) elif k.arg == "type": if typ is not default_type: - self.fail('"{}" gets multiple values for keyword argument "type"'.format( - constructor), f.lineno, f.col_offset) + self.fail( + '"{}" gets multiple values for keyword argument "type"'.format( + constructor + ), + f.lineno, + f.col_offset, + ) converted = self.visit(value) assert converted is not None typ = converted else: self.fail( f'Unexpected argument "{k.arg}" for argument constructor', - value.lineno, value.col_offset) + value.lineno, + value.col_offset, + ) return CallableArgument(typ, name, constructor, e.lineno, e.col_offset) def translate_argument_list(self, l: Sequence[ast3.expr]) -> TypeList: @@ -1763,10 +1889,13 @@ def translate_argument_list(self, l: Sequence[ast3.expr]) -> TypeList: def _extract_argument_name(self, n: ast3.expr) -> Optional[str]: if isinstance(n, Str): return n.s.strip() - elif isinstance(n, NameConstant) and str(n.value) == 'None': + elif isinstance(n, NameConstant) and str(n.value) == "None": return None - self.fail('Expected string literal for argument name, got {}'.format( - type(n).__name__), self.line, 0) + self.fail( + "Expected string literal for argument name, got {}".format(type(n).__name__), + self.line, + 0, + ) return None def visit_Name(self, n: Name) -> Type: @@ -1778,15 +1907,17 @@ def visit_BinOp(self, n: ast3.BinOp) -> Type: left = self.visit(n.left) right = self.visit(n.right) - return UnionType([left, right], - line=self.line, - column=self.convert_column(n.col_offset), - is_evaluated=self.is_evaluated, - uses_pep604_syntax=True) + return UnionType( + [left, right], + line=self.line, + column=self.convert_column(n.col_offset), + is_evaluated=self.is_evaluated, + uses_pep604_syntax=True, + ) def visit_NameConstant(self, n: NameConstant) -> Type: if isinstance(n.value, bool): - return RawExpressionType(n.value, 'builtins.bool', line=self.line) + return RawExpressionType(n.value, "builtins.bool", line=self.line) else: return UnboundType(str(n.value), line=self.line, column=n.col_offset) @@ -1795,26 +1926,36 @@ def visit_Constant(self, n: Constant) -> Type: val = n.value if val is None: # None is a type. - return UnboundType('None', line=self.line) + return UnboundType("None", line=self.line) if isinstance(val, str): # Parse forward reference. - if (n.kind and 'u' in n.kind) or self.assume_str_is_unicode: - return parse_type_string(n.s, 'builtins.unicode', self.line, n.col_offset, - assume_str_is_unicode=self.assume_str_is_unicode) + if (n.kind and "u" in n.kind) or self.assume_str_is_unicode: + return parse_type_string( + n.s, + "builtins.unicode", + self.line, + n.col_offset, + assume_str_is_unicode=self.assume_str_is_unicode, + ) else: - return parse_type_string(n.s, 'builtins.str', self.line, n.col_offset, - assume_str_is_unicode=self.assume_str_is_unicode) + return parse_type_string( + n.s, + "builtins.str", + self.line, + n.col_offset, + assume_str_is_unicode=self.assume_str_is_unicode, + ) if val is Ellipsis: # '...' is valid in some types. return EllipsisType(line=self.line) if isinstance(val, bool): # Special case for True/False. - return RawExpressionType(val, 'builtins.bool', line=self.line) + return RawExpressionType(val, "builtins.bool", line=self.line) if isinstance(val, (int, float, complex)): return self.numeric_type(val, n) if isinstance(val, bytes): contents = bytes_to_human_readable_repr(val) - return RawExpressionType(contents, 'builtins.bytes', self.line, column=n.col_offset) + return RawExpressionType(contents, "builtins.bytes", self.line, column=n.col_offset) # Everything else is invalid. return self.invalid_type(n) @@ -1836,18 +1977,15 @@ def numeric_type(self, value: object, n: AST) -> Type: # this by throwing away the type. if isinstance(value, int): numeric_value: Optional[int] = value - type_name = 'builtins.int' + type_name = "builtins.int" else: # Other kinds of numbers (floats, complex) are not valid parameters for # RawExpressionType so we just pass in 'None' for now. We'll report the # appropriate error at a later stage. numeric_value = None - type_name = f'builtins.{type(value).__name__}' + type_name = f"builtins.{type(value).__name__}" return RawExpressionType( - numeric_value, - type_name, - line=self.line, - column=getattr(n, 'col_offset', -1), + numeric_value, type_name, line=self.line, column=getattr(n, "col_offset", -1) ) # These next three methods are only used if we are on python < @@ -1872,26 +2010,34 @@ def visit_Str(self, n: Str) -> Type: # unused on < 3.8. kind: str = getattr(n, "kind") # noqa - if 'u' in kind or self.assume_str_is_unicode: - return parse_type_string(n.s, 'builtins.unicode', self.line, n.col_offset, - assume_str_is_unicode=self.assume_str_is_unicode) + if "u" in kind or self.assume_str_is_unicode: + return parse_type_string( + n.s, + "builtins.unicode", + self.line, + n.col_offset, + assume_str_is_unicode=self.assume_str_is_unicode, + ) else: - return parse_type_string(n.s, 'builtins.str', self.line, n.col_offset, - assume_str_is_unicode=self.assume_str_is_unicode) + return parse_type_string( + n.s, + "builtins.str", + self.line, + n.col_offset, + assume_str_is_unicode=self.assume_str_is_unicode, + ) # Bytes(bytes s) def visit_Bytes(self, n: Bytes) -> Type: contents = bytes_to_human_readable_repr(n.s) - return RawExpressionType(contents, 'builtins.bytes', self.line, column=n.col_offset) + return RawExpressionType(contents, "builtins.bytes", self.line, column=n.col_offset) def visit_Index(self, n: ast3.Index) -> Type: # cast for mypyc's benefit on Python 3.9 return self.visit(cast(Any, n).value) def visit_Slice(self, n: ast3.Slice) -> Type: - return self.invalid_type( - n, note="did you mean to use ',' instead of ':' ?" - ) + return self.invalid_type(n, note="did you mean to use ',' instead of ':' ?") # Subscript(expr value, slice slice, expr_context ctx) # Python 3.8 and before # Subscript(expr value, expr slice, expr_context ctx) # Python 3.9 and later @@ -1927,14 +2073,24 @@ def visit_Subscript(self, n: ast3.Subscript) -> Type: value = self.visit(n.value) if isinstance(value, UnboundType) and not value.args: - return UnboundType(value.name, params, line=self.line, column=value.column, - empty_tuple_index=empty_tuple_index) + return UnboundType( + value.name, + params, + line=self.line, + column=value.column, + empty_tuple_index=empty_tuple_index, + ) else: return self.invalid_type(n) def visit_Tuple(self, n: ast3.Tuple) -> Type: - return TupleType(self.translate_expr_list(n.elts), _dummy_fallback, - implicit=True, line=self.line, column=self.convert_column(n.col_offset)) + return TupleType( + self.translate_expr_list(n.elts), + _dummy_fallback, + implicit=True, + line=self.line, + column=self.convert_column(n.col_offset), + ) # Attribute(expr value, identifier attr, expr_context ctx) def visit_Attribute(self, n: Attribute) -> Type: diff --git a/mypy/fastparse2.py b/mypy/fastparse2.py index cc8d9599b7418..56b4429d5b390 100644 --- a/mypy/fastparse2.py +++ b/mypy/fastparse2.py @@ -14,73 +14,134 @@ different class hierarchies, which made it difficult to write a shared visitor between the two in a typesafe way. """ -from mypy.util import unnamed_function import sys +import typing # for typing.Type, which conflicts with types.Type import warnings +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union, cast -import typing # for typing.Type, which conflicts with types.Type -from typing import Tuple, Union, TypeVar, Callable, Sequence, Optional, Any, Dict, cast, List from typing_extensions import Final, Literal -from mypy.sharedparse import ( - special_function_elide_names, argument_elide_name, +from mypy import errorcodes as codes, message_registry +from mypy.errors import Errors +from mypy.fastparse import ( + INVALID_TYPE_IGNORE, + TYPE_IGNORE_PATTERN, + TypeConverter, + parse_type_comment, + parse_type_ignore_tag, ) from mypy.nodes import ( - MypyFile, Node, ImportBase, Import, ImportAll, ImportFrom, FuncDef, OverloadedFuncDef, - ClassDef, Decorator, Block, Var, OperatorAssignmentStmt, - ExpressionStmt, AssignmentStmt, ReturnStmt, RaiseStmt, AssertStmt, - DelStmt, BreakStmt, ContinueStmt, PassStmt, GlobalDecl, - WhileStmt, ForStmt, IfStmt, TryStmt, WithStmt, - TupleExpr, GeneratorExpr, ListComprehension, ListExpr, ConditionalExpr, - DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, UnicodeExpr, - FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr, - UnaryExpr, LambdaExpr, ComparisonExpr, DictionaryComprehension, - SetComprehension, ComplexExpr, EllipsisExpr, YieldExpr, Argument, - Expression, Statement, BackquoteExpr, PrintStmt, ExecStmt, - ArgKind, ARG_POS, ARG_OPT, ARG_STAR, ARG_NAMED, ARG_STAR2, OverloadPart, check_arg_names, + ARG_NAMED, + ARG_OPT, + ARG_POS, + ARG_STAR, + ARG_STAR2, + ArgKind, + Argument, + AssertStmt, + AssignmentStmt, + BackquoteExpr, + Block, + BreakStmt, + CallExpr, + ClassDef, + ComparisonExpr, + ComplexExpr, + ConditionalExpr, + ContinueStmt, + Decorator, + DelStmt, + DictExpr, + DictionaryComprehension, + EllipsisExpr, + ExecStmt, + Expression, + ExpressionStmt, FakeInfo, -) -from mypy.types import ( - Type, CallableType, AnyType, UnboundType, EllipsisType, TypeOfAny, Instance, - ProperType -) -from mypy import message_registry, errorcodes as codes -from mypy.errors import Errors -from mypy.fastparse import ( - TypeConverter, parse_type_comment, parse_type_ignore_tag, - TYPE_IGNORE_PATTERN, INVALID_TYPE_IGNORE + FloatExpr, + ForStmt, + FuncDef, + GeneratorExpr, + GlobalDecl, + IfStmt, + Import, + ImportAll, + ImportBase, + ImportFrom, + IndexExpr, + IntExpr, + LambdaExpr, + ListComprehension, + ListExpr, + MemberExpr, + MypyFile, + NameExpr, + Node, + OperatorAssignmentStmt, + OpExpr, + OverloadedFuncDef, + OverloadPart, + PassStmt, + PrintStmt, + RaiseStmt, + ReturnStmt, + SetComprehension, + SetExpr, + SliceExpr, + Statement, + StrExpr, + SuperExpr, + TryStmt, + TupleExpr, + UnaryExpr, + UnicodeExpr, + Var, + WhileStmt, + WithStmt, + YieldExpr, + check_arg_names, ) from mypy.options import Options -from mypy.util import bytes_to_human_readable_repr from mypy.reachability import mark_block_unreachable +from mypy.sharedparse import argument_elide_name, special_function_elide_names +from mypy.types import ( + AnyType, + CallableType, + EllipsisType, + Instance, + ProperType, + Type, + TypeOfAny, + UnboundType, +) +from mypy.util import bytes_to_human_readable_repr, unnamed_function try: from typed_ast import ast27 - from typed_ast.ast27 import ( - AST, - Call, - Name, - Attribute, - Tuple as ast27_Tuple, - ) + from typed_ast.ast27 import AST, Attribute, Call, Name, Tuple as ast27_Tuple + # Import ast3 from fastparse, which has special case for Python 3.8 from mypy.fastparse import ast3, ast3_parse except ImportError: try: from typed_ast import ast35 # type: ignore[attr-defined] # noqa: F401 except ImportError: - print('The typed_ast package is not installed.\n' - 'For Python 2 support, install mypy using `python3 -m pip install "mypy[python2]"`' - 'Alternatively, you can install typed_ast with `python3 -m pip install typed-ast`.', - file=sys.stderr) + print( + "The typed_ast package is not installed.\n" + 'For Python 2 support, install mypy using `python3 -m pip install "mypy[python2]"`' + "Alternatively, you can install typed_ast with `python3 -m pip install typed-ast`.", + file=sys.stderr, + ) else: - print('You need a more recent version of the typed_ast package.\n' - 'You can update to the latest version with ' - '`python3 -m pip install -U typed-ast`.', - file=sys.stderr) + print( + "You need a more recent version of the typed_ast package.\n" + "You can update to the latest version with " + "`python3 -m pip install -U typed-ast`.", + file=sys.stderr, + ) sys.exit(1) -N = TypeVar('N', bound=Node) +N = TypeVar("N", bound=Node) # There is no way to create reasonable fallbacks at this stage, # they must be patched later. @@ -91,11 +152,13 @@ TYPE_COMMENT_AST_ERROR: Final = "invalid type comment" -def parse(source: Union[str, bytes], - fnam: str, - module: Optional[str], - errors: Optional[Errors] = None, - options: Optional[Options] = None) -> MypyFile: +def parse( + source: Union[str, bytes], + fnam: str, + module: Optional[str], + errors: Optional[Errors] = None, + options: Optional[Options] = None, +) -> MypyFile: """Parse a source file, without doing any semantic analysis. Return the parse tree. If errors is not provided, raise ParseError @@ -108,22 +171,25 @@ def parse(source: Union[str, bytes], if options is None: options = Options() errors.set_file(fnam, module) - is_stub_file = fnam.endswith('.pyi') + is_stub_file = fnam.endswith(".pyi") try: assert options.python_version[0] < 3 and not is_stub_file # Disable deprecation warnings about <>. with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=DeprecationWarning) - ast = ast27.parse(source, fnam, 'exec') - tree = ASTConverter(options=options, - errors=errors, - ).visit(ast) + ast = ast27.parse(source, fnam, "exec") + tree = ASTConverter(options=options, errors=errors).visit(ast) assert isinstance(tree, MypyFile) tree.path = fnam tree.is_stub = is_stub_file except SyntaxError as e: - errors.report(e.lineno if e.lineno is not None else -1, e.offset, e.msg, blocker=True, - code=codes.SYNTAX) + errors.report( + e.lineno if e.lineno is not None else -1, + e.offset, + e.msg, + blocker=True, + code=codes.SYNTAX, + ) tree = MypyFile([], [], False, {}) if raise_on_error and errors.is_errors(): @@ -134,17 +200,15 @@ def parse(source: Union[str, bytes], def is_no_type_check_decorator(expr: ast27.expr) -> bool: if isinstance(expr, Name): - return expr.id == 'no_type_check' + return expr.id == "no_type_check" elif isinstance(expr, Attribute): if isinstance(expr.value, Name): - return expr.value.id == 'typing' and expr.attr == 'no_type_check' + return expr.value.id == "typing" and expr.attr == "no_type_check" return False class ASTConverter: - def __init__(self, - options: Options, - errors: Errors) -> None: + def __init__(self, options: Options, errors: Errors) -> None: # 'C' for class, 'F' for function self.class_and_function_stack: List[Literal["C", "F"]] = [] self.imports: List[ImportBase] = [] @@ -186,7 +250,7 @@ def visit(self, node: Optional[AST]) -> Any: # same as in typed_ast stub typeobj = type(node) visitor = self.visitor_cache.get(typeobj) if visitor is None: - method = 'visit_' + node.__class__.__name__ + method = "visit_" + node.__class__.__name__ visitor = getattr(self, method) self.visitor_cache[typeobj] = visitor return visitor(node) @@ -209,15 +273,20 @@ def get_lineno(self, node: Union[ast27.expr, ast27.stmt]) -> int: return node.decorator_list[0].lineno return node.lineno - def translate_stmt_list(self, - stmts: Sequence[ast27.stmt], - module: bool = False) -> List[Statement]: + def translate_stmt_list( + self, stmts: Sequence[ast27.stmt], module: bool = False + ) -> List[Statement]: # A "# type: ignore" comment before the first statement of a module # ignores the whole module: - if (module and stmts and self.type_ignores - and min(self.type_ignores) < self.get_lineno(stmts[0])): + if ( + module + and stmts + and self.type_ignores + and min(self.type_ignores) < self.get_lineno(stmts[0]) + ): self.errors.used_ignored_lines[self.errors.file][min(self.type_ignores)].append( - codes.FILE.code) + codes.FILE.code + ) block = Block(self.fix_function_overloads(self.translate_stmt_list(stmts))) mark_block_unreachable(block) return [block] @@ -229,62 +298,65 @@ def translate_stmt_list(self, res.append(node) return res - def translate_type_comment(self, n: ast27.stmt, - type_comment: Optional[str]) -> Optional[ProperType]: + def translate_type_comment( + self, n: ast27.stmt, type_comment: Optional[str] + ) -> Optional[ProperType]: if type_comment is None: return None else: lineno = n.lineno - extra_ignore, typ = parse_type_comment(type_comment, - lineno, - n.col_offset, - self.errors, - assume_str_is_unicode=self.unicode_literals) + extra_ignore, typ = parse_type_comment( + type_comment, + lineno, + n.col_offset, + self.errors, + assume_str_is_unicode=self.unicode_literals, + ) if extra_ignore is not None: self.type_ignores[lineno] = extra_ignore return typ op_map: Final[Dict[typing.Type[AST], str]] = { - ast27.Add: '+', - ast27.Sub: '-', - ast27.Mult: '*', - ast27.Div: '/', - ast27.Mod: '%', - ast27.Pow: '**', - ast27.LShift: '<<', - ast27.RShift: '>>', - ast27.BitOr: '|', - ast27.BitXor: '^', - ast27.BitAnd: '&', - ast27.FloorDiv: '//' + ast27.Add: "+", + ast27.Sub: "-", + ast27.Mult: "*", + ast27.Div: "/", + ast27.Mod: "%", + ast27.Pow: "**", + ast27.LShift: "<<", + ast27.RShift: ">>", + ast27.BitOr: "|", + ast27.BitXor: "^", + ast27.BitAnd: "&", + ast27.FloorDiv: "//", } def from_operator(self, op: ast27.operator) -> str: op_name = ASTConverter.op_map.get(type(op)) if op_name is None: - raise RuntimeError('Unknown operator ' + str(type(op))) - elif op_name == '@': - raise RuntimeError('mypy does not support the MatMult operator') + raise RuntimeError("Unknown operator " + str(type(op))) + elif op_name == "@": + raise RuntimeError("mypy does not support the MatMult operator") else: return op_name comp_op_map: Final[Dict[typing.Type[AST], str]] = { - ast27.Gt: '>', - ast27.Lt: '<', - ast27.Eq: '==', - ast27.GtE: '>=', - ast27.LtE: '<=', - ast27.NotEq: '!=', - ast27.Is: 'is', - ast27.IsNot: 'is not', - ast27.In: 'in', - ast27.NotIn: 'not in' + ast27.Gt: ">", + ast27.Lt: "<", + ast27.Eq: "==", + ast27.GtE: ">=", + ast27.LtE: "<=", + ast27.NotEq: "!=", + ast27.Is: "is", + ast27.IsNot: "is not", + ast27.In: "in", + ast27.NotIn: "not in", } def from_comp_operator(self, op: ast27.cmpop) -> str: op_name = ASTConverter.comp_op_map.get(type(op)) if op_name is None: - raise RuntimeError('Unknown comparison operator ' + str(type(op))) + raise RuntimeError("Unknown comparison operator " + str(type(op))) else: return op_name @@ -306,9 +378,11 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: current_overload: List[OverloadPart] = [] current_overload_name: Optional[str] = None for stmt in stmts: - if (current_overload_name is not None - and isinstance(stmt, (Decorator, FuncDef)) - and stmt.name == current_overload_name): + if ( + current_overload_name is not None + and isinstance(stmt, (Decorator, FuncDef)) + and stmt.name == current_overload_name + ): current_overload.append(stmt) else: if len(current_overload) == 1: @@ -331,7 +405,7 @@ def fix_function_overloads(self, stmts: List[Statement]) -> List[Statement]: return ret def in_method_scope(self) -> bool: - return self.class_and_function_stack[-2:] == ['C', 'F'] + return self.class_and_function_stack[-2:] == ["C", "F"] def translate_module_id(self, id: str) -> str: """Return the actual, internal module id for a source text id. @@ -339,11 +413,11 @@ def translate_module_id(self, id: str) -> str: For example, translate '__builtin__' in Python 2 to 'builtins'. """ if id == self.options.custom_typing_module: - return 'typing' - elif id == '__builtin__': + return "typing" + elif id == "__builtin__": # HACK: __builtin__ in Python 2 is aliases to builtins. However, the implementation # is named __builtin__.py (there is another layer of translation elsewhere). - return 'builtins' + return "builtins" return id def visit_Module(self, mod: ast27.Module) -> MypyFile: @@ -355,11 +429,7 @@ def visit_Module(self, mod: ast27.Module) -> MypyFile: else: self.fail(INVALID_TYPE_IGNORE, ti.lineno, -1) body = self.fix_function_overloads(self.translate_stmt_list(mod.body, module=True)) - return MypyFile(body, - self.imports, - False, - self.type_ignores, - ) + return MypyFile(body, self.imports, False, self.type_ignores) # --- stmt --- # FunctionDef(identifier name, arguments args, @@ -367,10 +437,14 @@ def visit_Module(self, mod: ast27.Module) -> MypyFile: # arguments = (arg* args, arg? vararg, arg* kwonlyargs, expr* kw_defaults, # arg? kwarg, expr* defaults) def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement: - self.class_and_function_stack.append('F') + self.class_and_function_stack.append("F") lineno = n.lineno - converter = TypeConverter(self.errors, line=lineno, override_column=n.col_offset, - assume_str_is_unicode=self.unicode_literals) + converter = TypeConverter( + self.errors, + line=lineno, + override_column=n.col_offset, + assume_str_is_unicode=self.unicode_literals, + ) args, decompose_stmts = self.transform_args(n.args, lineno) if special_function_elide_names(n.name): for arg in args: @@ -381,26 +455,31 @@ def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement: arg_types: List[Optional[Type]] = [] type_comment = n.type_comment - if (n.decorator_list and any(is_no_type_check_decorator(d) for d in n.decorator_list)): + if n.decorator_list and any(is_no_type_check_decorator(d) for d in n.decorator_list): arg_types = [None] * len(args) return_type = None elif type_comment is not None and len(type_comment) > 0: try: - func_type_ast = ast3_parse(type_comment, '', 'func_type') + func_type_ast = ast3_parse(type_comment, "", "func_type") assert isinstance(func_type_ast, ast3.FunctionType) # for ellipsis arg - if (len(func_type_ast.argtypes) == 1 and - isinstance(func_type_ast.argtypes[0], ast3.Ellipsis)): - arg_types = [a.type_annotation - if a.type_annotation is not None - else AnyType(TypeOfAny.unannotated) - for a in args] + if len(func_type_ast.argtypes) == 1 and isinstance( + func_type_ast.argtypes[0], ast3.Ellipsis + ): + arg_types = [ + a.type_annotation + if a.type_annotation is not None + else AnyType(TypeOfAny.unannotated) + for a in args + ] else: # PEP 484 disallows both type annotations and type comments if any(a.type_annotation is not None for a in args): self.fail(message_registry.DUPLICATE_TYPE_SIGNATURES, lineno, n.col_offset) - arg_types = [a if a is not None else AnyType(TypeOfAny.unannotated) for - a in converter.translate_expr_list(func_type_ast.argtypes)] + arg_types = [ + a if a is not None else AnyType(TypeOfAny.unannotated) + for a in converter.translate_expr_list(func_type_ast.argtypes) + ] return_type = converter.visit(func_type_ast.returns) # add implicit self type @@ -421,31 +500,34 @@ def visit_FunctionDef(self, n: ast27.FunctionDef) -> Statement: func_type = None if any(arg_types) or return_type: - if len(arg_types) != 1 and any(isinstance(t, EllipsisType) - for t in arg_types): - self.fail("Ellipses cannot accompany other argument types " - "in function type signature", lineno, n.col_offset) + if len(arg_types) != 1 and any(isinstance(t, EllipsisType) for t in arg_types): + self.fail( + "Ellipses cannot accompany other argument types " "in function type signature", + lineno, + n.col_offset, + ) elif len(arg_types) > len(arg_kinds): - self.fail('Type signature has too many arguments', lineno, n.col_offset, - blocker=False) + self.fail( + "Type signature has too many arguments", lineno, n.col_offset, blocker=False + ) elif len(arg_types) < len(arg_kinds): - self.fail('Type signature has too few arguments', lineno, n.col_offset, - blocker=False) + self.fail( + "Type signature has too few arguments", lineno, n.col_offset, blocker=False + ) else: any_type = AnyType(TypeOfAny.unannotated) - func_type = CallableType([a if a is not None else any_type for a in arg_types], - arg_kinds, - arg_names, - return_type if return_type is not None else any_type, - _dummy_fallback) + func_type = CallableType( + [a if a is not None else any_type for a in arg_types], + arg_kinds, + arg_names, + return_type if return_type is not None else any_type, + _dummy_fallback, + ) body = self.as_required_block(n.body, lineno) if decompose_stmts: body.body = decompose_stmts + body.body - func_def = FuncDef(n.name, - args, - body, - func_type) + func_def = FuncDef(n.name, args, body, func_type) if isinstance(func_def.type, CallableType): # semanal.py does some in-place modifications we want to avoid func_def.unanalyzed_type = func_def.type.copy_modified() @@ -475,23 +557,27 @@ def set_type_optional(self, type: Optional[Type], initializer: Optional[Expressi if self.options.no_implicit_optional: return # Indicate that type should be wrapped in an Optional if arg is initialized to None. - optional = isinstance(initializer, NameExpr) and initializer.name == 'None' + optional = isinstance(initializer, NameExpr) and initializer.name == "None" if isinstance(type, UnboundType): type.optional = optional - def transform_args(self, - n: ast27.arguments, - line: int, - ) -> Tuple[List[Argument], List[Statement]]: + def transform_args( + self, n: ast27.arguments, line: int + ) -> Tuple[List[Argument], List[Statement]]: type_comments: Sequence[Optional[str]] = n.type_comments - converter = TypeConverter(self.errors, line=line, - assume_str_is_unicode=self.unicode_literals) + converter = TypeConverter( + self.errors, line=line, assume_str_is_unicode=self.unicode_literals + ) decompose_stmts: List[Statement] = [] n_args = n.args - args = [(self.convert_arg(i, arg, line, decompose_stmts), - self.get_type(i, type_comments, converter)) - for i, arg in enumerate(n_args)] + args = [ + ( + self.convert_arg(i, arg, line, decompose_stmts), + self.get_type(i, type_comments, converter), + ) + for i, arg in enumerate(n_args) + ] defaults = self.translate_expr_list(n.defaults) names: List[str] = [name for arg in n_args for name in self.extract_names(arg)] @@ -507,17 +593,21 @@ def transform_args(self, # *arg if n.vararg is not None: - new_args.append(Argument(Var(n.vararg), - self.get_type(len(args), type_comments, converter), - None, - ARG_STAR)) + new_args.append( + Argument( + Var(n.vararg), + self.get_type(len(args), type_comments, converter), + None, + ARG_STAR, + ) + ) names.append(n.vararg) # **kwarg if n.kwarg is not None: - typ = self.get_type(len(args) + (0 if n.vararg is None else 1), - type_comments, - converter) + typ = self.get_type( + len(args) + (0 if n.vararg is None else 1), type_comments, converter + ) new_args.append(Argument(Var(n.kwarg), typ, None, ARG_STAR2)) names.append(n.kwarg) @@ -528,6 +618,7 @@ def transform_args(self, # We don't have any context object to give, but we have closed around the line num def fail_arg(msg: str, arg: None) -> None: self.fail(msg, line, 0) + check_arg_names(names, [None] * len(names), fail_arg) return new_args, decompose_stmts @@ -540,12 +631,13 @@ def extract_names(self, arg: ast27.expr) -> List[str]: else: return [] - def convert_arg(self, index: int, arg: ast27.expr, line: int, - decompose_stmts: List[Statement]) -> Var: + def convert_arg( + self, index: int, arg: ast27.expr, line: int, decompose_stmts: List[Statement] + ) -> Var: if isinstance(arg, Name): v = arg.id elif isinstance(arg, ast27_Tuple): - v = f'__tuple_arg_{index + 1}' + v = f"__tuple_arg_{index + 1}" rvalue = NameExpr(v) rvalue.set_line(line) assignment = AssignmentStmt([self.visit(arg)], rvalue) @@ -555,10 +647,9 @@ def convert_arg(self, index: int, arg: ast27.expr, line: int, raise RuntimeError(f"'{ast27.dump(arg)}' is not a valid argument.") return Var(v) - def get_type(self, - i: int, - type_comments: Sequence[Optional[str]], - converter: TypeConverter) -> Optional[Type]: + def get_type( + self, i: int, type_comments: Sequence[Optional[str]], converter: TypeConverter + ) -> Optional[Type]: if i < len(type_comments): comment = type_comments[i] if comment is not None: @@ -588,13 +679,15 @@ def stringify_name(self, n: AST) -> str: # stmt* body, # expr* decorator_list) def visit_ClassDef(self, n: ast27.ClassDef) -> ClassDef: - self.class_and_function_stack.append('C') - - cdef = ClassDef(n.name, - self.as_required_block(n.body, n.lineno), - None, - self.translate_expr_list(n.bases), - metaclass=None) + self.class_and_function_stack.append("C") + + cdef = ClassDef( + n.name, + self.as_required_block(n.body, n.lineno), + None, + self.translate_expr_list(n.bases), + metaclass=None, + ) cdef.decorators = self.translate_expr_list(n.decorator_list) cdef.line = n.lineno + len(n.decorator_list) cdef.column = n.col_offset @@ -621,49 +714,55 @@ def visit_Delete(self, n: ast27.Delete) -> DelStmt: # Assign(expr* targets, expr value, string? type_comment) def visit_Assign(self, n: ast27.Assign) -> AssignmentStmt: typ = self.translate_type_comment(n, n.type_comment) - stmt = AssignmentStmt(self.translate_expr_list(n.targets), - self.visit(n.value), - type=typ) + stmt = AssignmentStmt(self.translate_expr_list(n.targets), self.visit(n.value), type=typ) return self.set_line(stmt, n) # AugAssign(expr target, operator op, expr value) def visit_AugAssign(self, n: ast27.AugAssign) -> OperatorAssignmentStmt: - stmt = OperatorAssignmentStmt(self.from_operator(n.op), - self.visit(n.target), - self.visit(n.value)) + stmt = OperatorAssignmentStmt( + self.from_operator(n.op), self.visit(n.target), self.visit(n.value) + ) return self.set_line(stmt, n) # For(expr target, expr iter, stmt* body, stmt* orelse, string? type_comment) def visit_For(self, n: ast27.For) -> ForStmt: typ = self.translate_type_comment(n, n.type_comment) - stmt = ForStmt(self.visit(n.target), - self.visit(n.iter), - self.as_required_block(n.body, n.lineno), - self.as_block(n.orelse, n.lineno), - typ) + stmt = ForStmt( + self.visit(n.target), + self.visit(n.iter), + self.as_required_block(n.body, n.lineno), + self.as_block(n.orelse, n.lineno), + typ, + ) return self.set_line(stmt, n) # While(expr test, stmt* body, stmt* orelse) def visit_While(self, n: ast27.While) -> WhileStmt: - stmt = WhileStmt(self.visit(n.test), - self.as_required_block(n.body, n.lineno), - self.as_block(n.orelse, n.lineno)) + stmt = WhileStmt( + self.visit(n.test), + self.as_required_block(n.body, n.lineno), + self.as_block(n.orelse, n.lineno), + ) return self.set_line(stmt, n) # If(expr test, stmt* body, stmt* orelse) def visit_If(self, n: ast27.If) -> IfStmt: - stmt = IfStmt([self.visit(n.test)], - [self.as_required_block(n.body, n.lineno)], - self.as_block(n.orelse, n.lineno)) + stmt = IfStmt( + [self.visit(n.test)], + [self.as_required_block(n.body, n.lineno)], + self.as_block(n.orelse, n.lineno), + ) return self.set_line(stmt, n) # With(withitem* items, stmt* body, string? type_comment) def visit_With(self, n: ast27.With) -> WithStmt: typ = self.translate_type_comment(n, n.type_comment) - stmt = WithStmt([self.visit(n.context_expr)], - [self.visit(n.optional_vars)], - self.as_required_block(n.body, n.lineno), - typ) + stmt = WithStmt( + [self.visit(n.context_expr)], + [self.visit(n.optional_vars)], + self.as_required_block(n.body, n.lineno), + typ, + ) return self.set_line(stmt, n) # 'raise' [test [',' test [',' test]]] @@ -697,12 +796,14 @@ def visit_TryFinally(self, n: ast27.TryFinally) -> TryStmt: stmt = self.try_handler(n.body, [], [], n.finalbody, n.lineno) return self.set_line(stmt, n) - def try_handler(self, - body: List[ast27.stmt], - handlers: List[ast27.ExceptHandler], - orelse: List[ast27.stmt], - finalbody: List[ast27.stmt], - lineno: int) -> TryStmt: + def try_handler( + self, + body: List[ast27.stmt], + handlers: List[ast27.ExceptHandler], + orelse: List[ast27.stmt], + finalbody: List[ast27.stmt], + lineno: int, + ) -> TryStmt: vs: List[Optional[NameExpr]] = [] for item in handlers: if item.name is None: @@ -710,27 +811,30 @@ def try_handler(self, elif isinstance(item.name, Name): vs.append(self.set_line(NameExpr(item.name.id), item)) else: - self.fail('Sorry, "except , " is not supported', - item.lineno, item.col_offset) + self.fail( + 'Sorry, "except , " is not supported', + item.lineno, + item.col_offset, + ) vs.append(None) types = [self.visit(h.type) for h in handlers] handlers_ = [self.as_required_block(h.body, h.lineno) for h in handlers] - return TryStmt(self.as_required_block(body, lineno), - vs, - types, - handlers_, - self.as_block(orelse, lineno), - self.as_block(finalbody, lineno)) + return TryStmt( + self.as_required_block(body, lineno), + vs, + types, + handlers_, + self.as_block(orelse, lineno), + self.as_block(finalbody, lineno), + ) def visit_Print(self, n: ast27.Print) -> PrintStmt: stmt = PrintStmt(self.translate_expr_list(n.values), n.nl, self.visit(n.dest)) return self.set_line(stmt, n) def visit_Exec(self, n: ast27.Exec) -> ExecStmt: - stmt = ExecStmt(self.visit(n.body), - self.visit(n.globals), - self.visit(n.locals)) + stmt = ExecStmt(self.visit(n.body), self.visit(n.globals), self.visit(n.locals)) return self.set_line(stmt, n) def visit_Repr(self, n: ast27.Repr) -> BackquoteExpr: @@ -761,15 +865,15 @@ def visit_Import(self, n: ast27.Import) -> Import: # ImportFrom(identifier? module, alias* names, int? level) def visit_ImportFrom(self, n: ast27.ImportFrom) -> ImportBase: assert n.level is not None - if len(n.names) == 1 and n.names[0].name == '*': - mod = n.module if n.module is not None else '' + if len(n.names) == 1 and n.names[0].name == "*": + mod = n.module if n.module is not None else "" i: ImportBase = ImportAll(mod, n.level) else: - module_id = self.translate_module_id(n.module) if n.module is not None else '' + module_id = self.translate_module_id(n.module) if n.module is not None else "" i = ImportFrom(module_id, n.level, [(a.name, a.asname) for a in n.names]) # See comments in the constructor for more information about this field. - if module_id == '__future__' and any(a.name == 'unicode_literals' for a in n.names): + if module_id == "__future__" and any(a.name == "unicode_literals" for a in n.names): self.unicode_literals = True self.imports.append(i) return self.set_line(i, n) @@ -807,11 +911,11 @@ def visit_BoolOp(self, n: ast27.BoolOp) -> OpExpr: # mypy translates (1 and 2 and 3) as (1 and (2 and 3)) assert len(n.values) >= 2 if isinstance(n.op, ast27.And): - op = 'and' + op = "and" elif isinstance(n.op, ast27.Or): - op = 'or' + op = "or" else: - raise RuntimeError('unknown BoolOp ' + str(type(n))) + raise RuntimeError("unknown BoolOp " + str(type(n))) # potentially inefficient! e = self.group(self.translate_expr_list(n.values), op) @@ -828,7 +932,7 @@ def visit_BinOp(self, n: ast27.BinOp) -> OpExpr: op = self.from_operator(n.op) if op is None: - raise RuntimeError('cannot translate BinOp ' + str(type(n.op))) + raise RuntimeError("cannot translate BinOp " + str(type(n.op))) e = OpExpr(op, self.visit(n.left), self.visit(n.right)) return self.set_line(e, n) @@ -837,16 +941,16 @@ def visit_BinOp(self, n: ast27.BinOp) -> OpExpr: def visit_UnaryOp(self, n: ast27.UnaryOp) -> UnaryExpr: op = None if isinstance(n.op, ast27.Invert): - op = '~' + op = "~" elif isinstance(n.op, ast27.Not): - op = 'not' + op = "not" elif isinstance(n.op, ast27.UAdd): - op = '+' + op = "+" elif isinstance(n.op, ast27.USub): - op = '-' + op = "-" if op is None: - raise RuntimeError('cannot translate UnaryOp ' + str(type(n.op))) + raise RuntimeError("cannot translate UnaryOp " + str(type(n.op))) e = UnaryExpr(op, self.visit(n.operand)) return self.set_line(e, n) @@ -868,15 +972,14 @@ def visit_Lambda(self, n: ast27.Lambda) -> LambdaExpr: # IfExp(expr test, expr body, expr orelse) def visit_IfExp(self, n: ast27.IfExp) -> ConditionalExpr: - e = ConditionalExpr(self.visit(n.test), - self.visit(n.body), - self.visit(n.orelse)) + e = ConditionalExpr(self.visit(n.test), self.visit(n.body), self.visit(n.orelse)) return self.set_line(e, n) # Dict(expr* keys, expr* values) def visit_Dict(self, n: ast27.Dict) -> DictExpr: - e = DictExpr(list(zip(self.translate_expr_list(n.keys), - self.translate_expr_list(n.values)))) + e = DictExpr( + list(zip(self.translate_expr_list(n.keys), self.translate_expr_list(n.values))) + ) return self.set_line(e, n) # Set(expr* elts) @@ -899,12 +1002,14 @@ def visit_DictComp(self, n: ast27.DictComp) -> DictionaryComprehension: targets = [self.visit(c.target) for c in n.generators] iters = [self.visit(c.iter) for c in n.generators] ifs_list = [self.translate_expr_list(c.ifs) for c in n.generators] - e = DictionaryComprehension(self.visit(n.key), - self.visit(n.value), - targets, - iters, - ifs_list, - [False for _ in n.generators]) + e = DictionaryComprehension( + self.visit(n.key), + self.visit(n.value), + targets, + iters, + ifs_list, + [False for _ in n.generators], + ) return self.set_line(e, n) # GeneratorExp(expr elt, comprehension* generators) @@ -912,11 +1017,9 @@ def visit_GeneratorExp(self, n: ast27.GeneratorExp) -> GeneratorExpr: targets = [self.visit(c.target) for c in n.generators] iters = [self.visit(c.iter) for c in n.generators] ifs_list = [self.translate_expr_list(c.ifs) for c in n.generators] - e = GeneratorExpr(self.visit(n.elt), - targets, - iters, - ifs_list, - [False for _ in n.generators]) + e = GeneratorExpr( + self.visit(n.elt), targets, iters, ifs_list, [False for _ in n.generators] + ) return self.set_line(e, n) # Yield(expr? value) @@ -958,10 +1061,7 @@ def visit_Call(self, n: Call) -> CallExpr: arg_kinds.append(ARG_STAR2) signature.append(None) - e = CallExpr(self.visit(n.func), - self.translate_expr_list(arg_types), - arg_kinds, - signature) + e = CallExpr(self.visit(n.func), self.translate_expr_list(arg_types), arg_kinds, signature) return self.set_line(e, n) # Num(object n) -- a number as a PyObject. @@ -972,7 +1072,7 @@ def visit_Num(self, n: ast27.Num) -> Expression: # this by throwing away the type. value: object = n.n is_inverse = False - if str(n.n).startswith('-'): # Hackish because of complex. + if str(n.n).startswith("-"): # Hackish because of complex. value = -n.n is_inverse = True @@ -983,10 +1083,10 @@ def visit_Num(self, n: ast27.Num) -> Expression: elif isinstance(value, complex): expr = ComplexExpr(value) else: - raise RuntimeError('num not implemented for ' + str(type(n.n))) + raise RuntimeError("num not implemented for " + str(type(n.n))) if is_inverse: - expr = UnaryExpr('-', expr) + expr = UnaryExpr("-", expr) return self.set_line(expr, n) @@ -1020,9 +1120,11 @@ def visit_Attribute(self, n: Attribute) -> Expression: # less common than normal member expressions. member_expr = MemberExpr(self.visit(n.value), n.attr) obj = member_expr.expr - if (isinstance(obj, CallExpr) and - isinstance(obj.callee, NameExpr) and - obj.callee.name == 'super'): + if ( + isinstance(obj, CallExpr) + and isinstance(obj.callee, NameExpr) + and obj.callee.name == "super" + ): e: Expression = SuperExpr(member_expr.name, obj) else: e = member_expr @@ -1062,9 +1164,7 @@ def visit_Tuple(self, n: ast27_Tuple) -> TupleExpr: # Slice(expr? lower, expr? upper, expr? step) def visit_Slice(self, n: ast27.Slice) -> SliceExpr: - return SliceExpr(self.visit(n.lower), - self.visit(n.upper), - self.visit(n.step)) + return SliceExpr(self.visit(n.lower), self.visit(n.upper), self.visit(n.step)) # ExtSlice(slice* dims) def visit_ExtSlice(self, n: ast27.ExtSlice) -> TupleExpr: diff --git a/mypy/find_sources.py b/mypy/find_sources.py index 64e975f868337..cd9d9aa5f3630 100644 --- a/mypy/find_sources.py +++ b/mypy/find_sources.py @@ -2,12 +2,12 @@ import functools import os +from typing import List, Optional, Sequence, Set, Tuple -from typing import List, Sequence, Set, Tuple, Optional from typing_extensions import Final -from mypy.modulefinder import BuildSource, PYTHON_EXTENSIONS, mypy_path, matches_exclude from mypy.fscache import FileSystemCache +from mypy.modulefinder import PYTHON_EXTENSIONS, BuildSource, matches_exclude, mypy_path from mypy.options import Options PY_EXTENSIONS: Final = tuple(PYTHON_EXTENSIONS) @@ -17,9 +17,12 @@ class InvalidSourceList(Exception): """Exception indicating a problem in the list of sources given to mypy.""" -def create_source_list(paths: Sequence[str], options: Options, - fscache: Optional[FileSystemCache] = None, - allow_empty_dir: bool = False) -> List[BuildSource]: +def create_source_list( + paths: Sequence[str], + options: Options, + fscache: Optional[FileSystemCache] = None, + allow_empty_dir: bool = False, +) -> List[BuildSource]: """From a list of source files/directories, makes a list of BuildSources. Raises InvalidSourceList on errors. @@ -37,9 +40,7 @@ def create_source_list(paths: Sequence[str], options: Options, elif fscache.isdir(path): sub_sources = finder.find_sources_in_dir(path) if not sub_sources and not allow_empty_dir: - raise InvalidSourceList( - f"There are no .py[i] files in directory '{path}'" - ) + raise InvalidSourceList(f"There are no .py[i] files in directory '{path}'") sources.extend(sub_sources) else: mod = os.path.basename(path) if options.scripts_are_modules else None @@ -109,9 +110,7 @@ def find_sources_in_dir(self, path: str) -> List[BuildSource]: continue subpath = os.path.join(path, name) - if matches_exclude( - subpath, self.exclude, self.fscache, self.verbosity >= 2 - ): + if matches_exclude(subpath, self.exclude, self.fscache, self.verbosity >= 2): continue if self.fscache.isdir(subpath): @@ -176,7 +175,7 @@ def _crawl_up_helper(self, dir: str) -> Optional[Tuple[str, str]]: return "", dir parent, name = os.path.split(dir) - if name.endswith('-stubs'): + if name.endswith("-stubs"): name = name[:-6] # PEP-561 stub-only directory # recurse if there's an __init__.py @@ -218,10 +217,10 @@ def get_init_file(self, dir: str) -> Optional[str]: This prefers .pyi over .py (because of the ordering of PY_EXTENSIONS). """ for ext in PY_EXTENSIONS: - f = os.path.join(dir, '__init__' + ext) + f = os.path.join(dir, "__init__" + ext) if self.fscache.isfile(f): return f - if ext == '.py' and self.fscache.init_under_package_root(f): + if ext == ".py" and self.fscache.init_under_package_root(f): return f return None @@ -229,7 +228,7 @@ def get_init_file(self, dir: str) -> Optional[str]: def module_join(parent: str, child: str) -> str: """Join module ids, accounting for a possibly empty parent.""" if parent: - return parent + '.' + child + return parent + "." + child return child @@ -240,5 +239,5 @@ def strip_py(arg: str) -> Optional[str]: """ for ext in PY_EXTENSIONS: if arg.endswith(ext): - return arg[:-len(ext)] + return arg[: -len(ext)] return None diff --git a/mypy/fixup.py b/mypy/fixup.py index 85c1df079a5a5..d138b007bd002 100644 --- a/mypy/fixup.py +++ b/mypy/fixup.py @@ -1,27 +1,51 @@ """Fix up various things after deserialization.""" from typing import Any, Dict, Optional + from typing_extensions import Final +from mypy.lookup import lookup_fully_qualified from mypy.nodes import ( - MypyFile, SymbolTable, TypeInfo, FuncDef, OverloadedFuncDef, - Decorator, Var, TypeVarExpr, ClassDef, Block, TypeAlias, + Block, + ClassDef, + Decorator, + FuncDef, + MypyFile, + OverloadedFuncDef, + SymbolTable, + TypeAlias, + TypeInfo, + TypeVarExpr, + Var, ) from mypy.types import ( - CallableType, Instance, Overloaded, TupleType, TypedDictType, - TypeVarType, UnboundType, UnionType, TypeVisitor, LiteralType, - TypeType, NOT_READY, TypeAliasType, AnyType, TypeOfAny, ParamSpecType, - Parameters, UnpackType, TypeVarTupleType + NOT_READY, + AnyType, + CallableType, + Instance, + LiteralType, + Overloaded, + Parameters, + ParamSpecType, + TupleType, + TypeAliasType, + TypedDictType, + TypeOfAny, + TypeType, + TypeVarTupleType, + TypeVarType, + TypeVisitor, + UnboundType, + UnionType, + UnpackType, ) from mypy.visitor import NodeVisitor -from mypy.lookup import lookup_fully_qualified # N.B: we do a allow_missing fixup when fixing up a fine-grained # incremental cache load (since there may be cross-refs into deleted # modules) -def fixup_module(tree: MypyFile, modules: Dict[str, MypyFile], - allow_missing: bool) -> None: +def fixup_module(tree: MypyFile, modules: Dict[str, MypyFile], allow_missing: bool) -> None: node_fixer = NodeFixer(modules, allow_missing) node_fixer.visit_symbol_table(tree.names, tree.fullname) @@ -59,9 +83,12 @@ def visit_type_info(self, info: TypeInfo) -> None: if info.metaclass_type: info.metaclass_type.accept(self.type_fixer) if info._mro_refs: - info.mro = [lookup_fully_qualified_typeinfo(self.modules, name, - allow_missing=self.allow_missing) - for name in info._mro_refs] + info.mro = [ + lookup_fully_qualified_typeinfo( + self.modules, name, allow_missing=self.allow_missing + ) + for name in info._mro_refs + ] info._mro_refs = None finally: self.current_info = save_info @@ -76,8 +103,9 @@ def visit_symbol_table(self, symtab: SymbolTable, table_fullname: str) -> None: if cross_ref in self.modules: value.node = self.modules[cross_ref] else: - stnode = lookup_fully_qualified(cross_ref, self.modules, - raise_on_missing=not self.allow_missing) + stnode = lookup_fully_qualified( + cross_ref, self.modules, raise_on_missing=not self.allow_missing + ) if stnode is not None: assert stnode.node is not None, (table_fullname + "." + key, cross_ref) value.node = stnode.node @@ -93,7 +121,7 @@ def visit_symbol_table(self, symtab: SymbolTable, table_fullname: str) -> None: elif value.node is not None: value.node.accept(self) else: - assert False, f'Unexpected empty node {key!r}: {value}' + assert False, f"Unexpected empty node {key!r}: {value}" def visit_func_def(self, func: FuncDef) -> None: if self.current_info is not None: @@ -154,8 +182,9 @@ def visit_instance(self, inst: Instance) -> None: if type_ref is None: return # We've already been here. inst.type_ref = None - inst.type = lookup_fully_qualified_typeinfo(self.modules, type_ref, - allow_missing=self.allow_missing) + inst.type = lookup_fully_qualified_typeinfo( + self.modules, type_ref, allow_missing=self.allow_missing + ) # TODO: Is this needed or redundant? # Also fix up the bases, just in case. for base in inst.type.bases: @@ -171,8 +200,9 @@ def visit_type_alias_type(self, t: TypeAliasType) -> None: if type_ref is None: return # We've already been here. t.type_ref = None - t.alias = lookup_fully_qualified_alias(self.modules, type_ref, - allow_missing=self.allow_missing) + t.alias = lookup_fully_qualified_alias( + self.modules, type_ref, allow_missing=self.allow_missing + ) for a in t.args: a.accept(self) @@ -229,11 +259,17 @@ def visit_typeddict_type(self, tdt: TypedDictType) -> None: it.accept(self) if tdt.fallback is not None: if tdt.fallback.type_ref is not None: - if lookup_fully_qualified(tdt.fallback.type_ref, self.modules, - raise_on_missing=not self.allow_missing) is None: + if ( + lookup_fully_qualified( + tdt.fallback.type_ref, + self.modules, + raise_on_missing=not self.allow_missing, + ) + is None + ): # We reject fake TypeInfos for TypedDict fallbacks because # the latter are used in type checking and must be valid. - tdt.fallback.type_ref = 'typing._TypedDict' + tdt.fallback.type_ref = "typing._TypedDict" tdt.fallback.accept(self) def visit_literal_type(self, lt: LiteralType) -> None: @@ -278,33 +314,37 @@ def visit_type_type(self, t: TypeType) -> None: t.item.accept(self) -def lookup_fully_qualified_typeinfo(modules: Dict[str, MypyFile], name: str, *, - allow_missing: bool) -> TypeInfo: +def lookup_fully_qualified_typeinfo( + modules: Dict[str, MypyFile], name: str, *, allow_missing: bool +) -> TypeInfo: stnode = lookup_fully_qualified(name, modules, raise_on_missing=not allow_missing) node = stnode.node if stnode else None if isinstance(node, TypeInfo): return node else: # Looks like a missing TypeInfo during an initial daemon load, put something there - assert allow_missing, "Should never get here in normal mode," \ - " got {}:{} instead of TypeInfo".format(type(node).__name__, - node.fullname if node - else '') + assert ( + allow_missing + ), "Should never get here in normal mode," " got {}:{} instead of TypeInfo".format( + type(node).__name__, node.fullname if node else "" + ) return missing_info(modules) -def lookup_fully_qualified_alias(modules: Dict[str, MypyFile], name: str, *, - allow_missing: bool) -> TypeAlias: +def lookup_fully_qualified_alias( + modules: Dict[str, MypyFile], name: str, *, allow_missing: bool +) -> TypeAlias: stnode = lookup_fully_qualified(name, modules, raise_on_missing=not allow_missing) node = stnode.node if stnode else None if isinstance(node, TypeAlias): return node else: # Looks like a missing TypeAlias during an initial daemon load, put something there - assert allow_missing, "Should never get here in normal mode," \ - " got {}:{} instead of TypeAlias".format(type(node).__name__, - node.fullname if node - else '') + assert ( + allow_missing + ), "Should never get here in normal mode," " got {}:{} instead of TypeAlias".format( + type(node).__name__, node.fullname if node else "" + ) return missing_alias() @@ -312,18 +352,17 @@ def lookup_fully_qualified_alias(modules: Dict[str, MypyFile], name: str, *, def missing_info(modules: Dict[str, MypyFile]) -> TypeInfo: - suggestion = _SUGGESTION.format('info') + suggestion = _SUGGESTION.format("info") dummy_def = ClassDef(suggestion, Block([])) dummy_def.fullname = suggestion info = TypeInfo(SymbolTable(), dummy_def, "") - obj_type = lookup_fully_qualified_typeinfo(modules, 'builtins.object', allow_missing=False) + obj_type = lookup_fully_qualified_typeinfo(modules, "builtins.object", allow_missing=False) info.bases = [Instance(obj_type, [])] info.mro = [info, obj_type] return info def missing_alias() -> TypeAlias: - suggestion = _SUGGESTION.format('alias') - return TypeAlias(AnyType(TypeOfAny.special_form), suggestion, - line=-1, column=-1) + suggestion = _SUGGESTION.format("alias") + return TypeAlias(AnyType(TypeOfAny.special_form), suggestion, line=-1, column=-1) diff --git a/mypy/freetree.py b/mypy/freetree.py index 28409ffbfddba..07eb4cf0ceb67 100644 --- a/mypy/freetree.py +++ b/mypy/freetree.py @@ -1,7 +1,7 @@ """Generic node traverser visitor""" -from mypy.traverser import TraverserVisitor from mypy.nodes import Block, MypyFile +from mypy.traverser import TraverserVisitor class TreeFreer(TraverserVisitor): diff --git a/mypy/fscache.py b/mypy/fscache.py index d0be1abd8cb93..365ca80d334e9 100644 --- a/mypy/fscache.py +++ b/mypy/fscache.py @@ -31,9 +31,11 @@ import os import stat from typing import Dict, List, Set -from mypy.util import hash_digest + from mypy_extensions import mypyc_attr +from mypy.util import hash_digest + @mypyc_attr(allow_interpreted_subclasses=True) # for tests class FileSystemCache: @@ -104,7 +106,7 @@ def init_under_package_root(self, path: str) -> bool: if not self.package_root: return False dirname, basename = os.path.split(path) - if basename != '__init__.py': + if basename != "__init__.py": return False if not os.path.basename(dirname).isidentifier(): # Can't put an __init__.py in a place that's not an identifier @@ -139,7 +141,7 @@ def _fake_init(self, path: str) -> os.stat_result: init_under_package_root() returns True. """ dirname, basename = os.path.split(path) - assert basename == '__init__.py', path + assert basename == "__init__.py", path assert not os.path.exists(path), path # Not cached! dirname = os.path.normpath(dirname) st = self.stat(dirname) # May raise OSError @@ -160,8 +162,8 @@ def listdir(self, path: str) -> List[str]: if path in self.listdir_cache: res = self.listdir_cache[path] # Check the fake cache. - if path in self.fake_package_cache and '__init__.py' not in res: - res.append('__init__.py') # Updates the result as well as the cache + if path in self.fake_package_cache and "__init__.py" not in res: + res.append("__init__.py") # Updates the result as well as the cache return res if path in self.listdir_error_cache: raise copy_os_error(self.listdir_error_cache[path]) @@ -173,8 +175,8 @@ def listdir(self, path: str) -> List[str]: raise err self.listdir_cache[path] = results # Check the fake cache. - if path in self.fake_package_cache and '__init__.py' not in results: - results.append('__init__.py') + if path in self.fake_package_cache and "__init__.py" not in results: + results.append("__init__.py") return results def isfile(self, path: str) -> bool: @@ -271,11 +273,11 @@ def read(self, path: str) -> bytes: dirname, basename = os.path.split(path) dirname = os.path.normpath(dirname) # Check the fake cache. - if basename == '__init__.py' and dirname in self.fake_package_cache: - data = b'' + if basename == "__init__.py" and dirname in self.fake_package_cache: + data = b"" else: try: - with open(path, 'rb') as f: + with open(path, "rb") as f: data = f.read() except OSError as err: self.read_error_cache[path] = err diff --git a/mypy/fswatcher.py b/mypy/fswatcher.py index 21ec306eea6a4..8144a7f43caa3 100644 --- a/mypy/fswatcher.py +++ b/mypy/fswatcher.py @@ -1,8 +1,9 @@ """Watch parts of the file system for changes.""" -from mypy.fscache import FileSystemCache from typing import AbstractSet, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple +from mypy.fscache import FileSystemCache + class FileData(NamedTuple): st_mtime: float @@ -89,10 +90,7 @@ def find_changed(self) -> AbstractSet[str]: """Return paths that have changes since the last call, in the watched set.""" return self._find_changed(self._paths) - def update_changed(self, - remove: List[str], - update: List[str], - ) -> AbstractSet[str]: + def update_changed(self, remove: List[str], update: List[str]) -> AbstractSet[str]: """Alternative to find_changed() given explicit changes. This only calls self.fs.stat() on added or updated files, not diff --git a/mypy/gclogger.py b/mypy/gclogger.py index b8d7980f5f435..65508d2fda7a9 100644 --- a/mypy/gclogger.py +++ b/mypy/gclogger.py @@ -1,13 +1,12 @@ import gc import time - from typing import Mapping, Optional class GcLogger: """Context manager to log GC stats and overall time.""" - def __enter__(self) -> 'GcLogger': + def __enter__(self) -> "GcLogger": self.gc_start_time: Optional[float] = None self.gc_time = 0.0 self.gc_calls = 0 @@ -18,16 +17,16 @@ def __enter__(self) -> 'GcLogger': return self def gc_callback(self, phase: str, info: Mapping[str, int]) -> None: - if phase == 'start': + if phase == "start": assert self.gc_start_time is None, "Start phase out of sequence" self.gc_start_time = time.time() - elif phase == 'stop': + elif phase == "stop": assert self.gc_start_time is not None, "Stop phase out of sequence" self.gc_calls += 1 self.gc_time += time.time() - self.gc_start_time self.gc_start_time = None - self.gc_collected += info['collected'] - self.gc_uncollectable += info['uncollectable'] + self.gc_collected += info["collected"] + self.gc_uncollectable += info["uncollectable"] else: assert False, f"Unrecognized gc phase ({phase!r})" @@ -38,9 +37,9 @@ def __exit__(self, *args: object) -> None: def get_stats(self) -> Mapping[str, float]: end_time = time.time() result = {} - result['gc_time'] = self.gc_time - result['gc_calls'] = self.gc_calls - result['gc_collected'] = self.gc_collected - result['gc_uncollectable'] = self.gc_uncollectable - result['build_time'] = end_time - self.start_time + result["gc_time"] = self.gc_time + result["gc_calls"] = self.gc_calls + result["gc_collected"] = self.gc_collected + result["gc_uncollectable"] = self.gc_uncollectable + result["build_time"] = end_time - self.start_time return result diff --git a/mypy/indirection.py b/mypy/indirection.py index 56c1f97928f2e..c241e55698ffd 100644 --- a/mypy/indirection.py +++ b/mypy/indirection.py @@ -1,7 +1,7 @@ from typing import Dict, Iterable, List, Optional, Set, Union -from mypy.types import TypeVisitor import mypy.types as types +from mypy.types import TypeVisitor from mypy.util import split_module_names diff --git a/mypy/infer.py b/mypy/infer.py index ca521e211493b..d3ad0bc19f9b3 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -1,13 +1,16 @@ """Utilities for type argument inference.""" -from typing import List, Optional, Sequence, NamedTuple +from typing import List, NamedTuple, Optional, Sequence from mypy.constraints import ( - infer_constraints, infer_constraints_for_callable, SUBTYPE_OF, SUPERTYPE_OF + SUBTYPE_OF, + SUPERTYPE_OF, + infer_constraints, + infer_constraints_for_callable, ) -from mypy.types import Type, TypeVarId, CallableType, Instance from mypy.nodes import ArgKind from mypy.solve import solve_constraints +from mypy.types import CallableType, Instance, Type, TypeVarId class ArgumentInferContext(NamedTuple): @@ -24,12 +27,14 @@ class ArgumentInferContext(NamedTuple): iterable_type: Instance -def infer_function_type_arguments(callee_type: CallableType, - arg_types: Sequence[Optional[Type]], - arg_kinds: List[ArgKind], - formal_to_actual: List[List[int]], - context: ArgumentInferContext, - strict: bool = True) -> List[Optional[Type]]: +def infer_function_type_arguments( + callee_type: CallableType, + arg_types: Sequence[Optional[Type]], + arg_kinds: List[ArgKind], + formal_to_actual: List[List[int]], + context: ArgumentInferContext, + strict: bool = True, +) -> List[Optional[Type]]: """Infer the type arguments of a generic function. Return an array of lower bound types for the type variables -1 (at @@ -45,18 +50,18 @@ def infer_function_type_arguments(callee_type: CallableType, """ # Infer constraints. constraints = infer_constraints_for_callable( - callee_type, arg_types, arg_kinds, formal_to_actual, context) + callee_type, arg_types, arg_kinds, formal_to_actual, context + ) # Solve constraints. type_vars = callee_type.type_var_ids() return solve_constraints(type_vars, constraints, strict) -def infer_type_arguments(type_var_ids: List[TypeVarId], - template: Type, actual: Type, - is_supertype: bool = False) -> List[Optional[Type]]: +def infer_type_arguments( + type_var_ids: List[TypeVarId], template: Type, actual: Type, is_supertype: bool = False +) -> List[Optional[Type]]: # Like infer_function_type_arguments, but only match a single type # against a generic type. - constraints = infer_constraints(template, actual, - SUPERTYPE_OF if is_supertype else SUBTYPE_OF) + constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF) return solve_constraints(type_var_ids, constraints) diff --git a/mypy/ipc.py b/mypy/ipc.py index bf8bfa43b62af..08efc8c461cb1 100644 --- a/mypy/ipc.py +++ b/mypy/ipc.py @@ -9,17 +9,17 @@ import shutil import sys import tempfile +from types import TracebackType +from typing import Callable, Optional -from typing import Optional, Callable from typing_extensions import Final, Type -from types import TracebackType - -if sys.platform == 'win32': +if sys.platform == "win32": # This may be private, but it is needed for IPC on Windows, and is basically stable - import _winapi import ctypes + import _winapi + _IPCHandle = int kernel32 = ctypes.windll.kernel32 @@ -27,11 +27,13 @@ FlushFileBuffers: Callable[[_IPCHandle], int] = kernel32.FlushFileBuffers else: import socket + _IPCHandle = socket.socket class IPCException(Exception): """Exception for IPC issues.""" + pass @@ -51,7 +53,7 @@ def __init__(self, name: str, timeout: Optional[float]) -> None: def read(self, size: int = 100000) -> bytes: """Read bytes from an IPC connection until its empty.""" bdata = bytearray() - if sys.platform == 'win32': + if sys.platform == "win32": while True: ov, err = _winapi.ReadFile(self.connection, size, overlapped=True) try: @@ -85,7 +87,7 @@ def read(self, size: int = 100000) -> bytes: def write(self, data: bytes) -> None: """Write bytes to an IPC connection.""" - if sys.platform == 'win32': + if sys.platform == "win32": try: ov, err = _winapi.WriteFile(self.connection, data, overlapped=True) # TODO: remove once typeshed supports Literal types @@ -112,7 +114,7 @@ def write(self, data: bytes) -> None: self.connection.shutdown(socket.SHUT_WR) def close(self) -> None: - if sys.platform == 'win32': + if sys.platform == "win32": if self.connection != _winapi.NULL: _winapi.CloseHandle(self.connection) else: @@ -124,7 +126,7 @@ class IPCClient(IPCBase): def __init__(self, name: str, timeout: Optional[float]) -> None: super().__init__(name, timeout) - if sys.platform == 'win32': + if sys.platform == "win32": timeout = int(self.timeout * 1000) if self.timeout else _winapi.NMPWAIT_WAIT_FOREVER try: _winapi.WaitNamedPipe(self.name, timeout) @@ -150,39 +152,41 @@ def __init__(self, name: str, timeout: Optional[float]) -> None: raise IPCException("The connection is busy.") from e else: raise - _winapi.SetNamedPipeHandleState(self.connection, - _winapi.PIPE_READMODE_MESSAGE, - None, - None) + _winapi.SetNamedPipeHandleState( + self.connection, _winapi.PIPE_READMODE_MESSAGE, None, None + ) else: self.connection = socket.socket(socket.AF_UNIX) self.connection.settimeout(timeout) self.connection.connect(name) - def __enter__(self) -> 'IPCClient': + def __enter__(self) -> "IPCClient": return self - def __exit__(self, - exc_ty: 'Optional[Type[BaseException]]' = None, - exc_val: Optional[BaseException] = None, - exc_tb: Optional[TracebackType] = None, - ) -> None: + def __exit__( + self, + exc_ty: "Optional[Type[BaseException]]" = None, + exc_val: Optional[BaseException] = None, + exc_tb: Optional[TracebackType] = None, + ) -> None: self.close() class IPCServer(IPCBase): - BUFFER_SIZE: Final = 2 ** 16 + BUFFER_SIZE: Final = 2**16 def __init__(self, name: str, timeout: Optional[float] = None) -> None: - if sys.platform == 'win32': - name = r'\\.\pipe\{}-{}.pipe'.format( - name, base64.urlsafe_b64encode(os.urandom(6)).decode()) + if sys.platform == "win32": + name = r"\\.\pipe\{}-{}.pipe".format( + name, base64.urlsafe_b64encode(os.urandom(6)).decode() + ) else: - name = f'{name}.sock' + name = f"{name}.sock" super().__init__(name, timeout) - if sys.platform == 'win32': - self.connection = _winapi.CreateNamedPipe(self.name, + if sys.platform == "win32": + self.connection = _winapi.CreateNamedPipe( + self.name, _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE | _winapi.FILE_FLAG_OVERLAPPED, @@ -195,10 +199,10 @@ def __init__(self, name: str, timeout: Optional[float] = None) -> None: self.BUFFER_SIZE, _winapi.NMPWAIT_WAIT_FOREVER, 0, # Use default security descriptor - ) + ) if self.connection == -1: # INVALID_HANDLE_VALUE err = _winapi.GetLastError() - raise IPCException(f'Invalid handle to pipe: {err}') + raise IPCException(f"Invalid handle to pipe: {err}") else: self.sock_directory = tempfile.mkdtemp() sockfile = os.path.join(self.sock_directory, self.name) @@ -208,8 +212,8 @@ def __init__(self, name: str, timeout: Optional[float] = None) -> None: if timeout is not None: self.sock.settimeout(timeout) - def __enter__(self) -> 'IPCServer': - if sys.platform == 'win32': + def __enter__(self) -> "IPCServer": + if sys.platform == "win32": # NOTE: It is theoretically possible that this will hang forever if the # client never connects, though this can be "solved" by killing the server try: @@ -235,34 +239,36 @@ def __enter__(self) -> 'IPCServer': try: self.connection, _ = self.sock.accept() except socket.timeout as e: - raise IPCException('The socket timed out') from e + raise IPCException("The socket timed out") from e return self - def __exit__(self, - exc_ty: 'Optional[Type[BaseException]]' = None, - exc_val: Optional[BaseException] = None, - exc_tb: Optional[TracebackType] = None, - ) -> None: - if sys.platform == 'win32': + def __exit__( + self, + exc_ty: "Optional[Type[BaseException]]" = None, + exc_val: Optional[BaseException] = None, + exc_tb: Optional[TracebackType] = None, + ) -> None: + if sys.platform == "win32": try: # Wait for the client to finish reading the last write before disconnecting if not FlushFileBuffers(self.connection): - raise IPCException("Failed to flush NamedPipe buffer," - "maybe the client hung up?") + raise IPCException( + "Failed to flush NamedPipe buffer," "maybe the client hung up?" + ) finally: DisconnectNamedPipe(self.connection) else: self.close() def cleanup(self) -> None: - if sys.platform == 'win32': + if sys.platform == "win32": self.close() else: shutil.rmtree(self.sock_directory) @property def connection_name(self) -> str: - if sys.platform == 'win32': + if sys.platform == "win32": return self.name else: return self.sock.getsockname() diff --git a/mypy/join.py b/mypy/join.py index 70c250a7703c3..31f31ed887144 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -1,23 +1,50 @@ """Calculation of the least upper bound types (joins).""" -from mypy.backports import OrderedDict from typing import List, Optional, Tuple -from mypy.types import ( - Type, AnyType, NoneType, TypeVisitor, Instance, UnboundType, TypeVarType, CallableType, - TupleType, TypedDictType, ErasedType, UnionType, FunctionLike, Overloaded, LiteralType, - PartialType, DeletedType, UninhabitedType, TypeType, TypeOfAny, get_proper_type, - ProperType, get_proper_types, TypeAliasType, PlaceholderType, ParamSpecType, Parameters, - UnpackType, TypeVarTupleType, -) +import mypy.typeops +from mypy.backports import OrderedDict from mypy.maptype import map_instance_to_supertype +from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT +from mypy.state import state from mypy.subtypes import ( - is_subtype, is_equivalent, is_proper_subtype, - is_protocol_implementation, find_member + find_member, + is_equivalent, + is_proper_subtype, + is_protocol_implementation, + is_subtype, +) +from mypy.types import ( + AnyType, + CallableType, + DeletedType, + ErasedType, + FunctionLike, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + PartialType, + PlaceholderType, + ProperType, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeOfAny, + TypeType, + TypeVarTupleType, + TypeVarType, + TypeVisitor, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, + get_proper_type, + get_proper_types, ) -from mypy.nodes import INVARIANT, COVARIANT, CONTRAVARIANT -import mypy.typeops -from mypy.state import state class InstanceJoiner: @@ -315,11 +342,15 @@ def visit_callable_type(self, t: CallableType) -> ProperType: result = join_similar_callables(t, self.s) # We set the from_type_type flag to suppress error when a collection of # concrete class objects gets inferred as their common abstract superclass. - if not ((t.is_type_obj() and t.type_object().is_abstract) or - (self.s.is_type_obj() and self.s.type_object().is_abstract)): + if not ( + (t.is_type_obj() and t.type_object().is_abstract) + or (self.s.is_type_obj() and self.s.type_object().is_abstract) + ): result.from_type_type = True - if any(isinstance(tp, (NoneType, UninhabitedType)) - for tp in get_proper_types(result.arg_types)): + if any( + isinstance(tp, (NoneType, UninhabitedType)) + for tp in get_proper_types(result.arg_types) + ): # We don't want to return unusable Callable, attempt fallback instead. return join_types(t.fallback, self.s) return result @@ -400,8 +431,9 @@ def visit_tuple_type(self, t: TupleType) -> ProperType: if isinstance(self.s, TupleType) and self.s.length() == t.length(): if self.instance_joiner is None: self.instance_joiner = InstanceJoiner() - fallback = self.instance_joiner.join_instances(mypy.typeops.tuple_fallback(self.s), - mypy.typeops.tuple_fallback(t)) + fallback = self.instance_joiner.join_instances( + mypy.typeops.tuple_fallback(self.s), mypy.typeops.tuple_fallback(t) + ) assert isinstance(fallback, Instance) if self.s.length() == t.length(): items: List[Type] = [] @@ -415,12 +447,16 @@ def visit_tuple_type(self, t: TupleType) -> ProperType: def visit_typeddict_type(self, t: TypedDictType) -> ProperType: if isinstance(self.s, TypedDictType): - items = OrderedDict([ - (item_name, s_item_type) - for (item_name, s_item_type, t_item_type) in self.s.zip(t) - if (is_equivalent(s_item_type, t_item_type) and - (item_name in t.required_keys) == (item_name in self.s.required_keys)) - ]) + items = OrderedDict( + [ + (item_name, s_item_type) + for (item_name, s_item_type, t_item_type) in self.s.zip(t) + if ( + is_equivalent(s_item_type, t_item_type) + and (item_name in t.required_keys) == (item_name in self.s.required_keys) + ) + ] + ) fallback = self.s.create_anonymous_fallback() # We need to filter by items.keys() since some required keys present in both t and # self.s might be missing from the join if the types are incompatible. @@ -449,7 +485,7 @@ def visit_partial_type(self, t: PartialType) -> ProperType: def visit_type_type(self, t: TypeType) -> ProperType: if isinstance(self.s, TypeType): return TypeType.make_normalized(self.join(t.item, self.s.item), line=t.line) - elif isinstance(self.s, Instance) and self.s.type.fullname == 'builtins.type': + elif isinstance(self.s, Instance) and self.s.type.fullname == "builtins.type": return self.s else: return self.default(self.s) @@ -499,8 +535,11 @@ def is_similar_callables(t: CallableType, s: CallableType) -> bool: """Return True if t and s have identical numbers of arguments, default arguments and varargs. """ - return (len(t.arg_types) == len(s.arg_types) and t.min_args == s.min_args and - t.is_var_arg == s.is_var_arg) + return ( + len(t.arg_types) == len(s.arg_types) + and t.min_args == s.min_args + and t.is_var_arg == s.is_var_arg + ) def join_similar_callables(t: CallableType, s: CallableType) -> CallableType: @@ -512,15 +551,17 @@ def join_similar_callables(t: CallableType, s: CallableType) -> CallableType: # TODO in combine_similar_callables also applies here (names and kinds) # The fallback type can be either 'function' or 'type'. The result should have 'type' as # fallback only if both operands have it as 'type'. - if t.fallback.type.fullname != 'builtins.type': + if t.fallback.type.fullname != "builtins.type": fallback = t.fallback else: fallback = s.fallback - return t.copy_modified(arg_types=arg_types, - arg_names=combine_arg_names(t, s), - ret_type=join_types(t.ret_type, s.ret_type), - fallback=fallback, - name=None) + return t.copy_modified( + arg_types=arg_types, + arg_names=combine_arg_names(t, s), + ret_type=join_types(t.ret_type, s.ret_type), + fallback=fallback, + name=None, + ) def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType: @@ -530,15 +571,17 @@ def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType: # TODO kinds and argument names # The fallback type can be either 'function' or 'type'. The result should have 'type' as # fallback only if both operands have it as 'type'. - if t.fallback.type.fullname != 'builtins.type': + if t.fallback.type.fullname != "builtins.type": fallback = t.fallback else: fallback = s.fallback - return t.copy_modified(arg_types=arg_types, - arg_names=combine_arg_names(t, s), - ret_type=join_types(t.ret_type, s.ret_type), - fallback=fallback, - name=None) + return t.copy_modified( + arg_types=arg_types, + arg_names=combine_arg_names(t, s), + ret_type=join_types(t.ret_type, s.ret_type), + fallback=fallback, + name=None, + ) def combine_arg_names(t: CallableType, s: CallableType) -> List[Optional[str]]: @@ -612,6 +655,6 @@ def join_type_list(types: List[Type]) -> ProperType: def unpack_callback_protocol(t: Instance) -> Optional[Type]: assert t.type.is_protocol - if t.type.protocol_members == ['__call__']: - return find_member('__call__', t, t, is_operator=True) + if t.type.protocol_members == ["__call__"]: + return find_member("__call__", t, t, is_operator=True) return None diff --git a/mypy/literals.py b/mypy/literals.py index e20e37412ab27..2c272edc5fabf 100644 --- a/mypy/literals.py +++ b/mypy/literals.py @@ -1,15 +1,58 @@ -from typing import Optional, Union, Any, Tuple, Iterable +from typing import Any, Iterable, Optional, Tuple, Union + from typing_extensions import Final from mypy.nodes import ( - Expression, ComparisonExpr, OpExpr, MemberExpr, UnaryExpr, StarExpr, IndexExpr, LITERAL_YES, - LITERAL_NO, NameExpr, LITERAL_TYPE, IntExpr, FloatExpr, ComplexExpr, StrExpr, BytesExpr, - UnicodeExpr, ListExpr, TupleExpr, SetExpr, DictExpr, CallExpr, SliceExpr, CastExpr, - ConditionalExpr, EllipsisExpr, YieldFromExpr, YieldExpr, RevealExpr, SuperExpr, - TypeApplication, LambdaExpr, ListComprehension, SetComprehension, DictionaryComprehension, - GeneratorExpr, BackquoteExpr, TypeVarExpr, TypeAliasExpr, NamedTupleExpr, EnumCallExpr, - TypedDictExpr, NewTypeExpr, PromoteExpr, AwaitExpr, TempNode, AssignmentExpr, ParamSpecExpr, - AssertTypeExpr, TypeVarTupleExpr, + LITERAL_NO, + LITERAL_TYPE, + LITERAL_YES, + AssertTypeExpr, + AssignmentExpr, + AwaitExpr, + BackquoteExpr, + BytesExpr, + CallExpr, + CastExpr, + ComparisonExpr, + ComplexExpr, + ConditionalExpr, + DictExpr, + DictionaryComprehension, + EllipsisExpr, + EnumCallExpr, + Expression, + FloatExpr, + GeneratorExpr, + IndexExpr, + IntExpr, + LambdaExpr, + ListComprehension, + ListExpr, + MemberExpr, + NamedTupleExpr, + NameExpr, + NewTypeExpr, + OpExpr, + ParamSpecExpr, + PromoteExpr, + RevealExpr, + SetComprehension, + SetExpr, + SliceExpr, + StarExpr, + StrExpr, + SuperExpr, + TempNode, + TupleExpr, + TypeAliasExpr, + TypeApplication, + TypedDictExpr, + TypeVarExpr, + TypeVarTupleExpr, + UnaryExpr, + UnicodeExpr, + YieldExpr, + YieldFromExpr, ) from mypy.visitor import ExpressionVisitor @@ -96,45 +139,45 @@ def literal_hash(e: Expression) -> Optional[Key]: class _Hasher(ExpressionVisitor[Optional[Key]]): def visit_int_expr(self, e: IntExpr) -> Key: - return ('Literal', e.value) + return ("Literal", e.value) def visit_str_expr(self, e: StrExpr) -> Key: - return ('Literal', e.value, e.from_python_3) + return ("Literal", e.value, e.from_python_3) def visit_bytes_expr(self, e: BytesExpr) -> Key: - return ('Literal', e.value) + return ("Literal", e.value) def visit_unicode_expr(self, e: UnicodeExpr) -> Key: - return ('Literal', e.value) + return ("Literal", e.value) def visit_float_expr(self, e: FloatExpr) -> Key: - return ('Literal', e.value) + return ("Literal", e.value) def visit_complex_expr(self, e: ComplexExpr) -> Key: - return ('Literal', e.value) + return ("Literal", e.value) def visit_star_expr(self, e: StarExpr) -> Key: - return ('Star', literal_hash(e.expr)) + return ("Star", literal_hash(e.expr)) def visit_name_expr(self, e: NameExpr) -> Key: # N.B: We use the node itself as the key, and not the name, # because using the name causes issues when there is shadowing # (for example, in list comprehensions). - return ('Var', e.node) + return ("Var", e.node) def visit_member_expr(self, e: MemberExpr) -> Key: - return ('Member', literal_hash(e.expr), e.name) + return ("Member", literal_hash(e.expr), e.name) def visit_op_expr(self, e: OpExpr) -> Key: - return ('Binary', e.op, literal_hash(e.left), literal_hash(e.right)) + return ("Binary", e.op, literal_hash(e.left), literal_hash(e.right)) def visit_comparison_expr(self, e: ComparisonExpr) -> Key: rest: Any = tuple(e.operators) rest += tuple(literal_hash(o) for o in e.operands) - return ('Comparison',) + rest + return ("Comparison",) + rest def visit_unary_expr(self, e: UnaryExpr) -> Key: - return ('Unary', e.op, literal_hash(e.expr)) + return ("Unary", e.op, literal_hash(e.expr)) def seq_expr(self, e: Union[ListExpr, TupleExpr, SetExpr], name: str) -> Optional[Key]: if all(literal(x) == LITERAL_YES for x in e.items): @@ -143,7 +186,7 @@ def seq_expr(self, e: Union[ListExpr, TupleExpr, SetExpr], name: str) -> Optiona return None def visit_list_expr(self, e: ListExpr) -> Optional[Key]: - return self.seq_expr(e, 'List') + return self.seq_expr(e, "List") def visit_dict_expr(self, e: DictExpr) -> Optional[Key]: if all(a and literal(a) == literal(b) == LITERAL_YES for a, b in e.items): @@ -154,14 +197,14 @@ def visit_dict_expr(self, e: DictExpr) -> Optional[Key]: return None def visit_tuple_expr(self, e: TupleExpr) -> Optional[Key]: - return self.seq_expr(e, 'Tuple') + return self.seq_expr(e, "Tuple") def visit_set_expr(self, e: SetExpr) -> Optional[Key]: - return self.seq_expr(e, 'Set') + return self.seq_expr(e, "Set") def visit_index_expr(self, e: IndexExpr) -> Optional[Key]: if literal(e.index) == LITERAL_YES: - return ('Index', literal_hash(e.base), literal_hash(e.index)) + return ("Index", literal_hash(e.base), literal_hash(e.index)) return None def visit_assignment_expr(self, e: AssignmentExpr) -> Optional[Key]: diff --git a/mypy/lookup.py b/mypy/lookup.py index 8a8350080bc26..aa555ad113235 100644 --- a/mypy/lookup.py +++ b/mypy/lookup.py @@ -3,14 +3,16 @@ functions that will find a semantic node by its name. """ -from mypy.nodes import MypyFile, SymbolTableNode, TypeInfo from typing import Dict, Optional +from mypy.nodes import MypyFile, SymbolTableNode, TypeInfo + # TODO: gradually move existing lookup functions to this module. -def lookup_fully_qualified(name: str, modules: Dict[str, MypyFile], *, - raise_on_missing: bool = False) -> Optional[SymbolTableNode]: +def lookup_fully_qualified( + name: str, modules: Dict[str, MypyFile], *, raise_on_missing: bool = False +) -> Optional[SymbolTableNode]: """Find a symbol using it fully qualified name. The algorithm has two steps: first we try splitting the name on '.' to find @@ -24,11 +26,11 @@ def lookup_fully_qualified(name: str, modules: Dict[str, MypyFile], *, rest = [] # 1. Find a module tree in modules dictionary. while True: - if '.' not in head: + if "." not in head: if raise_on_missing: - assert '.' in head, f"Cannot find module for {name}" + assert "." in head, f"Cannot find module for {name}" return None - head, tail = head.rsplit('.', maxsplit=1) + head, tail = head.rsplit(".", maxsplit=1) rest.append(tail) mod = modules.get(head) if mod is not None: diff --git a/mypy/main.py b/mypy/main.py index 619147a1c2770..85a1eb0765eb7 100644 --- a/mypy/main.py +++ b/mypy/main.py @@ -1,31 +1,24 @@ """Mypy type checker command line tool.""" import argparse -from gettext import gettext import os import subprocess import sys import time +from gettext import gettext +from typing import IO, Any, Dict, List, Optional, Sequence, TextIO, Tuple, Union -from typing import Any, Dict, IO, List, Optional, Sequence, Tuple, TextIO, Union from typing_extensions import Final, NoReturn -from mypy import build -from mypy import defaults -from mypy import state -from mypy import util -from mypy.modulefinder import ( - BuildSource, FindModuleCache, SearchPaths, - get_search_dirs, mypy_path, -) -from mypy.find_sources import create_source_list, InvalidSourceList -from mypy.fscache import FileSystemCache -from mypy.errors import CompileError +from mypy import build, defaults, state, util +from mypy.config_parser import get_config_module_names, parse_config_file, parse_version from mypy.errorcodes import error_codes -from mypy.options import Options, BuildType -from mypy.config_parser import get_config_module_names, parse_version, parse_config_file +from mypy.errors import CompileError +from mypy.find_sources import InvalidSourceList, create_source_list +from mypy.fscache import FileSystemCache +from mypy.modulefinder import BuildSource, FindModuleCache, SearchPaths, get_search_dirs, mypy_path +from mypy.options import BuildType, Options from mypy.split_namespace import SplitNamespace - from mypy.version import __version__ orig_stat: Final = os.stat @@ -39,17 +32,20 @@ def stat_proxy(path: str) -> os.stat_result: print(f"stat({path!r}) -> {err}") raise else: - print("stat(%r) -> (st_mode=%o, st_mtime=%d, st_size=%d)" % - (path, st.st_mode, st.st_mtime, st.st_size)) + print( + "stat(%r) -> (st_mode=%o, st_mtime=%d, st_size=%d)" + % (path, st.st_mode, st.st_mtime, st.st_size) + ) return st -def main(script_path: Optional[str], - stdout: TextIO, - stderr: TextIO, - args: Optional[List[str]] = None, - clean_exit: bool = False, - ) -> None: +def main( + script_path: Optional[str], + stdout: TextIO, + stderr: TextIO, + args: Optional[List[str]] = None, + clean_exit: bool = False, +) -> None: """Main entry point to the type checker. Args: @@ -59,16 +55,15 @@ def main(script_path: Optional[str], clean_exit: Don't hard kill the process on exit. This allows catching SystemExit. """ - util.check_python_version('mypy') + util.check_python_version("mypy") t0 = time.time() # To log stat() calls: os.stat = stat_proxy - sys.setrecursionlimit(2 ** 14) + sys.setrecursionlimit(2**14) if args is None: args = sys.argv[1:] fscache = FileSystemCache() - sources, options = process_options(args, stdout=stdout, stderr=stderr, - fscache=fscache) + sources, options = process_options(args, stdout=stdout, stderr=stderr, fscache=fscache) if clean_exit: options.fast_exit = False @@ -82,12 +77,16 @@ def main(script_path: Optional[str], fail("error: --non-interactive is only supported with --install-types", stderr, options) if options.install_types and not options.incremental: - fail("error: --install-types not supported with incremental mode disabled", - stderr, options) + fail( + "error: --install-types not supported with incremental mode disabled", stderr, options + ) if options.install_types and options.python_executable is None: - fail("error: --install-types not supported without python executable or site packages", - stderr, options) + fail( + "error: --install-types not supported without python executable or site packages", + stderr, + options, + ) if options.install_types and not sources: install_types(formatter, options, non_interactive=options.non_interactive) @@ -107,6 +106,7 @@ def main(script_path: Optional[str], if MEM_PROFILE: from mypy.memprofile import print_memory_profile + print_memory_profile() code = 0 @@ -116,13 +116,12 @@ def main(script_path: Optional[str], n_errors, n_notes, n_files = util.count_stats(messages) if n_errors: summary = formatter.format_error( - n_errors, n_files, len(sources), blockers=blockers, - use_color=options.color_output + n_errors, n_files, len(sources), blockers=blockers, use_color=options.color_output ) - stdout.write(summary + '\n') + stdout.write(summary + "\n") # Only notes should also output success elif not messages or n_notes == len(messages): - stdout.write(formatter.format_success(len(sources), options.color_output) + '\n') + stdout.write(formatter.format_success(len(sources), options.color_output) + "\n") stdout.flush() if options.install_types and not options.non_interactive: @@ -144,12 +143,14 @@ def main(script_path: Optional[str], list([res]) -def run_build(sources: List[BuildSource], - options: Options, - fscache: FileSystemCache, - t0: float, - stdout: TextIO, - stderr: TextIO) -> Tuple[Optional[build.BuildResult], List[str], bool]: +def run_build( + sources: List[BuildSource], + options: Options, + fscache: FileSystemCache, + t0: float, + stdout: TextIO, + stderr: TextIO, +) -> Tuple[Optional[build.BuildResult], List[str], bool]: formatter = util.FancyFormatter(stdout, stderr, options.show_error_codes) messages = [] @@ -175,28 +176,38 @@ def flush_errors(new_messages: List[str], serious: bool) -> None: blockers = True if not e.use_stdout: serious = True - if (options.warn_unused_configs - and options.unused_configs - and not options.incremental - and not options.non_interactive): - print("Warning: unused section(s) in %s: %s" % - (options.config_file, - get_config_module_names(options.config_file, - [glob for glob in options.per_module_options.keys() - if glob in options.unused_configs])), - file=stderr) + if ( + options.warn_unused_configs + and options.unused_configs + and not options.incremental + and not options.non_interactive + ): + print( + "Warning: unused section(s) in %s: %s" + % ( + options.config_file, + get_config_module_names( + options.config_file, + [ + glob + for glob in options.per_module_options.keys() + if glob in options.unused_configs + ], + ), + ), + file=stderr, + ) maybe_write_junit_xml(time.time() - t0, serious, messages, options) return res, messages, blockers -def show_messages(messages: List[str], - f: TextIO, - formatter: util.FancyFormatter, - options: Options) -> None: +def show_messages( + messages: List[str], f: TextIO, formatter: util.FancyFormatter, options: Options +) -> None: for msg in messages: if options.color_output: msg = formatter.colorize(msg) - f.write(msg + '\n') + f.write(msg + "\n") f.flush() @@ -206,7 +217,7 @@ def __init__(self, prog: str) -> None: super().__init__(prog=prog, max_help_position=28) def _fill_text(self, text: str, width: int, indent: str) -> str: - if '\n' in text: + if "\n" in text: # Assume we want to manually format the text return super()._fill_text(text, width, indent) else: @@ -216,10 +227,7 @@ def _fill_text(self, text: str, width: int, indent: str) -> str: # Define pairs of flag prefixes with inverse meaning. -flag_prefix_pairs: Final = [ - ('allow', 'disallow'), - ('show', 'hide'), -] +flag_prefix_pairs: Final = [("allow", "disallow"), ("show", "hide")] flag_prefix_map: Final[Dict[str, str]] = {} for a, b in flag_prefix_pairs: flag_prefix_map[a] = b @@ -227,15 +235,15 @@ def _fill_text(self, text: str, width: int, indent: str) -> str: def invert_flag_name(flag: str) -> str: - split = flag[2:].split('-', 1) + split = flag[2:].split("-", 1) if len(split) == 2: prefix, rest = split if prefix in flag_prefix_map: - return f'--{flag_prefix_map[prefix]}-{rest}' - elif prefix == 'no': - return f'--{rest}' + return f"--{flag_prefix_map[prefix]}-{rest}" + elif prefix == "no": + return f"--{rest}" - return f'--no-{flag[2:]}' + return f"--no-{flag[2:]}" class PythonExecutableInferenceError(Exception): @@ -243,34 +251,38 @@ class PythonExecutableInferenceError(Exception): def python_executable_prefix(v: str) -> List[str]: - if sys.platform == 'win32': + if sys.platform == "win32": # on Windows, all Python executables are named `python`. To handle this, there # is the `py` launcher, which can be passed a version e.g. `py -3.8`, and it will # execute an installed Python 3.8 interpreter. See also: # https://docs.python.org/3/using/windows.html#python-launcher-for-windows - return ['py', f'-{v}'] + return ["py", f"-{v}"] else: - return [f'python{v}'] + return [f"python{v}"] def _python_executable_from_version(python_version: Tuple[int, int]) -> str: if sys.version_info[:2] == python_version: return sys.executable - str_ver = '.'.join(map(str, python_version)) + str_ver = ".".join(map(str, python_version)) try: - sys_exe = subprocess.check_output(python_executable_prefix(str_ver) + - ['-c', 'import sys; print(sys.executable)'], - stderr=subprocess.STDOUT).decode().strip() + sys_exe = ( + subprocess.check_output( + python_executable_prefix(str_ver) + ["-c", "import sys; print(sys.executable)"], + stderr=subprocess.STDOUT, + ) + .decode() + .strip() + ) return sys_exe except (subprocess.CalledProcessError, FileNotFoundError) as e: raise PythonExecutableInferenceError( - 'failed to find a Python executable matching version {},' - ' perhaps try --python-executable, or --no-site-packages?'.format(python_version) + "failed to find a Python executable matching version {}," + " perhaps try --python-executable, or --no-site-packages?".format(python_version) ) from e -def infer_python_executable(options: Options, - special_opts: argparse.Namespace) -> None: +def infer_python_executable(options: Options, special_opts: argparse.Namespace) -> None: """Infer the Python executable from the given version. This function mutates options based on special_opts to infer the correct Python executable @@ -331,8 +343,8 @@ class CapturableArgumentParser(argparse.ArgumentParser): """ def __init__(self, *args: Any, **kwargs: Any): - self.stdout = kwargs.pop('stdout', sys.stdout) - self.stderr = kwargs.pop('stderr', sys.stderr) + self.stdout = kwargs.pop("stdout", sys.stdout) + self.stderr = kwargs.pop("stderr", sys.stderr) super().__init__(*args, **kwargs) # ===================== @@ -372,8 +384,8 @@ def error(self, message: str) -> NoReturn: should either exit or raise an exception. """ self.print_usage(self.stderr) - args = {'prog': self.prog, 'message': message} - self.exit(2, gettext('%(prog)s: error: %(message)s\n') % args) + args = {"prog": self.prog, "message": message} + self.exit(2, gettext("%(prog)s: error: %(message)s\n") % args) class CapturableVersionAction(argparse.Action): @@ -388,42 +400,44 @@ class CapturableVersionAction(argparse.Action): (which does not appear to exist). """ - def __init__(self, - option_strings: Sequence[str], - version: str, - dest: str = argparse.SUPPRESS, - default: str = argparse.SUPPRESS, - help: str = "show program's version number and exit", - stdout: Optional[IO[str]] = None): + def __init__( + self, + option_strings: Sequence[str], + version: str, + dest: str = argparse.SUPPRESS, + default: str = argparse.SUPPRESS, + help: str = "show program's version number and exit", + stdout: Optional[IO[str]] = None, + ): super().__init__( - option_strings=option_strings, - dest=dest, - default=default, - nargs=0, - help=help) + option_strings=option_strings, dest=dest, default=default, nargs=0, help=help + ) self.version = version self.stdout = stdout or sys.stdout - def __call__(self, - parser: argparse.ArgumentParser, - namespace: argparse.Namespace, - values: Union[str, Sequence[Any], None], - option_string: Optional[str] = None) -> NoReturn: + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Union[str, Sequence[Any], None], + option_string: Optional[str] = None, + ) -> NoReturn: formatter = parser._get_formatter() formatter.add_text(self.version) parser._print_message(formatter.format_help(), self.stdout) parser.exit() -def process_options(args: List[str], - stdout: Optional[TextIO] = None, - stderr: Optional[TextIO] = None, - require_targets: bool = True, - server_options: bool = False, - fscache: Optional[FileSystemCache] = None, - program: str = 'mypy', - header: str = HEADER, - ) -> Tuple[List[BuildSource], Options]: +def process_options( + args: List[str], + stdout: Optional[TextIO] = None, + stderr: Optional[TextIO] = None, + require_targets: bool = True, + server_options: bool = False, + fscache: Optional[FileSystemCache] = None, + program: str = "mypy", + header: str = HEADER, +) -> Tuple[List[BuildSource], Options]: """Parse command line arguments. If a FileSystemCache is passed in, and package_root options are given, @@ -432,28 +446,31 @@ def process_options(args: List[str], stdout = stdout or sys.stdout stderr = stderr or sys.stderr - parser = CapturableArgumentParser(prog=program, - usage=header, - description=DESCRIPTION, - epilog=FOOTER, - fromfile_prefix_chars='@', - formatter_class=AugmentedHelpFormatter, - add_help=False, - stdout=stdout, - stderr=stderr) + parser = CapturableArgumentParser( + prog=program, + usage=header, + description=DESCRIPTION, + epilog=FOOTER, + fromfile_prefix_chars="@", + formatter_class=AugmentedHelpFormatter, + add_help=False, + stdout=stdout, + stderr=stderr, + ) strict_flag_names: List[str] = [] strict_flag_assignments: List[Tuple[str, bool]] = [] - def add_invertible_flag(flag: str, - *, - inverse: Optional[str] = None, - default: bool, - dest: Optional[str] = None, - help: str, - strict_flag: bool = False, - group: Optional[argparse._ActionsContainer] = None - ) -> None: + def add_invertible_flag( + flag: str, + *, + inverse: Optional[str] = None, + default: bool, + dest: Optional[str] = None, + help: str, + strict_flag: bool = False, + group: Optional[argparse._ActionsContainer] = None, + ) -> None: if inverse is None: inverse = invert_flag_name(flag) if group is None: @@ -462,15 +479,16 @@ def add_invertible_flag(flag: str, if help is not argparse.SUPPRESS: help += f" (inverse: {inverse})" - arg = group.add_argument(flag, - action='store_false' if default else 'store_true', - dest=dest, - help=help) + arg = group.add_argument( + flag, action="store_false" if default else "store_true", dest=dest, help=help + ) dest = arg.dest - arg = group.add_argument(inverse, - action='store_true' if default else 'store_false', - dest=dest, - help=argparse.SUPPRESS) + arg = group.add_argument( + inverse, + action="store_true" if default else "store_false", + dest=dest, + help=argparse.SUPPRESS, + ) if strict_flag: assert dest is not None strict_flag_names.append(flag) @@ -484,172 +502,280 @@ def add_invertible_flag(flag: str, # Note: we have a style guide for formatting the mypy --help text. See # https://github.com/python/mypy/wiki/Documentation-Conventions - general_group = parser.add_argument_group( - title='Optional arguments') + general_group = parser.add_argument_group(title="Optional arguments") general_group.add_argument( - '-h', '--help', action='help', - help="Show this help message and exit") + "-h", "--help", action="help", help="Show this help message and exit" + ) general_group.add_argument( - '-v', '--verbose', action='count', dest='verbosity', - help="More verbose messages") + "-v", "--verbose", action="count", dest="verbosity", help="More verbose messages" + ) compilation_status = "no" if __file__.endswith(".py") else "yes" general_group.add_argument( - '-V', '--version', action=CapturableVersionAction, - version='%(prog)s ' + __version__ + f" (compiled: {compilation_status})", + "-V", + "--version", + action=CapturableVersionAction, + version="%(prog)s " + __version__ + f" (compiled: {compilation_status})", help="Show program's version number and exit", - stdout=stdout) + stdout=stdout, + ) config_group = parser.add_argument_group( - title='Config file', + title="Config file", description="Use a config file instead of command line arguments. " - "This is useful if you are using many flags or want " - "to set different options per each module.") + "This is useful if you are using many flags or want " + "to set different options per each module.", + ) config_group.add_argument( - '--config-file', + "--config-file", help="Configuration file, must have a [mypy] section " - "(defaults to {})".format(', '.join(defaults.CONFIG_FILES))) - add_invertible_flag('--warn-unused-configs', default=False, strict_flag=True, - help="Warn about unused '[mypy-]' or '[[tool.mypy.overrides]]' " - "config sections", - group=config_group) + "(defaults to {})".format(", ".join(defaults.CONFIG_FILES)), + ) + add_invertible_flag( + "--warn-unused-configs", + default=False, + strict_flag=True, + help="Warn about unused '[mypy-]' or '[[tool.mypy.overrides]]' " + "config sections", + group=config_group, + ) imports_group = parser.add_argument_group( - title='Import discovery', - description="Configure how imports are discovered and followed.") + title="Import discovery", description="Configure how imports are discovered and followed." + ) add_invertible_flag( - '--namespace-packages', default=False, + "--namespace-packages", + default=False, help="Support namespace packages (PEP 420, __init__.py-less)", - group=imports_group) + group=imports_group, + ) imports_group.add_argument( - '--ignore-missing-imports', action='store_true', - help="Silently ignore imports of missing modules") + "--ignore-missing-imports", + action="store_true", + help="Silently ignore imports of missing modules", + ) imports_group.add_argument( - '--follow-imports', choices=['normal', 'silent', 'skip', 'error'], - default='normal', help="How to treat imports (default normal)") + "--follow-imports", + choices=["normal", "silent", "skip", "error"], + default="normal", + help="How to treat imports (default normal)", + ) imports_group.add_argument( - '--python-executable', action='store', metavar='EXECUTABLE', + "--python-executable", + action="store", + metavar="EXECUTABLE", help="Python executable used for finding PEP 561 compliant installed" - " packages and stubs", - dest='special-opts:python_executable') + " packages and stubs", + dest="special-opts:python_executable", + ) imports_group.add_argument( - '--no-site-packages', action='store_true', - dest='special-opts:no_executable', - help="Do not search for installed PEP 561 compliant packages") + "--no-site-packages", + action="store_true", + dest="special-opts:no_executable", + help="Do not search for installed PEP 561 compliant packages", + ) imports_group.add_argument( - '--no-silence-site-packages', action='store_true', - help="Do not silence errors in PEP 561 compliant installed packages") + "--no-silence-site-packages", + action="store_true", + help="Do not silence errors in PEP 561 compliant installed packages", + ) platform_group = parser.add_argument_group( - title='Platform configuration', + title="Platform configuration", description="Type check code assuming it will be run under certain " - "runtime conditions. By default, mypy assumes your code " - "will be run using the same operating system and Python " - "version you are using to run mypy itself.") + "runtime conditions. By default, mypy assumes your code " + "will be run using the same operating system and Python " + "version you are using to run mypy itself.", + ) platform_group.add_argument( - '--python-version', type=parse_version, metavar='x.y', - help='Type check code assuming it will be running on Python x.y', - dest='special-opts:python_version') + "--python-version", + type=parse_version, + metavar="x.y", + help="Type check code assuming it will be running on Python x.y", + dest="special-opts:python_version", + ) platform_group.add_argument( - '-2', '--py2', dest='special-opts:python_version', action='store_const', + "-2", + "--py2", + dest="special-opts:python_version", + action="store_const", const=defaults.PYTHON2_VERSION, - help="Use Python 2 mode (same as --python-version 2.7)") + help="Use Python 2 mode (same as --python-version 2.7)", + ) platform_group.add_argument( - '--platform', action='store', metavar='PLATFORM', + "--platform", + action="store", + metavar="PLATFORM", help="Type check special-cased code for the given OS platform " - "(defaults to sys.platform)") + "(defaults to sys.platform)", + ) platform_group.add_argument( - '--always-true', metavar='NAME', action='append', default=[], - help="Additional variable to be considered True (may be repeated)") + "--always-true", + metavar="NAME", + action="append", + default=[], + help="Additional variable to be considered True (may be repeated)", + ) platform_group.add_argument( - '--always-false', metavar='NAME', action='append', default=[], - help="Additional variable to be considered False (may be repeated)") + "--always-false", + metavar="NAME", + action="append", + default=[], + help="Additional variable to be considered False (may be repeated)", + ) disallow_any_group = parser.add_argument_group( - title='Disallow dynamic typing', - description="Disallow the use of the dynamic 'Any' type under certain conditions.") + title="Disallow dynamic typing", + description="Disallow the use of the dynamic 'Any' type under certain conditions.", + ) disallow_any_group.add_argument( - '--disallow-any-unimported', default=False, action='store_true', - help="Disallow Any types resulting from unfollowed imports") + "--disallow-any-unimported", + default=False, + action="store_true", + help="Disallow Any types resulting from unfollowed imports", + ) disallow_any_group.add_argument( - '--disallow-any-expr', default=False, action='store_true', - help='Disallow all expressions that have type Any') + "--disallow-any-expr", + default=False, + action="store_true", + help="Disallow all expressions that have type Any", + ) disallow_any_group.add_argument( - '--disallow-any-decorated', default=False, action='store_true', - help='Disallow functions that have Any in their signature ' - 'after decorator transformation') + "--disallow-any-decorated", + default=False, + action="store_true", + help="Disallow functions that have Any in their signature " + "after decorator transformation", + ) disallow_any_group.add_argument( - '--disallow-any-explicit', default=False, action='store_true', - help='Disallow explicit Any in type positions') - add_invertible_flag('--disallow-any-generics', default=False, strict_flag=True, - help='Disallow usage of generic types that do not specify explicit type ' - 'parameters', group=disallow_any_group) - add_invertible_flag('--disallow-subclassing-any', default=False, strict_flag=True, - help="Disallow subclassing values of type 'Any' when defining classes", - group=disallow_any_group) + "--disallow-any-explicit", + default=False, + action="store_true", + help="Disallow explicit Any in type positions", + ) + add_invertible_flag( + "--disallow-any-generics", + default=False, + strict_flag=True, + help="Disallow usage of generic types that do not specify explicit type " "parameters", + group=disallow_any_group, + ) + add_invertible_flag( + "--disallow-subclassing-any", + default=False, + strict_flag=True, + help="Disallow subclassing values of type 'Any' when defining classes", + group=disallow_any_group, + ) untyped_group = parser.add_argument_group( - title='Untyped definitions and calls', + title="Untyped definitions and calls", description="Configure how untyped definitions and calls are handled. " - "Note: by default, mypy ignores any untyped function definitions " - "and assumes any calls to such functions have a return " - "type of 'Any'.") - add_invertible_flag('--disallow-untyped-calls', default=False, strict_flag=True, - help="Disallow calling functions without type annotations" - " from functions with type annotations", - group=untyped_group) - add_invertible_flag('--disallow-untyped-defs', default=False, strict_flag=True, - help="Disallow defining functions without type annotations" - " or with incomplete type annotations", - group=untyped_group) - add_invertible_flag('--disallow-incomplete-defs', default=False, strict_flag=True, - help="Disallow defining functions with incomplete type annotations", - group=untyped_group) - add_invertible_flag('--check-untyped-defs', default=False, strict_flag=True, - help="Type check the interior of functions without type annotations", - group=untyped_group) - add_invertible_flag('--disallow-untyped-decorators', default=False, strict_flag=True, - help="Disallow decorating typed functions with untyped decorators", - group=untyped_group) + "Note: by default, mypy ignores any untyped function definitions " + "and assumes any calls to such functions have a return " + "type of 'Any'.", + ) + add_invertible_flag( + "--disallow-untyped-calls", + default=False, + strict_flag=True, + help="Disallow calling functions without type annotations" + " from functions with type annotations", + group=untyped_group, + ) + add_invertible_flag( + "--disallow-untyped-defs", + default=False, + strict_flag=True, + help="Disallow defining functions without type annotations" + " or with incomplete type annotations", + group=untyped_group, + ) + add_invertible_flag( + "--disallow-incomplete-defs", + default=False, + strict_flag=True, + help="Disallow defining functions with incomplete type annotations", + group=untyped_group, + ) + add_invertible_flag( + "--check-untyped-defs", + default=False, + strict_flag=True, + help="Type check the interior of functions without type annotations", + group=untyped_group, + ) + add_invertible_flag( + "--disallow-untyped-decorators", + default=False, + strict_flag=True, + help="Disallow decorating typed functions with untyped decorators", + group=untyped_group, + ) none_group = parser.add_argument_group( - title='None and Optional handling', + title="None and Optional handling", description="Adjust how values of type 'None' are handled. For more context on " - "how mypy handles values of type 'None', see: " - "https://mypy.readthedocs.io/en/stable/kinds_of_types.html#no-strict-optional") - add_invertible_flag('--no-implicit-optional', default=False, strict_flag=True, - help="Don't assume arguments with default values of None are Optional", - group=none_group) - none_group.add_argument( - '--strict-optional', action='store_true', - help=argparse.SUPPRESS) + "how mypy handles values of type 'None', see: " + "https://mypy.readthedocs.io/en/stable/kinds_of_types.html#no-strict-optional", + ) + add_invertible_flag( + "--no-implicit-optional", + default=False, + strict_flag=True, + help="Don't assume arguments with default values of None are Optional", + group=none_group, + ) + none_group.add_argument("--strict-optional", action="store_true", help=argparse.SUPPRESS) none_group.add_argument( - '--no-strict-optional', action='store_false', dest='strict_optional', - help="Disable strict Optional checks (inverse: --strict-optional)") + "--no-strict-optional", + action="store_false", + dest="strict_optional", + help="Disable strict Optional checks (inverse: --strict-optional)", + ) none_group.add_argument( - '--strict-optional-whitelist', metavar='GLOB', nargs='*', - help=argparse.SUPPRESS) + "--strict-optional-whitelist", metavar="GLOB", nargs="*", help=argparse.SUPPRESS + ) lint_group = parser.add_argument_group( - title='Configuring warnings', - description="Detect code that is sound but redundant or problematic.") - add_invertible_flag('--warn-redundant-casts', default=False, strict_flag=True, - help="Warn about casting an expression to its inferred type", - group=lint_group) - add_invertible_flag('--warn-unused-ignores', default=False, strict_flag=True, - help="Warn about unneeded '# type: ignore' comments", - group=lint_group) - add_invertible_flag('--no-warn-no-return', dest='warn_no_return', default=True, - help="Do not warn about functions that end without returning", - group=lint_group) - add_invertible_flag('--warn-return-any', default=False, strict_flag=True, - help="Warn about returning values of type Any" - " from non-Any typed functions", - group=lint_group) - add_invertible_flag('--warn-unreachable', default=False, strict_flag=False, - help="Warn about statements or expressions inferred to be" - " unreachable", - group=lint_group) + title="Configuring warnings", + description="Detect code that is sound but redundant or problematic.", + ) + add_invertible_flag( + "--warn-redundant-casts", + default=False, + strict_flag=True, + help="Warn about casting an expression to its inferred type", + group=lint_group, + ) + add_invertible_flag( + "--warn-unused-ignores", + default=False, + strict_flag=True, + help="Warn about unneeded '# type: ignore' comments", + group=lint_group, + ) + add_invertible_flag( + "--no-warn-no-return", + dest="warn_no_return", + default=True, + help="Do not warn about functions that end without returning", + group=lint_group, + ) + add_invertible_flag( + "--warn-return-any", + default=False, + strict_flag=True, + help="Warn about returning values of type Any" " from non-Any typed functions", + group=lint_group, + ) + add_invertible_flag( + "--warn-unreachable", + default=False, + strict_flag=False, + help="Warn about statements or expressions inferred to be" " unreachable", + group=lint_group, + ) # Note: this group is intentionally added here even though we don't add # --strict to this group near the end. @@ -658,237 +784,341 @@ def add_invertible_flag(flag: str, # but before the remaining flags. # We add `--strict` near the end so we don't accidentally miss any strictness # flags that are added after this group. - strictness_group = parser.add_argument_group( - title='Miscellaneous strictness flags') + strictness_group = parser.add_argument_group(title="Miscellaneous strictness flags") - add_invertible_flag('--allow-untyped-globals', default=False, strict_flag=False, - help="Suppress toplevel errors caused by missing annotations", - group=strictness_group) + add_invertible_flag( + "--allow-untyped-globals", + default=False, + strict_flag=False, + help="Suppress toplevel errors caused by missing annotations", + group=strictness_group, + ) - add_invertible_flag('--allow-redefinition', default=False, strict_flag=False, - help="Allow unconditional variable redefinition with a new type", - group=strictness_group) + add_invertible_flag( + "--allow-redefinition", + default=False, + strict_flag=False, + help="Allow unconditional variable redefinition with a new type", + group=strictness_group, + ) - add_invertible_flag('--no-implicit-reexport', default=True, strict_flag=True, - dest='implicit_reexport', - help="Treat imports as private unless aliased", - group=strictness_group) + add_invertible_flag( + "--no-implicit-reexport", + default=True, + strict_flag=True, + dest="implicit_reexport", + help="Treat imports as private unless aliased", + group=strictness_group, + ) - add_invertible_flag('--strict-equality', default=False, strict_flag=True, - help="Prohibit equality, identity, and container checks for" - " non-overlapping types", - group=strictness_group) + add_invertible_flag( + "--strict-equality", + default=False, + strict_flag=True, + help="Prohibit equality, identity, and container checks for" " non-overlapping types", + group=strictness_group, + ) - add_invertible_flag('--strict-concatenate', default=False, strict_flag=True, - help="Make arguments prepended via Concatenate be truly positional-only", - group=strictness_group) + add_invertible_flag( + "--strict-concatenate", + default=False, + strict_flag=True, + help="Make arguments prepended via Concatenate be truly positional-only", + group=strictness_group, + ) strict_help = "Strict mode; enables the following flags: {}".format( - ", ".join(strict_flag_names)) + ", ".join(strict_flag_names) + ) strictness_group.add_argument( - '--strict', action='store_true', dest='special-opts:strict', - help=strict_help) + "--strict", action="store_true", dest="special-opts:strict", help=strict_help + ) strictness_group.add_argument( - '--disable-error-code', metavar='NAME', action='append', default=[], - help="Disable a specific error code") + "--disable-error-code", + metavar="NAME", + action="append", + default=[], + help="Disable a specific error code", + ) strictness_group.add_argument( - '--enable-error-code', metavar='NAME', action='append', default=[], - help="Enable a specific error code" + "--enable-error-code", + metavar="NAME", + action="append", + default=[], + help="Enable a specific error code", ) error_group = parser.add_argument_group( - title='Configuring error messages', - description="Adjust the amount of detail shown in error messages.") - add_invertible_flag('--show-error-context', default=False, - dest='show_error_context', - help='Precede errors with "note:" messages explaining context', - group=error_group) - add_invertible_flag('--show-column-numbers', default=False, - help="Show column numbers in error messages", - group=error_group) - add_invertible_flag('--show-error-end', default=False, - help="Show end line/end column numbers in error messages." - " This implies --show-column-numbers", - group=error_group) - add_invertible_flag('--show-error-codes', default=False, - help="Show error codes in error messages", - group=error_group) - add_invertible_flag('--pretty', default=False, - help="Use visually nicer output in error messages:" - " Use soft word wrap, show source code snippets," - " and show error location markers", - group=error_group) - add_invertible_flag('--no-color-output', dest='color_output', default=True, - help="Do not colorize error messages", - group=error_group) - add_invertible_flag('--no-error-summary', dest='error_summary', default=True, - help="Do not show error stats summary", - group=error_group) - add_invertible_flag('--show-absolute-path', default=False, - help="Show absolute paths to files", - group=error_group) - error_group.add_argument('--soft-error-limit', default=defaults.MANY_ERRORS_THRESHOLD, - type=int, dest="many_errors_threshold", help=argparse.SUPPRESS) + title="Configuring error messages", + description="Adjust the amount of detail shown in error messages.", + ) + add_invertible_flag( + "--show-error-context", + default=False, + dest="show_error_context", + help='Precede errors with "note:" messages explaining context', + group=error_group, + ) + add_invertible_flag( + "--show-column-numbers", + default=False, + help="Show column numbers in error messages", + group=error_group, + ) + add_invertible_flag( + "--show-error-end", + default=False, + help="Show end line/end column numbers in error messages." + " This implies --show-column-numbers", + group=error_group, + ) + add_invertible_flag( + "--show-error-codes", + default=False, + help="Show error codes in error messages", + group=error_group, + ) + add_invertible_flag( + "--pretty", + default=False, + help="Use visually nicer output in error messages:" + " Use soft word wrap, show source code snippets," + " and show error location markers", + group=error_group, + ) + add_invertible_flag( + "--no-color-output", + dest="color_output", + default=True, + help="Do not colorize error messages", + group=error_group, + ) + add_invertible_flag( + "--no-error-summary", + dest="error_summary", + default=True, + help="Do not show error stats summary", + group=error_group, + ) + add_invertible_flag( + "--show-absolute-path", + default=False, + help="Show absolute paths to files", + group=error_group, + ) + error_group.add_argument( + "--soft-error-limit", + default=defaults.MANY_ERRORS_THRESHOLD, + type=int, + dest="many_errors_threshold", + help=argparse.SUPPRESS, + ) incremental_group = parser.add_argument_group( - title='Incremental mode', + title="Incremental mode", description="Adjust how mypy incrementally type checks and caches modules. " - "Mypy caches type information about modules into a cache to " - "let you speed up future invocations of mypy. Also see " - "mypy's daemon mode: " - "mypy.readthedocs.io/en/stable/mypy_daemon.html#mypy-daemon") + "Mypy caches type information about modules into a cache to " + "let you speed up future invocations of mypy. Also see " + "mypy's daemon mode: " + "mypy.readthedocs.io/en/stable/mypy_daemon.html#mypy-daemon", + ) incremental_group.add_argument( - '-i', '--incremental', action='store_true', - help=argparse.SUPPRESS) + "-i", "--incremental", action="store_true", help=argparse.SUPPRESS + ) incremental_group.add_argument( - '--no-incremental', action='store_false', dest='incremental', - help="Disable module cache (inverse: --incremental)") + "--no-incremental", + action="store_false", + dest="incremental", + help="Disable module cache (inverse: --incremental)", + ) incremental_group.add_argument( - '--cache-dir', action='store', metavar='DIR', + "--cache-dir", + action="store", + metavar="DIR", help="Store module cache info in the given folder in incremental mode " - "(defaults to '{}')".format(defaults.CACHE_DIR)) - add_invertible_flag('--sqlite-cache', default=False, - help="Use a sqlite database to store the cache", - group=incremental_group) + "(defaults to '{}')".format(defaults.CACHE_DIR), + ) + add_invertible_flag( + "--sqlite-cache", + default=False, + help="Use a sqlite database to store the cache", + group=incremental_group, + ) incremental_group.add_argument( - '--cache-fine-grained', action='store_true', - help="Include fine-grained dependency information in the cache for the mypy daemon") + "--cache-fine-grained", + action="store_true", + help="Include fine-grained dependency information in the cache for the mypy daemon", + ) incremental_group.add_argument( - '--skip-version-check', action='store_true', - help="Allow using cache written by older mypy version") + "--skip-version-check", + action="store_true", + help="Allow using cache written by older mypy version", + ) incremental_group.add_argument( - '--skip-cache-mtime-checks', action='store_true', - help="Skip cache internal consistency checks based on mtime") + "--skip-cache-mtime-checks", + action="store_true", + help="Skip cache internal consistency checks based on mtime", + ) internals_group = parser.add_argument_group( - title='Advanced options', - description="Debug and customize mypy internals.") - internals_group.add_argument( - '--pdb', action='store_true', help="Invoke pdb on fatal error") + title="Advanced options", description="Debug and customize mypy internals." + ) + internals_group.add_argument("--pdb", action="store_true", help="Invoke pdb on fatal error") internals_group.add_argument( - '--show-traceback', '--tb', action='store_true', - help="Show traceback on fatal error") + "--show-traceback", "--tb", action="store_true", help="Show traceback on fatal error" + ) internals_group.add_argument( - '--raise-exceptions', action='store_true', help="Raise exception on fatal error" + "--raise-exceptions", action="store_true", help="Raise exception on fatal error" ) internals_group.add_argument( - '--custom-typing-module', metavar='MODULE', dest='custom_typing_module', - help="Use a custom typing module") + "--custom-typing-module", + metavar="MODULE", + dest="custom_typing_module", + help="Use a custom typing module", + ) internals_group.add_argument( - '--custom-typeshed-dir', metavar='DIR', - help="Use the custom typeshed in DIR") - add_invertible_flag('--warn-incomplete-stub', default=False, - help="Warn if missing type annotation in typeshed, only relevant with" - " --disallow-untyped-defs or --disallow-incomplete-defs enabled", - group=internals_group) + "--custom-typeshed-dir", metavar="DIR", help="Use the custom typeshed in DIR" + ) + add_invertible_flag( + "--warn-incomplete-stub", + default=False, + help="Warn if missing type annotation in typeshed, only relevant with" + " --disallow-untyped-defs or --disallow-incomplete-defs enabled", + group=internals_group, + ) internals_group.add_argument( - '--shadow-file', nargs=2, metavar=('SOURCE_FILE', 'SHADOW_FILE'), - dest='shadow_file', action='append', + "--shadow-file", + nargs=2, + metavar=("SOURCE_FILE", "SHADOW_FILE"), + dest="shadow_file", + action="append", help="When encountering SOURCE_FILE, read and type check " - "the contents of SHADOW_FILE instead.") - add_invertible_flag('--fast-exit', default=True, help=argparse.SUPPRESS, - group=internals_group) + "the contents of SHADOW_FILE instead.", + ) + add_invertible_flag("--fast-exit", default=True, help=argparse.SUPPRESS, group=internals_group) report_group = parser.add_argument_group( - title='Report generation', - description='Generate a report in the specified format.') + title="Report generation", description="Generate a report in the specified format." + ) for report_type in sorted(defaults.REPORTER_NAMES): - if report_type not in {'memory-xml'}: - report_group.add_argument(f"--{report_type.replace('_', '-')}-report", - metavar='DIR', - dest=f'special-opts:{report_type}_report') + if report_type not in {"memory-xml"}: + report_group.add_argument( + f"--{report_type.replace('_', '-')}-report", + metavar="DIR", + dest=f"special-opts:{report_type}_report", + ) - other_group = parser.add_argument_group( - title='Miscellaneous') + other_group = parser.add_argument_group(title="Miscellaneous") + other_group.add_argument("--quickstart-file", help=argparse.SUPPRESS) + other_group.add_argument("--junit-xml", help="Write junit.xml to the given file") other_group.add_argument( - '--quickstart-file', help=argparse.SUPPRESS) - other_group.add_argument( - '--junit-xml', help="Write junit.xml to the given file") - other_group.add_argument( - '--find-occurrences', metavar='CLASS.MEMBER', - dest='special-opts:find_occurrences', - help="Print out all usages of a class member (experimental)") + "--find-occurrences", + metavar="CLASS.MEMBER", + dest="special-opts:find_occurrences", + help="Print out all usages of a class member (experimental)", + ) other_group.add_argument( - '--scripts-are-modules', action='store_true', - help="Script x becomes module x instead of __main__") + "--scripts-are-modules", + action="store_true", + help="Script x becomes module x instead of __main__", + ) - add_invertible_flag('--install-types', default=False, strict_flag=False, - help="Install detected missing library stub packages using pip", - group=other_group) - add_invertible_flag('--non-interactive', default=False, strict_flag=False, - help=("Install stubs without asking for confirmation and hide " + - "errors, with --install-types"), - group=other_group, inverse="--interactive") + add_invertible_flag( + "--install-types", + default=False, + strict_flag=False, + help="Install detected missing library stub packages using pip", + group=other_group, + ) + add_invertible_flag( + "--non-interactive", + default=False, + strict_flag=False, + help=( + "Install stubs without asking for confirmation and hide " + + "errors, with --install-types" + ), + group=other_group, + inverse="--interactive", + ) if server_options: # TODO: This flag is superfluous; remove after a short transition (2018-03-16) other_group.add_argument( - '--experimental', action='store_true', dest='fine_grained_incremental', - help="Enable fine-grained incremental mode") + "--experimental", + action="store_true", + dest="fine_grained_incremental", + help="Enable fine-grained incremental mode", + ) other_group.add_argument( - '--use-fine-grained-cache', action='store_true', - help="Use the cache in fine-grained incremental mode") + "--use-fine-grained-cache", + action="store_true", + help="Use the cache in fine-grained incremental mode", + ) # hidden options parser.add_argument( - '--stats', action='store_true', dest='dump_type_stats', help=argparse.SUPPRESS) - parser.add_argument( - '--inferstats', action='store_true', dest='dump_inference_stats', - help=argparse.SUPPRESS) + "--stats", action="store_true", dest="dump_type_stats", help=argparse.SUPPRESS + ) parser.add_argument( - '--dump-build-stats', action='store_true', - help=argparse.SUPPRESS) + "--inferstats", action="store_true", dest="dump_inference_stats", help=argparse.SUPPRESS + ) + parser.add_argument("--dump-build-stats", action="store_true", help=argparse.SUPPRESS) # dump timing stats for each processed file into the given output file - parser.add_argument( - '--timing-stats', dest='timing_stats', help=argparse.SUPPRESS) + parser.add_argument("--timing-stats", dest="timing_stats", help=argparse.SUPPRESS) # --debug-cache will disable any cache-related compressions/optimizations, # which will make the cache writing process output pretty-printed JSON (which # is easier to debug). - parser.add_argument('--debug-cache', action='store_true', help=argparse.SUPPRESS) + parser.add_argument("--debug-cache", action="store_true", help=argparse.SUPPRESS) # --dump-deps will dump all fine-grained dependencies to stdout - parser.add_argument('--dump-deps', action='store_true', help=argparse.SUPPRESS) + parser.add_argument("--dump-deps", action="store_true", help=argparse.SUPPRESS) # --dump-graph will dump the contents of the graph of SCCs and exit. - parser.add_argument('--dump-graph', action='store_true', help=argparse.SUPPRESS) + parser.add_argument("--dump-graph", action="store_true", help=argparse.SUPPRESS) # --semantic-analysis-only does exactly that. - parser.add_argument('--semantic-analysis-only', action='store_true', help=argparse.SUPPRESS) + parser.add_argument("--semantic-analysis-only", action="store_true", help=argparse.SUPPRESS) # --local-partial-types disallows partial types spanning module top level and a function # (implicitly defined in fine-grained incremental mode) - parser.add_argument('--local-partial-types', action='store_true', help=argparse.SUPPRESS) + parser.add_argument("--local-partial-types", action="store_true", help=argparse.SUPPRESS) # --logical-deps adds some more dependencies that are not semantically needed, but # may be helpful to determine relative importance of classes and functions for overall # type precision in a code base. It also _removes_ some deps, so this flag should be never # used except for generating code stats. This also automatically enables --cache-fine-grained. # NOTE: This is an experimental option that may be modified or removed at any time. - parser.add_argument('--logical-deps', action='store_true', help=argparse.SUPPRESS) + parser.add_argument("--logical-deps", action="store_true", help=argparse.SUPPRESS) # --bazel changes some behaviors for use with Bazel (https://bazel.build). - parser.add_argument('--bazel', action='store_true', help=argparse.SUPPRESS) + parser.add_argument("--bazel", action="store_true", help=argparse.SUPPRESS) # --package-root adds a directory below which directories are considered # packages even without __init__.py. May be repeated. - parser.add_argument('--package-root', metavar='ROOT', action='append', default=[], - help=argparse.SUPPRESS) + parser.add_argument( + "--package-root", metavar="ROOT", action="append", default=[], help=argparse.SUPPRESS + ) # --cache-map FILE ... gives a mapping from source files to cache files. # Each triple of arguments is a source file, a cache meta file, and a cache data file. # Modules not mentioned in the file will go through cache_dir. # Must be followed by another flag or by '--' (and then only file args may follow). - parser.add_argument('--cache-map', nargs='+', dest='special-opts:cache_map', - help=argparse.SUPPRESS) - parser.add_argument('--enable-incomplete-features', action='store_true', - help=argparse.SUPPRESS) + parser.add_argument( + "--cache-map", nargs="+", dest="special-opts:cache_map", help=argparse.SUPPRESS + ) + parser.add_argument( + "--enable-incomplete-features", action="store_true", help=argparse.SUPPRESS + ) # options specifying code to check code_group = parser.add_argument_group( title="Running code", description="Specify the code you want to type check. For more details, see " - "mypy.readthedocs.io/en/stable/running_mypy.html#running-mypy") + "mypy.readthedocs.io/en/stable/running_mypy.html#running-mypy", + ) add_invertible_flag( - '--explicit-package-bases', default=False, + "--explicit-package-bases", + default=False, help="Use current directory and MYPYPATH to determine module names of files passed", - group=code_group) + group=code_group, + ) add_invertible_flag( - '--fast-module-lookup', default=False, - help=argparse.SUPPRESS, - group=code_group) + "--fast-module-lookup", default=False, help=argparse.SUPPRESS, group=code_group + ) code_group.add_argument( "--exclude", action="append", @@ -898,25 +1128,40 @@ def add_invertible_flag(flag: str, "Regular expression to match file names, directory names or paths which mypy should " "ignore while recursively discovering files to check, e.g. --exclude '/setup\\.py$'. " "May be specified more than once, eg. --exclude a --exclude b" - ) + ), ) code_group.add_argument( - '-m', '--module', action='append', metavar='MODULE', + "-m", + "--module", + action="append", + metavar="MODULE", default=[], - dest='special-opts:modules', - help="Type-check module; can repeat for more modules") + dest="special-opts:modules", + help="Type-check module; can repeat for more modules", + ) code_group.add_argument( - '-p', '--package', action='append', metavar='PACKAGE', + "-p", + "--package", + action="append", + metavar="PACKAGE", default=[], - dest='special-opts:packages', - help="Type-check package recursively; can be repeated") + dest="special-opts:packages", + help="Type-check package recursively; can be repeated", + ) code_group.add_argument( - '-c', '--command', action='append', metavar='PROGRAM_TEXT', - dest='special-opts:command', - help="Type-check program passed in as string") + "-c", + "--command", + action="append", + metavar="PROGRAM_TEXT", + dest="special-opts:command", + help="Type-check program passed in as string", + ) code_group.add_argument( - metavar='files', nargs='*', dest='special-opts:files', - help="Type-check given files or directories") + metavar="files", + nargs="*", + dest="special-opts:files", + help="Type-check given files or directories", + ) # Parse arguments once into a dummy namespace so we can get the # filename for the config file and know if the user requested all strict options. @@ -939,18 +1184,18 @@ def set_strict_flags() -> None: # Set strict flags before parsing (if strict mode enabled), so other command # line options can override. - if getattr(dummy, 'special-opts:strict'): # noqa + if getattr(dummy, "special-opts:strict"): # noqa set_strict_flags() # Override cache_dir if provided in the environment - environ_cache_dir = os.getenv('MYPY_CACHE_DIR', '') + environ_cache_dir = os.getenv("MYPY_CACHE_DIR", "") if environ_cache_dir.strip(): options.cache_dir = environ_cache_dir options.cache_dir = os.path.expanduser(options.cache_dir) # Parse command line for real, using a split namespace. special_opts = argparse.Namespace() - parser.parse_args(args, SplitNamespace(options, special_opts, 'special-opts:')) + parser.parse_args(args, SplitNamespace(options, special_opts, "special-opts:")) # The python_version is either the default, which can be overridden via a config file, # or stored in special_opts and is passed via the command line. @@ -975,9 +1220,14 @@ def set_strict_flags() -> None: # Check for invalid argument combinations. if require_targets: - code_methods = sum(bool(c) for c in [special_opts.modules + special_opts.packages, - special_opts.command, - special_opts.files]) + code_methods = sum( + bool(c) + for c in [ + special_opts.modules + special_opts.packages, + special_opts.command, + special_opts.files, + ] + ) if code_methods == 0 and not options.install_types: parser.error("Missing target module, package, files, or command.") elif code_methods > 1: @@ -991,8 +1241,10 @@ def set_strict_flags() -> None: # Check for overlapping `--always-true` and `--always-false` flags. overlap = set(options.always_true) & set(options.always_false) if overlap: - parser.error("You can't make a variable always true and always false (%s)" % - ', '.join(sorted(overlap))) + parser.error( + "You can't make a variable always true and always false (%s)" + % ", ".join(sorted(overlap)) + ) # Process `--enable-error-code` and `--disable-error-code` flags disabled_codes = set(options.disable_error_code) @@ -1015,7 +1267,7 @@ def set_strict_flags() -> None: # TODO: Deprecate, then kill this flag options.strict_optional = True if special_opts.find_occurrences: - state.find_occurrences = special_opts.find_occurrences.split('.') + state.find_occurrences = special_opts.find_occurrences.split(".") assert state.find_occurrences is not None if len(state.find_occurrences) < 2: parser.error("Can only find occurrences of class members.") @@ -1024,8 +1276,8 @@ def set_strict_flags() -> None: # Set reports. for flag, val in vars(special_opts).items(): - if flag.endswith('_report') and val is not None: - report_type = flag[:-7].replace('_', '-') + if flag.endswith("_report") and val is not None: + report_type = flag[:-7].replace("_", "-") report_dir = val options.report_dirs[report_type] = report_dir @@ -1065,8 +1317,7 @@ def set_strict_flags() -> None: cache = FindModuleCache(search_paths, fscache, options) for p in special_opts.packages: if os.sep in p or os.altsep and os.altsep in p: - fail(f"Package name '{p}' cannot have a slash in it.", - stderr, options) + fail(f"Package name '{p}' cannot have a slash in it.", stderr, options) p_targets = cache.find_modules_recursive(p) if not p_targets: fail(f"Can't find package '{p}'", stderr, options) @@ -1076,7 +1327,7 @@ def set_strict_flags() -> None: return targets, options elif special_opts.command: options.build_type = BuildType.PROGRAM_TEXT - targets = [BuildSource(None, None, '\n'.join(special_opts.command))] + targets = [BuildSource(None, None, "\n".join(special_opts.command))] return targets, options else: try: @@ -1089,9 +1340,9 @@ def set_strict_flags() -> None: return targets, options -def process_package_roots(fscache: Optional[FileSystemCache], - parser: argparse.ArgumentParser, - options: Options) -> None: +def process_package_roots( + fscache: Optional[FileSystemCache], parser: argparse.ArgumentParser, options: Options +) -> None: """Validate and normalize package_root.""" if fscache is None: parser.error("--package-root does not work here (no fscache)") @@ -1117,45 +1368,48 @@ def process_package_roots(fscache: Optional[FileSystemCache], if root.startswith(dotdotslash): parser.error(f"Package root cannot be above current directory: {root!r}") if root in trivial_paths: - root = '' + root = "" package_root.append(root) options.package_root = package_root # Pass the package root on the the filesystem cache. fscache.set_package_root(package_root) -def process_cache_map(parser: argparse.ArgumentParser, - special_opts: argparse.Namespace, - options: Options) -> None: +def process_cache_map( + parser: argparse.ArgumentParser, special_opts: argparse.Namespace, options: Options +) -> None: """Validate cache_map and copy into options.cache_map.""" n = len(special_opts.cache_map) if n % 3 != 0: parser.error("--cache-map requires one or more triples (see source)") for i in range(0, n, 3): - source, meta_file, data_file = special_opts.cache_map[i:i + 3] + source, meta_file, data_file = special_opts.cache_map[i : i + 3] if source in options.cache_map: parser.error(f"Duplicate --cache-map source {source})") - if not source.endswith('.py') and not source.endswith('.pyi'): + if not source.endswith(".py") and not source.endswith(".pyi"): parser.error(f"Invalid --cache-map source {source} (triple[0] must be *.py[i])") - if not meta_file.endswith('.meta.json'): - parser.error("Invalid --cache-map meta_file %s (triple[1] must be *.meta.json)" % - meta_file) - if not data_file.endswith('.data.json'): - parser.error("Invalid --cache-map data_file %s (triple[2] must be *.data.json)" % - data_file) + if not meta_file.endswith(".meta.json"): + parser.error( + "Invalid --cache-map meta_file %s (triple[1] must be *.meta.json)" % meta_file + ) + if not data_file.endswith(".data.json"): + parser.error( + "Invalid --cache-map data_file %s (triple[2] must be *.data.json)" % data_file + ) options.cache_map[source] = (meta_file, data_file) def maybe_write_junit_xml(td: float, serious: bool, messages: List[str], options: Options) -> None: if options.junit_xml: - py_version = f'{options.python_version[0]}_{options.python_version[1]}' + py_version = f"{options.python_version[0]}_{options.python_version[1]}" util.write_junit_xml( - td, serious, messages, options.junit_xml, py_version, options.platform) + td, serious, messages, options.junit_xml, py_version, options.platform + ) def fail(msg: str, stderr: TextIO, options: Options) -> NoReturn: """Fail with a serious error.""" - stderr.write(f'{msg}\n') + stderr.write(f"{msg}\n") maybe_write_junit_xml(0.0, serious=True, messages=[msg], options=options) sys.exit(2) @@ -1164,13 +1418,11 @@ def read_types_packages_to_install(cache_dir: str, after_run: bool) -> List[str] if not os.path.isdir(cache_dir): if not after_run: sys.stderr.write( - "error: Can't determine which types to install with no files to check " + - "(and no cache from previous mypy run)\n" + "error: Can't determine which types to install with no files to check " + + "(and no cache from previous mypy run)\n" ) else: - sys.stderr.write( - "error: --install-types failed (no mypy cache directory)\n" - ) + sys.stderr.write("error: --install-types failed (no mypy cache directory)\n") sys.exit(2) fnam = build.missing_stubs_file(cache_dir) if not os.path.isfile(fnam): @@ -1180,11 +1432,13 @@ def read_types_packages_to_install(cache_dir: str, after_run: bool) -> List[str] return [line.strip() for line in f.readlines()] -def install_types(formatter: util.FancyFormatter, - options: Options, - *, - after_run: bool = False, - non_interactive: bool = False) -> bool: +def install_types( + formatter: util.FancyFormatter, + options: Options, + *, + after_run: bool = False, + non_interactive: bool = False, +) -> bool: """Install stub packages using pip if some missing stubs were detected.""" packages = read_types_packages_to_install(options.cache_dir, after_run) if not packages: @@ -1192,15 +1446,15 @@ def install_types(formatter: util.FancyFormatter, return False if after_run and not non_interactive: print() - print('Installing missing stub packages:') - assert options.python_executable, 'Python executable required to install types' - cmd = [options.python_executable, '-m', 'pip', 'install'] + packages - print(formatter.style(' '.join(cmd), 'none', bold=True)) + print("Installing missing stub packages:") + assert options.python_executable, "Python executable required to install types" + cmd = [options.python_executable, "-m", "pip", "install"] + packages + print(formatter.style(" ".join(cmd), "none", bold=True)) print() if not non_interactive: - x = input('Install? [yN] ') - if not x.strip() or not x.lower().startswith('y'): - print(formatter.style('mypy: Skipping installation', 'red', bold=True)) + x = input("Install? [yN] ") + if not x.strip() or not x.lower().startswith("y"): + print(formatter.style("mypy: Skipping installation", "red", bold=True)) sys.exit(2) print() subprocess.run(cmd) diff --git a/mypy/maptype.py b/mypy/maptype.py index 1216c6015378c..59d86d9f79b8f 100644 --- a/mypy/maptype.py +++ b/mypy/maptype.py @@ -2,11 +2,10 @@ from mypy.expandtype import expand_type from mypy.nodes import TypeInfo -from mypy.types import Type, TypeVarId, Instance, AnyType, TypeOfAny, ProperType +from mypy.types import AnyType, Instance, ProperType, Type, TypeOfAny, TypeVarId -def map_instance_to_supertype(instance: Instance, - superclass: TypeInfo) -> Instance: +def map_instance_to_supertype(instance: Instance, superclass: TypeInfo) -> Instance: """Produce a supertype of `instance` that is an Instance of `superclass`, mapping type arguments up the chain of bases. @@ -24,8 +23,7 @@ def map_instance_to_supertype(instance: Instance, return map_instance_to_supertypes(instance, superclass)[0] -def map_instance_to_supertypes(instance: Instance, - supertype: TypeInfo) -> List[Instance]: +def map_instance_to_supertypes(instance: Instance, supertype: TypeInfo) -> List[Instance]: # FIX: Currently we should only have one supertype per interface, so no # need to return an array result: List[Instance] = [] @@ -45,8 +43,7 @@ def map_instance_to_supertypes(instance: Instance, return [Instance(supertype, [any_type] * len(supertype.type_vars))] -def class_derivation_paths(typ: TypeInfo, - supertype: TypeInfo) -> List[List[TypeInfo]]: +def class_derivation_paths(typ: TypeInfo, supertype: TypeInfo) -> List[List[TypeInfo]]: """Return an array of non-empty paths of direct base classes from type to supertype. Return [] if no such path could be found. @@ -70,8 +67,7 @@ def class_derivation_paths(typ: TypeInfo, return result -def map_instance_to_direct_supertypes(instance: Instance, - supertype: TypeInfo) -> List[Instance]: +def map_instance_to_direct_supertypes(instance: Instance, supertype: TypeInfo) -> List[Instance]: # FIX: There should only be one supertypes, always. typ = instance.type result: List[Instance] = [] diff --git a/mypy/meet.py b/mypy/meet.py index ebaf0f675ef13..deb95f11283a5 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -1,19 +1,44 @@ -from mypy.backports import OrderedDict -from typing import List, Optional, Tuple, Callable +from typing import Callable, List, Optional, Tuple -from mypy.types import ( - Type, AnyType, TypeVisitor, UnboundType, NoneType, TypeVarType, Instance, CallableType, - TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType, - UninhabitedType, TypeType, TypeOfAny, Overloaded, FunctionLike, LiteralType, - ProperType, get_proper_type, get_proper_types, TypeAliasType, TypeGuardedType, - ParamSpecType, Parameters, UnpackType, TypeVarTupleType, TypeVarLikeType -) -from mypy.subtypes import is_equivalent, is_subtype, is_callable_compatible, is_proper_subtype +from mypy import join +from mypy.backports import OrderedDict from mypy.erasetype import erase_type from mypy.maptype import map_instance_to_supertype -from mypy.typeops import tuple_fallback, make_simplified_union, is_recursive_pair from mypy.state import state -from mypy import join +from mypy.subtypes import is_callable_compatible, is_equivalent, is_proper_subtype, is_subtype +from mypy.typeops import is_recursive_pair, make_simplified_union, tuple_fallback +from mypy.types import ( + AnyType, + CallableType, + DeletedType, + ErasedType, + FunctionLike, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + PartialType, + ProperType, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeGuardedType, + TypeOfAny, + TypeType, + TypeVarLikeType, + TypeVarTupleType, + TypeVarType, + TypeVisitor, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, + get_proper_type, + get_proper_types, +) # TODO Describe this module. @@ -62,28 +87,31 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: if declared == narrowed: return declared if isinstance(declared, UnionType): - return make_simplified_union([narrow_declared_type(x, narrowed) - for x in declared.relevant_items()]) + return make_simplified_union( + [narrow_declared_type(x, narrowed) for x in declared.relevant_items()] + ) if is_enum_overlapping_union(declared, narrowed): return narrowed - elif not is_overlapping_types(declared, narrowed, - prohibit_none_typevar_overlap=True): + elif not is_overlapping_types(declared, narrowed, prohibit_none_typevar_overlap=True): if state.strict_optional: return UninhabitedType() else: return NoneType() elif isinstance(narrowed, UnionType): - return make_simplified_union([narrow_declared_type(declared, x) - for x in narrowed.relevant_items()]) + return make_simplified_union( + [narrow_declared_type(declared, x) for x in narrowed.relevant_items()] + ) elif isinstance(narrowed, AnyType): return narrowed elif isinstance(narrowed, TypeVarType) and is_subtype(narrowed.upper_bound, declared): return narrowed elif isinstance(declared, TypeType) and isinstance(narrowed, TypeType): return TypeType.make_normalized(narrow_declared_type(declared.item, narrowed.item)) - elif (isinstance(declared, TypeType) - and isinstance(narrowed, Instance) - and narrowed.type.is_metaclass()): + elif ( + isinstance(declared, TypeType) + and isinstance(narrowed, Instance) + and narrowed.type.is_metaclass() + ): # We'd need intersection types, so give up. return declared elif isinstance(declared, Instance): @@ -95,8 +123,9 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type: return meet_types(declared, narrowed) elif isinstance(declared, TypedDictType) and isinstance(narrowed, Instance): # Special case useful for selecting TypedDicts from unions using isinstance(x, dict). - if (narrowed.type.fullname == 'builtins.dict' and - all(isinstance(t, AnyType) for t in get_proper_types(narrowed.args))): + if narrowed.type.fullname == "builtins.dict" and all( + isinstance(t, AnyType) for t in get_proper_types(narrowed.args) + ): return declared return meet_types(declared, narrowed) return narrowed @@ -149,32 +178,39 @@ def get_possible_variants(typ: Type) -> List[Type]: def is_enum_overlapping_union(x: ProperType, y: ProperType) -> bool: """Return True if x is an Enum, and y is an Union with at least one Literal from x""" return ( - isinstance(x, Instance) and x.type.is_enum and - isinstance(y, UnionType) and - any(isinstance(p, LiteralType) and x.type == p.fallback.type - for p in (get_proper_type(z) for z in y.relevant_items())) + isinstance(x, Instance) + and x.type.is_enum + and isinstance(y, UnionType) + and any( + isinstance(p, LiteralType) and x.type == p.fallback.type + for p in (get_proper_type(z) for z in y.relevant_items()) + ) ) def is_literal_in_union(x: ProperType, y: ProperType) -> bool: """Return True if x is a Literal and y is an Union that includes x""" - return (isinstance(x, LiteralType) and isinstance(y, UnionType) and - any(x == get_proper_type(z) for z in y.items)) + return ( + isinstance(x, LiteralType) + and isinstance(y, UnionType) + and any(x == get_proper_type(z) for z in y.items) + ) -def is_overlapping_types(left: Type, - right: Type, - ignore_promotions: bool = False, - prohibit_none_typevar_overlap: bool = False) -> bool: +def is_overlapping_types( + left: Type, + right: Type, + ignore_promotions: bool = False, + prohibit_none_typevar_overlap: bool = False, +) -> bool: """Can a value of type 'left' also be of type 'right' or vice-versa? If 'ignore_promotions' is True, we ignore promotions while checking for overlaps. If 'prohibit_none_typevar_overlap' is True, we disallow None from overlapping with TypeVars (in both strict-optional and non-strict-optional mode). """ - if ( - isinstance(left, TypeGuardedType) # type: ignore[misc] - or isinstance(right, TypeGuardedType) # type: ignore[misc] + if isinstance(left, TypeGuardedType) or isinstance( # type: ignore[misc] + right, TypeGuardedType ): # A type guard forces the new type even if it doesn't overlap the old. return True @@ -182,13 +218,15 @@ def is_overlapping_types(left: Type, left, right = get_proper_types((left, right)) def _is_overlapping_types(left: Type, right: Type) -> bool: - '''Encode the kind of overlapping check to perform. + """Encode the kind of overlapping check to perform. - This function mostly exists so we don't have to repeat keyword arguments everywhere.''' + This function mostly exists so we don't have to repeat keyword arguments everywhere.""" return is_overlapping_types( - left, right, + left, + right, ignore_promotions=ignore_promotions, - prohibit_none_typevar_overlap=prohibit_none_typevar_overlap) + prohibit_none_typevar_overlap=prohibit_none_typevar_overlap, + ) # We should never encounter this type. if isinstance(left, PartialType) or isinstance(right, PartialType): @@ -235,8 +273,9 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: ): return True - if (is_proper_subtype(left, right, ignore_promotions=ignore_promotions) - or is_proper_subtype(right, left, ignore_promotions=ignore_promotions)): + if is_proper_subtype(left, right, ignore_promotions=ignore_promotions) or is_proper_subtype( + right, left, ignore_promotions=ignore_promotions + ): return True # See the docstring for 'get_possible_variants' for more info on what the @@ -263,8 +302,12 @@ def is_none_typevarlike_overlap(t1: Type, t2: Type) -> bool: if is_none_typevarlike_overlap(left, right) or is_none_typevarlike_overlap(right, left): return False - if (len(left_possible) > 1 or len(right_possible) > 1 - or isinstance(left, TypeVarLikeType) or isinstance(right, TypeVarLikeType)): + if ( + len(left_possible) > 1 + or len(right_possible) > 1 + or isinstance(left, TypeVarLikeType) + or isinstance(right, TypeVarLikeType) + ): for l in left_possible: for r in right_possible: if _is_overlapping_types(l, r): @@ -293,8 +336,7 @@ def is_none_typevarlike_overlap(t1: Type, t2: Type) -> bool: return are_typed_dicts_overlapping(left, right, ignore_promotions=ignore_promotions) elif typed_dict_mapping_pair(left, right): # Overlaps between TypedDicts and Mappings require dedicated logic. - return typed_dict_mapping_overlap(left, right, - overlapping=_is_overlapping_types) + return typed_dict_mapping_overlap(left, right, overlapping=_is_overlapping_types) elif isinstance(left, TypedDictType): left = left.fallback elif isinstance(right, TypedDictType): @@ -329,9 +371,9 @@ def _type_object_overlap(left: Type, right: Type) -> bool: if left_meta is not None: return _is_overlapping_types(left_meta, right) # builtins.type (default metaclass) overlaps with all metaclasses - return right.type.has_base('builtins.type') + return right.type.has_base("builtins.type") elif isinstance(left.item, AnyType): - return right.type.has_base('builtins.type') + return right.type.has_base("builtins.type") # 3. Callable[..., C] vs Meta is considered below, when we switch to fallbacks. return False @@ -339,10 +381,13 @@ def _type_object_overlap(left: Type, right: Type) -> bool: return _type_object_overlap(left, right) or _type_object_overlap(right, left) if isinstance(left, CallableType) and isinstance(right, CallableType): - return is_callable_compatible(left, right, - is_compat=_is_overlapping_types, - ignore_pos_arg_names=True, - allow_partial_overlap=True) + return is_callable_compatible( + left, + right, + is_compat=_is_overlapping_types, + ignore_pos_arg_names=True, + allow_partial_overlap=True, + ) elif isinstance(left, CallableType): left = left.fallback elif isinstance(right, CallableType): @@ -366,8 +411,9 @@ def _type_object_overlap(left: Type, right: Type) -> bool: if isinstance(left, Instance) and isinstance(right, Instance): # First we need to handle promotions and structural compatibility for instances # that came as fallbacks, so simply call is_subtype() to avoid code duplication. - if (is_subtype(left, right, ignore_promotions=ignore_promotions) - or is_subtype(right, left, ignore_promotions=ignore_promotions)): + if is_subtype(left, right, ignore_promotions=ignore_promotions) or is_subtype( + right, left, ignore_promotions=ignore_promotions + ): return True # Two unrelated types cannot be partially overlapping: they're disjoint. @@ -393,8 +439,10 @@ def _type_object_overlap(left: Type, right: Type) -> bool: # Or, to use a more concrete example, List[Union[A, B]] and List[Union[B, C]] # would be considered partially overlapping since it's possible for both lists # to contain only instances of B at runtime. - if all(_is_overlapping_types(left_arg, right_arg) - for left_arg, right_arg in zip(left.args, right.args)): + if all( + _is_overlapping_types(left_arg, right_arg) + for left_arg, right_arg in zip(left.args, right.args) + ): return True return False @@ -409,33 +457,45 @@ def _type_object_overlap(left: Type, right: Type) -> bool: return False -def is_overlapping_erased_types(left: Type, right: Type, *, - ignore_promotions: bool = False) -> bool: +def is_overlapping_erased_types( + left: Type, right: Type, *, ignore_promotions: bool = False +) -> bool: """The same as 'is_overlapping_erased_types', except the types are erased first.""" - return is_overlapping_types(erase_type(left), erase_type(right), - ignore_promotions=ignore_promotions, - prohibit_none_typevar_overlap=True) + return is_overlapping_types( + erase_type(left), + erase_type(right), + ignore_promotions=ignore_promotions, + prohibit_none_typevar_overlap=True, + ) -def are_typed_dicts_overlapping(left: TypedDictType, right: TypedDictType, *, - ignore_promotions: bool = False, - prohibit_none_typevar_overlap: bool = False) -> bool: +def are_typed_dicts_overlapping( + left: TypedDictType, + right: TypedDictType, + *, + ignore_promotions: bool = False, + prohibit_none_typevar_overlap: bool = False, +) -> bool: """Returns 'true' if left and right are overlapping TypeDictTypes.""" # All required keys in left are present and overlapping with something in right for key in left.required_keys: if key not in right.items: return False - if not is_overlapping_types(left.items[key], right.items[key], - ignore_promotions=ignore_promotions, - prohibit_none_typevar_overlap=prohibit_none_typevar_overlap): + if not is_overlapping_types( + left.items[key], + right.items[key], + ignore_promotions=ignore_promotions, + prohibit_none_typevar_overlap=prohibit_none_typevar_overlap, + ): return False # Repeat check in the other direction for key in right.required_keys: if key not in left.items: return False - if not is_overlapping_types(left.items[key], right.items[key], - ignore_promotions=ignore_promotions): + if not is_overlapping_types( + left.items[key], right.items[key], ignore_promotions=ignore_promotions + ): return False # The presence of any additional optional keys does not affect whether the two @@ -444,26 +504,35 @@ def are_typed_dicts_overlapping(left: TypedDictType, right: TypedDictType, *, return True -def are_tuples_overlapping(left: Type, right: Type, *, - ignore_promotions: bool = False, - prohibit_none_typevar_overlap: bool = False) -> bool: +def are_tuples_overlapping( + left: Type, + right: Type, + *, + ignore_promotions: bool = False, + prohibit_none_typevar_overlap: bool = False, +) -> bool: """Returns true if left and right are overlapping tuples.""" left, right = get_proper_types((left, right)) left = adjust_tuple(left, right) or left right = adjust_tuple(right, left) or right - assert isinstance(left, TupleType), f'Type {left} is not a tuple' - assert isinstance(right, TupleType), f'Type {right} is not a tuple' + assert isinstance(left, TupleType), f"Type {left} is not a tuple" + assert isinstance(right, TupleType), f"Type {right} is not a tuple" if len(left.items) != len(right.items): return False - return all(is_overlapping_types(l, r, - ignore_promotions=ignore_promotions, - prohibit_none_typevar_overlap=prohibit_none_typevar_overlap) - for l, r in zip(left.items, right.items)) + return all( + is_overlapping_types( + l, + r, + ignore_promotions=ignore_promotions, + prohibit_none_typevar_overlap=prohibit_none_typevar_overlap, + ) + for l, r in zip(left.items, right.items) + ) def adjust_tuple(left: ProperType, r: ProperType) -> Optional[TupleType]: """Find out if `left` is a Tuple[A, ...], and adjust its length to `right`""" - if isinstance(left, Instance) and left.type.fullname == 'builtins.tuple': + if isinstance(left, Instance) and left.type.fullname == "builtins.tuple": n = r.length() if isinstance(r, TupleType) else 1 return TupleType([left.args[0]] * n, left) return None @@ -471,8 +540,9 @@ def adjust_tuple(left: ProperType, r: ProperType) -> Optional[TupleType]: def is_tuple(typ: Type) -> bool: typ = get_proper_type(typ) - return (isinstance(typ, TupleType) - or (isinstance(typ, Instance) and typ.type.fullname == 'builtins.tuple')) + return isinstance(typ, TupleType) or ( + isinstance(typ, Instance) and typ.type.fullname == "builtins.tuple" + ) class TypeMeetVisitor(TypeVisitor[ProperType]): @@ -500,14 +570,14 @@ def visit_union_type(self, t: UnionType) -> ProperType: for y in self.s.items: meets.append(meet_types(x, y)) else: - meets = [meet_types(x, self.s) - for x in t.items] + meets = [meet_types(x, self.s) for x in t.items] return make_simplified_union(meets) def visit_none_type(self, t: NoneType) -> ProperType: if state.strict_optional: - if isinstance(self.s, NoneType) or (isinstance(self.s, Instance) and - self.s.type.fullname == 'builtins.object'): + if isinstance(self.s, NoneType) or ( + isinstance(self.s, Instance) and self.s.type.fullname == "builtins.object" + ): return t else: return UninhabitedType() @@ -622,8 +692,10 @@ def visit_callable_type(self, t: CallableType) -> ProperType: result = meet_similar_callables(t, self.s) # We set the from_type_type flag to suppress error when a collection of # concrete class objects gets inferred as their common abstract superclass. - if not ((t.is_type_obj() and t.type_object().is_abstract) or - (self.s.is_type_obj() and self.s.type_object().is_abstract)): + if not ( + (t.is_type_obj() and t.type_object().is_abstract) + or (self.s.is_type_obj() and self.s.type_object().is_abstract) + ): result.from_type_type = True if isinstance(get_proper_type(result.ret_type), UninhabitedType): # Return a plain None or instead of a weird function. @@ -669,7 +741,7 @@ def visit_tuple_type(self, t: TupleType) -> ProperType: return TupleType(items, tuple_fallback(t)) elif isinstance(self.s, Instance): # meet(Tuple[t1, t2, <...>], Tuple[s, ...]) == Tuple[meet(t1, s), meet(t2, s), <...>]. - if self.s.type.fullname == 'builtins.tuple' and self.s.args: + if self.s.type.fullname == "builtins.tuple" and self.s.args: return t.copy_modified(items=[meet_types(it, self.s.args[0]) for it in t.items]) elif is_proper_subtype(t, self.s): # A named tuple that inherits from a normal class @@ -679,8 +751,9 @@ def visit_tuple_type(self, t: TupleType) -> ProperType: def visit_typeddict_type(self, t: TypedDictType) -> ProperType: if isinstance(self.s, TypedDictType): for (name, l, r) in self.s.zip(t): - if (not is_equivalent(l, r) or - (name in t.required_keys) != (name in self.s.required_keys)): + if not is_equivalent(l, r) or (name in t.required_keys) != ( + name in self.s.required_keys + ): return self.default(self.s) item_list: List[Tuple[str, Type]] = [] for (item_name, s_item_type, t_item_type) in self.s.zipall(t): @@ -709,7 +782,7 @@ def visit_literal_type(self, t: LiteralType) -> ProperType: def visit_partial_type(self, t: PartialType) -> ProperType: # We can't determine the meet of partial types. We should never get here. - assert False, 'Internal error' + assert False, "Internal error" def visit_type_type(self, t: TypeType) -> ProperType: if isinstance(self.s, TypeType): @@ -717,7 +790,7 @@ def visit_type_type(self, t: TypeType) -> ProperType: if not isinstance(typ, NoneType): typ = TypeType.make_normalized(typ, line=t.line) return typ - elif isinstance(self.s, Instance) and self.s.type.fullname == 'builtins.type': + elif isinstance(self.s, Instance) and self.s.type.fullname == "builtins.type": return t elif isinstance(self.s, CallableType): return self.meet(t, self.s) @@ -749,14 +822,16 @@ def meet_similar_callables(t: CallableType, s: CallableType) -> CallableType: # TODO in combine_similar_callables also applies here (names and kinds) # The fallback type can be either 'function' or 'type'. The result should have 'function' as # fallback only if both operands have it as 'function'. - if t.fallback.type.fullname != 'builtins.function': + if t.fallback.type.fullname != "builtins.function": fallback = t.fallback else: fallback = s.fallback - return t.copy_modified(arg_types=arg_types, - ret_type=meet_types(t.ret_type, s.ret_type), - fallback=fallback, - name=None) + return t.copy_modified( + arg_types=arg_types, + ret_type=meet_types(t.ret_type, s.ret_type), + fallback=fallback, + name=None, + ) def meet_type_list(types: List[Type]) -> Type: @@ -787,11 +862,12 @@ def typed_dict_mapping_pair(left: Type, right: Type) -> bool: _, other = right, left else: return False - return isinstance(other, Instance) and other.type.has_base('typing.Mapping') + return isinstance(other, Instance) and other.type.has_base("typing.Mapping") -def typed_dict_mapping_overlap(left: Type, right: Type, - overlapping: Callable[[Type, Type], bool]) -> bool: +def typed_dict_mapping_overlap( + left: Type, right: Type, overlapping: Callable[[Type, Type], bool] +) -> bool: """Check if a TypedDict type is overlapping with a Mapping. The basic logic here consists of two rules: @@ -831,7 +907,7 @@ def typed_dict_mapping_overlap(left: Type, right: Type, assert isinstance(right, TypedDictType) typed, other = right, left - mapping = next(base for base in other.type.mro if base.fullname == 'typing.Mapping') + mapping = next(base for base in other.type.mro if base.fullname == "typing.Mapping") other = map_instance_to_supertype(other, mapping) key_type, value_type = get_proper_types(other.args) diff --git a/mypy/memprofile.py b/mypy/memprofile.py index ac49fd346abc4..b49bf8048e3b4 100644 --- a/mypy/memprofile.py +++ b/mypy/memprofile.py @@ -4,18 +4,17 @@ owned by particular AST nodes, etc. """ -from collections import defaultdict import gc import sys -from typing import List, Dict, Iterable, Tuple, cast +from collections import defaultdict +from typing import Dict, Iterable, List, Tuple, cast from mypy.nodes import FakeInfo, Node from mypy.types import Type from mypy.util import get_class_descriptors -def collect_memory_stats() -> Tuple[Dict[str, int], - Dict[str, int]]: +def collect_memory_stats() -> Tuple[Dict[str, int], Dict[str, int]]: """Return stats about memory use. Return a tuple with these items: @@ -31,25 +30,25 @@ def collect_memory_stats() -> Tuple[Dict[str, int], # Processing these would cause a crash. continue n = type(obj).__name__ - if hasattr(obj, '__dict__'): + if hasattr(obj, "__dict__"): # Keep track of which class a particular __dict__ is associated with. - inferred[id(obj.__dict__)] = f'{n} (__dict__)' + inferred[id(obj.__dict__)] = f"{n} (__dict__)" if isinstance(obj, (Node, Type)): # type: ignore - if hasattr(obj, '__dict__'): + if hasattr(obj, "__dict__"): for x in obj.__dict__.values(): if isinstance(x, list): # Keep track of which node a list is associated with. - inferred[id(x)] = f'{n} (list)' + inferred[id(x)] = f"{n} (list)" if isinstance(x, tuple): # Keep track of which node a list is associated with. - inferred[id(x)] = f'{n} (tuple)' + inferred[id(x)] = f"{n} (tuple)" for k in get_class_descriptors(type(obj)): x = getattr(obj, k, None) if isinstance(x, list): - inferred[id(x)] = f'{n} (list)' + inferred[id(x)] = f"{n} (list)" if isinstance(x, tuple): - inferred[id(x)] = f'{n} (tuple)' + inferred[id(x)] = f"{n} (tuple)" freqs: Dict[str, int] = {} memuse: Dict[str, int] = {} @@ -65,27 +64,28 @@ def collect_memory_stats() -> Tuple[Dict[str, int], def print_memory_profile(run_gc: bool = True) -> None: - if not sys.platform.startswith('win'): + if not sys.platform.startswith("win"): import resource + system_memuse = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss else: system_memuse = -1 # TODO: Support this on Windows if run_gc: gc.collect() freqs, memuse = collect_memory_stats() - print('%7s %7s %7s %s' % ('Freq', 'Size(k)', 'AvgSize', 'Type')) - print('-------------------------------------------') + print("%7s %7s %7s %s" % ("Freq", "Size(k)", "AvgSize", "Type")) + print("-------------------------------------------") totalmem = 0 i = 0 for n, mem in sorted(memuse.items(), key=lambda x: -x[1]): f = freqs[n] if i < 50: - print('%7d %7d %7.0f %s' % (f, mem // 1024, mem / f, n)) + print("%7d %7d %7.0f %s" % (f, mem // 1024, mem / f, n)) i += 1 totalmem += mem print() - print('Mem usage RSS ', system_memuse // 1024) - print('Total reachable ', totalmem // 1024) + print("Mem usage RSS ", system_memuse // 1024) + print("Total reachable ", totalmem // 1024) def find_recursive_objects(objs: List[object]) -> None: @@ -112,8 +112,8 @@ def visit(o: object) -> None: if type(obj) in (list, tuple, set): for x in cast(Iterable[object], obj): visit(x) - if hasattr(obj, '__slots__'): + if hasattr(obj, "__slots__"): for base in type.mro(type(obj)): - for slot in getattr(base, '__slots__', ()): + for slot in getattr(base, "__slots__", ()): if hasattr(obj, slot): visit(getattr(obj, slot)) diff --git a/mypy/message_registry.py b/mypy/message_registry.py index 981c82cfc12cf..422b57bebfa41 100644 --- a/mypy/message_registry.py +++ b/mypy/message_registry.py @@ -7,6 +7,7 @@ """ from typing import NamedTuple, Optional + from typing_extensions import Final from mypy import errorcodes as codes @@ -68,7 +69,7 @@ def format(self, *args: object, **kwargs: object) -> "ErrorMessage": INCOMPATIBLE_TYPES_IN_YIELD: Final = ErrorMessage('Incompatible types in "yield"') INCOMPATIBLE_TYPES_IN_YIELD_FROM: Final = ErrorMessage('Incompatible types in "yield from"') INCOMPATIBLE_TYPES_IN_STR_INTERPOLATION: Final = "Incompatible types in string interpolation" -INCOMPATIBLE_TYPES_IN_CAPTURE: Final = ErrorMessage('Incompatible types in capture pattern') +INCOMPATIBLE_TYPES_IN_CAPTURE: Final = ErrorMessage("Incompatible types in capture pattern") MUST_HAVE_NONE_RETURN_TYPE: Final = ErrorMessage('The return type of "{}" must be None') INVALID_TUPLE_INDEX_TYPE: Final = ErrorMessage("Invalid tuple index type") TUPLE_INDEX_OUT_OF_RANGE: Final = ErrorMessage("Tuple index out of range") @@ -76,7 +77,7 @@ def format(self, *args: object, **kwargs: object) -> "ErrorMessage": CANNOT_INFER_LAMBDA_TYPE: Final = ErrorMessage("Cannot infer type of lambda") CANNOT_ACCESS_INIT: Final = ( 'Accessing "__init__" on an instance is unsound, since instance.__init__ could be from' - ' an incompatible subclass' + " an incompatible subclass" ) NON_INSTANCE_NEW_TYPE: Final = ErrorMessage('"__new__" must return a class instance (got {})') INVALID_NEW_TYPE: Final = ErrorMessage('Incompatible return type for "__new__"') @@ -141,14 +142,13 @@ def format(self, *args: object, **kwargs: object) -> "ErrorMessage": code=codes.TRUTHY_BOOL, ) FUNCTION_ALWAYS_TRUE: Final = ErrorMessage( - 'Function {} could always be true in boolean context', - code=codes.TRUTHY_BOOL, + "Function {} could always be true in boolean context", code=codes.TRUTHY_BOOL ) -NOT_CALLABLE: Final = '{} not callable' +NOT_CALLABLE: Final = "{} not callable" PYTHON2_PRINT_FILE_TYPE: Final = ( 'Argument "file" to "print" has incompatible type "{}"; expected "{}"' ) -TYPE_MUST_BE_USED: Final = 'Value of type {} must be used' +TYPE_MUST_BE_USED: Final = "Value of type {} must be used" # Generic GENERIC_INSTANCE_VAR_CLASS_ACCESS: Final = ( @@ -171,8 +171,9 @@ def format(self, *args: object, **kwargs: object) -> "ErrorMessage": TYPEVAR_BOUND_MUST_BE_TYPE: Final = 'TypeVar "bound" must be a type' TYPEVAR_UNEXPECTED_ARGUMENT: Final = 'Unexpected argument to "TypeVar()"' UNBOUND_TYPEVAR: Final = ( - 'A function returning TypeVar should receive at least ' - 'one argument containing the same Typevar') + "A function returning TypeVar should receive at least " + "one argument containing the same Typevar" +) # Super TOO_MANY_ARGS_FOR_SUPER: Final = ErrorMessage('Too many arguments for "super"') @@ -230,10 +231,8 @@ def format(self, *args: object, **kwargs: object) -> "ErrorMessage": 'Cannot override class variable (previously declared on base class "{}") with instance ' "variable" ) -CLASS_VAR_WITH_TYPEVARS: Final = 'ClassVar cannot contain type variables' -CLASS_VAR_OUTSIDE_OF_CLASS: Final = ( - 'ClassVar can only be used for assignments in class body' -) +CLASS_VAR_WITH_TYPEVARS: Final = "ClassVar cannot contain type variables" +CLASS_VAR_OUTSIDE_OF_CLASS: Final = "ClassVar can only be used for assignments in class body" # Protocol RUNTIME_PROTOCOL_EXPECTED: Final = ErrorMessage( diff --git a/mypy/messages.py b/mypy/messages.py index 47d8669a6c4fd..3068390ad30c1 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -8,57 +8,109 @@ Historically we tried to avoid all message string literals in the type checker but we are moving away from this convention. """ -from contextlib import contextmanager - -from mypy.backports import OrderedDict -import re import difflib +import re +from contextlib import contextmanager from textwrap import dedent - from typing import ( - cast, List, Dict, Any, Sequence, Iterable, Iterator, Tuple, Set, Optional, Union, Callable + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + Union, + cast, ) + from typing_extensions import Final +from mypy import errorcodes as codes, message_registry +from mypy.backports import OrderedDict from mypy.erasetype import erase_type -from mypy.errors import Errors, ErrorWatcher, ErrorInfo -from mypy.types import ( - Type, CallableType, Instance, TypeVarType, TupleType, TypedDictType, LiteralType, - UnionType, NoneType, AnyType, Overloaded, FunctionLike, DeletedType, TypeType, - UninhabitedType, TypeOfAny, UnboundType, PartialType, get_proper_type, ProperType, - ParamSpecType, Parameters, get_proper_types -) -from mypy.typetraverser import TypeTraverserVisitor +from mypy.errorcodes import ErrorCode +from mypy.errors import ErrorInfo, Errors, ErrorWatcher from mypy.nodes import ( - TypeInfo, Context, MypyFile, FuncDef, reverse_builtin_aliases, - ArgKind, ARG_POS, ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, - ReturnStmt, NameExpr, Var, CONTRAVARIANT, COVARIANT, SymbolNode, - CallExpr, IndexExpr, StrExpr, SymbolTable, SYMBOL_FUNCBASE_TYPES + ARG_NAMED, + ARG_NAMED_OPT, + ARG_OPT, + ARG_POS, + ARG_STAR, + ARG_STAR2, + CONTRAVARIANT, + COVARIANT, + SYMBOL_FUNCBASE_TYPES, + ArgKind, + CallExpr, + Context, + FuncDef, + IndexExpr, + MypyFile, + NameExpr, + ReturnStmt, + StrExpr, + SymbolNode, + SymbolTable, + TypeInfo, + Var, + reverse_builtin_aliases, ) from mypy.operators import op_methods, op_methods_to_symbols +from mypy.sametypes import is_same_type from mypy.subtypes import ( - is_subtype, find_member, get_member_flags, - IS_SETTABLE, IS_CLASSVAR, IS_CLASS_OR_STATIC, + IS_CLASS_OR_STATIC, + IS_CLASSVAR, + IS_SETTABLE, + find_member, + get_member_flags, + is_subtype, ) -from mypy.sametypes import is_same_type from mypy.typeops import separate_union_literals -from mypy.util import unmangle, plural_s -from mypy.errorcodes import ErrorCode -from mypy import message_registry, errorcodes as codes +from mypy.types import ( + AnyType, + CallableType, + DeletedType, + FunctionLike, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + PartialType, + ProperType, + TupleType, + Type, + TypedDictType, + TypeOfAny, + TypeType, + TypeVarType, + UnboundType, + UninhabitedType, + UnionType, + get_proper_type, + get_proper_types, +) +from mypy.typetraverser import TypeTraverserVisitor +from mypy.util import plural_s, unmangle TYPES_FOR_UNIMPORTED_HINTS: Final = { - 'typing.Any', - 'typing.Callable', - 'typing.Dict', - 'typing.Iterable', - 'typing.Iterator', - 'typing.List', - 'typing.Optional', - 'typing.Set', - 'typing.Tuple', - 'typing.TypeVar', - 'typing.Union', - 'typing.cast', + "typing.Any", + "typing.Callable", + "typing.Dict", + "typing.Iterable", + "typing.Iterator", + "typing.List", + "typing.Optional", + "typing.Set", + "typing.Tuple", + "typing.TypeVar", + "typing.Union", + "typing.cast", } @@ -76,16 +128,16 @@ # test-data/unit/fixtures/) that provides the definition. This is used for # generating better error messages when running mypy tests only. SUGGESTED_TEST_FIXTURES: Final = { - 'builtins.list': 'list.pyi', - 'builtins.dict': 'dict.pyi', - 'builtins.set': 'set.pyi', - 'builtins.tuple': 'tuple.pyi', - 'builtins.bool': 'bool.pyi', - 'builtins.Exception': 'exception.pyi', - 'builtins.BaseException': 'exception.pyi', - 'builtins.isinstance': 'isinstancelist.pyi', - 'builtins.property': 'property.pyi', - 'builtins.classmethod': 'classmethod.pyi', + "builtins.list": "list.pyi", + "builtins.dict": "dict.pyi", + "builtins.set": "set.pyi", + "builtins.tuple": "tuple.pyi", + "builtins.bool": "bool.pyi", + "builtins.Exception": "exception.pyi", + "builtins.BaseException": "exception.pyi", + "builtins.isinstance": "isinstancelist.pyi", + "builtins.property": "property.pyi", + "builtins.classmethod": "classmethod.pyi", } @@ -118,10 +170,12 @@ def __init__(self, errors: Errors, modules: Dict[str, MypyFile]) -> None: # Helpers # - def filter_errors(self, *, filter_errors: bool = True, - save_filtered_errors: bool = False) -> ErrorWatcher: - return ErrorWatcher(self.errors, filter_errors=filter_errors, - save_filtered_errors=save_filtered_errors) + def filter_errors( + self, *, filter_errors: bool = True, save_filtered_errors: bool = False + ) -> ErrorWatcher: + return ErrorWatcher( + self.errors, filter_errors=filter_errors, save_filtered_errors=save_filtered_errors + ) def add_errors(self, errors: List[ErrorInfo]) -> None: """Add errors in messages to this builder.""" @@ -139,16 +193,18 @@ def disable_type_names(self) -> Iterator[None]: def are_type_names_disabled(self) -> bool: return len(self._disable_type_names) > 0 and self._disable_type_names[-1] - def report(self, - msg: str, - context: Optional[Context], - severity: str, - *, - code: Optional[ErrorCode] = None, - file: Optional[str] = None, - origin: Optional[Context] = None, - offset: int = 0, - allow_dups: bool = False) -> None: + def report( + self, + msg: str, + context: Optional[Context], + severity: str, + *, + code: Optional[ErrorCode] = None, + file: Optional[str] = None, + origin: Optional[Context] = None, + offset: int = 0, + allow_dups: bool = False, + ) -> None: """Report an error or note (unless disabled).""" if origin is not None: end_line = origin.end_line @@ -156,47 +212,80 @@ def report(self, end_line = context.end_line else: end_line = None - self.errors.report(context.get_line() if context else -1, - context.get_column() if context else -1, - msg, severity=severity, file=file, offset=offset, - origin_line=origin.get_line() if origin else None, - end_line=end_line, - end_column=context.end_column if context else -1, - code=code, allow_dups=allow_dups) - - def fail(self, - msg: str, - context: Optional[Context], - *, - code: Optional[ErrorCode] = None, - file: Optional[str] = None, - origin: Optional[Context] = None, - allow_dups: bool = False) -> None: + self.errors.report( + context.get_line() if context else -1, + context.get_column() if context else -1, + msg, + severity=severity, + file=file, + offset=offset, + origin_line=origin.get_line() if origin else None, + end_line=end_line, + end_column=context.end_column if context else -1, + code=code, + allow_dups=allow_dups, + ) + + def fail( + self, + msg: str, + context: Optional[Context], + *, + code: Optional[ErrorCode] = None, + file: Optional[str] = None, + origin: Optional[Context] = None, + allow_dups: bool = False, + ) -> None: """Report an error message (unless disabled).""" - self.report(msg, context, 'error', code=code, file=file, - origin=origin, allow_dups=allow_dups) - - def note(self, - msg: str, - context: Context, - file: Optional[str] = None, - origin: Optional[Context] = None, - offset: int = 0, - allow_dups: bool = False, - *, - code: Optional[ErrorCode] = None) -> None: + self.report( + msg, context, "error", code=code, file=file, origin=origin, allow_dups=allow_dups + ) + + def note( + self, + msg: str, + context: Context, + file: Optional[str] = None, + origin: Optional[Context] = None, + offset: int = 0, + allow_dups: bool = False, + *, + code: Optional[ErrorCode] = None, + ) -> None: """Report a note (unless disabled).""" - self.report(msg, context, 'note', file=file, origin=origin, - offset=offset, allow_dups=allow_dups, code=code) - - def note_multiline(self, messages: str, context: Context, file: Optional[str] = None, - origin: Optional[Context] = None, offset: int = 0, - allow_dups: bool = False, - code: Optional[ErrorCode] = None) -> None: + self.report( + msg, + context, + "note", + file=file, + origin=origin, + offset=offset, + allow_dups=allow_dups, + code=code, + ) + + def note_multiline( + self, + messages: str, + context: Context, + file: Optional[str] = None, + origin: Optional[Context] = None, + offset: int = 0, + allow_dups: bool = False, + code: Optional[ErrorCode] = None, + ) -> None: """Report as many notes as lines in the message (unless disabled).""" for msg in messages.splitlines(): - self.report(msg, context, 'note', file=file, origin=origin, - offset=offset, allow_dups=allow_dups, code=code) + self.report( + msg, + context, + "note", + file=file, + origin=origin, + offset=offset, + allow_dups=allow_dups, + code=code, + ) # # Specific operations @@ -206,12 +295,14 @@ def note_multiline(self, messages: str, context: Context, file: Optional[str] = # get some information as arguments, and they build an error message based # on them. - def has_no_attr(self, - original_type: Type, - typ: Type, - member: str, - context: Context, - module_symbol_table: Optional[SymbolTable] = None) -> Type: + def has_no_attr( + self, + original_type: Type, + typ: Type, + member: str, + context: Context, + module_symbol_table: Optional[SymbolTable] = None, + ) -> Type: """Report a missing or non-accessible member. original_type is the top-level type on which the error occurred. @@ -231,12 +322,14 @@ def has_no_attr(self, original_type = get_proper_type(original_type) typ = get_proper_type(typ) - if (isinstance(original_type, Instance) and - original_type.type.has_readable_member(member)): + if isinstance(original_type, Instance) and original_type.type.has_readable_member(member): self.fail(f'Member "{member}" is not assignable', context) - elif member == '__contains__': - self.fail('Unsupported right operand type for in ({})'.format( - format_type(original_type)), context, code=codes.OPERATOR) + elif member == "__contains__": + self.fail( + "Unsupported right operand type for in ({})".format(format_type(original_type)), + context, + code=codes.OPERATOR, + ) elif member in op_methods.values(): # Access to a binary operator member (e.g. _add). This case does # not handle indexing operations. @@ -244,44 +337,69 @@ def has_no_attr(self, if method == member: self.unsupported_left_operand(op, original_type, context) break - elif member == '__neg__': - self.fail('Unsupported operand type for unary - ({})'.format( - format_type(original_type)), context, code=codes.OPERATOR) - elif member == '__pos__': - self.fail('Unsupported operand type for unary + ({})'.format( - format_type(original_type)), context, code=codes.OPERATOR) - elif member == '__invert__': - self.fail('Unsupported operand type for ~ ({})'.format( - format_type(original_type)), context, code=codes.OPERATOR) - elif member == '__getitem__': + elif member == "__neg__": + self.fail( + "Unsupported operand type for unary - ({})".format(format_type(original_type)), + context, + code=codes.OPERATOR, + ) + elif member == "__pos__": + self.fail( + "Unsupported operand type for unary + ({})".format(format_type(original_type)), + context, + code=codes.OPERATOR, + ) + elif member == "__invert__": + self.fail( + "Unsupported operand type for ~ ({})".format(format_type(original_type)), + context, + code=codes.OPERATOR, + ) + elif member == "__getitem__": # Indexed get. # TODO: Fix this consistently in format_type if isinstance(original_type, CallableType) and original_type.is_type_obj(): - self.fail('The type {} is not generic and not indexable'.format( - format_type(original_type)), context) + self.fail( + "The type {} is not generic and not indexable".format( + format_type(original_type) + ), + context, + ) else: - self.fail('Value of type {} is not indexable'.format( - format_type(original_type)), context, code=codes.INDEX) - elif member == '__setitem__': + self.fail( + "Value of type {} is not indexable".format(format_type(original_type)), + context, + code=codes.INDEX, + ) + elif member == "__setitem__": # Indexed set. - self.fail('Unsupported target for indexed assignment ({})'.format( - format_type(original_type)), context, code=codes.INDEX) - elif member == '__call__': - if isinstance(original_type, Instance) and \ - (original_type.type.fullname == 'builtins.function'): + self.fail( + "Unsupported target for indexed assignment ({})".format( + format_type(original_type) + ), + context, + code=codes.INDEX, + ) + elif member == "__call__": + if isinstance(original_type, Instance) and ( + original_type.type.fullname == "builtins.function" + ): # "'function' not callable" is a confusing error message. # Explain that the problem is that the type of the function is not known. - self.fail('Cannot call function of unknown type', context, code=codes.OPERATOR) + self.fail("Cannot call function of unknown type", context, code=codes.OPERATOR) else: - self.fail(message_registry.NOT_CALLABLE.format( - format_type(original_type)), context, code=codes.OPERATOR) + self.fail( + message_registry.NOT_CALLABLE.format(format_type(original_type)), + context, + code=codes.OPERATOR, + ) else: # The non-special case: a missing ordinary attribute. - extra = '' - if member == '__iter__': - extra = ' (not iterable)' - elif member == '__aiter__': - extra = ' (not async iterable)' + extra = "" + if member == "__iter__": + extra = " (not iterable)" + elif member == "__aiter__": + extra = " (not async iterable)" if not self.are_type_names_disabled(): failed = False if isinstance(original_type, Instance) and original_type.type.names: @@ -297,9 +415,9 @@ def has_no_attr(self, matches = [m for m in COMMON_MISTAKES.get(member, []) if m in alternatives] matches.extend(best_matches(member, alternatives)[:3]) - if member == '__aiter__' and matches == ['__iter__']: + if member == "__aiter__" and matches == ["__iter__"]: matches = [] # Avoid misleading suggestion - if member == '__div__' and matches == ['__truediv__']: + if member == "__div__" and matches == ["__truediv__"]: # TODO: Handle differences in division between Python 2 and 3 more cleanly matches = [] if matches: @@ -311,71 +429,83 @@ def has_no_attr(self, extra, ), context, - code=codes.ATTR_DEFINED) + code=codes.ATTR_DEFINED, + ) failed = True if not failed: self.fail( '{} has no attribute "{}"{}'.format( - format_type(original_type), member, extra), + format_type(original_type), member, extra + ), context, - code=codes.ATTR_DEFINED) + code=codes.ATTR_DEFINED, + ) elif isinstance(original_type, UnionType): # The checker passes "object" in lieu of "None" for attribute # checks, so we manually convert it back. typ_format, orig_type_format = format_type_distinctly(typ, original_type) - if typ_format == '"object"' and \ - any(type(item) == NoneType for item in original_type.items): + if typ_format == '"object"' and any( + type(item) == NoneType for item in original_type.items + ): typ_format = '"None"' - self.fail('Item {} of {} has no attribute "{}"{}'.format( - typ_format, orig_type_format, member, extra), context, - code=codes.UNION_ATTR) + self.fail( + 'Item {} of {} has no attribute "{}"{}'.format( + typ_format, orig_type_format, member, extra + ), + context, + code=codes.UNION_ATTR, + ) elif isinstance(original_type, TypeVarType): bound = get_proper_type(original_type.upper_bound) if isinstance(bound, UnionType): typ_fmt, bound_fmt = format_type_distinctly(typ, bound) original_type_fmt = format_type(original_type) self.fail( - 'Item {} of the upper bound {} of type variable {} has no ' + "Item {} of the upper bound {} of type variable {} has no " 'attribute "{}"{}'.format( - typ_fmt, bound_fmt, original_type_fmt, member, extra), - context, code=codes.UNION_ATTR) + typ_fmt, bound_fmt, original_type_fmt, member, extra + ), + context, + code=codes.UNION_ATTR, + ) return AnyType(TypeOfAny.from_error) - def unsupported_operand_types(self, - op: str, - left_type: Any, - right_type: Any, - context: Context, - *, - code: ErrorCode = codes.OPERATOR) -> None: + def unsupported_operand_types( + self, + op: str, + left_type: Any, + right_type: Any, + context: Context, + *, + code: ErrorCode = codes.OPERATOR, + ) -> None: """Report unsupported operand types for a binary operation. Types can be Type objects or strings. """ - left_str = '' + left_str = "" if isinstance(left_type, str): left_str = left_type else: left_str = format_type(left_type) - right_str = '' + right_str = "" if isinstance(right_type, str): right_str = right_type else: right_str = format_type(right_type) if self.are_type_names_disabled(): - msg = f'Unsupported operand types for {op} (likely involving Union)' + msg = f"Unsupported operand types for {op} (likely involving Union)" else: - msg = f'Unsupported operand types for {op} ({left_str} and {right_str})' + msg = f"Unsupported operand types for {op} ({left_str} and {right_str})" self.fail(msg, context, code=code) - def unsupported_left_operand(self, op: str, typ: Type, - context: Context) -> None: + def unsupported_left_operand(self, op: str, typ: Type, context: Context) -> None: if self.are_type_names_disabled(): - msg = f'Unsupported left operand type for {op} (some union)' + msg = f"Unsupported left operand type for {op} (some union)" else: - msg = f'Unsupported left operand type for {op} ({format_type(typ)})' + msg = f"Unsupported left operand type for {op} ({format_type(typ)})" self.fail(msg, context, code=codes.OPERATOR) def not_callable(self, typ: Type, context: Context) -> Type: @@ -383,20 +513,25 @@ def not_callable(self, typ: Type, context: Context) -> Type: return AnyType(TypeOfAny.from_error) def untyped_function_call(self, callee: CallableType, context: Context) -> Type: - name = callable_name(callee) or '(unknown)' - self.fail(f'Call to untyped function {name} in typed context', context, - code=codes.NO_UNTYPED_CALL) + name = callable_name(callee) or "(unknown)" + self.fail( + f"Call to untyped function {name} in typed context", + context, + code=codes.NO_UNTYPED_CALL, + ) return AnyType(TypeOfAny.from_error) - def incompatible_argument(self, - n: int, - m: int, - callee: CallableType, - arg_type: Type, - arg_kind: ArgKind, - object_type: Optional[Type], - context: Context, - outer_context: Context) -> Optional[ErrorCode]: + def incompatible_argument( + self, + n: int, + m: int, + callee: CallableType, + arg_type: Type, + arg_kind: ArgKind, + object_type: Optional[Type], + context: Context, + outer_context: Context, + ) -> Optional[ErrorCode]: """Report an error about an incompatible argument type. The argument type is arg_type, argument number is n and the @@ -409,7 +544,7 @@ def incompatible_argument(self, """ arg_type = get_proper_type(arg_type) - target = '' + target = "" callee_name = callable_name(callee) if callee_name is not None: name = callee_name @@ -419,56 +554,70 @@ def incompatible_argument(self, base = extract_type(name) for method, op in op_methods_to_symbols.items(): - for variant in method, '__r' + method[2:]: + for variant in method, "__r" + method[2:]: # FIX: do not rely on textual formatting if name.startswith(f'"{variant}" of'): - if op == 'in' or variant != method: + if op == "in" or variant != method: # Reversed order of base/argument. - self.unsupported_operand_types(op, arg_type, base, - context, code=codes.OPERATOR) + self.unsupported_operand_types( + op, arg_type, base, context, code=codes.OPERATOR + ) else: - self.unsupported_operand_types(op, base, arg_type, - context, code=codes.OPERATOR) + self.unsupported_operand_types( + op, base, arg_type, context, code=codes.OPERATOR + ) return codes.OPERATOR if name.startswith('"__cmp__" of'): - self.unsupported_operand_types("comparison", arg_type, base, - context, code=codes.OPERATOR) + self.unsupported_operand_types( + "comparison", arg_type, base, context, code=codes.OPERATOR + ) return codes.INDEX if name.startswith('"__getitem__" of'): - self.invalid_index_type(arg_type, callee.arg_types[n - 1], base, context, - code=codes.INDEX) + self.invalid_index_type( + arg_type, callee.arg_types[n - 1], base, context, code=codes.INDEX + ) return codes.INDEX if name.startswith('"__setitem__" of'): if n == 1: - self.invalid_index_type(arg_type, callee.arg_types[n - 1], base, context, - code=codes.INDEX) + self.invalid_index_type( + arg_type, callee.arg_types[n - 1], base, context, code=codes.INDEX + ) return codes.INDEX else: - msg = '{} (expression has type {}, target has type {})' - arg_type_str, callee_type_str = format_type_distinctly(arg_type, - callee.arg_types[n - 1]) - self.fail(msg.format(message_registry.INCOMPATIBLE_TYPES_IN_ASSIGNMENT, - arg_type_str, callee_type_str), - context, code=codes.ASSIGNMENT) + msg = "{} (expression has type {}, target has type {})" + arg_type_str, callee_type_str = format_type_distinctly( + arg_type, callee.arg_types[n - 1] + ) + self.fail( + msg.format( + message_registry.INCOMPATIBLE_TYPES_IN_ASSIGNMENT, + arg_type_str, + callee_type_str, + ), + context, + code=codes.ASSIGNMENT, + ) return codes.ASSIGNMENT - target = f'to {name} ' + target = f"to {name} " - msg = '' + msg = "" code = codes.MISC notes: List[str] = [] - if callee_name == '': + if callee_name == "": name = callee_name[1:-1] n -= 1 - actual_type_str, expected_type_str = format_type_distinctly(arg_type, - callee.arg_types[0]) - msg = '{} item {} has incompatible type {}; expected {}'.format( - name.title(), n, actual_type_str, expected_type_str) + actual_type_str, expected_type_str = format_type_distinctly( + arg_type, callee.arg_types[0] + ) + msg = "{} item {} has incompatible type {}; expected {}".format( + name.title(), n, actual_type_str, expected_type_str + ) code = codes.LIST_ITEM - elif callee_name == '': + elif callee_name == "": name = callee_name[1:-1] n -= 1 key_type, value_type = cast(TupleType, arg_type).items @@ -480,54 +629,66 @@ def incompatible_argument(self, expected_key_type_str = format_type(expected_key_type) else: key_type_str, expected_key_type_str = format_type_distinctly( - key_type, expected_key_type) + key_type, expected_key_type + ) if is_subtype(value_type, expected_value_type): value_type_str = format_type(value_type) expected_value_type_str = format_type(expected_value_type) else: value_type_str, expected_value_type_str = format_type_distinctly( - value_type, expected_value_type) - - msg = '{} entry {} has incompatible type {}: {}; expected {}: {}'.format( - name.title(), n, key_type_str, value_type_str, - expected_key_type_str, expected_value_type_str) + value_type, expected_value_type + ) + + msg = "{} entry {} has incompatible type {}: {}; expected {}: {}".format( + name.title(), + n, + key_type_str, + value_type_str, + expected_key_type_str, + expected_value_type_str, + ) code = codes.DICT_ITEM - elif callee_name == '': - actual_type_str, expected_type_str = map(strip_quotes, - format_type_distinctly(arg_type, - callee.arg_types[0])) - msg = 'List comprehension has incompatible type List[{}]; expected List[{}]'.format( - actual_type_str, expected_type_str) - elif callee_name == '': - actual_type_str, expected_type_str = map(strip_quotes, - format_type_distinctly(arg_type, - callee.arg_types[0])) - msg = 'Set comprehension has incompatible type Set[{}]; expected Set[{}]'.format( - actual_type_str, expected_type_str) - elif callee_name == '': - actual_type_str, expected_type_str = format_type_distinctly(arg_type, - callee.arg_types[n - 1]) - msg = ('{} expression in dictionary comprehension has incompatible type {}; ' - 'expected type {}').format( - 'Key' if n == 1 else 'Value', - actual_type_str, - expected_type_str) - elif callee_name == '': - actual_type_str, expected_type_str = format_type_distinctly(arg_type, - callee.arg_types[0]) - msg = 'Generator has incompatible item type {}; expected {}'.format( - actual_type_str, expected_type_str) + elif callee_name == "": + actual_type_str, expected_type_str = map( + strip_quotes, format_type_distinctly(arg_type, callee.arg_types[0]) + ) + msg = "List comprehension has incompatible type List[{}]; expected List[{}]".format( + actual_type_str, expected_type_str + ) + elif callee_name == "": + actual_type_str, expected_type_str = map( + strip_quotes, format_type_distinctly(arg_type, callee.arg_types[0]) + ) + msg = "Set comprehension has incompatible type Set[{}]; expected Set[{}]".format( + actual_type_str, expected_type_str + ) + elif callee_name == "": + actual_type_str, expected_type_str = format_type_distinctly( + arg_type, callee.arg_types[n - 1] + ) + msg = ( + "{} expression in dictionary comprehension has incompatible type {}; " + "expected type {}" + ).format("Key" if n == 1 else "Value", actual_type_str, expected_type_str) + elif callee_name == "": + actual_type_str, expected_type_str = format_type_distinctly( + arg_type, callee.arg_types[0] + ) + msg = "Generator has incompatible item type {}; expected {}".format( + actual_type_str, expected_type_str + ) else: try: expected_type = callee.arg_types[m - 1] except IndexError: # Varargs callees expected_type = callee.arg_types[-1] arg_type_str, expected_type_str = format_type_distinctly( - arg_type, expected_type, bare=True) + arg_type, expected_type, bare=True + ) if arg_kind == ARG_STAR: - arg_type_str = '*' + arg_type_str + arg_type_str = "*" + arg_type_str elif arg_kind == ARG_STAR2: - arg_type_str = '**' + arg_type_str + arg_type_str = "**" + arg_type_str # For function calls with keyword arguments, display the argument name rather than the # number. @@ -536,26 +697,32 @@ def incompatible_argument(self, arg_name = outer_context.arg_names[n - 1] if arg_name is not None: arg_label = f'"{arg_name}"' - if (arg_kind == ARG_STAR2 - and isinstance(arg_type, TypedDictType) - and m <= len(callee.arg_names) - and callee.arg_names[m - 1] is not None - and callee.arg_kinds[m - 1] != ARG_STAR2): + if ( + arg_kind == ARG_STAR2 + and isinstance(arg_type, TypedDictType) + and m <= len(callee.arg_names) + and callee.arg_names[m - 1] is not None + and callee.arg_kinds[m - 1] != ARG_STAR2 + ): arg_name = callee.arg_names[m - 1] assert arg_name is not None arg_type_str, expected_type_str = format_type_distinctly( - arg_type.items[arg_name], - expected_type, - bare=True) + arg_type.items[arg_name], expected_type, bare=True + ) arg_label = f'"{arg_name}"' if isinstance(outer_context, IndexExpr) and isinstance(outer_context.index, StrExpr): - msg = 'Value of "{}" has incompatible type {}; expected {}' .format( - outer_context.index.value, quote_type_string(arg_type_str), - quote_type_string(expected_type_str)) + msg = 'Value of "{}" has incompatible type {}; expected {}'.format( + outer_context.index.value, + quote_type_string(arg_type_str), + quote_type_string(expected_type_str), + ) else: - msg = 'Argument {} {}has incompatible type {}; expected {}'.format( - arg_label, target, quote_type_string(arg_type_str), - quote_type_string(expected_type_str)) + msg = "Argument {} {}has incompatible type {}; expected {}".format( + arg_label, + target, + quote_type_string(arg_type_str), + quote_type_string(expected_type_str), + ) object_type = get_proper_type(object_type) if isinstance(object_type, TypedDictType): code = codes.TYPEDDICT_ITEM @@ -575,43 +742,51 @@ def incompatible_argument(self, self.note(note_msg, context, code=code) return code - def incompatible_argument_note(self, - original_caller_type: ProperType, - callee_type: ProperType, - context: Context, - code: Optional[ErrorCode]) -> None: + def incompatible_argument_note( + self, + original_caller_type: ProperType, + callee_type: ProperType, + context: Context, + code: Optional[ErrorCode], + ) -> None: if isinstance(original_caller_type, (Instance, TupleType, TypedDictType)): if isinstance(callee_type, Instance) and callee_type.type.is_protocol: - self.report_protocol_problems(original_caller_type, callee_type, - context, code=code) + self.report_protocol_problems( + original_caller_type, callee_type, context, code=code + ) if isinstance(callee_type, UnionType): for item in callee_type.items: item = get_proper_type(item) if isinstance(item, Instance) and item.type.is_protocol: - self.report_protocol_problems(original_caller_type, item, - context, code=code) - if (isinstance(callee_type, CallableType) and - isinstance(original_caller_type, Instance)): - call = find_member('__call__', original_caller_type, original_caller_type, - is_operator=True) + self.report_protocol_problems( + original_caller_type, item, context, code=code + ) + if isinstance(callee_type, CallableType) and isinstance(original_caller_type, Instance): + call = find_member( + "__call__", original_caller_type, original_caller_type, is_operator=True + ) if call: self.note_call(original_caller_type, call, context, code=code) self.maybe_note_concatenate_pos_args(original_caller_type, callee_type, context, code) - def maybe_note_concatenate_pos_args(self, - original_caller_type: ProperType, - callee_type: ProperType, - context: Context, - code: Optional[ErrorCode] = None) -> None: + def maybe_note_concatenate_pos_args( + self, + original_caller_type: ProperType, + callee_type: ProperType, + context: Context, + code: Optional[ErrorCode] = None, + ) -> None: # pos-only vs positional can be confusing, with Concatenate - if (isinstance(callee_type, CallableType) and - isinstance(original_caller_type, CallableType) and - (original_caller_type.from_concatenate or callee_type.from_concatenate)): + if ( + isinstance(callee_type, CallableType) + and isinstance(original_caller_type, CallableType) + and (original_caller_type.from_concatenate or callee_type.from_concatenate) + ): names: List[str] = [] for c, o in zip( - callee_type.formal_arguments(), - original_caller_type.formal_arguments()): + callee_type.formal_arguments(), original_caller_type.formal_arguments() + ): if None in (c.pos, o.pos): # non-positional continue @@ -620,34 +795,54 @@ def maybe_note_concatenate_pos_args(self, if names: missing_arguments = '"' + '", "'.join(names) + '"' - self.note(f'This may be because "{original_caller_type.name}" has arguments ' - f'named: {missing_arguments}', context, code=code) - - def invalid_index_type(self, index_type: Type, expected_type: Type, base_str: str, - context: Context, *, code: ErrorCode) -> None: + self.note( + f'This may be because "{original_caller_type.name}" has arguments ' + f"named: {missing_arguments}", + context, + code=code, + ) + + def invalid_index_type( + self, + index_type: Type, + expected_type: Type, + base_str: str, + context: Context, + *, + code: ErrorCode, + ) -> None: index_str, expected_str = format_type_distinctly(index_type, expected_type) - self.fail('Invalid index type {} for {}; expected type {}'.format( - index_str, base_str, expected_str), context, code=code) - - def too_few_arguments(self, callee: CallableType, context: Context, - argument_names: Optional[Sequence[Optional[str]]]) -> None: + self.fail( + "Invalid index type {} for {}; expected type {}".format( + index_str, base_str, expected_str + ), + context, + code=code, + ) + + def too_few_arguments( + self, + callee: CallableType, + context: Context, + argument_names: Optional[Sequence[Optional[str]]], + ) -> None: if argument_names is not None: num_positional_args = sum(k is None for k in argument_names) - arguments_left = callee.arg_names[num_positional_args:callee.min_args] + arguments_left = callee.arg_names[num_positional_args : callee.min_args] diff = [k for k in arguments_left if k not in argument_names] if len(diff) == 1: - msg = 'Missing positional argument' + msg = "Missing positional argument" else: - msg = 'Missing positional arguments' + msg = "Missing positional arguments" callee_name = callable_name(callee) if callee_name is not None and diff and all(d is not None for d in diff): args = '", "'.join(cast(List[str], diff)) msg += f' "{args}" in call to {callee_name}' else: - msg = 'Too few arguments' + for_function(callee) + msg = "Too few arguments" + for_function(callee) else: - msg = 'Too few arguments' + for_function(callee) + msg = "Too few arguments" + for_function(callee) self.fail(msg, context, code=codes.CALL_ARG) def missing_named_argument(self, callee: CallableType, context: Context, name: str) -> None: @@ -655,14 +850,13 @@ def missing_named_argument(self, callee: CallableType, context: Context, name: s self.fail(msg, context, code=codes.CALL_ARG) def too_many_arguments(self, callee: CallableType, context: Context) -> None: - msg = 'Too many arguments' + for_function(callee) + msg = "Too many arguments" + for_function(callee) self.fail(msg, context, code=codes.CALL_ARG) self.maybe_note_about_special_args(callee, context) - def too_many_arguments_from_typed_dict(self, - callee: CallableType, - arg_type: TypedDictType, - context: Context) -> None: + def too_many_arguments_from_typed_dict( + self, callee: CallableType, arg_type: TypedDictType, context: Context + ) -> None: # Try to determine the name of the extra argument. for key in arg_type.items: if key not in callee.arg_names: @@ -673,25 +867,25 @@ def too_many_arguments_from_typed_dict(self, return self.fail(msg, context) - def too_many_positional_arguments(self, callee: CallableType, - context: Context) -> None: - msg = 'Too many positional arguments' + for_function(callee) + def too_many_positional_arguments(self, callee: CallableType, context: Context) -> None: + msg = "Too many positional arguments" + for_function(callee) self.fail(msg, context) self.maybe_note_about_special_args(callee, context) def maybe_note_about_special_args(self, callee: CallableType, context: Context) -> None: # https://github.com/python/mypy/issues/11309 - first_arg = callee.def_extras.get('first_arg') - if first_arg and first_arg not in {'self', 'cls', 'mcs'}: + first_arg = callee.def_extras.get("first_arg") + if first_arg and first_arg not in {"self", "cls", "mcs"}: self.note( - 'Looks like the first special argument in a method ' + "Looks like the first special argument in a method " 'is not named "self", "cls", or "mcs", ' - 'maybe it is missing?', + "maybe it is missing?", context, ) - def unexpected_keyword_argument(self, callee: CallableType, name: str, arg_type: Type, - context: Context) -> None: + def unexpected_keyword_argument( + self, callee: CallableType, name: str, arg_type: Type, context: Context + ) -> None: msg = f'Unexpected keyword argument "{name}"' + for_function(callee) # Suggest intended keyword, look for type match else fallback on any match. matching_type_args = [] @@ -714,15 +908,22 @@ def unexpected_keyword_argument(self, callee: CallableType, name: str, arg_type: assert callee.definition is not None fname = callable_name(callee) if not fname: # an alias to function with a different name - fname = 'Called function' - self.note(f'{fname} defined here', callee.definition, - file=module.path, origin=context, code=codes.CALL_ARG) + fname = "Called function" + self.note( + f"{fname} defined here", + callee.definition, + file=module.path, + origin=context, + code=codes.CALL_ARG, + ) - def duplicate_argument_value(self, callee: CallableType, index: int, - context: Context) -> None: - self.fail('{} gets multiple values for keyword argument "{}"'. - format(callable_name(callee) or 'Function', callee.arg_names[index]), - context) + def duplicate_argument_value(self, callee: CallableType, index: int, context: Context) -> None: + self.fail( + '{} gets multiple values for keyword argument "{}"'.format( + callable_name(callee) or "Function", callee.arg_names[index] + ), + context, + ) def does_not_return_value(self, callee_type: Optional[Type], context: Context) -> None: """Report an error about use of an unusable type.""" @@ -731,10 +932,13 @@ def does_not_return_value(self, callee_type: Optional[Type], context: Context) - if isinstance(callee_type, FunctionLike): name = callable_name(callee_type) if name is not None: - self.fail(f'{capitalize(name)} does not return a value', context, - code=codes.FUNC_RETURNS_VALUE) + self.fail( + f"{capitalize(name)} does not return a value", + context, + code=codes.FUNC_RETURNS_VALUE, + ) else: - self.fail('Function does not return a value', context, code=codes.FUNC_RETURNS_VALUE) + self.fail("Function does not return a value", context, code=codes.FUNC_RETURNS_VALUE) def underscore_function_call(self, context: Context) -> None: self.fail('Calling function named "_" is not allowed', context) @@ -745,7 +949,7 @@ def deleted_as_rvalue(self, typ: DeletedType, context: Context) -> None: s = "" else: s = f' "{typ.source}"' - self.fail(f'Trying to read deleted variable{s}', context) + self.fail(f"Trying to read deleted variable{s}", context) def deleted_as_lvalue(self, typ: DeletedType, context: Context) -> None: """Report an error about using an deleted type as an lvalue. @@ -757,83 +961,103 @@ def deleted_as_lvalue(self, typ: DeletedType, context: Context) -> None: s = "" else: s = f' "{typ.source}"' - self.fail(f'Assignment to variable{s} outside except: block', context) - - def no_variant_matches_arguments(self, - overload: Overloaded, - arg_types: List[Type], - context: Context, - *, - code: Optional[ErrorCode] = None) -> None: + self.fail(f"Assignment to variable{s} outside except: block", context) + + def no_variant_matches_arguments( + self, + overload: Overloaded, + arg_types: List[Type], + context: Context, + *, + code: Optional[ErrorCode] = None, + ) -> None: code = code or codes.CALL_OVERLOAD name = callable_name(overload) if name: - name_str = f' of {name}' + name_str = f" of {name}" else: - name_str = '' - arg_types_str = ', '.join(format_type(arg) for arg in arg_types) + name_str = "" + arg_types_str = ", ".join(format_type(arg) for arg in arg_types) num_args = len(arg_types) if num_args == 0: - self.fail(f'All overload variants{name_str} require at least one argument', - context, code=code) + self.fail( + f"All overload variants{name_str} require at least one argument", + context, + code=code, + ) elif num_args == 1: - self.fail('No overload variant{} matches argument type {}' - .format(name_str, arg_types_str), context, code=code) + self.fail( + "No overload variant{} matches argument type {}".format(name_str, arg_types_str), + context, + code=code, + ) else: - self.fail('No overload variant{} matches argument types {}' - .format(name_str, arg_types_str), context, code=code) + self.fail( + "No overload variant{} matches argument types {}".format(name_str, arg_types_str), + context, + code=code, + ) - self.note( - f'Possible overload variant{plural_s(len(overload.items))}:', - context, code=code) + self.note(f"Possible overload variant{plural_s(len(overload.items))}:", context, code=code) for item in overload.items: self.note(pretty_callable(item), context, offset=4, code=code) - def wrong_number_values_to_unpack(self, provided: int, expected: int, - context: Context) -> None: + def wrong_number_values_to_unpack( + self, provided: int, expected: int, context: Context + ) -> None: if provided < expected: if provided == 1: - self.fail(f'Need more than 1 value to unpack ({expected} expected)', - context) + self.fail(f"Need more than 1 value to unpack ({expected} expected)", context) else: - self.fail('Need more than {} values to unpack ({} expected)'.format( - provided, expected), context) + self.fail( + "Need more than {} values to unpack ({} expected)".format(provided, expected), + context, + ) elif provided > expected: - self.fail('Too many values to unpack ({} expected, {} provided)'.format( - expected, provided), context) + self.fail( + "Too many values to unpack ({} expected, {} provided)".format(expected, provided), + context, + ) def unpacking_strings_disallowed(self, context: Context) -> None: self.fail("Unpacking a string is disallowed", context) def type_not_iterable(self, type: Type, context: Context) -> None: - self.fail(f'{format_type(type)} object is not iterable', context) + self.fail(f"{format_type(type)} object is not iterable", context) def possible_missing_await(self, context: Context) -> None: self.note('Maybe you forgot to use "await"?', context) - def incompatible_operator_assignment(self, op: str, - context: Context) -> None: - self.fail(f'Result type of {op} incompatible in assignment', - context) + def incompatible_operator_assignment(self, op: str, context: Context) -> None: + self.fail(f"Result type of {op} incompatible in assignment", context) def overload_signature_incompatible_with_supertype( - self, name: str, name_in_super: str, supertype: str, - context: Context) -> None: + self, name: str, name_in_super: str, supertype: str, context: Context + ) -> None: target = self.override_target(name, name_in_super, supertype) - self.fail('Signature of "{}" incompatible with {}'.format( - name, target), context, code=codes.OVERRIDE) + self.fail( + 'Signature of "{}" incompatible with {}'.format(name, target), + context, + code=codes.OVERRIDE, + ) note_template = 'Overload variants must be defined in the same order as they are in "{}"' self.note(note_template.format(supertype), context, code=codes.OVERRIDE) def signature_incompatible_with_supertype( - self, name: str, name_in_super: str, supertype: str, context: Context, - original: Optional[FunctionLike] = None, - override: Optional[FunctionLike] = None) -> None: + self, + name: str, + name_in_super: str, + supertype: str, + context: Context, + original: Optional[FunctionLike] = None, + override: Optional[FunctionLike] = None, + ) -> None: code = codes.OVERRIDE target = self.override_target(name, name_in_super, supertype) - self.fail('Signature of "{}" incompatible with {}'.format( - name, target), context, code=code) + self.fail( + 'Signature of "{}" incompatible with {}'.format(name, target), context, code=code + ) INCLUDE_DECORATOR = True # Include @classmethod and @staticmethod decorators, if any ALLOW_DUPS = True # Allow duplicate notes, needed when signatures are duplicates @@ -844,121 +1068,167 @@ def signature_incompatible_with_supertype( # note: def f(self) -> str # note: Subclass: # note: def f(self, x: str) -> None - if original is not None and isinstance(original, (CallableType, Overloaded)) \ - and override is not None and isinstance(override, (CallableType, Overloaded)): - self.note('Superclass:', context, offset=ALIGN_OFFSET + OFFSET, code=code) - self.pretty_callable_or_overload(original, context, offset=ALIGN_OFFSET + 2 * OFFSET, - add_class_or_static_decorator=INCLUDE_DECORATOR, - allow_dups=ALLOW_DUPS, code=code) - - self.note('Subclass:', context, offset=ALIGN_OFFSET + OFFSET, code=code) - self.pretty_callable_or_overload(override, context, offset=ALIGN_OFFSET + 2 * OFFSET, - add_class_or_static_decorator=INCLUDE_DECORATOR, - allow_dups=ALLOW_DUPS, code=code) - - def pretty_callable_or_overload(self, - tp: Union[CallableType, Overloaded], - context: Context, - *, - offset: int = 0, - add_class_or_static_decorator: bool = False, - allow_dups: bool = False, - code: Optional[ErrorCode] = None) -> None: + if ( + original is not None + and isinstance(original, (CallableType, Overloaded)) + and override is not None + and isinstance(override, (CallableType, Overloaded)) + ): + self.note("Superclass:", context, offset=ALIGN_OFFSET + OFFSET, code=code) + self.pretty_callable_or_overload( + original, + context, + offset=ALIGN_OFFSET + 2 * OFFSET, + add_class_or_static_decorator=INCLUDE_DECORATOR, + allow_dups=ALLOW_DUPS, + code=code, + ) + + self.note("Subclass:", context, offset=ALIGN_OFFSET + OFFSET, code=code) + self.pretty_callable_or_overload( + override, + context, + offset=ALIGN_OFFSET + 2 * OFFSET, + add_class_or_static_decorator=INCLUDE_DECORATOR, + allow_dups=ALLOW_DUPS, + code=code, + ) + + def pretty_callable_or_overload( + self, + tp: Union[CallableType, Overloaded], + context: Context, + *, + offset: int = 0, + add_class_or_static_decorator: bool = False, + allow_dups: bool = False, + code: Optional[ErrorCode] = None, + ) -> None: if isinstance(tp, CallableType): if add_class_or_static_decorator: decorator = pretty_class_or_static_decorator(tp) if decorator is not None: self.note(decorator, context, offset=offset, allow_dups=allow_dups, code=code) - self.note(pretty_callable(tp), context, - offset=offset, allow_dups=allow_dups, code=code) + self.note( + pretty_callable(tp), context, offset=offset, allow_dups=allow_dups, code=code + ) elif isinstance(tp, Overloaded): - self.pretty_overload(tp, context, offset, - add_class_or_static_decorator=add_class_or_static_decorator, - allow_dups=allow_dups, code=code) + self.pretty_overload( + tp, + context, + offset, + add_class_or_static_decorator=add_class_or_static_decorator, + allow_dups=allow_dups, + code=code, + ) def argument_incompatible_with_supertype( - self, arg_num: int, name: str, type_name: Optional[str], - name_in_supertype: str, arg_type_in_supertype: Type, supertype: str, - context: Context) -> None: + self, + arg_num: int, + name: str, + type_name: Optional[str], + name_in_supertype: str, + arg_type_in_supertype: Type, + supertype: str, + context: Context, + ) -> None: target = self.override_target(name, name_in_supertype, supertype) arg_type_in_supertype_f = format_type_bare(arg_type_in_supertype) - self.fail('Argument {} of "{}" is incompatible with {}; ' - 'supertype defines the argument type as "{}"' - .format(arg_num, name, target, arg_type_in_supertype_f), - context, - code=codes.OVERRIDE) - self.note( - 'This violates the Liskov substitution principle', + self.fail( + 'Argument {} of "{}" is incompatible with {}; ' + 'supertype defines the argument type as "{}"'.format( + arg_num, name, target, arg_type_in_supertype_f + ), context, - code=codes.OVERRIDE) + code=codes.OVERRIDE, + ) + self.note("This violates the Liskov substitution principle", context, code=codes.OVERRIDE) self.note( - 'See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides', + "See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides", context, - code=codes.OVERRIDE) + code=codes.OVERRIDE, + ) if name == "__eq__" and type_name: multiline_msg = self.comparison_method_example_msg(class_name=type_name) self.note_multiline(multiline_msg, context, code=codes.OVERRIDE) def comparison_method_example_msg(self, class_name: str) -> str: - return dedent('''\ + return dedent( + """\ It is recommended for "__eq__" to work with arbitrary objects, for example: def __eq__(self, other: object) -> bool: if not isinstance(other, {class_name}): return NotImplemented return - '''.format(class_name=class_name)) + """.format( + class_name=class_name + ) + ) def return_type_incompatible_with_supertype( - self, name: str, name_in_supertype: str, supertype: str, - original: Type, override: Type, - context: Context) -> None: + self, + name: str, + name_in_supertype: str, + supertype: str, + original: Type, + override: Type, + context: Context, + ) -> None: target = self.override_target(name, name_in_supertype, supertype) override_str, original_str = format_type_distinctly(override, original) - self.fail('Return type {} of "{}" incompatible with return type {} in {}' - .format(override_str, name, original_str, target), - context, - code=codes.OVERRIDE) + self.fail( + 'Return type {} of "{}" incompatible with return type {} in {}'.format( + override_str, name, original_str, target + ), + context, + code=codes.OVERRIDE, + ) - def override_target(self, name: str, name_in_super: str, - supertype: str) -> str: + def override_target(self, name: str, name_in_super: str, supertype: str) -> str: target = f'supertype "{supertype}"' if name_in_super != name: target = f'"{name_in_super}" of {target}' return target - def incompatible_type_application(self, expected_arg_count: int, - actual_arg_count: int, - context: Context) -> None: + def incompatible_type_application( + self, expected_arg_count: int, actual_arg_count: int, context: Context + ) -> None: if expected_arg_count == 0: - self.fail('Type application targets a non-generic function or class', - context) + self.fail("Type application targets a non-generic function or class", context) elif actual_arg_count > expected_arg_count: - self.fail('Type application has too many types ({} expected)' - .format(expected_arg_count), context) + self.fail( + "Type application has too many types ({} expected)".format(expected_arg_count), + context, + ) else: - self.fail('Type application has too few types ({} expected)' - .format(expected_arg_count), context) + self.fail( + "Type application has too few types ({} expected)".format(expected_arg_count), + context, + ) - def could_not_infer_type_arguments(self, callee_type: CallableType, n: int, - context: Context) -> None: + def could_not_infer_type_arguments( + self, callee_type: CallableType, n: int, context: Context + ) -> None: callee_name = callable_name(callee_type) if callee_name is not None and n > 0: - self.fail(f'Cannot infer type argument {n} of {callee_name}', context) + self.fail(f"Cannot infer type argument {n} of {callee_name}", context) else: - self.fail('Cannot infer function type argument', context) + self.fail("Cannot infer function type argument", context) def invalid_var_arg(self, typ: Type, context: Context) -> None: - self.fail('List or tuple expected as variadic arguments', context) + self.fail("List or tuple expected as variadic arguments", context) def invalid_keyword_var_arg(self, typ: Type, is_mapping: bool, context: Context) -> None: typ = get_proper_type(typ) if isinstance(typ, Instance) and is_mapping: - self.fail('Keywords must be strings', context) + self.fail("Keywords must be strings", context) else: self.fail( - f'Argument after ** must be a mapping, not {format_type(typ)}', - context, code=codes.ARG_TYPE) + f"Argument after ** must be a mapping, not {format_type(typ)}", + context, + code=codes.ARG_TYPE, + ) def undefined_in_superclass(self, member: str, context: Context) -> None: self.fail(f'"{member}" undefined in superclass', context) @@ -968,46 +1238,62 @@ def first_argument_for_super_must_be_type(self, actual: Type, context: Context) if isinstance(actual, Instance): # Don't include type of instance, because it can look confusingly like a type # object. - type_str = 'a non-type instance' + type_str = "a non-type instance" else: type_str = format_type(actual) - self.fail(f'Argument 1 for "super" must be a type object; got {type_str}', context, - code=codes.ARG_TYPE) + self.fail( + f'Argument 1 for "super" must be a type object; got {type_str}', + context, + code=codes.ARG_TYPE, + ) def too_few_string_formatting_arguments(self, context: Context) -> None: - self.fail('Not enough arguments for format string', context, - code=codes.STRING_FORMATTING) + self.fail("Not enough arguments for format string", context, code=codes.STRING_FORMATTING) def too_many_string_formatting_arguments(self, context: Context) -> None: - self.fail('Not all arguments converted during string formatting', context, - code=codes.STRING_FORMATTING) + self.fail( + "Not all arguments converted during string formatting", + context, + code=codes.STRING_FORMATTING, + ) def unsupported_placeholder(self, placeholder: str, context: Context) -> None: - self.fail(f'Unsupported format character "{placeholder}"', context, - code=codes.STRING_FORMATTING) + self.fail( + f'Unsupported format character "{placeholder}"', context, code=codes.STRING_FORMATTING + ) def string_interpolation_with_star_and_key(self, context: Context) -> None: - self.fail('String interpolation contains both stars and mapping keys', context, - code=codes.STRING_FORMATTING) + self.fail( + "String interpolation contains both stars and mapping keys", + context, + code=codes.STRING_FORMATTING, + ) - def requires_int_or_single_byte(self, context: Context, - format_call: bool = False) -> None: - self.fail('"{}c" requires an integer in range(256) or a single byte' - .format(':' if format_call else '%'), - context, code=codes.STRING_FORMATTING) + def requires_int_or_single_byte(self, context: Context, format_call: bool = False) -> None: + self.fail( + '"{}c" requires an integer in range(256) or a single byte'.format( + ":" if format_call else "%" + ), + context, + code=codes.STRING_FORMATTING, + ) - def requires_int_or_char(self, context: Context, - format_call: bool = False) -> None: - self.fail('"{}c" requires int or char'.format(':' if format_call else '%'), - context, code=codes.STRING_FORMATTING) + def requires_int_or_char(self, context: Context, format_call: bool = False) -> None: + self.fail( + '"{}c" requires int or char'.format(":" if format_call else "%"), + context, + code=codes.STRING_FORMATTING, + ) def key_not_in_mapping(self, key: str, context: Context) -> None: - self.fail(f'Key "{key}" not found in mapping', context, - code=codes.STRING_FORMATTING) + self.fail(f'Key "{key}" not found in mapping', context, code=codes.STRING_FORMATTING) def string_interpolation_mixing_key_and_non_keys(self, context: Context) -> None: - self.fail('String interpolation mixes specifier with and without mapping keys', context, - code=codes.STRING_FORMATTING) + self.fail( + "String interpolation mixes specifier with and without mapping keys", + context, + code=codes.STRING_FORMATTING, + ) def cannot_determine_type(self, name: str, context: Context) -> None: self.fail(f'Cannot determine type of "{name}"', context, code=codes.HAS_TYPE) @@ -1016,38 +1302,47 @@ def cannot_determine_type_in_base(self, name: str, base: str, context: Context) self.fail(f'Cannot determine type of "{name}" in base class "{base}"', context) def no_formal_self(self, name: str, item: CallableType, context: Context) -> None: - self.fail('Attribute function "%s" with type %s does not accept self argument' - % (name, format_type(item)), context) + self.fail( + 'Attribute function "%s" with type %s does not accept self argument' + % (name, format_type(item)), + context, + ) - def incompatible_self_argument(self, name: str, arg: Type, sig: CallableType, - is_classmethod: bool, context: Context) -> None: - kind = 'class attribute function' if is_classmethod else 'attribute function' - self.fail('Invalid self argument %s to %s "%s" with type %s' - % (format_type(arg), kind, name, format_type(sig)), context) + def incompatible_self_argument( + self, name: str, arg: Type, sig: CallableType, is_classmethod: bool, context: Context + ) -> None: + kind = "class attribute function" if is_classmethod else "attribute function" + self.fail( + 'Invalid self argument %s to %s "%s" with type %s' + % (format_type(arg), kind, name, format_type(sig)), + context, + ) def incompatible_conditional_function_def(self, defn: FuncDef) -> None: - self.fail('All conditional function variants must have identical ' - 'signatures', defn) + self.fail("All conditional function variants must have identical " "signatures", defn) - def cannot_instantiate_abstract_class(self, class_name: str, - abstract_attributes: List[str], - context: Context) -> None: + def cannot_instantiate_abstract_class( + self, class_name: str, abstract_attributes: List[str], context: Context + ) -> None: attrs = format_string_list([f'"{a}"' for a in abstract_attributes]) - self.fail('Cannot instantiate abstract class "%s" with abstract ' - 'attribute%s %s' % (class_name, plural_s(abstract_attributes), - attrs), - context, code=codes.ABSTRACT) - - def base_class_definitions_incompatible(self, name: str, base1: TypeInfo, - base2: TypeInfo, - context: Context) -> None: - self.fail('Definition of "{}" in base class "{}" is incompatible ' - 'with definition in base class "{}"'.format( - name, base1.name, base2.name), context) + self.fail( + 'Cannot instantiate abstract class "%s" with abstract ' + "attribute%s %s" % (class_name, plural_s(abstract_attributes), attrs), + context, + code=codes.ABSTRACT, + ) + + def base_class_definitions_incompatible( + self, name: str, base1: TypeInfo, base2: TypeInfo, context: Context + ) -> None: + self.fail( + 'Definition of "{}" in base class "{}" is incompatible ' + 'with definition in base class "{}"'.format(name, base1.name, base2.name), + context, + ) def cant_assign_to_method(self, context: Context) -> None: - self.fail(message_registry.CANNOT_ASSIGN_TO_METHOD, context, - code=codes.ASSIGNMENT) + self.fail(message_registry.CANNOT_ASSIGN_TO_METHOD, context, code=codes.ASSIGNMENT) def cant_assign_to_classvar(self, name: str, context: Context) -> None: self.fail(f'Cannot assign to class variable "{name}" via instance', context) @@ -1056,8 +1351,11 @@ def final_cant_override_writable(self, name: str, ctx: Context) -> None: self.fail(f'Cannot override writable attribute "{name}" with a final one', ctx) def cant_override_final(self, name: str, base_name: str, ctx: Context) -> None: - self.fail('Cannot override final attribute "{}"' - ' (previously declared in base class "{}")'.format(name, base_name), ctx) + self.fail( + 'Cannot override final attribute "{}"' + ' (previously declared in base class "{}")'.format(name, base_name), + ctx, + ) def cant_assign_to_final(self, name: str, attr_assign: bool, ctx: Context) -> None: """Warn about a prohibited assignment to a final attribute. @@ -1073,96 +1371,121 @@ def protocol_members_cant_be_final(self, ctx: Context) -> None: def final_without_value(self, ctx: Context) -> None: self.fail("Final name must be initialized with a value", ctx) - def read_only_property(self, name: str, type: TypeInfo, - context: Context) -> None: + def read_only_property(self, name: str, type: TypeInfo, context: Context) -> None: self.fail(f'Property "{name}" defined in "{type.name}" is read-only', context) - def incompatible_typevar_value(self, - callee: CallableType, - typ: Type, - typevar_name: str, - context: Context) -> None: - self.fail(message_registry.INCOMPATIBLE_TYPEVAR_VALUE - .format(typevar_name, callable_name(callee) or 'function', format_type(typ)), - context, - code=codes.TYPE_VAR) + def incompatible_typevar_value( + self, callee: CallableType, typ: Type, typevar_name: str, context: Context + ) -> None: + self.fail( + message_registry.INCOMPATIBLE_TYPEVAR_VALUE.format( + typevar_name, callable_name(callee) or "function", format_type(typ) + ), + context, + code=codes.TYPE_VAR, + ) def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) -> None: - left_str = 'element' if kind == 'container' else 'left operand' - right_str = 'container item' if kind == 'container' else 'right operand' - message = 'Non-overlapping {} check ({} type: {}, {} type: {})' + left_str = "element" if kind == "container" else "left operand" + right_str = "container item" if kind == "container" else "right operand" + message = "Non-overlapping {} check ({} type: {}, {} type: {})" left_typ, right_typ = format_type_distinctly(left, right) - self.fail(message.format(kind, left_str, left_typ, right_str, right_typ), ctx, - code=codes.COMPARISON_OVERLAP) + self.fail( + message.format(kind, left_str, left_typ, right_str, right_typ), + ctx, + code=codes.COMPARISON_OVERLAP, + ) def overload_inconsistently_applies_decorator(self, decorator: str, context: Context) -> None: self.fail( f'Overload does not consistently use the "@{decorator}" ' - + 'decorator on all function signatures.', - context) + + "decorator on all function signatures.", + context, + ) def overloaded_signatures_overlap(self, index1: int, index2: int, context: Context) -> None: - self.fail('Overloaded function signatures {} and {} overlap with ' - 'incompatible return types'.format(index1, index2), context) + self.fail( + "Overloaded function signatures {} and {} overlap with " + "incompatible return types".format(index1, index2), + context, + ) - def overloaded_signature_will_never_match(self, index1: int, index2: int, - context: Context) -> None: + def overloaded_signature_will_never_match( + self, index1: int, index2: int, context: Context + ) -> None: self.fail( - 'Overloaded function signature {index2} will never be matched: ' - 'signature {index1}\'s parameter type(s) are the same or broader'.format( - index1=index1, - index2=index2), - context) + "Overloaded function signature {index2} will never be matched: " + "signature {index1}'s parameter type(s) are the same or broader".format( + index1=index1, index2=index2 + ), + context, + ) def overloaded_signatures_typevar_specific(self, index: int, context: Context) -> None: - self.fail(f'Overloaded function implementation cannot satisfy signature {index} ' + - 'due to inconsistencies in how they use type variables', context) + self.fail( + f"Overloaded function implementation cannot satisfy signature {index} " + + "due to inconsistencies in how they use type variables", + context, + ) def overloaded_signatures_arg_specific(self, index: int, context: Context) -> None: - self.fail('Overloaded function implementation does not accept all possible arguments ' - 'of signature {}'.format(index), context) + self.fail( + "Overloaded function implementation does not accept all possible arguments " + "of signature {}".format(index), + context, + ) def overloaded_signatures_ret_specific(self, index: int, context: Context) -> None: - self.fail('Overloaded function implementation cannot produce return type ' - 'of signature {}'.format(index), context) + self.fail( + "Overloaded function implementation cannot produce return type " + "of signature {}".format(index), + context, + ) def warn_both_operands_are_from_unions(self, context: Context) -> None: - self.note('Both left and right operands are unions', context, code=codes.OPERATOR) + self.note("Both left and right operands are unions", context, code=codes.OPERATOR) def warn_operand_was_from_union(self, side: str, original: Type, context: Context) -> None: - self.note(f'{side} operand is of type {format_type(original)}', context, - code=codes.OPERATOR) + self.note( + f"{side} operand is of type {format_type(original)}", context, code=codes.OPERATOR + ) def operator_method_signatures_overlap( - self, reverse_class: TypeInfo, reverse_method: str, forward_class: Type, - forward_method: str, context: Context) -> None: - self.fail('Signatures of "{}" of "{}" and "{}" of {} ' - 'are unsafely overlapping'.format( - reverse_method, reverse_class.name, - forward_method, format_type(forward_class)), - context) - - def forward_operator_not_callable( - self, forward_method: str, context: Context) -> None: + self, + reverse_class: TypeInfo, + reverse_method: str, + forward_class: Type, + forward_method: str, + context: Context, + ) -> None: + self.fail( + 'Signatures of "{}" of "{}" and "{}" of {} ' + "are unsafely overlapping".format( + reverse_method, reverse_class.name, forward_method, format_type(forward_class) + ), + context, + ) + + def forward_operator_not_callable(self, forward_method: str, context: Context) -> None: self.fail(f'Forward operator "{forward_method}" is not callable', context) - def signatures_incompatible(self, method: str, other_method: str, - context: Context) -> None: - self.fail('Signatures of "{}" and "{}" are incompatible'.format( - method, other_method), context) + def signatures_incompatible(self, method: str, other_method: str, context: Context) -> None: + self.fail( + 'Signatures of "{}" and "{}" are incompatible'.format(method, other_method), context + ) def yield_from_invalid_operand_type(self, expr: Type, context: Context) -> Type: - text = format_type(expr) if format_type(expr) != 'object' else expr + text = format_type(expr) if format_type(expr) != "object" else expr self.fail(f'"yield from" can\'t be applied to {text}', context) return AnyType(TypeOfAny.from_error) def invalid_signature(self, func_type: Type, context: Context) -> None: - self.fail(f'Invalid signature {format_type(func_type)}', context) + self.fail(f"Invalid signature {format_type(func_type)}", context) def invalid_signature_for_special_method( - self, func_type: Type, context: Context, method_name: str) -> None: - self.fail(f'Invalid signature {format_type(func_type)} for "{method_name}"', - context) + self, func_type: Type, context: Context, method_name: str + ) -> None: + self.fail(f'Invalid signature {format_type(func_type)} for "{method_name}"', context) def reveal_type(self, typ: Type, context: Context) -> None: self.note(f'Revealed type is "{typ}"', context) @@ -1174,7 +1497,7 @@ def reveal_locals(self, type_map: Dict[str, Optional[Type]], context: Context) - if sorted_locals: self.note("Revealed local types are:", context) for k, v in sorted_locals.items(): - self.note(f' {k}: {v}', context) + self.note(f" {k}: {v}", context) else: self.note("There are no locals to reveal", context) @@ -1182,52 +1505,67 @@ def unsupported_type_type(self, item: Type, context: Context) -> None: self.fail(f'Cannot instantiate type "Type[{format_type_bare(item)}]"', context) def redundant_cast(self, typ: Type, context: Context) -> None: - self.fail(f'Redundant cast to {format_type(typ)}', context, - code=codes.REDUNDANT_CAST) + self.fail(f"Redundant cast to {format_type(typ)}", context, code=codes.REDUNDANT_CAST) def assert_type_fail(self, source_type: Type, target_type: Type, context: Context) -> None: - self.fail(f"Expression is of type {format_type(source_type)}, " - f"not {format_type(target_type)}", context, - code=codes.ASSERT_TYPE) + self.fail( + f"Expression is of type {format_type(source_type)}, " + f"not {format_type(target_type)}", + context, + code=codes.ASSERT_TYPE, + ) def unimported_type_becomes_any(self, prefix: str, typ: Type, ctx: Context) -> None: - self.fail(f"{prefix} becomes {format_type(typ)} due to an unfollowed import", - ctx, code=codes.NO_ANY_UNIMPORTED) - - def need_annotation_for_var(self, node: SymbolNode, context: Context, - python_version: Optional[Tuple[int, int]] = None) -> None: - hint = '' + self.fail( + f"{prefix} becomes {format_type(typ)} due to an unfollowed import", + ctx, + code=codes.NO_ANY_UNIMPORTED, + ) + + def need_annotation_for_var( + self, node: SymbolNode, context: Context, python_version: Optional[Tuple[int, int]] = None + ) -> None: + hint = "" has_variable_annotations = not python_version or python_version >= (3, 6) # Only gives hint if it's a variable declaration and the partial type is a builtin type - if (python_version and isinstance(node, Var) and isinstance(node.type, PartialType) and - node.type.type and node.type.type.fullname in reverse_builtin_aliases): + if ( + python_version + and isinstance(node, Var) + and isinstance(node.type, PartialType) + and node.type.type + and node.type.type.fullname in reverse_builtin_aliases + ): alias = reverse_builtin_aliases[node.type.type.fullname] - alias = alias.split('.')[-1] - type_dec = '' - if alias == 'Dict': - type_dec = f'{type_dec}, {type_dec}' + alias = alias.split(".")[-1] + type_dec = "" + if alias == "Dict": + type_dec = f"{type_dec}, {type_dec}" if has_variable_annotations: hint = f' (hint: "{node.name}: {alias}[{type_dec}] = ...")' else: hint = f' (hint: "{node.name} = ... # type: {alias}[{type_dec}]")' if has_variable_annotations: - needed = 'annotation' + needed = "annotation" else: - needed = 'comment' + needed = "comment" - self.fail(f'Need type {needed} for "{unmangle(node.name)}"{hint}', context, - code=codes.VAR_ANNOTATED) + self.fail( + f'Need type {needed} for "{unmangle(node.name)}"{hint}', + context, + code=codes.VAR_ANNOTATED, + ) def explicit_any(self, ctx: Context) -> None: self.fail('Explicit "Any" is not allowed', ctx) def unexpected_typeddict_keys( - self, - typ: TypedDictType, - expected_keys: List[str], - actual_keys: List[str], - context: Context) -> None: + self, + typ: TypedDictType, + expected_keys: List[str], + actual_keys: List[str], + context: Context, + ) -> None: actual_set = set(actual_keys) expected_set = set(expected_keys) if not typ.is_anonymous(): @@ -1235,84 +1573,98 @@ def unexpected_typeddict_keys( if actual_set < expected_set: # Use list comprehension instead of set operations to preserve order. missing = [key for key in expected_keys if key not in actual_set] - self.fail('Missing {} for TypedDict {}'.format( - format_key_list(missing, short=True), format_type(typ)), - context, code=codes.TYPEDDICT_ITEM) + self.fail( + "Missing {} for TypedDict {}".format( + format_key_list(missing, short=True), format_type(typ) + ), + context, + code=codes.TYPEDDICT_ITEM, + ) return else: extra = [key for key in actual_keys if key not in expected_set] if extra: # If there are both extra and missing keys, only report extra ones for # simplicity. - self.fail('Extra {} for TypedDict {}'.format( - format_key_list(extra, short=True), format_type(typ)), - context, code=codes.TYPEDDICT_ITEM) + self.fail( + "Extra {} for TypedDict {}".format( + format_key_list(extra, short=True), format_type(typ) + ), + context, + code=codes.TYPEDDICT_ITEM, + ) return found = format_key_list(actual_keys, short=True) if not expected_keys: - self.fail(f'Unexpected TypedDict {found}', context) + self.fail(f"Unexpected TypedDict {found}", context) return expected = format_key_list(expected_keys) if actual_keys and actual_set < expected_set: - found = f'only {found}' - self.fail(f'Expected {expected} but found {found}', context, - code=codes.TYPEDDICT_ITEM) - - def typeddict_key_must_be_string_literal( - self, - typ: TypedDictType, - context: Context) -> None: + found = f"only {found}" + self.fail(f"Expected {expected} but found {found}", context, code=codes.TYPEDDICT_ITEM) + + def typeddict_key_must_be_string_literal(self, typ: TypedDictType, context: Context) -> None: self.fail( - 'TypedDict key must be a string literal; expected one of {}'.format( - format_item_name_list(typ.items.keys())), context, code=codes.LITERAL_REQ) + "TypedDict key must be a string literal; expected one of {}".format( + format_item_name_list(typ.items.keys()) + ), + context, + code=codes.LITERAL_REQ, + ) def typeddict_key_not_found( - self, - typ: TypedDictType, - item_name: str, - context: Context) -> None: + self, typ: TypedDictType, item_name: str, context: Context + ) -> None: if typ.is_anonymous(): - self.fail('"{}" is not a valid TypedDict key; expected one of {}'.format( - item_name, format_item_name_list(typ.items.keys())), context) + self.fail( + '"{}" is not a valid TypedDict key; expected one of {}'.format( + item_name, format_item_name_list(typ.items.keys()) + ), + context, + ) else: - self.fail('TypedDict {} has no key "{}"'.format( - format_type(typ), item_name), context, code=codes.TYPEDDICT_ITEM) + self.fail( + 'TypedDict {} has no key "{}"'.format(format_type(typ), item_name), + context, + code=codes.TYPEDDICT_ITEM, + ) matches = best_matches(item_name, typ.items.keys()) if matches: - self.note("Did you mean {}?".format( - pretty_seq(matches[:3], "or")), context, code=codes.TYPEDDICT_ITEM) - - def typeddict_context_ambiguous( - self, - types: List[TypedDictType], - context: Context) -> None: - formatted_types = ', '.join(list(format_type_distinctly(*types))) - self.fail('Type of TypedDict is ambiguous, could be any of ({})'.format( - formatted_types), context) + self.note( + "Did you mean {}?".format(pretty_seq(matches[:3], "or")), + context, + code=codes.TYPEDDICT_ITEM, + ) + + def typeddict_context_ambiguous(self, types: List[TypedDictType], context: Context) -> None: + formatted_types = ", ".join(list(format_type_distinctly(*types))) + self.fail( + "Type of TypedDict is ambiguous, could be any of ({})".format(formatted_types), context + ) def typeddict_key_cannot_be_deleted( - self, - typ: TypedDictType, - item_name: str, - context: Context) -> None: + self, typ: TypedDictType, item_name: str, context: Context + ) -> None: if typ.is_anonymous(): - self.fail(f'TypedDict key "{item_name}" cannot be deleted', - context) + self.fail(f'TypedDict key "{item_name}" cannot be deleted', context) else: - self.fail('Key "{}" of TypedDict {} cannot be deleted'.format( - item_name, format_type(typ)), context) + self.fail( + 'Key "{}" of TypedDict {} cannot be deleted'.format(item_name, format_type(typ)), + context, + ) def typeddict_setdefault_arguments_inconsistent( - self, - default: Type, - expected: Type, - context: Context) -> None: + self, default: Type, expected: Type, context: Context + ) -> None: msg = 'Argument 2 to "setdefault" of "TypedDict" has incompatible type {}; expected {}' - self.fail(msg.format(format_type(default), format_type(expected)), context, - code=codes.TYPEDDICT_ITEM) + self.fail( + msg.format(format_type(default), format_type(expected)), + context, + code=codes.TYPEDDICT_ITEM, + ) def type_arguments_not_allowed(self, context: Context) -> None: - self.fail('Parameterized generics cannot be used with class or instance checks', context) + self.fail("Parameterized generics cannot be used with class or instance checks", context) def disallowed_any_type(self, typ: Type, context: Context) -> None: typ = get_proper_type(typ) @@ -1323,70 +1675,89 @@ def disallowed_any_type(self, typ: Type, context: Context) -> None: self.fail(message, context) def incorrectly_returning_any(self, typ: Type, context: Context) -> None: - message = f'Returning Any from function declared to return {format_type(typ)}' + message = f"Returning Any from function declared to return {format_type(typ)}" self.fail(message, context, code=codes.NO_ANY_RETURN) def incorrect__exit__return(self, context: Context) -> None: self.fail( - '"bool" is invalid as return type for "__exit__" that always returns False', context, - code=codes.EXIT_RETURN) + '"bool" is invalid as return type for "__exit__" that always returns False', + context, + code=codes.EXIT_RETURN, + ) self.note( 'Use "typing_extensions.Literal[False]" as the return type or change it to "None"', - context, code=codes.EXIT_RETURN) + context, + code=codes.EXIT_RETURN, + ) self.note( 'If return type of "__exit__" implies that it may return True, ' - 'the context manager may swallow exceptions', - context, code=codes.EXIT_RETURN) + "the context manager may swallow exceptions", + context, + code=codes.EXIT_RETURN, + ) def untyped_decorated_function(self, typ: Type, context: Context) -> None: typ = get_proper_type(typ) if isinstance(typ, AnyType): self.fail("Function is untyped after decorator transformation", context) else: - self.fail('Type of decorated function contains type "Any" ({})'.format( - format_type(typ)), context) + self.fail( + 'Type of decorated function contains type "Any" ({})'.format(format_type(typ)), + context, + ) def typed_function_untyped_decorator(self, func_name: str, context: Context) -> None: self.fail(f'Untyped decorator makes function "{func_name}" untyped', context) - def bad_proto_variance(self, actual: int, tvar_name: str, expected: int, - context: Context) -> None: - msg = capitalize('{} type variable "{}" used in protocol where' - ' {} one is expected'.format(variance_string(actual), - tvar_name, - variance_string(expected))) + def bad_proto_variance( + self, actual: int, tvar_name: str, expected: int, context: Context + ) -> None: + msg = capitalize( + '{} type variable "{}" used in protocol where' + " {} one is expected".format( + variance_string(actual), tvar_name, variance_string(expected) + ) + ) self.fail(msg, context) def concrete_only_assign(self, typ: Type, context: Context) -> None: - self.fail("Can only assign concrete classes to a variable of type {}" - .format(format_type(typ)), context) + self.fail( + "Can only assign concrete classes to a variable of type {}".format(format_type(typ)), + context, + ) def concrete_only_call(self, typ: Type, context: Context) -> None: - self.fail("Only concrete class can be given where {} is expected" - .format(format_type(typ)), context) + self.fail( + "Only concrete class can be given where {} is expected".format(format_type(typ)), + context, + ) def cannot_use_function_with_type( - self, method_name: str, type_name: str, context: Context) -> None: + self, method_name: str, type_name: str, context: Context + ) -> None: self.fail(f"Cannot use {method_name}() with {type_name} type", context) - def report_non_method_protocol(self, tp: TypeInfo, members: List[str], - context: Context) -> None: - self.fail("Only protocols that don't have non-method members can be" - " used with issubclass()", context) + def report_non_method_protocol( + self, tp: TypeInfo, members: List[str], context: Context + ) -> None: + self.fail( + "Only protocols that don't have non-method members can be" " used with issubclass()", + context, + ) if len(members) < 3: - attrs = ', '.join(members) - self.note('Protocol "{}" has non-method member(s): {}' - .format(tp.name, attrs), context) - - def note_call(self, - subtype: Type, - call: Type, - context: Context, - *, - code: Optional[ErrorCode]) -> None: - self.note('"{}.__call__" has type {}'.format(format_type_bare(subtype), - format_type(call, verbosity=1)), - context, code=code) + attrs = ", ".join(members) + self.note('Protocol "{}" has non-method member(s): {}'.format(tp.name, attrs), context) + + def note_call( + self, subtype: Type, call: Type, context: Context, *, code: Optional[ErrorCode] + ) -> None: + self.note( + '"{}.__call__" has type {}'.format( + format_type_bare(subtype), format_type(call, verbosity=1) + ), + context, + code=code, + ) def unreachable_statement(self, context: Context) -> None: self.fail("Statement is unreachable", context, code=codes.UNREACHABLE) @@ -1396,15 +1767,16 @@ def redundant_left_operand(self, op_name: str, context: Context) -> None: it does not change the truth value of the entire condition as a whole. 'op_name' should either be the string "and" or the string "or". """ - self.redundant_expr(f'Left operand of "{op_name}"', op_name == 'and', context) + self.redundant_expr(f'Left operand of "{op_name}"', op_name == "and", context) def unreachable_right_operand(self, op_name: str, context: Context) -> None: """Indicates that the right operand of a boolean expression is redundant: it does not change the truth value of the entire condition as a whole. 'op_name' should either be the string "and" or the string "or". """ - self.fail(f'Right operand of "{op_name}" is never evaluated', - context, code=codes.UNREACHABLE) + self.fail( + f'Right operand of "{op_name}" is never evaluated', context, code=codes.UNREACHABLE + ) def redundant_condition_in_comprehension(self, truthiness: bool, context: Context) -> None: self.redundant_expr("If condition in comprehension", truthiness, context) @@ -1413,24 +1785,28 @@ def redundant_condition_in_if(self, truthiness: bool, context: Context) -> None: self.redundant_expr("If condition", truthiness, context) def redundant_expr(self, description: str, truthiness: bool, context: Context) -> None: - self.fail(f"{description} is always {str(truthiness).lower()}", - context, code=codes.REDUNDANT_EXPR) - - def impossible_intersection(self, - formatted_base_class_list: str, - reason: str, - context: Context, - ) -> None: + self.fail( + f"{description} is always {str(truthiness).lower()}", + context, + code=codes.REDUNDANT_EXPR, + ) + + def impossible_intersection( + self, formatted_base_class_list: str, reason: str, context: Context + ) -> None: template = "Subclass of {} cannot exist: would have {}" - self.fail(template.format(formatted_base_class_list, reason), context, - code=codes.UNREACHABLE) - - def report_protocol_problems(self, - subtype: Union[Instance, TupleType, TypedDictType], - supertype: Instance, - context: Context, - *, - code: Optional[ErrorCode]) -> None: + self.fail( + template.format(formatted_base_class_list, reason), context, code=codes.UNREACHABLE + ) + + def report_protocol_problems( + self, + subtype: Union[Instance, TupleType, TypedDictType], + supertype: Instance, + context: Context, + *, + code: Optional[ErrorCode], + ) -> None: """Report possible protocol conflicts between 'subtype' and 'supertype'. This includes missing members, incompatible types, and incompatible @@ -1465,43 +1841,55 @@ def report_protocol_problems(self, # Report missing members missing = get_missing_protocol_members(subtype, supertype) - if (missing and len(missing) < len(supertype.type.protocol_members) and - len(missing) <= MAX_ITEMS): - self.note('"{}" is missing following "{}" protocol member{}:' - .format(subtype.type.name, supertype.type.name, plural_s(missing)), - context, - code=code) - self.note(', '.join(missing), context, offset=OFFSET, code=code) + if ( + missing + and len(missing) < len(supertype.type.protocol_members) + and len(missing) <= MAX_ITEMS + ): + self.note( + '"{}" is missing following "{}" protocol member{}:'.format( + subtype.type.name, supertype.type.name, plural_s(missing) + ), + context, + code=code, + ) + self.note(", ".join(missing), context, offset=OFFSET, code=code) elif len(missing) > MAX_ITEMS or len(missing) == len(supertype.type.protocol_members): # This is an obviously wrong type: too many missing members return # Report member type conflicts conflict_types = get_conflict_protocol_types(subtype, supertype) - if conflict_types and (not is_subtype(subtype, erase_type(supertype)) or - not subtype.type.defn.type_vars or - not supertype.type.defn.type_vars): - self.note(f'Following member(s) of {format_type(subtype)} have conflicts:', - context, - code=code) + if conflict_types and ( + not is_subtype(subtype, erase_type(supertype)) + or not subtype.type.defn.type_vars + or not supertype.type.defn.type_vars + ): + self.note( + f"Following member(s) of {format_type(subtype)} have conflicts:", + context, + code=code, + ) for name, got, exp in conflict_types[:MAX_ITEMS]: exp = get_proper_type(exp) got = get_proper_type(got) - if (not isinstance(exp, (CallableType, Overloaded)) or - not isinstance(got, (CallableType, Overloaded))): - self.note('{}: expected {}, got {}'.format(name, - *format_type_distinctly(exp, got)), - context, - offset=OFFSET, - code=code) + if not isinstance(exp, (CallableType, Overloaded)) or not isinstance( + got, (CallableType, Overloaded) + ): + self.note( + "{}: expected {}, got {}".format(name, *format_type_distinctly(exp, got)), + context, + offset=OFFSET, + code=code, + ) else: - self.note('Expected:', context, offset=OFFSET, code=code) + self.note("Expected:", context, offset=OFFSET, code=code) if isinstance(exp, CallableType): self.note(pretty_callable(exp), context, offset=2 * OFFSET, code=code) else: assert isinstance(exp, Overloaded) self.pretty_overload(exp, context, 2 * OFFSET, code=code) - self.note('Got:', context, offset=OFFSET, code=code) + self.note("Got:", context, offset=OFFSET, code=code) if isinstance(got, CallableType): self.note(pretty_callable(got), context, offset=2 * OFFSET, code=code) else: @@ -1513,90 +1901,119 @@ def report_protocol_problems(self, conflict_flags = get_bad_protocol_flags(subtype, supertype) for name, subflags, superflags in conflict_flags[:MAX_ITEMS]: if IS_CLASSVAR in subflags and IS_CLASSVAR not in superflags: - self.note('Protocol member {}.{} expected instance variable,' - ' got class variable'.format(supertype.type.name, name), - context, - code=code) + self.note( + "Protocol member {}.{} expected instance variable," + " got class variable".format(supertype.type.name, name), + context, + code=code, + ) if IS_CLASSVAR in superflags and IS_CLASSVAR not in subflags: - self.note('Protocol member {}.{} expected class variable,' - ' got instance variable'.format(supertype.type.name, name), - context, - code=code) + self.note( + "Protocol member {}.{} expected class variable," + " got instance variable".format(supertype.type.name, name), + context, + code=code, + ) if IS_SETTABLE in superflags and IS_SETTABLE not in subflags: - self.note('Protocol member {}.{} expected settable variable,' - ' got read-only attribute'.format(supertype.type.name, name), - context, - code=code) + self.note( + "Protocol member {}.{} expected settable variable," + " got read-only attribute".format(supertype.type.name, name), + context, + code=code, + ) if IS_CLASS_OR_STATIC in superflags and IS_CLASS_OR_STATIC not in subflags: - self.note('Protocol member {}.{} expected class or static method' - .format(supertype.type.name, name), - context, - code=code) + self.note( + "Protocol member {}.{} expected class or static method".format( + supertype.type.name, name + ), + context, + code=code, + ) self.print_more(conflict_flags, context, OFFSET, MAX_ITEMS, code=code) - def pretty_overload(self, - tp: Overloaded, - context: Context, - offset: int, - *, - add_class_or_static_decorator: bool = False, - allow_dups: bool = False, - code: Optional[ErrorCode] = None) -> None: + def pretty_overload( + self, + tp: Overloaded, + context: Context, + offset: int, + *, + add_class_or_static_decorator: bool = False, + allow_dups: bool = False, + code: Optional[ErrorCode] = None, + ) -> None: for item in tp.items: - self.note('@overload', context, offset=offset, allow_dups=allow_dups, code=code) + self.note("@overload", context, offset=offset, allow_dups=allow_dups, code=code) if add_class_or_static_decorator: decorator = pretty_class_or_static_decorator(item) if decorator is not None: self.note(decorator, context, offset=offset, allow_dups=allow_dups, code=code) - self.note(pretty_callable(item), context, - offset=offset, allow_dups=allow_dups, code=code) + self.note( + pretty_callable(item), context, offset=offset, allow_dups=allow_dups, code=code + ) - def print_more(self, - conflicts: Sequence[Any], - context: Context, - offset: int, - max_items: int, - *, - code: Optional[ErrorCode] = None) -> None: + def print_more( + self, + conflicts: Sequence[Any], + context: Context, + offset: int, + max_items: int, + *, + code: Optional[ErrorCode] = None, + ) -> None: if len(conflicts) > max_items: - self.note(f'<{len(conflicts) - max_items} more conflict(s) not shown>', - context, offset=offset, code=code) - - def try_report_long_tuple_assignment_error(self, - subtype: ProperType, - supertype: ProperType, - context: Context, - msg: str = message_registry.INCOMPATIBLE_TYPES, - subtype_label: Optional[str] = None, - supertype_label: Optional[str] = None, - code: Optional[ErrorCode] = None) -> bool: + self.note( + f"<{len(conflicts) - max_items} more conflict(s) not shown>", + context, + offset=offset, + code=code, + ) + + def try_report_long_tuple_assignment_error( + self, + subtype: ProperType, + supertype: ProperType, + context: Context, + msg: str = message_registry.INCOMPATIBLE_TYPES, + subtype_label: Optional[str] = None, + supertype_label: Optional[str] = None, + code: Optional[ErrorCode] = None, + ) -> bool: """Try to generate meaningful error message for very long tuple assignment Returns a bool: True when generating long tuple assignment error, False when no such error reported """ if isinstance(subtype, TupleType): - if (len(subtype.items) > 10 and - isinstance(supertype, Instance) and - supertype.type.fullname == 'builtins.tuple'): + if ( + len(subtype.items) > 10 + and isinstance(supertype, Instance) + and supertype.type.fullname == "builtins.tuple" + ): lhs_type = supertype.args[0] lhs_types = [lhs_type] * len(subtype.items) - self.generate_incompatible_tuple_error(lhs_types, - subtype.items, context, msg, code) + self.generate_incompatible_tuple_error( + lhs_types, subtype.items, context, msg, code + ) return True - elif (isinstance(supertype, TupleType) and - (len(subtype.items) > 10 or len(supertype.items) > 10)): + elif isinstance(supertype, TupleType) and ( + len(subtype.items) > 10 or len(supertype.items) > 10 + ): if len(subtype.items) != len(supertype.items): if supertype_label is not None and subtype_label is not None: - error_msg = "{} ({} {}, {} {})".format(msg, subtype_label, - self.format_long_tuple_type(subtype), supertype_label, - self.format_long_tuple_type(supertype)) + error_msg = "{} ({} {}, {} {})".format( + msg, + subtype_label, + self.format_long_tuple_type(subtype), + supertype_label, + self.format_long_tuple_type(supertype), + ) self.fail(error_msg, context, code=code) return True - self.generate_incompatible_tuple_error(supertype.items, - subtype.items, context, msg, code) + self.generate_incompatible_tuple_error( + supertype.items, subtype.items, context, msg, code + ) return True return False @@ -1604,33 +2021,38 @@ def format_long_tuple_type(self, typ: TupleType) -> str: """Format very long tuple type using an ellipsis notation""" item_cnt = len(typ.items) if item_cnt > 10: - return 'Tuple[{}, {}, ... <{} more items>]'\ - .format(format_type_bare(typ.items[0]), - format_type_bare(typ.items[1]), str(item_cnt - 2)) + return "Tuple[{}, {}, ... <{} more items>]".format( + format_type_bare(typ.items[0]), format_type_bare(typ.items[1]), str(item_cnt - 2) + ) else: return format_type_bare(typ) - def generate_incompatible_tuple_error(self, - lhs_types: List[Type], - rhs_types: List[Type], - context: Context, - msg: str = message_registry.INCOMPATIBLE_TYPES, - code: Optional[ErrorCode] = None) -> None: + def generate_incompatible_tuple_error( + self, + lhs_types: List[Type], + rhs_types: List[Type], + context: Context, + msg: str = message_registry.INCOMPATIBLE_TYPES, + code: Optional[ErrorCode] = None, + ) -> None: """Generate error message for individual incompatible tuple pairs""" error_cnt = 0 notes = [] # List[str] for i, (lhs_t, rhs_t) in enumerate(zip(lhs_types, rhs_types)): if not is_subtype(lhs_t, rhs_t): if error_cnt < 3: - notes.append('Expression tuple item {} has type {}; {} expected; ' - .format(str(i), format_type(rhs_t), format_type(lhs_t))) + notes.append( + "Expression tuple item {} has type {}; {} expected; ".format( + str(i), format_type(rhs_t), format_type(lhs_t) + ) + ) error_cnt += 1 - error_msg = msg + f' ({str(error_cnt)} tuple items are incompatible' + error_msg = msg + f" ({str(error_cnt)} tuple items are incompatible" if error_cnt - 3 > 0: - error_msg += f'; {str(error_cnt - 3)} items are omitted)' + error_msg += f"; {str(error_cnt - 3)} items are omitted)" else: - error_msg += ')' + error_msg += ")" self.fail(error_msg, context, code=code) for note in notes: self.note(note, context, code=code) @@ -1639,30 +2061,38 @@ def add_fixture_note(self, fullname: str, ctx: Context) -> None: self.note(f'Maybe your test fixture does not define "{fullname}"?', ctx) if fullname in SUGGESTED_TEST_FIXTURES: self.note( - 'Consider adding [builtins fixtures/{}] to your test description'.format( - SUGGESTED_TEST_FIXTURES[fullname]), ctx) + "Consider adding [builtins fixtures/{}] to your test description".format( + SUGGESTED_TEST_FIXTURES[fullname] + ), + ctx, + ) def quote_type_string(type_string: str) -> str: """Quotes a type representation for use in messages.""" - no_quote_regex = r'^<(tuple|union): \d+ items>$' - if (type_string in ['Module', 'overloaded function', '', ''] - or re.match(no_quote_regex, type_string) is not None or type_string.endswith('?')): + no_quote_regex = r"^<(tuple|union): \d+ items>$" + if ( + type_string in ["Module", "overloaded function", "", ""] + or re.match(no_quote_regex, type_string) is not None + or type_string.endswith("?") + ): # Messages are easier to read if these aren't quoted. We use a # regex to match strings with variable contents. return type_string return f'"{type_string}"' -def format_callable_args(arg_types: List[Type], arg_kinds: List[ArgKind], - arg_names: List[Optional[str]], format: Callable[[Type], str], - verbosity: int) -> str: +def format_callable_args( + arg_types: List[Type], + arg_kinds: List[ArgKind], + arg_names: List[Optional[str]], + format: Callable[[Type], str], + verbosity: int, +) -> str: """Format a bunch of Callable arguments into a string""" arg_strings = [] - for arg_name, arg_type, arg_kind in zip( - arg_names, arg_types, arg_kinds): - if (arg_kind == ARG_POS and arg_name is None - or verbosity == 0 and arg_kind.is_positional()): + for arg_name, arg_type, arg_kind in zip(arg_names, arg_types, arg_kinds): + if arg_kind == ARG_POS and arg_name is None or verbosity == 0 and arg_kind.is_positional(): arg_strings.append(format(arg_type)) else: @@ -1670,17 +2100,14 @@ def format_callable_args(arg_types: List[Type], arg_kinds: List[ArgKind], if arg_kind.is_star() or arg_name is None: arg_strings.append(f"{constructor}({format(arg_type)})") else: - arg_strings.append("{}({}, {})".format( - constructor, - format(arg_type), - repr(arg_name))) + arg_strings.append( + "{}({}, {})".format(constructor, format(arg_type), repr(arg_name)) + ) return ", ".join(arg_strings) -def format_type_inner(typ: Type, - verbosity: int, - fullnames: Optional[Set[str]]) -> str: +def format_type_inner(typ: Type, verbosity: int, fullnames: Optional[Set[str]]) -> str: """ Convert a type to a relatively short string suitable for error messages. @@ -1688,16 +2115,17 @@ def format_type_inner(typ: Type, verbosity: a coarse grained control on the verbosity of the type fullnames: a set of names that should be printed in full """ + def format(typ: Type) -> str: return format_type_inner(typ, verbosity, fullnames) def format_list(types: Sequence[Type]) -> str: - return ', '.join(format(typ) for typ in types) + return ", ".join(format(typ) for typ in types) def format_literal_value(typ: LiteralType) -> str: if typ.is_enum_literal(): underlying_type = format(typ.fallback) - return f'{underlying_type}.{typ.value}' + return f"{underlying_type}.{typ.value}" else: return typ.value_repr() @@ -1707,9 +2135,9 @@ def format_literal_value(typ: LiteralType) -> str: if isinstance(typ, Instance): itype = typ # Get the short name of the type. - if itype.type.fullname in ('types.ModuleType', '_importlib_modulespec.ModuleType'): + if itype.type.fullname in ("types.ModuleType", "_importlib_modulespec.ModuleType"): # Make some common error messages simpler and tidier. - return 'Module' + return "Module" if verbosity >= 2 or (fullnames and itype.type.fullname in fullnames): base_str = itype.type.fullname else: @@ -1717,16 +2145,16 @@ def format_literal_value(typ: LiteralType) -> str: if not itype.args: # No type arguments, just return the type name return base_str - elif itype.type.fullname == 'builtins.tuple': + elif itype.type.fullname == "builtins.tuple": item_type_str = format(itype.args[0]) - return f'Tuple[{item_type_str}, ...]' + return f"Tuple[{item_type_str}, ...]" elif itype.type.fullname in reverse_builtin_aliases: alias = reverse_builtin_aliases[itype.type.fullname] - alias = alias.split('.')[-1] - return f'{alias}[{format_list(itype.args)}]' + alias = alias.split(".")[-1] + return f"{alias}[{format_list(itype.args)}]" else: # There are type arguments. Convert the arguments to strings. - return f'{base_str}[{format_list(itype.args)}]' + return f"{base_str}[{format_list(itype.args)}]" elif isinstance(typ, TypeVarType): # This is similar to non-generic instance types. return typ.name @@ -1734,20 +2162,17 @@ def format_literal_value(typ: LiteralType) -> str: # Concatenate[..., P] if typ.prefix.arg_types: args = format_callable_args( - typ.prefix.arg_types, - typ.prefix.arg_kinds, - typ.prefix.arg_names, - format, - verbosity) + typ.prefix.arg_types, typ.prefix.arg_kinds, typ.prefix.arg_names, format, verbosity + ) - return f'[{args}, **{typ.name_with_suffix()}]' + return f"[{args}, **{typ.name_with_suffix()}]" else: return typ.name_with_suffix() elif isinstance(typ, TupleType): # Prefer the name of the fallback class (if not tuple), as it's more informative. - if typ.partial_fallback.type.fullname != 'builtins.tuple': + if typ.partial_fallback.type.fullname != "builtins.tuple": return format(typ.partial_fallback) - s = f'Tuple[{format_list(typ.items)}]' + s = f"Tuple[{format_list(typ.items)}]" return s elif isinstance(typ, TypedDictType): # If the TypedDictType is named, return the name @@ -1755,53 +2180,54 @@ def format_literal_value(typ: LiteralType) -> str: return format(typ.fallback) items = [] for (item_name, item_type) in typ.items.items(): - modifier = '' if item_name in typ.required_keys else '?' - items.append(f'{item_name!r}{modifier}: {format(item_type)}') + modifier = "" if item_name in typ.required_keys else "?" + items.append(f"{item_name!r}{modifier}: {format(item_type)}") s = f"TypedDict({{{', '.join(items)}}})" return s elif isinstance(typ, LiteralType): - return f'Literal[{format_literal_value(typ)}]' + return f"Literal[{format_literal_value(typ)}]" elif isinstance(typ, UnionType): literal_items, union_items = separate_union_literals(typ) # Coalesce multiple Literal[] members. This also changes output order. # If there's just one Literal item, retain the original ordering. if len(literal_items) > 1: - literal_str = 'Literal[{}]'.format( - ', '.join(format_literal_value(t) for t in literal_items) + literal_str = "Literal[{}]".format( + ", ".join(format_literal_value(t) for t in literal_items) ) if len(union_items) == 1 and isinstance(get_proper_type(union_items[0]), NoneType): - return f'Optional[{literal_str}]' + return f"Optional[{literal_str}]" elif union_items: - return f'Union[{format_list(union_items)}, {literal_str}]' + return f"Union[{format_list(union_items)}, {literal_str}]" else: return literal_str else: # Only print Union as Optional if the Optional wouldn't have to contain another Union - print_as_optional = (len(typ.items) - - sum(isinstance(get_proper_type(t), NoneType) - for t in typ.items) == 1) + print_as_optional = ( + len(typ.items) - sum(isinstance(get_proper_type(t), NoneType) for t in typ.items) + == 1 + ) if print_as_optional: rest = [t for t in typ.items if not isinstance(get_proper_type(t), NoneType)] - return f'Optional[{format(rest[0])}]' + return f"Optional[{format(rest[0])}]" else: - s = f'Union[{format_list(typ.items)}]' + s = f"Union[{format_list(typ.items)}]" return s elif isinstance(typ, NoneType): - return 'None' + return "None" elif isinstance(typ, AnyType): - return 'Any' + return "Any" elif isinstance(typ, DeletedType): - return '' + return "" elif isinstance(typ, UninhabitedType): if typ.is_noreturn: - return 'NoReturn' + return "NoReturn" else: - return '' + return "" elif isinstance(typ, TypeType): - return f'Type[{format(typ.item)}]' + return f"Type[{format(typ.item)}]" elif isinstance(typ, FunctionLike): func = typ if func.is_type_obj(): @@ -1810,41 +2236,33 @@ def format_literal_value(typ: LiteralType) -> str: return format(TypeType.make_normalized(erase_type(func.items[0].ret_type))) elif isinstance(func, CallableType): if func.type_guard is not None: - return_type = f'TypeGuard[{format(func.type_guard)}]' + return_type = f"TypeGuard[{format(func.type_guard)}]" else: return_type = format(func.ret_type) if func.is_ellipsis_args: - return f'Callable[..., {return_type}]' + return f"Callable[..., {return_type}]" param_spec = func.param_spec() if param_spec is not None: - return f'Callable[{format(param_spec)}, {return_type}]' + return f"Callable[{format(param_spec)}, {return_type}]" args = format_callable_args( - func.arg_types, - func.arg_kinds, - func.arg_names, - format, - verbosity) - return f'Callable[[{args}], {return_type}]' + func.arg_types, func.arg_kinds, func.arg_names, format, verbosity + ) + return f"Callable[[{args}], {return_type}]" else: # Use a simple representation for function types; proper # function types may result in long and difficult-to-read # error messages. - return 'overloaded function' + return "overloaded function" elif isinstance(typ, UnboundType): return str(typ) elif isinstance(typ, Parameters): - args = format_callable_args( - typ.arg_types, - typ.arg_kinds, - typ.arg_names, - format, - verbosity) - return f'[{args}]' + args = format_callable_args(typ.arg_types, typ.arg_kinds, typ.arg_names, format, verbosity) + return f"[{args}]" elif typ is None: - raise RuntimeError('Type is None') + raise RuntimeError("Type is None") else: # Default case; we simply have to return something meaningful here. - return 'object' + return "object" def collect_all_instances(t: Type) -> List[Instance]: @@ -1878,8 +2296,8 @@ def find_type_overlaps(*types: Type) -> Set[str]: for inst in collect_all_instances(type): d.setdefault(inst.type.name, set()).add(inst.type.fullname) for shortname in d.keys(): - if f'typing.{shortname}' in TYPES_FOR_UNIMPORTED_HINTS: - d[shortname].add(f'typing.{shortname}') + if f"typing.{shortname}" in TYPES_FOR_UNIMPORTED_HINTS: + d[shortname].add(f"typing.{shortname}") overlaps: Set[str] = set() for fullnames in d.values(): @@ -1902,8 +2320,7 @@ def format_type(typ: Type, verbosity: int = 0) -> str: return quote_type_string(format_type_bare(typ, verbosity)) -def format_type_bare(typ: Type, - verbosity: int = 0) -> str: +def format_type_bare(typ: Type, verbosity: int = 0) -> str: """ Convert a type to a relatively short string suitable for error messages. @@ -1933,8 +2350,7 @@ def format_type_distinctly(*types: Type, bare: bool = False) -> Tuple[str, ...]: overlapping = find_type_overlaps(*types) for verbosity in range(2): strs = [ - format_type_inner(type, verbosity=verbosity, fullnames=overlapping) - for type in types + format_type_inner(type, verbosity=verbosity, fullnames=overlapping) for type in types ] if len(set(strs)) == len(strs): break @@ -1948,9 +2364,9 @@ def pretty_class_or_static_decorator(tp: CallableType) -> Optional[str]: """Return @classmethod or @staticmethod, if any, for the given callable type.""" if tp.definition is not None and isinstance(tp.definition, SYMBOL_FUNCBASE_TYPES): if tp.definition.is_class: - return '@classmethod' + return "@classmethod" if tp.definition.is_static: - return '@staticmethod' + return "@staticmethod" return None @@ -1959,50 +2375,56 @@ def pretty_callable(tp: CallableType) -> str: For example: def [T <: int] f(self, x: int, y: T) -> None """ - s = '' + s = "" asterisk = False for i in range(len(tp.arg_types)): if s: - s += ', ' + s += ", " if tp.arg_kinds[i].is_named() and not asterisk: - s += '*, ' + s += "*, " asterisk = True if tp.arg_kinds[i] == ARG_STAR: - s += '*' + s += "*" asterisk = True if tp.arg_kinds[i] == ARG_STAR2: - s += '**' + s += "**" name = tp.arg_names[i] if name: - s += name + ': ' + s += name + ": " s += format_type_bare(tp.arg_types[i]) if tp.arg_kinds[i].is_optional(): - s += ' = ...' + s += " = ..." # If we got a "special arg" (i.e: self, cls, etc...), prepend it to the arg list - if (isinstance(tp.definition, FuncDef) and - tp.definition.name is not None and - hasattr(tp.definition, 'arguments')): + if ( + isinstance(tp.definition, FuncDef) + and tp.definition.name is not None + and hasattr(tp.definition, "arguments") + ): definition_args = [arg.variable.name for arg in tp.definition.arguments] - if definition_args and tp.arg_names != definition_args \ - and len(definition_args) > 0 and definition_args[0]: + if ( + definition_args + and tp.arg_names != definition_args + and len(definition_args) > 0 + and definition_args[0] + ): if s: - s = ', ' + s + s = ", " + s s = definition_args[0] + s - s = f'{tp.definition.name}({s})' + s = f"{tp.definition.name}({s})" elif tp.name: - first_arg = tp.def_extras.get('first_arg') + first_arg = tp.def_extras.get("first_arg") if first_arg: if s: - s = ', ' + s + s = ", " + s s = first_arg + s - s = f'{tp.name.split()[0]}({s})' # skip "of Class" part + s = f"{tp.name.split()[0]}({s})" # skip "of Class" part else: - s = f'({s})' + s = f"({s})" - s += ' -> ' + s += " -> " if tp.type_guard is not None: - s += f'TypeGuard[{format_type_bare(tp.type_guard)}]' + s += f"TypeGuard[{format_type_bare(tp.type_guard)}]" else: s += format_type_bare(tp.ret_type) @@ -2011,29 +2433,33 @@ def [T <: int] f(self, x: int, y: T) -> None for tvar in tp.variables: if isinstance(tvar, TypeVarType): upper_bound = get_proper_type(tvar.upper_bound) - if (isinstance(upper_bound, Instance) and - upper_bound.type.fullname != 'builtins.object'): - tvars.append(f'{tvar.name} <: {format_type_bare(upper_bound)}') + if ( + isinstance(upper_bound, Instance) + and upper_bound.type.fullname != "builtins.object" + ): + tvars.append(f"{tvar.name} <: {format_type_bare(upper_bound)}") elif tvar.values: - tvars.append('{} in ({})' - .format(tvar.name, ', '.join([format_type_bare(tp) - for tp in tvar.values]))) + tvars.append( + "{} in ({})".format( + tvar.name, ", ".join([format_type_bare(tp) for tp in tvar.values]) + ) + ) else: tvars.append(tvar.name) else: # For other TypeVarLikeTypes, just use the repr tvars.append(repr(tvar)) s = f"[{', '.join(tvars)}] {s}" - return f'def {s}' + return f"def {s}" def variance_string(variance: int) -> str: if variance == COVARIANT: - return 'covariant' + return "covariant" elif variance == CONTRAVARIANT: - return 'contravariant' + return "contravariant" else: - return 'invariant' + return "invariant" def get_missing_protocol_members(left: Instance, right: Instance) -> List[str]: @@ -2055,7 +2481,7 @@ def get_conflict_protocol_types(left: Instance, right: Instance) -> List[Tuple[s assert right.type.is_protocol conflicts: List[Tuple[str, Type, Type]] = [] for member in right.type.protocol_members: - if member in ('__init__', '__new__'): + if member in ("__init__", "__new__"): continue supertype = find_member(member, right, left) assert supertype is not None @@ -2070,8 +2496,9 @@ def get_conflict_protocol_types(left: Instance, right: Instance) -> List[Tuple[s return conflicts -def get_bad_protocol_flags(left: Instance, right: Instance - ) -> List[Tuple[str, Set[int], Set[int]]]: +def get_bad_protocol_flags( + left: Instance, right: Instance +) -> List[Tuple[str, Set[int], Set[int]]]: """Return all incompatible attribute flags for members that are present in both 'left' and 'right'. """ @@ -2079,24 +2506,32 @@ def get_bad_protocol_flags(left: Instance, right: Instance all_flags: List[Tuple[str, Set[int], Set[int]]] = [] for member in right.type.protocol_members: if find_member(member, left, left): - item = (member, - get_member_flags(member, left.type), - get_member_flags(member, right.type)) + item = ( + member, + get_member_flags(member, left.type), + get_member_flags(member, right.type), + ) all_flags.append(item) bad_flags = [] for name, subflags, superflags in all_flags: - if (IS_CLASSVAR in subflags and IS_CLASSVAR not in superflags or - IS_CLASSVAR in superflags and IS_CLASSVAR not in subflags or - IS_SETTABLE in superflags and IS_SETTABLE not in subflags or - IS_CLASS_OR_STATIC in superflags and IS_CLASS_OR_STATIC not in subflags): + if ( + IS_CLASSVAR in subflags + and IS_CLASSVAR not in superflags + or IS_CLASSVAR in superflags + and IS_CLASSVAR not in subflags + or IS_SETTABLE in superflags + and IS_SETTABLE not in subflags + or IS_CLASS_OR_STATIC in superflags + and IS_CLASS_OR_STATIC not in subflags + ): bad_flags.append((name, subflags, superflags)) return bad_flags def capitalize(s: str) -> str: """Capitalize the first character of a string.""" - if s == '': - return '' + if s == "": + return "" else: return s[0].upper() + s[1:] @@ -2106,14 +2541,14 @@ def extract_type(name: str) -> str: the type portion in quotes (e.g. "y"). Otherwise, return the string unmodified. """ - name = re.sub('^"[a-zA-Z0-9_]+" of ', '', name) + name = re.sub('^"[a-zA-Z0-9_]+" of ', "", name) return name def strip_quotes(s: str) -> str: """Strip a double quote at the beginning and end of the string, if any.""" - s = re.sub('^"', '', s) - s = re.sub('"$', '', s) + s = re.sub('^"', "", s) + s = re.sub('"$', "", s) return s @@ -2124,39 +2559,42 @@ def format_string_list(lst: List[str]) -> str: elif len(lst) <= 5: return f"{', '.join(lst[:-1])} and {lst[-1]}" else: - return '%s, ... and %s (%i methods suppressed)' % ( - ', '.join(lst[:2]), lst[-1], len(lst) - 3) + return "%s, ... and %s (%i methods suppressed)" % ( + ", ".join(lst[:2]), + lst[-1], + len(lst) - 3, + ) def format_item_name_list(s: Iterable[str]) -> str: lst = list(s) if len(lst) <= 5: - return '(' + ', '.join([f'"{name}"' for name in lst]) + ')' + return "(" + ", ".join([f'"{name}"' for name in lst]) + ")" else: - return '(' + ', '.join([f'"{name}"' for name in lst[:5]]) + ', ...)' + return "(" + ", ".join([f'"{name}"' for name in lst[:5]]) + ", ...)" def callable_name(type: FunctionLike) -> Optional[str]: name = type.get_name() - if name is not None and name[0] != '<': - return f'"{name}"'.replace(' of ', '" of "') + if name is not None and name[0] != "<": + return f'"{name}"'.replace(" of ", '" of "') return name def for_function(callee: CallableType) -> str: name = callable_name(callee) if name is not None: - return f' for {name}' - return '' + return f" for {name}" + return "" def find_defining_module(modules: Dict[str, MypyFile], typ: CallableType) -> Optional[MypyFile]: if not typ.definition: return None fullname = typ.definition.fullname - if fullname is not None and '.' in fullname: - for i in range(fullname.count('.')): - module_name = fullname.rsplit('.', i + 1)[0] + if fullname is not None and "." in fullname: + for i in range(fullname.count(".")): + module_name = fullname.rsplit(".", i + 1)[0] try: return modules[module_name] except KeyError: @@ -2166,15 +2604,14 @@ def find_defining_module(modules: Dict[str, MypyFile], typ: CallableType) -> Opt # For hard-coding suggested missing member alternatives. -COMMON_MISTAKES: Final[Dict[str, Sequence[str]]] = { - 'add': ('append', 'extend'), -} +COMMON_MISTAKES: Final[Dict[str, Sequence[str]]] = {"add": ("append", "extend")} def best_matches(current: str, options: Iterable[str]) -> List[str]: ratios = {v: difflib.SequenceMatcher(a=current, b=v).ratio() for v in options} - return sorted((o for o in options if ratios[o] > 0.75), - reverse=True, key=lambda v: (ratios[v], v)) + return sorted( + (o for o in options if ratios[o] > 0.75), reverse=True, key=lambda v: (ratios[v], v) + ) def pretty_seq(args: Sequence[str], conjunction: str) -> str: @@ -2187,35 +2624,41 @@ def pretty_seq(args: Sequence[str], conjunction: str) -> str: return ", ".join(quoted[:-1]) + last_sep + quoted[-1] -def append_invariance_notes(notes: List[str], arg_type: Instance, - expected_type: Instance) -> List[str]: +def append_invariance_notes( + notes: List[str], arg_type: Instance, expected_type: Instance +) -> List[str]: """Explain that the type is invariant and give notes for how to solve the issue.""" - invariant_type = '' - covariant_suggestion = '' - if (arg_type.type.fullname == 'builtins.list' and - expected_type.type.fullname == 'builtins.list' and - is_subtype(arg_type.args[0], expected_type.args[0])): - invariant_type = 'List' + invariant_type = "" + covariant_suggestion = "" + if ( + arg_type.type.fullname == "builtins.list" + and expected_type.type.fullname == "builtins.list" + and is_subtype(arg_type.args[0], expected_type.args[0]) + ): + invariant_type = "List" covariant_suggestion = 'Consider using "Sequence" instead, which is covariant' - elif (arg_type.type.fullname == 'builtins.dict' and - expected_type.type.fullname == 'builtins.dict' and - is_same_type(arg_type.args[0], expected_type.args[0]) and - is_subtype(arg_type.args[1], expected_type.args[1])): - invariant_type = 'Dict' - covariant_suggestion = ('Consider using "Mapping" instead, ' - 'which is covariant in the value type') + elif ( + arg_type.type.fullname == "builtins.dict" + and expected_type.type.fullname == "builtins.dict" + and is_same_type(arg_type.args[0], expected_type.args[0]) + and is_subtype(arg_type.args[1], expected_type.args[1]) + ): + invariant_type = "Dict" + covariant_suggestion = ( + 'Consider using "Mapping" instead, ' "which is covariant in the value type" + ) if invariant_type and covariant_suggestion: notes.append( - f'"{invariant_type}" is invariant -- see ' + - "https://mypy.readthedocs.io/en/stable/common_issues.html#variance") + f'"{invariant_type}" is invariant -- see ' + + "https://mypy.readthedocs.io/en/stable/common_issues.html#variance" + ) notes.append(covariant_suggestion) return notes -def make_inferred_type_note(context: Context, - subtype: Type, - supertype: Type, - supertype_str: str) -> str: +def make_inferred_type_note( + context: Context, subtype: Type, supertype: Type, supertype_str: str +) -> str: """Explain that the user may have forgotten to type a variable. The user does not expect an error if the inferred container type is the same as the return @@ -2225,30 +2668,33 @@ def make_inferred_type_note(context: Context, """ subtype = get_proper_type(subtype) supertype = get_proper_type(supertype) - if (isinstance(subtype, Instance) and - isinstance(supertype, Instance) and - subtype.type.fullname == supertype.type.fullname and - subtype.args and - supertype.args and - isinstance(context, ReturnStmt) and - isinstance(context.expr, NameExpr) and - isinstance(context.expr.node, Var) and - context.expr.node.is_inferred): + if ( + isinstance(subtype, Instance) + and isinstance(supertype, Instance) + and subtype.type.fullname == supertype.type.fullname + and subtype.args + and supertype.args + and isinstance(context, ReturnStmt) + and isinstance(context.expr, NameExpr) + and isinstance(context.expr.node, Var) + and context.expr.node.is_inferred + ): for subtype_arg, supertype_arg in zip(subtype.args, supertype.args): if not is_subtype(subtype_arg, supertype_arg): - return '' + return "" var_name = context.expr.name return 'Perhaps you need a type annotation for "{}"? Suggestion: {}'.format( - var_name, supertype_str) - return '' + var_name, supertype_str + ) + return "" def format_key_list(keys: List[str], *, short: bool = False) -> str: formatted_keys = [f'"{key}"' for key in keys] - td = '' if short else 'TypedDict ' + td = "" if short else "TypedDict " if len(keys) == 0: - return f'no {td}keys' + return f"no {td}keys" elif len(keys) == 1: - return f'{td}key {formatted_keys[0]}' + return f"{td}key {formatted_keys[0]}" else: return f"{td}keys ({', '.join(formatted_keys)})" diff --git a/mypy/metastore.py b/mypy/metastore.py index 29f1bbba2feb1..7c83827e278b7 100644 --- a/mypy/metastore.py +++ b/mypy/metastore.py @@ -11,10 +11,11 @@ import binascii import os import time - from abc import abstractmethod -from typing import List, Iterable, Any, Optional +from typing import Any, Iterable, List, Optional + from typing_extensions import TYPE_CHECKING + if TYPE_CHECKING: # We avoid importing sqlite3 unless we are using it so we can mostly work # on semi-broken pythons that are missing it. @@ -66,11 +67,12 @@ def commit(self) -> None: pass @abstractmethod - def list_all(self) -> Iterable[str]: ... + def list_all(self) -> Iterable[str]: + ... def random_string() -> str: - return binascii.hexlify(os.urandom(8)).decode('ascii') + return binascii.hexlify(os.urandom(8)).decode("ascii") class FilesystemMetadataStore(MetadataStore): @@ -105,10 +107,10 @@ def write(self, name: str, data: str, mtime: Optional[float] = None) -> bool: return False path = os.path.join(self.cache_dir_prefix, name) - tmp_filename = path + '.' + random_string() + tmp_filename = path + "." + random_string() try: os.makedirs(os.path.dirname(path), exist_ok=True) - with open(tmp_filename, 'w') as f: + with open(tmp_filename, "w") as f: f.write(data) os.replace(tmp_filename, path) if mtime is not None: @@ -137,19 +139,19 @@ def list_all(self) -> Iterable[str]: yield os.path.join(dir, file) -SCHEMA = ''' +SCHEMA = """ CREATE TABLE IF NOT EXISTS files ( path TEXT UNIQUE NOT NULL, mtime REAL, data TEXT ); CREATE INDEX IF NOT EXISTS path_idx on files(path); -''' +""" # No migrations yet MIGRATIONS: List[str] = [] -def connect_db(db_file: str) -> 'sqlite3.Connection': +def connect_db(db_file: str) -> "sqlite3.Connection": import sqlite3.dbapi2 db = sqlite3.dbapi2.connect(db_file) @@ -172,14 +174,14 @@ def __init__(self, cache_dir_prefix: str) -> None: return os.makedirs(cache_dir_prefix, exist_ok=True) - self.db = connect_db(os.path.join(cache_dir_prefix, 'cache.db')) + self.db = connect_db(os.path.join(cache_dir_prefix, "cache.db")) def _query(self, name: str, field: str) -> Any: # Raises FileNotFound for consistency with the file system version if not self.db: raise FileNotFoundError() - cur = self.db.execute(f'SELECT {field} FROM files WHERE path = ?', (name,)) + cur = self.db.execute(f"SELECT {field} FROM files WHERE path = ?", (name,)) results = cur.fetchall() if not results: raise FileNotFoundError() @@ -187,10 +189,10 @@ def _query(self, name: str, field: str) -> Any: return results[0][0] def getmtime(self, name: str) -> float: - return self._query(name, 'mtime') + return self._query(name, "mtime") def read(self, name: str) -> str: - return self._query(name, 'data') + return self._query(name, "data") def write(self, name: str, data: str, mtime: Optional[float] = None) -> bool: import sqlite3 @@ -200,8 +202,10 @@ def write(self, name: str, data: str, mtime: Optional[float] = None) -> bool: try: if mtime is None: mtime = time.time() - self.db.execute('INSERT OR REPLACE INTO files(path, mtime, data) VALUES(?, ?, ?)', - (name, mtime, data)) + self.db.execute( + "INSERT OR REPLACE INTO files(path, mtime, data) VALUES(?, ?, ?)", + (name, mtime, data), + ) except sqlite3.OperationalError: return False return True @@ -210,7 +214,7 @@ def remove(self, name: str) -> None: if not self.db: raise FileNotFoundError() - self.db.execute('DELETE FROM files WHERE path = ?', (name,)) + self.db.execute("DELETE FROM files WHERE path = ?", (name,)) def commit(self) -> None: if self.db: @@ -218,5 +222,5 @@ def commit(self) -> None: def list_all(self) -> Iterable[str]: if self.db: - for row in self.db.execute('SELECT path FROM files'): + for row in self.db.execute("SELECT path FROM files"): yield row[0] diff --git a/mypy/mixedtraverser.py b/mypy/mixedtraverser.py index c14648cdf6541..425752c1c1292 100644 --- a/mypy/mixedtraverser.py +++ b/mypy/mixedtraverser.py @@ -1,12 +1,24 @@ from typing import Optional from mypy.nodes import ( - AssertTypeExpr, Var, FuncItem, ClassDef, AssignmentStmt, ForStmt, WithStmt, - CastExpr, TypeApplication, TypeAliasExpr, TypeVarExpr, TypedDictExpr, NamedTupleExpr, - PromoteExpr, NewTypeExpr + AssertTypeExpr, + AssignmentStmt, + CastExpr, + ClassDef, + ForStmt, + FuncItem, + NamedTupleExpr, + NewTypeExpr, + PromoteExpr, + TypeAliasExpr, + TypeApplication, + TypedDictExpr, + TypeVarExpr, + Var, + WithStmt, ) -from mypy.types import Type from mypy.traverser import TraverserVisitor +from mypy.types import Type from mypy.typetraverser import TypeTraverserVisitor diff --git a/mypy/modulefinder.py b/mypy/modulefinder.py index c9e3d058ffbda..2b61905c60e74 100644 --- a/mypy/modulefinder.py +++ b/mypy/modulefinder.py @@ -20,13 +20,14 @@ import tomli as tomllib from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union + from typing_extensions import Final, TypeAlias as _TypeAlias +from mypy import pyinfo from mypy.fscache import FileSystemCache from mypy.nodes import MypyFile from mypy.options import Options from mypy.stubinfo import is_legacy_bundled_package -from mypy import pyinfo # Paths to be searched in find_module(). @@ -78,12 +79,14 @@ def error_message_templates(self, daemon: bool) -> Tuple[str, List[str]]: notes = [doc_link] elif self is ModuleNotFoundReason.WRONG_WORKING_DIRECTORY: msg = 'Cannot find implementation or library stub for module named "{module}"' - notes = ["You may be running mypy in a subpackage, " - "mypy should be run on the package root"] + notes = [ + "You may be running mypy in a subpackage, " + "mypy should be run on the package root" + ] elif self is ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS: msg = ( 'Skipping analyzing "{module}": module is installed, but missing library stubs ' - 'or py.typed marker' + "or py.typed marker" ) notes = [doc_link] elif self is ModuleNotFoundReason.APPROVED_STUBS_NOT_INSTALLED: @@ -93,7 +96,8 @@ def error_message_templates(self, daemon: bool) -> Tuple[str, List[str]]: notes = ['Hint: "python3 -m pip install {stub_dist}"'] if not daemon: notes.append( - '(or run "mypy --install-types" to install all missing stub packages)') + '(or run "mypy --install-types" to install all missing stub packages)' + ) notes.append(doc_link) else: assert False @@ -108,19 +112,22 @@ def error_message_templates(self, daemon: bool) -> Tuple[str, List[str]]: class BuildSource: """A single source file.""" - def __init__(self, path: Optional[str], module: Optional[str], - text: Optional[str] = None, base_dir: Optional[str] = None) -> None: + def __init__( + self, + path: Optional[str], + module: Optional[str], + text: Optional[str] = None, + base_dir: Optional[str] = None, + ) -> None: self.path = path # File where it's found (e.g. 'xxx/yyy/foo/bar.py') - self.module = module or '__main__' # Module name (e.g. 'foo.bar') + self.module = module or "__main__" # Module name (e.g. 'foo.bar') self.text = text # Source code, if initially supplied, else None self.base_dir = base_dir # Directory where the package is rooted (e.g. 'xxx/yyy') def __repr__(self) -> str: - return 'BuildSource(path={!r}, module={!r}, has_text={}, base_dir={!r})'.format( - self.path, - self.module, - self.text is not None, - self.base_dir) + return "BuildSource(path={!r}, module={!r}, has_text={}, base_dir={!r})".format( + self.path, self.module, self.text is not None, self.base_dir + ) class BuildSourceSet: @@ -137,7 +144,7 @@ def __init__(self, sources: List[BuildSource]) -> None: if source.path: self.source_paths.add(source.path) if source.module: - self.source_modules[source.module] = source.path or '' + self.source_modules[source.module] = source.path or "" def is_source(self, file: MypyFile) -> bool: if file.path and file.path in self.source_paths: @@ -161,12 +168,14 @@ class FindModuleCache: cleared by client code. """ - def __init__(self, - search_paths: SearchPaths, - fscache: Optional[FileSystemCache], - options: Optional[Options], - stdlib_py_versions: Optional[StdlibVersions] = None, - source_set: Optional[BuildSourceSet] = None) -> None: + def __init__( + self, + search_paths: SearchPaths, + fscache: Optional[FileSystemCache], + options: Optional[Options], + stdlib_py_versions: Optional[StdlibVersions] = None, + source_set: Optional[BuildSourceSet] = None, + ) -> None: self.search_paths = search_paths self.source_set = source_set self.fscache = fscache or FileSystemCache() @@ -180,8 +189,8 @@ def __init__(self, custom_typeshed_dir = None if options: custom_typeshed_dir = options.custom_typeshed_dir - self.stdlib_py_versions = ( - stdlib_py_versions or load_stdlib_py_versions(custom_typeshed_dir) + self.stdlib_py_versions = stdlib_py_versions or load_stdlib_py_versions( + custom_typeshed_dir ) self.python_major_ver = 3 if options is None else options.python_version[0] @@ -204,14 +213,15 @@ def find_module_via_source_set(self, id: str) -> Optional[ModuleSearchResult]: # in case of deletion of init files, which is covered by some tests. # TODO: are there some combination of flags in which this check should be skipped? d = os.path.dirname(p) - for _ in range(id.count('.')): - if not any(self.fscache.isfile(os.path.join(d, '__init__' + x)) - for x in PYTHON_EXTENSIONS): + for _ in range(id.count(".")): + if not any( + self.fscache.isfile(os.path.join(d, "__init__" + x)) for x in PYTHON_EXTENSIONS + ): return None d = os.path.dirname(d) return p - idx = id.rfind('.') + idx = id.rfind(".") if idx != -1: # When we're looking for foo.bar.baz and can't find a matching module # in the source set, look up for a foo.bar module. @@ -220,8 +230,9 @@ def find_module_via_source_set(self, id: str) -> Optional[ModuleSearchResult]: return None basename, ext = os.path.splitext(parent) - if (not any(parent.endswith('__init__' + x) for x in PYTHON_EXTENSIONS) - and (ext in PYTHON_EXTENSIONS and not self.fscache.isdir(basename))): + if not any(parent.endswith("__init__" + x) for x in PYTHON_EXTENSIONS) and ( + ext in PYTHON_EXTENSIONS and not self.fscache.isdir(basename) + ): # If we do find such a *module* (and crucially, we don't want a package, # hence the filtering out of __init__ files, and checking for the presence # of a folder with a matching name), then we can be pretty confident that @@ -242,7 +253,7 @@ def find_lib_path_dirs(self, id: str, lib_path: Tuple[str, ...]) -> PackageDirs: This is run for the python_path, mypy_path, and typeshed_path search paths. """ - components = id.split('.') + components = id.split(".") dir_chain = os.sep.join(components[:-1]) # e.g., 'foo/bar' dirs = [] @@ -282,8 +293,7 @@ def get_toplevel_possibilities(self, lib_path: Tuple[str, ...], id: str) -> List components.setdefault(name, []).append(dir) if self.python_major_ver == 2: - components = {id: filter_redundant_py2_dirs(dirs) - for id, dirs in components.items()} + components = {id: filter_redundant_py2_dirs(dirs) for id, dirs in components.items()} self.initial_components[lib_path] = components return components.get(id, []) @@ -295,16 +305,18 @@ def find_module(self, id: str, *, fast_path: bool = False) -> ModuleSearchResult error descriptions. """ if id not in self.results: - top_level = id.partition('.')[0] + top_level = id.partition(".")[0] use_typeshed = True if id in self.stdlib_py_versions: use_typeshed = self._typeshed_has_version(id) elif top_level in self.stdlib_py_versions: use_typeshed = self._typeshed_has_version(top_level) self.results[id] = self._find_module(id, use_typeshed) - if (not (fast_path or (self.options is not None and self.options.fast_module_lookup)) - and self.results[id] is ModuleNotFoundReason.NOT_FOUND - and self._can_find_module_in_parent_dir(id)): + if ( + not (fast_path or (self.options is not None and self.options.fast_module_lookup)) + and self.results[id] is ModuleNotFoundReason.NOT_FOUND + and self._can_find_module_in_parent_dir(id) + ): self.results[id] = ModuleNotFoundReason.WRONG_WORKING_DIRECTORY return self.results[id] @@ -315,26 +327,29 @@ def _typeshed_has_version(self, module: str) -> bool: min_version, max_version = self.stdlib_py_versions[module] return version >= min_version and (max_version is None or version <= max_version) - def _find_module_non_stub_helper(self, components: List[str], - pkg_dir: str) -> Union[OnePackageDir, ModuleNotFoundReason]: + def _find_module_non_stub_helper( + self, components: List[str], pkg_dir: str + ) -> Union[OnePackageDir, ModuleNotFoundReason]: plausible_match = False dir_path = pkg_dir for index, component in enumerate(components): dir_path = os.path.join(dir_path, component) - if self.fscache.isfile(os.path.join(dir_path, 'py.typed')): + if self.fscache.isfile(os.path.join(dir_path, "py.typed")): return os.path.join(pkg_dir, *components[:-1]), index == 0 - elif not plausible_match and (self.fscache.isdir(dir_path) - or self.fscache.isfile(dir_path + ".py")): + elif not plausible_match and ( + self.fscache.isdir(dir_path) or self.fscache.isfile(dir_path + ".py") + ): plausible_match = True # If this is not a directory then we can't traverse further into it if not self.fscache.isdir(dir_path): break if is_legacy_bundled_package(components[0], self.python_major_ver): - if (len(components) == 1 - or (self.find_module(components[0]) is - ModuleNotFoundReason.APPROVED_STUBS_NOT_INSTALLED)): + if len(components) == 1 or ( + self.find_module(components[0]) + is ModuleNotFoundReason.APPROVED_STUBS_NOT_INSTALLED + ): return ModuleNotFoundReason.APPROVED_STUBS_NOT_INSTALLED - if is_legacy_bundled_package('.'.join(components[:2]), self.python_major_ver): + if is_legacy_bundled_package(".".join(components[:2]), self.python_major_ver): return ModuleNotFoundReason.APPROVED_STUBS_NOT_INSTALLED if plausible_match: return ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS @@ -344,7 +359,7 @@ def _find_module_non_stub_helper(self, components: List[str], def _update_ns_ancestors(self, components: List[str], match: Tuple[str, bool]) -> None: path, verify = match for i in range(1, len(components)): - pkg_id = '.'.join(components[:-i]) + pkg_id = ".".join(components[:-i]) if pkg_id not in self.ns_ancestors and self.fscache.isdir(path): self.ns_ancestors[pkg_id] = path path = os.path.dirname(path) @@ -358,10 +373,11 @@ def _can_find_module_in_parent_dir(self, id: str) -> bool: SearchPaths((), (), (), ()), self.fscache, self.options, - stdlib_py_versions=self.stdlib_py_versions + stdlib_py_versions=self.stdlib_py_versions, ) - while any(file.endswith(("__init__.py", "__init__.pyi")) - for file in os.listdir(working_dir)): + while any( + file.endswith(("__init__.py", "__init__.pyi")) for file in os.listdir(working_dir) + ): working_dir = os.path.dirname(working_dir) parent_search.search_paths = SearchPaths((working_dir,), (), (), ()) if not isinstance(parent_search._find_module(id, False), ModuleNotFoundReason): @@ -398,9 +414,11 @@ def _find_module(self, id: str, use_typeshed: bool) -> ModuleSearchResult: # # Thankfully, such cases are efficiently handled by looking up the module path # via BuildSourceSet. - p = (self.find_module_via_source_set(id) - if (self.options is not None and self.options.fast_module_lookup) - else None) + p = ( + self.find_module_via_source_set(id) + if (self.options is not None and self.options.fast_module_lookup) + else None + ) if p: return p @@ -408,7 +426,7 @@ def _find_module(self, id: str, use_typeshed: bool) -> ModuleSearchResult: # many elements of lib_path don't even have a subdirectory 'foo/bar'. Discover # that only once and cache it for when we look for modules like 'foo.bar.blah' # that will require the same subdirectory. - components = id.split('.') + components = id.split(".") dir_chain = os.sep.join(components[:-1]) # e.g., 'foo/bar' # We have two sets of folders so that we collect *all* stubs folders and @@ -419,16 +437,16 @@ def _find_module(self, id: str, use_typeshed: bool) -> ModuleSearchResult: need_installed_stubs = False # Third-party stub/typed packages for pkg_dir in self.search_paths.package_path: - stub_name = components[0] + '-stubs' + stub_name = components[0] + "-stubs" stub_dir = os.path.join(pkg_dir, stub_name) if self.python_major_ver == 2: - alt_stub_name = components[0] + '-python2-stubs' + alt_stub_name = components[0] + "-python2-stubs" alt_stub_dir = os.path.join(pkg_dir, alt_stub_name) if fscache.isdir(alt_stub_dir): stub_name = alt_stub_name stub_dir = alt_stub_dir if fscache.isdir(stub_dir) and self._is_compatible_stub_package(stub_dir): - stub_typed_file = os.path.join(stub_dir, 'py.typed') + stub_typed_file = os.path.join(stub_dir, "py.typed") stub_components = [stub_name] + components[1:] path = os.path.join(pkg_dir, *stub_components[:-1]) if fscache.isdir(path): @@ -437,7 +455,7 @@ def _find_module(self, id: str, use_typeshed: bool) -> ModuleSearchResult: # 'partial\n' to make the package partial # Partial here means that mypy should look at the runtime # package if installed. - if fscache.read(stub_typed_file).decode().strip() == 'partial': + if fscache.read(stub_typed_file).decode().strip() == "partial": runtime_path = os.path.join(pkg_dir, dir_chain) third_party_inline_dirs.append((runtime_path, True)) # if the package is partial, we don't verify the module, as @@ -477,7 +495,7 @@ def _find_module(self, id: str, use_typeshed: bool) -> ModuleSearchResult: # elements of lib_path. This is probably much shorter than lib_path itself. # Now just look for 'baz.pyi', 'baz/__init__.py', etc., inside those directories. seplast = os.sep + components[-1] # so e.g. '/baz' - sepinit = os.sep + '__init__' + sepinit = os.sep + "__init__" near_misses = [] # Collect near misses for namespace mode (see below). for base_dir, verify in candidate_base_dirs: base_path = base_dir + seplast # so e.g. '/usr/lib/python3.4/foo/bar/baz' @@ -488,10 +506,10 @@ def _find_module(self, id: str, use_typeshed: bool) -> ModuleSearchResult: # Prefer package over module, i.e. baz/__init__.py* over baz.py*. for extension in PYTHON_EXTENSIONS: path = base_path + sepinit + extension - suffix = '-stubs' + suffix = "-stubs" if self.python_major_ver == 2: - if os.path.isdir(base_path + '-python2-stubs'): - suffix = '-python2-stubs' + if os.path.isdir(base_path + "-python2-stubs"): + suffix = "-python2-stubs" path_stubs = base_path + suffix + sepinit + extension if fscache.isfile_case(path, dir_prefix): has_init = True @@ -541,8 +559,10 @@ def _find_module(self, id: str, use_typeshed: bool) -> ModuleSearchResult: # foo/__init__.py it returns 2 (regardless of what's in # foo/bar). It doesn't look higher than that. if self.options and self.options.namespace_packages and near_misses: - levels = [highest_init_level(fscache, id, path, dir_prefix) - for path, dir_prefix in near_misses] + levels = [ + highest_init_level(fscache, id, path, dir_prefix) + for path, dir_prefix in near_misses + ] index = levels.index(max(levels)) return near_misses[index][0] @@ -567,14 +587,14 @@ def _is_compatible_stub_package(self, stub_dir: str) -> bool: Stub packages may contain a metadata file which specifies whether the stubs are compatible with Python 2 and 3. """ - metadata_fnam = os.path.join(stub_dir, 'METADATA.toml') + metadata_fnam = os.path.join(stub_dir, "METADATA.toml") if os.path.isfile(metadata_fnam): with open(metadata_fnam, "rb") as f: metadata = tomllib.load(f) if self.python_major_ver == 2: - return bool(metadata.get('python2', False)) + return bool(metadata.get("python2", False)) else: - return bool(metadata.get('python3', True)) + return bool(metadata.get("python3", True)) return True def find_modules_recursive(self, module: str) -> List[BuildSource]: @@ -584,7 +604,7 @@ def find_modules_recursive(self, module: str) -> List[BuildSource]: sources = [BuildSource(module_path, module, None)] package_path = None - if module_path.endswith(('__init__.py', '__init__.pyi')): + if module_path.endswith(("__init__.py", "__init__.pyi")): package_path = os.path.dirname(module_path) elif self.fscache.isdir(module_path): package_path = module_path @@ -616,23 +636,22 @@ def find_modules_recursive(self, module: str) -> List[BuildSource]: or self.fscache.isfile(os.path.join(subpath, "__init__.pyi")) ): seen.add(name) - sources.extend(self.find_modules_recursive(module + '.' + name)) + sources.extend(self.find_modules_recursive(module + "." + name)) else: stem, suffix = os.path.splitext(name) - if stem == '__init__': + if stem == "__init__": continue - if stem not in seen and '.' not in stem and suffix in PYTHON_EXTENSIONS: + if stem not in seen and "." not in stem and suffix in PYTHON_EXTENSIONS: # (If we sorted names by keyfunc) we could probably just make the BuildSource # ourselves, but this ensures compatibility with find_module / the cache seen.add(stem) - sources.extend(self.find_modules_recursive(module + '.' + stem)) + sources.extend(self.find_modules_recursive(module + "." + stem)) return sources -def matches_exclude(subpath: str, - excludes: List[str], - fscache: FileSystemCache, - verbose: bool) -> bool: +def matches_exclude( + subpath: str, excludes: List[str], fscache: FileSystemCache, verbose: bool +) -> bool: if not excludes: return False subpath_str = os.path.relpath(subpath).replace(os.sep, "/") @@ -641,49 +660,52 @@ def matches_exclude(subpath: str, for exclude in excludes: if re.search(exclude, subpath_str): if verbose: - print(f"TRACE: Excluding {subpath_str} (matches pattern {exclude})", - file=sys.stderr) + print( + f"TRACE: Excluding {subpath_str} (matches pattern {exclude})", file=sys.stderr + ) return True return False def verify_module(fscache: FileSystemCache, id: str, path: str, prefix: str) -> bool: """Check that all packages containing id have a __init__ file.""" - if path.endswith(('__init__.py', '__init__.pyi')): + if path.endswith(("__init__.py", "__init__.pyi")): path = os.path.dirname(path) - for i in range(id.count('.')): + for i in range(id.count(".")): path = os.path.dirname(path) - if not any(fscache.isfile_case(os.path.join(path, f'__init__{extension}'), - prefix) - for extension in PYTHON_EXTENSIONS): + if not any( + fscache.isfile_case(os.path.join(path, f"__init__{extension}"), prefix) + for extension in PYTHON_EXTENSIONS + ): return False return True def highest_init_level(fscache: FileSystemCache, id: str, path: str, prefix: str) -> int: """Compute the highest level where an __init__ file is found.""" - if path.endswith(('__init__.py', '__init__.pyi')): + if path.endswith(("__init__.py", "__init__.pyi")): path = os.path.dirname(path) level = 0 - for i in range(id.count('.')): + for i in range(id.count(".")): path = os.path.dirname(path) - if any(fscache.isfile_case(os.path.join(path, f'__init__{extension}'), - prefix) - for extension in PYTHON_EXTENSIONS): + if any( + fscache.isfile_case(os.path.join(path, f"__init__{extension}"), prefix) + for extension in PYTHON_EXTENSIONS + ): level = i + 1 return level def mypy_path() -> List[str]: - path_env = os.getenv('MYPYPATH') + path_env = os.getenv("MYPYPATH") if not path_env: return [] return path_env.split(os.pathsep) -def default_lib_path(data_dir: str, - pyversion: Tuple[int, int], - custom_typeshed_dir: Optional[str]) -> List[str]: +def default_lib_path( + data_dir: str, pyversion: Tuple[int, int], custom_typeshed_dir: Optional[str] +) -> List[str]: """Return default standard library search paths.""" path: List[str] = [] @@ -692,11 +714,14 @@ def default_lib_path(data_dir: str, mypy_extensions_dir = os.path.join(custom_typeshed_dir, "stubs", "mypy-extensions") versions_file = os.path.join(typeshed_dir, "VERSIONS") if not os.path.isdir(typeshed_dir) or not os.path.isfile(versions_file): - print("error: --custom-typeshed-dir does not point to a valid typeshed ({})".format( - custom_typeshed_dir)) + print( + "error: --custom-typeshed-dir does not point to a valid typeshed ({})".format( + custom_typeshed_dir + ) + ) sys.exit(2) else: - auto = os.path.join(data_dir, 'stubs-auto') + auto = os.path.join(data_dir, "stubs-auto") if os.path.isdir(auto): data_dir = auto typeshed_dir = os.path.join(data_dir, "typeshed", "stdlib") @@ -712,12 +737,16 @@ def default_lib_path(data_dir: str, path.append(mypy_extensions_dir) # Add fallback path that can be used if we have a broken installation. - if sys.platform != 'win32': - path.append('/usr/local/lib/mypy') + if sys.platform != "win32": + path.append("/usr/local/lib/mypy") if not path: - print("Could not resolve typeshed subdirectories. Your mypy install is broken.\n" - "Python executable is located at {}.\nMypy located at {}".format( - sys.executable, data_dir), file=sys.stderr) + print( + "Could not resolve typeshed subdirectories. Your mypy install is broken.\n" + "Python executable is located at {}.\nMypy located at {}".format( + sys.executable, data_dir + ), + file=sys.stderr, + ) sys.exit(1) return path @@ -741,8 +770,10 @@ def get_search_dirs(python_executable: Optional[str]) -> Tuple[List[str], List[s # executable try: sys_path, site_packages = ast.literal_eval( - subprocess.check_output([python_executable, pyinfo.__file__, 'getsearchdirs'], - stderr=subprocess.PIPE).decode()) + subprocess.check_output( + [python_executable, pyinfo.__file__, "getsearchdirs"], stderr=subprocess.PIPE + ).decode() + ) except OSError as err: reason = os.strerror(err.errno) raise CompileError( @@ -770,10 +801,9 @@ def add_py2_mypypath_entries(mypypath: List[str]) -> List[str]: return result -def compute_search_paths(sources: List[BuildSource], - options: Options, - data_dir: str, - alt_lib_path: Optional[str] = None) -> SearchPaths: +def compute_search_paths( + sources: List[BuildSource], options: Options, data_dir: str, alt_lib_path: Optional[str] = None +) -> SearchPaths: """Compute the search paths as specified in PEP 561. There are the following 4 members created: @@ -781,22 +811,23 @@ def compute_search_paths(sources: List[BuildSource], - MYPYPATH (set either via config or environment variable) - installed package directories (which will later be split into stub-only and inline) - typeshed - """ + """ # Determine the default module search path. lib_path = collections.deque( - default_lib_path(data_dir, - options.python_version, - custom_typeshed_dir=options.custom_typeshed_dir)) + default_lib_path( + data_dir, options.python_version, custom_typeshed_dir=options.custom_typeshed_dir + ) + ) if options.use_builtins_fixtures: # Use stub builtins (to speed up test cases and to make them easier to # debug). This is a test-only feature, so assume our files are laid out # as in the source tree. # We also need to allow overriding where to look for it. Argh. - root_dir = os.getenv('MYPY_TEST_PREFIX', None) + root_dir = os.getenv("MYPY_TEST_PREFIX", None) if not root_dir: root_dir = os.path.dirname(os.path.dirname(__file__)) - lib_path.appendleft(os.path.join(root_dir, 'test-data', 'unit', 'lib-stub')) + lib_path.appendleft(os.path.join(root_dir, "test-data", "unit", "lib-stub")) # alt_lib_path is used by some tests to bypass the normal lib_path mechanics. # If we don't have one, grab directories of source files. python_path: List[str] = [] @@ -815,7 +846,7 @@ def compute_search_paths(sources: List[BuildSource], # TODO: Don't do this in some cases; for motivation see see # https://github.com/python/mypy/issues/4195#issuecomment-341915031 if options.bazel: - dir = '.' + dir = "." else: dir = os.getcwd() if dir not in lib_path: @@ -847,8 +878,11 @@ def compute_search_paths(sources: List[BuildSource], or (os.path.altsep and any(p.startswith(site + os.path.altsep) for p in mypypath)) ): print(f"{site} is in the MYPYPATH. Please remove it.", file=sys.stderr) - print("See https://mypy.readthedocs.io/en/stable/running_mypy.html" - "#how-mypy-handles-imports for more info", file=sys.stderr) + print( + "See https://mypy.readthedocs.io/en/stable/running_mypy.html" + "#how-mypy-handles-imports for more info", + file=sys.stderr, + ) sys.exit(1) return SearchPaths( @@ -881,8 +915,9 @@ def load_stdlib_py_versions(custom_typeshed_dir: Optional[str]) -> StdlibVersion module, version_range = line.split(":") versions = version_range.split("-") min_version = parse_version(versions[0]) - max_version = (parse_version(versions[1]) - if len(versions) >= 2 and versions[1].strip() else None) + max_version = ( + parse_version(versions[1]) if len(versions) >= 2 and versions[1].strip() else None + ) result[module] = min_version, max_version # Modules that are Python 2 only or have separate Python 2 stubs diff --git a/mypy/moduleinspect.py b/mypy/moduleinspect.py index 90532ae191508..ec2e964f7ffc6 100644 --- a/mypy/moduleinspect.py +++ b/mypy/moduleinspect.py @@ -1,25 +1,27 @@ """Basic introspection of modules.""" -from typing import List, Optional, Union -from types import ModuleType -from multiprocessing import Process, Queue import importlib import inspect import os import pkgutil import queue import sys +from multiprocessing import Process, Queue +from types import ModuleType +from typing import List, Optional, Union class ModuleProperties: # Note that all __init__ args must have default values - def __init__(self, - name: str = "", - file: Optional[str] = None, - path: Optional[List[str]] = None, - all: Optional[List[str]] = None, - is_c_module: bool = False, - subpackages: Optional[List[str]] = None) -> None: + def __init__( + self, + name: str = "", + file: Optional[str] = None, + path: Optional[List[str]] = None, + all: Optional[List[str]] = None, + is_c_module: bool = False, + subpackages: Optional[List[str]] = None, + ) -> None: self.name = name # __name__ attribute self.file = file # __file__ attribute self.path = path # __path__ attribute @@ -29,11 +31,11 @@ def __init__(self, def is_c_module(module: ModuleType) -> bool: - if module.__dict__.get('__file__') is None: + if module.__dict__.get("__file__") is None: # Could be a namespace package. These must be handled through # introspection, since there is no source file. return True - return os.path.splitext(module.__dict__['__file__'])[-1] in ['.so', '.pyd'] + return os.path.splitext(module.__dict__["__file__"])[-1] in [".so", ".pyd"] class InspectError(Exception): @@ -51,7 +53,7 @@ def get_package_properties(package_id: str) -> ModuleProperties: path: Optional[List[str]] = getattr(package, "__path__", None) if not isinstance(path, list): path = None - pkg_all = getattr(package, '__all__', None) + pkg_all = getattr(package, "__all__", None) if pkg_all is not None: try: pkg_all = list(pkg_all) @@ -65,28 +67,27 @@ def get_package_properties(package_id: str) -> ModuleProperties: if is_c: # This is a C extension module, now get the list of all sub-packages # using the inspect module - subpackages = [package.__name__ + "." + name - for name, val in inspect.getmembers(package) - if inspect.ismodule(val) - and val.__name__ == package.__name__ + "." + name] + subpackages = [ + package.__name__ + "." + name + for name, val in inspect.getmembers(package) + if inspect.ismodule(val) and val.__name__ == package.__name__ + "." + name + ] else: # It's a module inside a package. There's nothing else to walk/yield. subpackages = [] else: - all_packages = pkgutil.walk_packages(path, prefix=package.__name__ + ".", - onerror=lambda r: None) + all_packages = pkgutil.walk_packages( + path, prefix=package.__name__ + ".", onerror=lambda r: None + ) subpackages = [qualified_name for importer, qualified_name, ispkg in all_packages] - return ModuleProperties(name=name, - file=file, - path=path, - all=pkg_all, - is_c_module=is_c, - subpackages=subpackages) + return ModuleProperties( + name=name, file=file, path=path, all=pkg_all, is_c_module=is_c, subpackages=subpackages + ) -def worker(tasks: 'Queue[str]', - results: 'Queue[Union[str, ModuleProperties]]', - sys_path: List[str]) -> None: +def worker( + tasks: "Queue[str]", results: "Queue[Union[str, ModuleProperties]]", sys_path: List[str] +) -> None: """The main loop of a worker introspection process.""" sys.path = sys_path while True: @@ -139,7 +140,7 @@ def get_package_properties(self, package_id: str) -> ModuleProperties: if res is None: # The process died; recover and report error. self._start() - raise InspectError(f'Process died when importing {package_id!r}') + raise InspectError(f"Process died when importing {package_id!r}") if isinstance(res, str): # Error importing module if self.counter > 0: @@ -161,7 +162,7 @@ def _get_from_queue(self) -> Union[ModuleProperties, str, None]: n = 0 while True: if n == max_iter: - raise RuntimeError('Timeout waiting for subprocess') + raise RuntimeError("Timeout waiting for subprocess") try: return self.results.get(timeout=0.05) except queue.Empty: @@ -169,7 +170,7 @@ def _get_from_queue(self) -> Union[ModuleProperties, str, None]: return None n += 1 - def __enter__(self) -> 'ModuleInspect': + def __enter__(self) -> "ModuleInspect": return self def __exit__(self, *args: object) -> None: diff --git a/mypy/mro.py b/mypy/mro.py index 1bea83c6d97d5..3c29013d62c9d 100644 --- a/mypy/mro.py +++ b/mypy/mro.py @@ -1,4 +1,4 @@ -from typing import Optional, Callable, List +from typing import Callable, List, Optional from mypy.nodes import TypeInfo from mypy.types import Instance @@ -22,14 +22,14 @@ class MroError(Exception): """Raised if a consistent mro cannot be determined for a class.""" -def linearize_hierarchy(info: TypeInfo, - obj_type: Optional[Callable[[], Instance]] = None) -> List[TypeInfo]: +def linearize_hierarchy( + info: TypeInfo, obj_type: Optional[Callable[[], Instance]] = None +) -> List[TypeInfo]: # TODO describe if info.mro: return info.mro bases = info.direct_base_classes() - if (not bases and info.fullname != 'builtins.object' and - obj_type is not None): + if not bases and info.fullname != "builtins.object" and obj_type is not None: # Second pass in import cycle, add a dummy `object` base class, # otherwise MRO calculation may spuriously fail. # MRO will be re-calculated for real in the third pass. diff --git a/mypy/nodes.py b/mypy/nodes.py index 75ec06583f9be..25689be806fe4 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1,21 +1,32 @@ """Abstract syntax tree node classes (i.e. parse tree).""" import os -from enum import Enum, unique from abc import abstractmethod -from mypy.backports import OrderedDict from collections import defaultdict +from enum import Enum, unique from typing import ( - Any, TypeVar, List, Tuple, cast, Set, Dict, Union, Optional, Callable, Sequence, Iterator + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, + cast, ) -from typing_extensions import DefaultDict, Final, TYPE_CHECKING, TypeAlias as _TypeAlias + from mypy_extensions import trait +from typing_extensions import TYPE_CHECKING, DefaultDict, Final, TypeAlias as _TypeAlias import mypy.strconv -from mypy.util import short_type -from mypy.visitor import NodeVisitor, StatementVisitor, ExpressionVisitor - +from mypy.backports import OrderedDict from mypy.bogus_type import Bogus +from mypy.util import short_type +from mypy.visitor import ExpressionVisitor, NodeVisitor, StatementVisitor if TYPE_CHECKING: from mypy.patterns import Pattern @@ -23,7 +34,8 @@ class Context: """Base type for objects that are valid as error message locations.""" - __slots__ = ('line', 'column', 'end_line', 'end_column') + + __slots__ = ("line", "column", "end_line", "end_column") def __init__(self, line: int = -1, column: int = -1) -> None: self.line = line @@ -31,11 +43,13 @@ def __init__(self, line: int = -1, column: int = -1) -> None: self.end_line: Optional[int] = None self.end_column: Optional[int] = None - def set_line(self, - target: Union['Context', int], - column: Optional[int] = None, - end_line: Optional[int] = None, - end_column: Optional[int] = None) -> None: + def set_line( + self, + target: Union["Context", int], + column: Optional[int] = None, + end_line: Optional[int] = None, + end_column: Optional[int] = None, + ) -> None: """If target is a node, pull line (and column) information into this node. If column is specified, this will override any column information coming from a node. @@ -71,7 +85,7 @@ def get_column(self) -> int: import mypy.types -T = TypeVar('T') +T = TypeVar("T") JsonDict: _TypeAlias = Dict[str, Any] @@ -98,77 +112,72 @@ def get_column(self) -> int: LITERAL_TYPE: Final = 1 LITERAL_NO: Final = 0 -node_kinds: Final = { - LDEF: 'Ldef', - GDEF: 'Gdef', - MDEF: 'Mdef', - UNBOUND_IMPORTED: 'UnboundImported', -} +node_kinds: Final = {LDEF: "Ldef", GDEF: "Gdef", MDEF: "Mdef", UNBOUND_IMPORTED: "UnboundImported"} inverse_node_kinds: Final = {_kind: _name for _name, _kind in node_kinds.items()} implicit_module_attrs: Final = { - '__name__': '__builtins__.str', - '__doc__': None, # depends on Python version, see semanal.py - '__path__': None, # depends on if the module is a package - '__file__': '__builtins__.str', - '__package__': '__builtins__.str', - '__annotations__': None, # dict[str, Any] bounded in add_implicit_module_attrs() + "__name__": "__builtins__.str", + "__doc__": None, # depends on Python version, see semanal.py + "__path__": None, # depends on if the module is a package + "__file__": "__builtins__.str", + "__package__": "__builtins__.str", + "__annotations__": None, # dict[str, Any] bounded in add_implicit_module_attrs() } # These aliases exist because built-in class objects are not subscriptable. # For example `list[int]` fails at runtime. Instead List[int] should be used. type_aliases: Final = { - 'typing.List': 'builtins.list', - 'typing.Dict': 'builtins.dict', - 'typing.Set': 'builtins.set', - 'typing.FrozenSet': 'builtins.frozenset', - 'typing.ChainMap': 'collections.ChainMap', - 'typing.Counter': 'collections.Counter', - 'typing.DefaultDict': 'collections.defaultdict', - 'typing.Deque': 'collections.deque', - 'typing.OrderedDict': 'collections.OrderedDict', + "typing.List": "builtins.list", + "typing.Dict": "builtins.dict", + "typing.Set": "builtins.set", + "typing.FrozenSet": "builtins.frozenset", + "typing.ChainMap": "collections.ChainMap", + "typing.Counter": "collections.Counter", + "typing.DefaultDict": "collections.defaultdict", + "typing.Deque": "collections.deque", + "typing.OrderedDict": "collections.OrderedDict", # HACK: a lie in lieu of actual support for PEP 675 - 'typing.LiteralString': 'builtins.str', + "typing.LiteralString": "builtins.str", } # This keeps track of the oldest supported Python version where the corresponding # alias source is available. type_aliases_source_versions: Final = { - 'typing.List': (2, 7), - 'typing.Dict': (2, 7), - 'typing.Set': (2, 7), - 'typing.FrozenSet': (2, 7), - 'typing.ChainMap': (3, 3), - 'typing.Counter': (2, 7), - 'typing.DefaultDict': (2, 7), - 'typing.Deque': (2, 7), - 'typing.OrderedDict': (3, 7), - 'typing.LiteralString': (3, 11), + "typing.List": (2, 7), + "typing.Dict": (2, 7), + "typing.Set": (2, 7), + "typing.FrozenSet": (2, 7), + "typing.ChainMap": (3, 3), + "typing.Counter": (2, 7), + "typing.DefaultDict": (2, 7), + "typing.Deque": (2, 7), + "typing.OrderedDict": (3, 7), + "typing.LiteralString": (3, 11), } # This keeps track of aliases in `typing_extensions`, which we treat specially. typing_extensions_aliases: Final = { # See: https://github.com/python/mypy/issues/11528 - 'typing_extensions.OrderedDict': 'collections.OrderedDict', + "typing_extensions.OrderedDict": "collections.OrderedDict", # HACK: a lie in lieu of actual support for PEP 675 - 'typing_extensions.LiteralString': 'builtins.str', + "typing_extensions.LiteralString": "builtins.str", } reverse_builtin_aliases: Final = { - 'builtins.list': 'typing.List', - 'builtins.dict': 'typing.Dict', - 'builtins.set': 'typing.Set', - 'builtins.frozenset': 'typing.FrozenSet', + "builtins.list": "typing.List", + "builtins.dict": "typing.Dict", + "builtins.set": "typing.Set", + "builtins.frozenset": "typing.FrozenSet", } _nongen_builtins: Final = {"builtins.tuple": "typing.Tuple", "builtins.enumerate": ""} _nongen_builtins.update((name, alias) for alias, name in type_aliases.items()) # Drop OrderedDict from this for backward compatibility -del _nongen_builtins['collections.OrderedDict'] +del _nongen_builtins["collections.OrderedDict"] # HACK: consequence of hackily treating LiteralString as an alias for str -del _nongen_builtins['builtins.str'] +del _nongen_builtins["builtins.str"] def get_nongen_builtins(python_version: Tuple[int, int]) -> Dict[str, str]: @@ -195,7 +204,7 @@ def __str__(self) -> str: return ans def accept(self, visitor: NodeVisitor[T]) -> T: - raise RuntimeError('Not implemented') + raise RuntimeError("Not implemented") @trait @@ -205,7 +214,7 @@ class Statement(Node): __slots__ = () def accept(self, visitor: StatementVisitor[T]) -> T: - raise RuntimeError('Not implemented') + raise RuntimeError("Not implemented") @trait @@ -215,7 +224,7 @@ class Expression(Node): __slots__ = () def accept(self, visitor: ExpressionVisitor[T]) -> T: - raise RuntimeError('Not implemented') + raise RuntimeError("Not implemented") class FakeExpression(Expression): @@ -242,38 +251,52 @@ class SymbolNode(Node): @property @abstractmethod - def name(self) -> str: pass + def name(self) -> str: + pass # fullname can often be None even though the type system # disagrees. We mark this with Bogus to let mypyc know not to # worry about it. @property @abstractmethod - def fullname(self) -> Bogus[str]: pass + def fullname(self) -> Bogus[str]: + pass @abstractmethod - def serialize(self) -> JsonDict: pass + def serialize(self) -> JsonDict: + pass @classmethod - def deserialize(cls, data: JsonDict) -> 'SymbolNode': - classname = data['.class'] + def deserialize(cls, data: JsonDict) -> "SymbolNode": + classname = data[".class"] method = deserialize_map.get(classname) if method is not None: return method(data) - raise NotImplementedError(f'unexpected .class {classname}') + raise NotImplementedError(f"unexpected .class {classname}") # Items: fullname, related symbol table node, surrounding type (if any) -Definition: _TypeAlias = Tuple[str, 'SymbolTableNode', Optional['TypeInfo']] +Definition: _TypeAlias = Tuple[str, "SymbolTableNode", Optional["TypeInfo"]] class MypyFile(SymbolNode): """The abstract syntax tree of a single source file.""" - __slots__ = ('_fullname', 'path', 'defs', 'alias_deps', - 'is_bom', 'names', 'imports', 'ignored_lines', 'is_stub', - 'is_cache_skeleton', 'is_partial_stub_package', 'plugin_deps', - 'future_import_flags') + __slots__ = ( + "_fullname", + "path", + "defs", + "alias_deps", + "is_bom", + "names", + "imports", + "ignored_lines", + "is_stub", + "is_cache_skeleton", + "is_partial_stub_package", + "plugin_deps", + "future_import_flags", + ) # Fully qualified module name _fullname: Bogus[str] @@ -305,11 +328,13 @@ class MypyFile(SymbolNode): # Future imports defined in this file. Populated during semantic analysis. future_import_flags: Set[str] - def __init__(self, - defs: List[Statement], - imports: List['ImportBase'], - is_bom: bool = False, - ignored_lines: Optional[Dict[int, List[str]]] = None) -> None: + def __init__( + self, + defs: List[Statement], + imports: List["ImportBase"], + is_bom: bool = False, + ignored_lines: Optional[Dict[int, List[str]]] = None, + ) -> None: super().__init__() self.defs = defs self.line = 1 # Dummy line number @@ -322,7 +347,7 @@ def __init__(self, else: self.ignored_lines = {} - self.path = '' + self.path = "" self.is_stub = False self.is_cache_skeleton = False self.is_partial_stub_package = False @@ -337,7 +362,7 @@ def local_definitions(self) -> Iterator[Definition]: @property def name(self) -> str: - return '' if not self._fullname else self._fullname.split('.')[-1] + return "" if not self._fullname else self._fullname.split(".")[-1] @property def fullname(self) -> Bogus[str]: @@ -347,39 +372,40 @@ def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_mypy_file(self) def is_package_init_file(self) -> bool: - return len(self.path) != 0 and os.path.basename(self.path).startswith('__init__.') + return len(self.path) != 0 and os.path.basename(self.path).startswith("__init__.") def is_future_flag_set(self, flag: str) -> bool: return flag in self.future_import_flags def serialize(self) -> JsonDict: - return {'.class': 'MypyFile', - '_fullname': self._fullname, - 'names': self.names.serialize(self._fullname), - 'is_stub': self.is_stub, - 'path': self.path, - 'is_partial_stub_package': self.is_partial_stub_package, - 'future_import_flags': list(self.future_import_flags), - } + return { + ".class": "MypyFile", + "_fullname": self._fullname, + "names": self.names.serialize(self._fullname), + "is_stub": self.is_stub, + "path": self.path, + "is_partial_stub_package": self.is_partial_stub_package, + "future_import_flags": list(self.future_import_flags), + } @classmethod - def deserialize(cls, data: JsonDict) -> 'MypyFile': - assert data['.class'] == 'MypyFile', data + def deserialize(cls, data: JsonDict) -> "MypyFile": + assert data[".class"] == "MypyFile", data tree = MypyFile([], []) - tree._fullname = data['_fullname'] - tree.names = SymbolTable.deserialize(data['names']) - tree.is_stub = data['is_stub'] - tree.path = data['path'] - tree.is_partial_stub_package = data['is_partial_stub_package'] + tree._fullname = data["_fullname"] + tree.names = SymbolTable.deserialize(data["names"]) + tree.is_stub = data["is_stub"] + tree.path = data["path"] + tree.is_partial_stub_package = data["is_partial_stub_package"] tree.is_cache_skeleton = True - tree.future_import_flags = set(data['future_import_flags']) + tree.future_import_flags = set(data["future_import_flags"]) return tree class ImportBase(Statement): """Base class for all import statements.""" - __slots__ = ('is_unreachable', 'is_top_level', 'is_mypy_only', 'assignments') + __slots__ = ("is_unreachable", "is_top_level", "is_mypy_only", "assignments") is_unreachable: bool # Set by semanal.SemanticAnalyzerPass1 if inside `if False` etc. is_top_level: bool # Ditto if outside any class or def @@ -404,7 +430,7 @@ def __init__(self) -> None: class Import(ImportBase): """import m [as n]""" - __slots__ = ('ids',) + __slots__ = ("ids",) ids: List[Tuple[str, Optional[str]]] # (module id, as id) @@ -419,7 +445,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class ImportFrom(ImportBase): """from m import x [as y], ...""" - __slots__ = ('id', 'names', 'relative') + __slots__ = ("id", "names", "relative") id: str relative: int @@ -438,7 +464,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class ImportAll(ImportBase): """from m import *""" - __slots__ = ('id', 'relative', 'imported_names') + __slots__ = ("id", "relative", "imported_names") id: str relative: int @@ -467,7 +493,7 @@ class ImportedName(SymbolNode): can't be visited. """ - __slots__ = ('target_fullname',) + __slots__ = ("target_fullname",) def __init__(self, target_fullname: str) -> None: super().__init__() @@ -475,7 +501,7 @@ def __init__(self, target_fullname: str) -> None: @property def name(self) -> str: - return self.target_fullname.split('.')[-1] + return self.target_fullname.split(".")[-1] @property def fullname(self) -> str: @@ -485,11 +511,11 @@ def serialize(self) -> JsonDict: assert False, "ImportedName leaked from semantic analysis" @classmethod - def deserialize(cls, data: JsonDict) -> 'ImportedName': + def deserialize(cls, data: JsonDict) -> "ImportedName": assert False, "ImportedName should never be serialized" def __str__(self) -> str: - return f'ImportedName({self.target_fullname})' + return f"ImportedName({self.target_fullname})" FUNCBASE_FLAGS: Final = ["is_property", "is_class", "is_static", "is_final"] @@ -509,15 +535,16 @@ class FuncBase(Node): SymbolNode subclasses that are also FuncBase subclasses. """ - __slots__ = ('type', - 'unanalyzed_type', - 'info', - 'is_property', - 'is_class', # Uses "@classmethod" (explicit or implicit) - 'is_static', # Uses "@staticmethod" - 'is_final', # Uses "@final" - '_fullname', - ) + __slots__ = ( + "type", + "unanalyzed_type", + "info", + "is_property", + "is_class", # Uses "@classmethod" (explicit or implicit) + "is_static", # Uses "@staticmethod" + "is_final", # Uses "@final" + "_fullname", + ) def __init__(self) -> None: super().__init__() @@ -539,14 +566,15 @@ def __init__(self) -> None: @property @abstractmethod - def name(self) -> str: pass + def name(self) -> str: + pass @property def fullname(self) -> Bogus[str]: return self._fullname -OverloadPart: _TypeAlias = Union['FuncDef', 'Decorator'] +OverloadPart: _TypeAlias = Union["FuncDef", "Decorator"] class OverloadedFuncDef(FuncBase, SymbolNode, Statement): @@ -559,13 +587,13 @@ class OverloadedFuncDef(FuncBase, SymbolNode, Statement): Overloaded variants must be consecutive in the source file. """ - __slots__ = ('items', 'unanalyzed_items', 'impl') + __slots__ = ("items", "unanalyzed_items", "impl") items: List[OverloadPart] unanalyzed_items: List[OverloadPart] impl: Optional[OverloadPart] - def __init__(self, items: List['OverloadPart']) -> None: + def __init__(self, items: List["OverloadPart"]) -> None: super().__init__() self.items = items self.unanalyzed_items = items.copy() @@ -587,31 +615,32 @@ def accept(self, visitor: StatementVisitor[T]) -> T: return visitor.visit_overloaded_func_def(self) def serialize(self) -> JsonDict: - return {'.class': 'OverloadedFuncDef', - 'items': [i.serialize() for i in self.items], - 'type': None if self.type is None else self.type.serialize(), - 'fullname': self._fullname, - 'impl': None if self.impl is None else self.impl.serialize(), - 'flags': get_flags(self, FUNCBASE_FLAGS), - } + return { + ".class": "OverloadedFuncDef", + "items": [i.serialize() for i in self.items], + "type": None if self.type is None else self.type.serialize(), + "fullname": self._fullname, + "impl": None if self.impl is None else self.impl.serialize(), + "flags": get_flags(self, FUNCBASE_FLAGS), + } @classmethod - def deserialize(cls, data: JsonDict) -> 'OverloadedFuncDef': - assert data['.class'] == 'OverloadedFuncDef' - res = OverloadedFuncDef([ - cast(OverloadPart, SymbolNode.deserialize(d)) - for d in data['items']]) - if data.get('impl') is not None: - res.impl = cast(OverloadPart, SymbolNode.deserialize(data['impl'])) + def deserialize(cls, data: JsonDict) -> "OverloadedFuncDef": + assert data[".class"] == "OverloadedFuncDef" + res = OverloadedFuncDef( + [cast(OverloadPart, SymbolNode.deserialize(d)) for d in data["items"]] + ) + if data.get("impl") is not None: + res.impl = cast(OverloadPart, SymbolNode.deserialize(data["impl"])) # set line for empty overload items, as not set in __init__ if len(res.items) > 0: res.set_line(res.impl.line) - if data.get('type') is not None: - typ = mypy.types.deserialize_type(data['type']) + if data.get("type") is not None: + typ = mypy.types.deserialize_type(data["type"]) assert isinstance(typ, mypy.types.ProperType) res.type = typ - res._fullname = data['fullname'] - set_flags(res, data['flags']) + res._fullname = data["fullname"] + set_flags(res, data["flags"]) # NOTE: res.info will be set in the fixup phase. return res @@ -619,14 +648,16 @@ def deserialize(cls, data: JsonDict) -> 'OverloadedFuncDef': class Argument(Node): """A single argument in a FuncItem.""" - __slots__ = ('variable', 'type_annotation', 'initializer', 'kind', 'pos_only') + __slots__ = ("variable", "type_annotation", "initializer", "kind", "pos_only") - def __init__(self, - variable: 'Var', - type_annotation: 'Optional[mypy.types.Type]', - initializer: Optional[Expression], - kind: 'ArgKind', - pos_only: bool = False) -> None: + def __init__( + self, + variable: "Var", + type_annotation: "Optional[mypy.types.Type]", + initializer: Optional[Expression], + kind: "ArgKind", + pos_only: bool = False, + ) -> None: super().__init__() self.variable = variable self.type_annotation = type_annotation @@ -634,59 +665,64 @@ def __init__(self, self.kind = kind # must be an ARG_* constant self.pos_only = pos_only - def set_line(self, - target: Union[Context, int], - column: Optional[int] = None, - end_line: Optional[int] = None, - end_column: Optional[int] = None) -> None: + def set_line( + self, + target: Union[Context, int], + column: Optional[int] = None, + end_line: Optional[int] = None, + end_column: Optional[int] = None, + ) -> None: super().set_line(target, column, end_line, end_column) if self.initializer and self.initializer.line < 0: - self.initializer.set_line( - self.line, self.column, self.end_line, self.end_column) + self.initializer.set_line(self.line, self.column, self.end_line, self.end_column) - self.variable.set_line( - self.line, self.column, self.end_line, self.end_column) + self.variable.set_line(self.line, self.column, self.end_line, self.end_column) FUNCITEM_FLAGS: Final = FUNCBASE_FLAGS + [ - 'is_overload', 'is_generator', 'is_coroutine', 'is_async_generator', - 'is_awaitable_coroutine', + "is_overload", + "is_generator", + "is_coroutine", + "is_async_generator", + "is_awaitable_coroutine", ] class FuncItem(FuncBase): """Base class for nodes usable as overloaded function items.""" - __slots__ = ('arguments', # Note that can be unset if deserialized (type is a lie!) - 'arg_names', # Names of arguments - 'arg_kinds', # Kinds of arguments - 'min_args', # Minimum number of arguments - 'max_pos', # Maximum number of positional arguments, -1 if no explicit - # limit (*args not included) - 'body', # Body of the function - 'is_overload', # Is this an overload variant of function with more than - # one overload variant? - 'is_generator', # Contains a yield statement? - 'is_coroutine', # Defined using 'async def' syntax? - 'is_async_generator', # Is an async def generator? - 'is_awaitable_coroutine', # Decorated with '@{typing,asyncio}.coroutine'? - 'expanded', # Variants of function with type variables with values expanded - ) - - __deletable__ = ('arguments', 'max_pos', 'min_args') - - def __init__(self, - arguments: Optional[List[Argument]] = None, - body: Optional['Block'] = None, - typ: 'Optional[mypy.types.FunctionLike]' = None) -> None: + __slots__ = ( + "arguments", # Note that can be unset if deserialized (type is a lie!) + "arg_names", # Names of arguments + "arg_kinds", # Kinds of arguments + "min_args", # Minimum number of arguments + "max_pos", # Maximum number of positional arguments, -1 if no explicit + # limit (*args not included) + "body", # Body of the function + "is_overload", # Is this an overload variant of function with more than + # one overload variant? + "is_generator", # Contains a yield statement? + "is_coroutine", # Defined using 'async def' syntax? + "is_async_generator", # Is an async def generator? + "is_awaitable_coroutine", # Decorated with '@{typing,asyncio}.coroutine'? + "expanded", # Variants of function with type variables with values expanded + ) + + __deletable__ = ("arguments", "max_pos", "min_args") + + def __init__( + self, + arguments: Optional[List[Argument]] = None, + body: Optional["Block"] = None, + typ: "Optional[mypy.types.FunctionLike]" = None, + ) -> None: super().__init__() self.arguments = arguments or [] self.arg_names = [None if arg.pos_only else arg.variable.name for arg in self.arguments] self.arg_kinds: List[ArgKind] = [arg.kind for arg in self.arguments] - self.max_pos: int = ( - self.arg_kinds.count(ARG_POS) + self.arg_kinds.count(ARG_OPT)) - self.body: 'Block' = body or Block([]) + self.max_pos: int = self.arg_kinds.count(ARG_POS) + self.arg_kinds.count(ARG_OPT) + self.body: "Block" = body or Block([]) self.type = typ self.unanalyzed_type = typ self.is_overload: bool = False @@ -704,11 +740,13 @@ def __init__(self, def max_fixed_argc(self) -> int: return self.max_pos - def set_line(self, - target: Union[Context, int], - column: Optional[int] = None, - end_line: Optional[int] = None, - end_column: Optional[int] = None) -> None: + def set_line( + self, + target: Union[Context, int], + column: Optional[int] = None, + end_line: Optional[int] = None, + end_column: Optional[int] = None, + ) -> None: super().set_line(target, column, end_line, end_column) for arg in self.arguments: arg.set_line(self.line, self.column, self.end_line, end_column) @@ -717,9 +755,7 @@ def is_dynamic(self) -> bool: return self.type is None -FUNCDEF_FLAGS: Final = FUNCITEM_FLAGS + [ - 'is_decorated', 'is_conditional', 'is_abstract', -] +FUNCDEF_FLAGS: Final = FUNCITEM_FLAGS + ["is_decorated", "is_conditional", "is_abstract"] class FuncDef(FuncItem, SymbolNode, Statement): @@ -728,19 +764,16 @@ class FuncDef(FuncItem, SymbolNode, Statement): This is a non-lambda function defined using 'def'. """ - __slots__ = ('_name', - 'is_decorated', - 'is_conditional', - 'is_abstract', - 'original_def', - ) + __slots__ = ("_name", "is_decorated", "is_conditional", "is_abstract", "original_def") # Note that all __init__ args must have default values - def __init__(self, - name: str = '', # Function name - arguments: Optional[List[Argument]] = None, - body: Optional['Block'] = None, - typ: 'Optional[mypy.types.FunctionLike]' = None) -> None: + def __init__( + self, + name: str = "", # Function name + arguments: Optional[List[Argument]] = None, + body: Optional["Block"] = None, + typ: "Optional[mypy.types.FunctionLike]" = None, + ) -> None: super().__init__(arguments, body, typ) self._name = name self.is_decorated = False @@ -764,31 +797,36 @@ def serialize(self) -> JsonDict: # TODO: After a FuncDef is deserialized, the only time we use `arg_names` # and `arg_kinds` is when `type` is None and we need to infer a type. Can # we store the inferred type ahead of time? - return {'.class': 'FuncDef', - 'name': self._name, - 'fullname': self._fullname, - 'arg_names': self.arg_names, - 'arg_kinds': [int(x.value) for x in self.arg_kinds], - 'type': None if self.type is None else self.type.serialize(), - 'flags': get_flags(self, FUNCDEF_FLAGS), - # TODO: Do we need expanded, original_def? - } + return { + ".class": "FuncDef", + "name": self._name, + "fullname": self._fullname, + "arg_names": self.arg_names, + "arg_kinds": [int(x.value) for x in self.arg_kinds], + "type": None if self.type is None else self.type.serialize(), + "flags": get_flags(self, FUNCDEF_FLAGS), + # TODO: Do we need expanded, original_def? + } @classmethod - def deserialize(cls, data: JsonDict) -> 'FuncDef': - assert data['.class'] == 'FuncDef' + def deserialize(cls, data: JsonDict) -> "FuncDef": + assert data[".class"] == "FuncDef" body = Block([]) - ret = FuncDef(data['name'], - [], - body, - (None if data['type'] is None - else cast(mypy.types.FunctionLike, - mypy.types.deserialize_type(data['type'])))) - ret._fullname = data['fullname'] - set_flags(ret, data['flags']) + ret = FuncDef( + data["name"], + [], + body, + ( + None + if data["type"] is None + else cast(mypy.types.FunctionLike, mypy.types.deserialize_type(data["type"])) + ), + ) + ret._fullname = data["fullname"] + set_flags(ret, data["flags"]) # NOTE: ret.info is set in the fixup phase. - ret.arg_names = data['arg_names'] - ret.arg_kinds = [ArgKind(x) for x in data['arg_kinds']] + ret.arg_names = data["arg_names"] + ret.arg_kinds = [ArgKind(x) for x in data["arg_kinds"]] # Leave these uninitialized so that future uses will trigger an error del ret.arguments del ret.max_pos @@ -807,7 +845,7 @@ class Decorator(SymbolNode, Statement): A single Decorator object can include any number of function decorators. """ - __slots__ = ('func', 'decorators', 'original_decorators', 'var', 'is_overload') + __slots__ = ("func", "decorators", "original_decorators", "var", "is_overload") func: FuncDef # Decorated function decorators: List[Expression] # Decorators (may be empty) @@ -817,8 +855,7 @@ class Decorator(SymbolNode, Statement): var: "Var" # Represents the decorated function obj is_overload: bool - def __init__(self, func: FuncDef, decorators: List[Expression], - var: 'Var') -> None: + def __init__(self, func: FuncDef, decorators: List[Expression], var: "Var") -> None: super().__init__() self.func = func self.decorators = decorators @@ -839,39 +876,50 @@ def is_final(self) -> bool: return self.func.is_final @property - def info(self) -> 'TypeInfo': + def info(self) -> "TypeInfo": return self.func.info @property - def type(self) -> 'Optional[mypy.types.Type]': + def type(self) -> "Optional[mypy.types.Type]": return self.var.type def accept(self, visitor: StatementVisitor[T]) -> T: return visitor.visit_decorator(self) def serialize(self) -> JsonDict: - return {'.class': 'Decorator', - 'func': self.func.serialize(), - 'var': self.var.serialize(), - 'is_overload': self.is_overload, - } + return { + ".class": "Decorator", + "func": self.func.serialize(), + "var": self.var.serialize(), + "is_overload": self.is_overload, + } @classmethod - def deserialize(cls, data: JsonDict) -> 'Decorator': - assert data['.class'] == 'Decorator' - dec = Decorator(FuncDef.deserialize(data['func']), - [], - Var.deserialize(data['var'])) - dec.is_overload = data['is_overload'] + def deserialize(cls, data: JsonDict) -> "Decorator": + assert data[".class"] == "Decorator" + dec = Decorator(FuncDef.deserialize(data["func"]), [], Var.deserialize(data["var"])) + dec.is_overload = data["is_overload"] return dec VAR_FLAGS: Final = [ - 'is_self', 'is_initialized_in_class', 'is_staticmethod', - 'is_classmethod', 'is_property', 'is_settable_property', 'is_suppressed_import', - 'is_classvar', 'is_abstract_var', 'is_final', 'final_unset_in_class', 'final_set_in_init', - 'explicit_self_type', 'is_ready', 'from_module_getattr', - 'has_explicit_value', 'allow_incompatible_override', + "is_self", + "is_initialized_in_class", + "is_staticmethod", + "is_classmethod", + "is_property", + "is_settable_property", + "is_suppressed_import", + "is_classvar", + "is_abstract_var", + "is_final", + "final_unset_in_class", + "final_set_in_init", + "explicit_self_type", + "is_ready", + "from_module_getattr", + "has_explicit_value", + "allow_incompatible_override", ] @@ -881,43 +929,44 @@ class Var(SymbolNode): It can refer to global/local variable or a data attribute. """ - __slots__ = ('_name', - '_fullname', - 'info', - 'type', - 'final_value', - 'is_self', - 'is_ready', - 'is_inferred', - 'is_initialized_in_class', - 'is_staticmethod', - 'is_classmethod', - 'is_property', - 'is_settable_property', - 'is_classvar', - 'is_abstract_var', - 'is_final', - 'final_unset_in_class', - 'final_set_in_init', - 'is_suppressed_import', - 'explicit_self_type', - 'from_module_getattr', - 'has_explicit_value', - 'allow_incompatible_override', - ) - - def __init__(self, name: str, type: 'Optional[mypy.types.Type]' = None) -> None: - super().__init__() - self._name = name # Name without module prefix + __slots__ = ( + "_name", + "_fullname", + "info", + "type", + "final_value", + "is_self", + "is_ready", + "is_inferred", + "is_initialized_in_class", + "is_staticmethod", + "is_classmethod", + "is_property", + "is_settable_property", + "is_classvar", + "is_abstract_var", + "is_final", + "final_unset_in_class", + "final_set_in_init", + "is_suppressed_import", + "explicit_self_type", + "from_module_getattr", + "has_explicit_value", + "allow_incompatible_override", + ) + + def __init__(self, name: str, type: "Optional[mypy.types.Type]" = None) -> None: + super().__init__() + self._name = name # Name without module prefix # TODO: Should be Optional[str] - self._fullname = cast('Bogus[str]', None) # Name with module prefix + self._fullname = cast("Bogus[str]", None) # Name with module prefix # TODO: Should be Optional[TypeInfo] self.info = VAR_NO_INFO self.type: Optional[mypy.types.Type] = type # Declared or inferred type, or None # Is this the first argument to an ordinary method (usually "self")? self.is_self = False self.is_ready = True # If inferred, is the inferred type available? - self.is_inferred = (self.type is None) + self.is_inferred = self.type is None # Is this initialized explicitly to a non-None value in class body? self.is_initialized_in_class = False self.is_staticmethod = False @@ -976,28 +1025,39 @@ def serialize(self) -> JsonDict: "flags": get_flags(self, VAR_FLAGS), } if self.final_value is not None: - data['final_value'] = self.final_value + data["final_value"] = self.final_value return data @classmethod - def deserialize(cls, data: JsonDict) -> 'Var': - assert data['.class'] == 'Var' - name = data['name'] - type = None if data['type'] is None else mypy.types.deserialize_type(data['type']) + def deserialize(cls, data: JsonDict) -> "Var": + assert data[".class"] == "Var" + name = data["name"] + type = None if data["type"] is None else mypy.types.deserialize_type(data["type"]) v = Var(name, type) v.is_ready = False # Override True default set in __init__ - v._fullname = data['fullname'] - set_flags(v, data['flags']) - v.final_value = data.get('final_value') + v._fullname = data["fullname"] + set_flags(v, data["flags"]) + v.final_value = data.get("final_value") return v class ClassDef(Statement): """Class definition""" - __slots__ = ('name', 'fullname', 'defs', 'type_vars', 'base_type_exprs', - 'removed_base_type_exprs', 'info', 'metaclass', 'decorators', - 'keywords', 'analyzed', 'has_incompatible_baseclass') + __slots__ = ( + "name", + "fullname", + "defs", + "type_vars", + "base_type_exprs", + "removed_base_type_exprs", + "info", + "metaclass", + "decorators", + "keywords", + "analyzed", + "has_incompatible_baseclass", + ) name: str # Name of the class without module prefix fullname: Bogus[str] # Fully qualified name of the class @@ -1014,13 +1074,15 @@ class ClassDef(Statement): analyzed: Optional[Expression] has_incompatible_baseclass: bool - def __init__(self, - name: str, - defs: 'Block', - type_vars: Optional[List['mypy.types.TypeVarLikeType']] = None, - base_type_exprs: Optional[List[Expression]] = None, - metaclass: Optional[Expression] = None, - keywords: Optional[List[Tuple[str, Expression]]] = None) -> None: + def __init__( + self, + name: str, + defs: "Block", + type_vars: Optional[List["mypy.types.TypeVarLikeType"]] = None, + base_type_exprs: Optional[List[Expression]] = None, + metaclass: Optional[Expression] = None, + keywords: Optional[List[Tuple[str, Expression]]] = None, + ) -> None: super().__init__() self.name = name self.fullname = None # type: ignore @@ -1044,29 +1106,33 @@ def is_generic(self) -> bool: def serialize(self) -> JsonDict: # Not serialized: defs, base_type_exprs, metaclass, decorators, # analyzed (for named tuples etc.) - return {'.class': 'ClassDef', - 'name': self.name, - 'fullname': self.fullname, - 'type_vars': [v.serialize() for v in self.type_vars], - } + return { + ".class": "ClassDef", + "name": self.name, + "fullname": self.fullname, + "type_vars": [v.serialize() for v in self.type_vars], + } @classmethod - def deserialize(self, data: JsonDict) -> 'ClassDef': - assert data['.class'] == 'ClassDef' - res = ClassDef(data['name'], - Block([]), - # https://github.com/python/mypy/issues/12257 - [cast(mypy.types.TypeVarLikeType, mypy.types.deserialize_type(v)) - for v in data['type_vars']], - ) - res.fullname = data['fullname'] + def deserialize(self, data: JsonDict) -> "ClassDef": + assert data[".class"] == "ClassDef" + res = ClassDef( + data["name"], + Block([]), + # https://github.com/python/mypy/issues/12257 + [ + cast(mypy.types.TypeVarLikeType, mypy.types.deserialize_type(v)) + for v in data["type_vars"] + ], + ) + res.fullname = data["fullname"] return res class GlobalDecl(Statement): """Declaration global x, y, ...""" - __slots__ = ('names',) + __slots__ = ("names",) names: List[str] @@ -1081,7 +1147,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class NonlocalDecl(Statement): """Declaration nonlocal x, y, ...""" - __slots__ = ('names',) + __slots__ = ("names",) names: List[str] @@ -1094,7 +1160,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class Block(Statement): - __slots__ = ('body', 'is_unreachable') + __slots__ = ("body", "is_unreachable") def __init__(self, body: List[Statement]) -> None: super().__init__() @@ -1116,7 +1182,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class ExpressionStmt(Statement): """An expression as a statement, such as print(s).""" - __slots__ = ('expr',) + __slots__ = ("expr",) expr: Expression @@ -1139,8 +1205,15 @@ class AssignmentStmt(Statement): An lvalue can be NameExpr, TupleExpr, ListExpr, MemberExpr, or IndexExpr. """ - __slots__ = ('lvalues', 'rvalue', 'type', 'unanalyzed_type', 'new_syntax', - 'is_alias_def', 'is_final_def') + __slots__ = ( + "lvalues", + "rvalue", + "type", + "unanalyzed_type", + "new_syntax", + "is_alias_def", + "is_final_def", + ) lvalues: List[Lvalue] # This is a TempNode if and only if no rvalue (x: t). @@ -1161,8 +1234,13 @@ class AssignmentStmt(Statement): # during type checking when MROs are known). is_final_def: bool - def __init__(self, lvalues: List[Lvalue], rvalue: Expression, - type: 'Optional[mypy.types.Type]' = None, new_syntax: bool = False) -> None: + def __init__( + self, + lvalues: List[Lvalue], + rvalue: Expression, + type: "Optional[mypy.types.Type]" = None, + new_syntax: bool = False, + ) -> None: super().__init__() self.lvalues = lvalues self.rvalue = rvalue @@ -1179,7 +1257,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class OperatorAssignmentStmt(Statement): """Operator assignment statement such as x += 1""" - __slots__ = ('op', 'lvalue', 'rvalue') + __slots__ = ("op", "lvalue", "rvalue") op: str # TODO: Enum? lvalue: Lvalue @@ -1196,7 +1274,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class WhileStmt(Statement): - __slots__ = ('expr', 'body', 'else_body') + __slots__ = ("expr", "body", "else_body") expr: Expression body: Block @@ -1213,9 +1291,17 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class ForStmt(Statement): - __slots__ = ('index', 'index_type', 'unanalyzed_index_type', - 'inferred_item_type', 'inferred_iterator_type', - 'expr', 'body', 'else_body', 'is_async') + __slots__ = ( + "index", + "index_type", + "unanalyzed_index_type", + "inferred_item_type", + "inferred_iterator_type", + "expr", + "body", + "else_body", + "is_async", + ) # Index variables index: Lvalue @@ -1233,12 +1319,14 @@ class ForStmt(Statement): else_body: Optional[Block] is_async: bool # True if `async for ...` (PEP 492, Python 3.5) - def __init__(self, - index: Lvalue, - expr: Expression, - body: Block, - else_body: Optional[Block], - index_type: 'Optional[mypy.types.Type]' = None) -> None: + def __init__( + self, + index: Lvalue, + expr: Expression, + body: Block, + else_body: Optional[Block], + index_type: "Optional[mypy.types.Type]" = None, + ) -> None: super().__init__() self.index = index self.index_type = index_type @@ -1255,7 +1343,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class ReturnStmt(Statement): - __slots__ = ('expr',) + __slots__ = ("expr",) expr: Optional[Expression] @@ -1268,7 +1356,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class AssertStmt(Statement): - __slots__ = ('expr', 'msg') + __slots__ = ("expr", "msg") expr: Expression msg: Optional[Expression] @@ -1283,7 +1371,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class DelStmt(Statement): - __slots__ = ('expr',) + __slots__ = ("expr",) expr: Lvalue @@ -1317,14 +1405,15 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class IfStmt(Statement): - __slots__ = ('expr', 'body', 'else_body') + __slots__ = ("expr", "body", "else_body") expr: List[Expression] body: List[Block] else_body: Optional[Block] - def __init__(self, expr: List[Expression], body: List[Block], - else_body: Optional[Block]) -> None: + def __init__( + self, expr: List[Expression], body: List[Block], else_body: Optional[Block] + ) -> None: super().__init__() self.expr = expr self.body = body @@ -1335,7 +1424,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class RaiseStmt(Statement): - __slots__ = ('expr', 'from_expr', 'legacy_mode') + __slots__ = ("expr", "from_expr", "legacy_mode") # Plain 'raise' is a valid statement. expr: Optional[Expression] @@ -1354,7 +1443,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class TryStmt(Statement): - __slots__ = ('body', 'types', 'vars', 'handlers', 'else_body', 'finally_body') + __slots__ = ("body", "types", "vars", "handlers", "else_body", "finally_body") body: Block # Try body # Plain 'except:' also possible @@ -1364,10 +1453,15 @@ class TryStmt(Statement): else_body: Optional[Block] finally_body: Optional[Block] - def __init__(self, body: Block, vars: List['Optional[NameExpr]'], - types: List[Optional[Expression]], - handlers: List[Block], else_body: Optional[Block], - finally_body: Optional[Block]) -> None: + def __init__( + self, + body: Block, + vars: List["Optional[NameExpr]"], + types: List[Optional[Expression]], + handlers: List[Block], + else_body: Optional[Block], + finally_body: Optional[Block], + ) -> None: super().__init__() self.body = body self.vars = vars @@ -1381,8 +1475,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class WithStmt(Statement): - __slots__ = ('expr', 'target', 'unanalyzed_type', - 'analyzed_types', 'body', 'is_async') + __slots__ = ("expr", "target", "unanalyzed_type", "analyzed_types", "body", "is_async") expr: List[Expression] target: List[Optional[Lvalue]] @@ -1393,8 +1486,13 @@ class WithStmt(Statement): body: Block is_async: bool # True if `async with ...` (PEP 492, Python 3.5) - def __init__(self, expr: List[Expression], target: List[Optional[Lvalue]], - body: Block, target_type: 'Optional[mypy.types.Type]' = None) -> None: + def __init__( + self, + expr: List[Expression], + target: List[Optional[Lvalue]], + body: Block, + target_type: "Optional[mypy.types.Type]" = None, + ) -> None: super().__init__() self.expr = expr self.target = target @@ -1409,12 +1507,17 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class MatchStmt(Statement): subject: Expression - patterns: List['Pattern'] + patterns: List["Pattern"] guards: List[Optional[Expression]] bodies: List[Block] - def __init__(self, subject: Expression, patterns: List['Pattern'], - guards: List[Optional[Expression]], bodies: List[Block]) -> None: + def __init__( + self, + subject: Expression, + patterns: List["Pattern"], + guards: List[Optional[Expression]], + bodies: List[Block], + ) -> None: super().__init__() assert len(patterns) == len(guards) == len(bodies) self.subject = subject @@ -1429,17 +1532,16 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class PrintStmt(Statement): """Python 2 print statement""" - __slots__ = ('args', 'newline', 'target') + __slots__ = ("args", "newline", "target") args: List[Expression] newline: bool # The file-like target object (given using >>). target: Optional[Expression] - def __init__(self, - args: List[Expression], - newline: bool, - target: Optional[Expression] = None) -> None: + def __init__( + self, args: List[Expression], newline: bool, target: Optional[Expression] = None + ) -> None: super().__init__() self.args = args self.newline = newline @@ -1452,15 +1554,15 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class ExecStmt(Statement): """Python 2 exec statement""" - __slots__ = ('expr', 'globals', 'locals') + __slots__ = ("expr", "globals", "locals") expr: Expression globals: Optional[Expression] locals: Optional[Expression] - def __init__(self, expr: Expression, - globals: Optional[Expression], - locals: Optional[Expression]) -> None: + def __init__( + self, expr: Expression, globals: Optional[Expression], locals: Optional[Expression] + ) -> None: super().__init__() self.expr = expr self.globals = globals @@ -1476,7 +1578,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T: class IntExpr(Expression): """Integer literal""" - __slots__ = ('value',) + __slots__ = ("value",) value: int # 0 by default @@ -1499,10 +1601,11 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: # 'x', u'x' -> StrExpr # UnicodeExpr is unused + class StrExpr(Expression): """String literal""" - __slots__ = ('value', 'from_python_3') + __slots__ = ("value", "from_python_3") value: str # '' by default @@ -1532,7 +1635,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class BytesExpr(Expression): """Bytes literal""" - __slots__ = ('value',) + __slots__ = ("value",) # Note: we deliberately do NOT use bytes here because it ends up # unnecessarily complicating a lot of the result logic. For example, @@ -1556,7 +1659,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class UnicodeExpr(Expression): """Unicode literal (Python 2.x)""" - __slots__ = ('value',) + __slots__ = ("value",) value: str @@ -1571,7 +1674,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class FloatExpr(Expression): """Float literal""" - __slots__ = ('value',) + __slots__ = ("value",) value: float # 0.0 by default @@ -1586,7 +1689,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class ComplexExpr(Expression): """Complex literal""" - __slots__ = ('value',) + __slots__ = ("value",) value: complex @@ -1610,7 +1713,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class StarExpr(Expression): """Star expression""" - __slots__ = ('expr', 'valid') + __slots__ = ("expr", "valid") expr: Expression valid: bool @@ -1629,8 +1732,15 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class RefExpr(Expression): """Abstract base class for name-like constructs""" - __slots__ = ('kind', 'node', 'fullname', 'is_new_def', 'is_inferred_def', 'is_alias_rvalue', - 'type_guard') + __slots__ = ( + "kind", + "node", + "fullname", + "is_new_def", + "is_inferred_def", + "is_alias_rvalue", + "type_guard", + ) def __init__(self) -> None: super().__init__() @@ -1659,7 +1769,7 @@ class NameExpr(RefExpr): This refers to a local name, global name or a module. """ - __slots__ = ('name', 'is_special_form') + __slots__ = ("name", "is_special_form") def __init__(self, name: str) -> None: super().__init__() @@ -1677,7 +1787,7 @@ def serialize(self) -> JsonDict: class MemberExpr(RefExpr): """Member access expression x.y""" - __slots__ = ('expr', 'name', 'def_var') + __slots__ = ("expr", "name", "def_var") def __init__(self, expr: Expression, name: str) -> None: super().__init__() @@ -1708,18 +1818,10 @@ class ArgKind(Enum): ARG_NAMED_OPT = 5 def is_positional(self, star: bool = False) -> bool: - return ( - self == ARG_POS - or self == ARG_OPT - or (star and self == ARG_STAR) - ) + return self == ARG_POS or self == ARG_OPT or (star and self == ARG_STAR) def is_named(self, star: bool = False) -> bool: - return ( - self == ARG_NAMED - or self == ARG_NAMED_OPT - or (star and self == ARG_STAR2) - ) + return self == ARG_NAMED or self == ARG_NAMED_OPT or (star and self == ARG_STAR2) def is_required(self) -> bool: return self == ARG_POS or self == ARG_NAMED @@ -1746,14 +1848,16 @@ class CallExpr(Expression): such as cast(...) and None # type: .... """ - __slots__ = ('callee', 'args', 'arg_kinds', 'arg_names', 'analyzed') + __slots__ = ("callee", "args", "arg_kinds", "arg_names", "analyzed") - def __init__(self, - callee: Expression, - args: List[Expression], - arg_kinds: List[ArgKind], - arg_names: List[Optional[str]], - analyzed: Optional[Expression] = None) -> None: + def __init__( + self, + callee: Expression, + args: List[Expression], + arg_kinds: List[ArgKind], + arg_names: List[Optional[str]], + analyzed: Optional[Expression] = None, + ) -> None: super().__init__() if not arg_names: arg_names = [None] * len(args) @@ -1772,7 +1876,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class YieldFromExpr(Expression): - __slots__ = ('expr',) + __slots__ = ("expr",) expr: Expression @@ -1785,7 +1889,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class YieldExpr(Expression): - __slots__ = ('expr',) + __slots__ = ("expr",) expr: Optional[Expression] @@ -1803,7 +1907,7 @@ class IndexExpr(Expression): Also wraps type application such as List[int] as a special form. """ - __slots__ = ('base', 'index', 'method_type', 'analyzed') + __slots__ = ("base", "index", "method_type", "analyzed") base: Expression index: Expression @@ -1827,7 +1931,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class UnaryExpr(Expression): """Unary operation""" - __slots__ = ('op', 'expr', 'method_type') + __slots__ = ("op", "expr", "method_type") op: str # TODO: Enum? expr: Expression @@ -1847,7 +1951,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class AssignmentExpr(Expression): """Assignment expressions in Python 3.8+, like "a := 2".""" - __slots__ = ('target', 'value') + __slots__ = ("target", "value") def __init__(self, target: Expression, value: Expression) -> None: super().__init__() @@ -1862,8 +1966,7 @@ class OpExpr(Expression): """Binary operation (other than . or [] or comparison operators, which have specific nodes).""" - __slots__ = ('op', 'left', 'right', - 'method_type', 'right_always', 'right_unreachable') + __slots__ = ("op", "left", "right", "method_type", "right_always", "right_unreachable") op: str # TODO: Enum? left: Expression @@ -1891,7 +1994,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class ComparisonExpr(Expression): """Comparison expression (e.g. a < b > c < d).""" - __slots__ = ('operators', 'operands', 'method_types') + __slots__ = ("operators", "operands", "method_types") operators: List[str] operands: List[Expression] @@ -1921,15 +2024,18 @@ class SliceExpr(Expression): This is only valid as index in index expressions. """ - __slots__ = ('begin_index', 'end_index', 'stride') + __slots__ = ("begin_index", "end_index", "stride") begin_index: Optional[Expression] end_index: Optional[Expression] stride: Optional[Expression] - def __init__(self, begin_index: Optional[Expression], - end_index: Optional[Expression], - stride: Optional[Expression]) -> None: + def __init__( + self, + begin_index: Optional[Expression], + end_index: Optional[Expression], + stride: Optional[Expression], + ) -> None: super().__init__() self.begin_index = begin_index self.end_index = end_index @@ -1942,12 +2048,12 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class CastExpr(Expression): """Cast expression cast(type, expr).""" - __slots__ = ('expr', 'type') + __slots__ = ("expr", "type") expr: Expression type: "mypy.types.Type" - def __init__(self, expr: Expression, typ: 'mypy.types.Type') -> None: + def __init__(self, expr: Expression, typ: "mypy.types.Type") -> None: super().__init__() self.expr = expr self.type = typ @@ -1958,12 +2064,13 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class AssertTypeExpr(Expression): """Represents a typing.assert_type(expr, type) call.""" - __slots__ = ('expr', 'type') + + __slots__ = ("expr", "type") expr: Expression type: "mypy.types.Type" - def __init__(self, expr: Expression, typ: 'mypy.types.Type') -> None: + def __init__(self, expr: Expression, typ: "mypy.types.Type") -> None: super().__init__() self.expr = expr self.type = typ @@ -1975,16 +2082,18 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class RevealExpr(Expression): """Reveal type expression reveal_type(expr) or reveal_locals() expression.""" - __slots__ = ('expr', 'kind', 'local_nodes') + __slots__ = ("expr", "kind", "local_nodes") expr: Optional[Expression] kind: int local_nodes: Optional[List[Var]] def __init__( - self, kind: int, - expr: Optional[Expression] = None, - local_nodes: 'Optional[List[Var]]' = None) -> None: + self, + kind: int, + expr: Optional[Expression] = None, + local_nodes: "Optional[List[Var]]" = None, + ) -> None: super().__init__() self.expr = expr self.kind = kind @@ -1997,7 +2106,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class SuperExpr(Expression): """Expression super().name""" - __slots__ = ('name', 'info', 'call') + __slots__ = ("name", "info", "call") name: str info: Optional["TypeInfo"] # Type that contains this super expression @@ -2018,7 +2127,7 @@ class LambdaExpr(FuncItem, Expression): @property def name(self) -> str: - return '' + return "" def expr(self) -> Expression: """Return the expression (the body) of the lambda.""" @@ -2037,7 +2146,7 @@ def is_dynamic(self) -> bool: class ListExpr(Expression): """List literal expression [...].""" - __slots__ = ('items',) + __slots__ = ("items",) items: List[Expression] @@ -2052,7 +2161,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class DictExpr(Expression): """Dictionary literal expression {key: value, ...}.""" - __slots__ = ('items',) + __slots__ = ("items",) items: List[Tuple[Optional[Expression], Expression]] @@ -2069,7 +2178,7 @@ class TupleExpr(Expression): Also lvalue sequences (..., ...) and [..., ...]""" - __slots__ = ('items',) + __slots__ = ("items",) items: List[Expression] @@ -2084,7 +2193,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class SetExpr(Expression): """Set literal expression {value, ...}.""" - __slots__ = ('items',) + __slots__ = ("items",) items: List[Expression] @@ -2099,7 +2208,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class GeneratorExpr(Expression): """Generator expression ... for ... in ... [ for ... in ... ] [ if ... ].""" - __slots__ = ('left_expr', 'sequences', 'condlists', 'is_async', 'indices') + __slots__ = ("left_expr", "sequences", "condlists", "is_async", "indices") left_expr: Expression sequences: List[Expression] @@ -2107,9 +2216,14 @@ class GeneratorExpr(Expression): is_async: List[bool] indices: List[Lvalue] - def __init__(self, left_expr: Expression, indices: List[Lvalue], - sequences: List[Expression], condlists: List[List[Expression]], - is_async: List[bool]) -> None: + def __init__( + self, + left_expr: Expression, + indices: List[Lvalue], + sequences: List[Expression], + condlists: List[List[Expression]], + is_async: List[bool], + ) -> None: super().__init__() self.left_expr = left_expr self.sequences = sequences @@ -2124,7 +2238,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class ListComprehension(Expression): """List comprehension (e.g. [x + 1 for x in a])""" - __slots__ = ('generator',) + __slots__ = ("generator",) generator: GeneratorExpr @@ -2139,7 +2253,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class SetComprehension(Expression): """Set comprehension (e.g. {x + 1 for x in a})""" - __slots__ = ('generator',) + __slots__ = ("generator",) generator: GeneratorExpr @@ -2154,7 +2268,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class DictionaryComprehension(Expression): """Dictionary comprehension (e.g. {k: v for k, v in a}""" - __slots__ = ('key', 'value', 'sequences', 'condlists', 'is_async', 'indices') + __slots__ = ("key", "value", "sequences", "condlists", "is_async", "indices") key: Expression value: Expression @@ -2163,9 +2277,15 @@ class DictionaryComprehension(Expression): is_async: List[bool] indices: List[Lvalue] - def __init__(self, key: Expression, value: Expression, indices: List[Lvalue], - sequences: List[Expression], condlists: List[List[Expression]], - is_async: List[bool]) -> None: + def __init__( + self, + key: Expression, + value: Expression, + indices: List[Lvalue], + sequences: List[Expression], + condlists: List[List[Expression]], + is_async: List[bool], + ) -> None: super().__init__() self.key = key self.value = value @@ -2181,7 +2301,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class ConditionalExpr(Expression): """Conditional expression (e.g. x if y else z)""" - __slots__ = ('cond', 'if_expr', 'else_expr') + __slots__ = ("cond", "if_expr", "else_expr") cond: Expression if_expr: Expression @@ -2200,7 +2320,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class BackquoteExpr(Expression): """Python 2 expression `...`.""" - __slots__ = ('expr',) + __slots__ = ("expr",) expr: Expression @@ -2215,12 +2335,12 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class TypeApplication(Expression): """Type application expr[type, ...]""" - __slots__ = ('expr', 'types') + __slots__ = ("expr", "types") expr: Expression types: List["mypy.types.Type"] - def __init__(self, expr: Expression, types: List['mypy.types.Type']) -> None: + def __init__(self, expr: Expression, types: List["mypy.types.Type"]) -> None: super().__init__() self.expr = expr self.types = types @@ -2249,7 +2369,7 @@ class TypeVarLikeExpr(SymbolNode, Expression): Note that they are constructed by the semantic analyzer. """ - __slots__ = ('_name', '_fullname', 'upper_bound', 'variance') + __slots__ = ("_name", "_fullname", "upper_bound", "variance") _name: str _fullname: str @@ -2263,7 +2383,7 @@ class TypeVarLikeExpr(SymbolNode, Expression): variance: int def __init__( - self, name: str, fullname: str, upper_bound: 'mypy.types.Type', variance: int = INVARIANT + self, name: str, fullname: str, upper_bound: "mypy.types.Type", variance: int = INVARIANT ) -> None: super().__init__() self._name = name @@ -2292,16 +2412,20 @@ class TypeVarExpr(TypeVarLikeExpr): 2. a generic function that refers to the type variable in its signature. """ - __slots__ = ('values',) + __slots__ = ("values",) # Value restriction: only types in the list are valid as values. If the # list is empty, there is no restriction. values: List["mypy.types.Type"] - def __init__(self, name: str, fullname: str, - values: List['mypy.types.Type'], - upper_bound: 'mypy.types.Type', - variance: int = INVARIANT) -> None: + def __init__( + self, + name: str, + fullname: str, + values: List["mypy.types.Type"], + upper_bound: "mypy.types.Type", + variance: int = INVARIANT, + ) -> None: super().__init__(name, fullname, upper_bound, variance) self.values = values @@ -2309,22 +2433,25 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_type_var_expr(self) def serialize(self) -> JsonDict: - return {'.class': 'TypeVarExpr', - 'name': self._name, - 'fullname': self._fullname, - 'values': [t.serialize() for t in self.values], - 'upper_bound': self.upper_bound.serialize(), - 'variance': self.variance, - } + return { + ".class": "TypeVarExpr", + "name": self._name, + "fullname": self._fullname, + "values": [t.serialize() for t in self.values], + "upper_bound": self.upper_bound.serialize(), + "variance": self.variance, + } @classmethod - def deserialize(cls, data: JsonDict) -> 'TypeVarExpr': - assert data['.class'] == 'TypeVarExpr' - return TypeVarExpr(data['name'], - data['fullname'], - [mypy.types.deserialize_type(v) for v in data['values']], - mypy.types.deserialize_type(data['upper_bound']), - data['variance']) + def deserialize(cls, data: JsonDict) -> "TypeVarExpr": + assert data[".class"] == "TypeVarExpr" + return TypeVarExpr( + data["name"], + data["fullname"], + [mypy.types.deserialize_type(v) for v in data["values"]], + mypy.types.deserialize_type(data["upper_bound"]), + data["variance"], + ) class ParamSpecExpr(TypeVarLikeExpr): @@ -2335,21 +2462,21 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: def serialize(self) -> JsonDict: return { - '.class': 'ParamSpecExpr', - 'name': self._name, - 'fullname': self._fullname, - 'upper_bound': self.upper_bound.serialize(), - 'variance': self.variance, + ".class": "ParamSpecExpr", + "name": self._name, + "fullname": self._fullname, + "upper_bound": self.upper_bound.serialize(), + "variance": self.variance, } @classmethod - def deserialize(cls, data: JsonDict) -> 'ParamSpecExpr': - assert data['.class'] == 'ParamSpecExpr' + def deserialize(cls, data: JsonDict) -> "ParamSpecExpr": + assert data[".class"] == "ParamSpecExpr" return ParamSpecExpr( - data['name'], - data['fullname'], - mypy.types.deserialize_type(data['upper_bound']), - data['variance'] + data["name"], + data["fullname"], + mypy.types.deserialize_type(data["upper_bound"]), + data["variance"], ) @@ -2363,28 +2490,28 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: def serialize(self) -> JsonDict: return { - '.class': 'TypeVarTupleExpr', - 'name': self._name, - 'fullname': self._fullname, - 'upper_bound': self.upper_bound.serialize(), - 'variance': self.variance, + ".class": "TypeVarTupleExpr", + "name": self._name, + "fullname": self._fullname, + "upper_bound": self.upper_bound.serialize(), + "variance": self.variance, } @classmethod - def deserialize(cls, data: JsonDict) -> 'TypeVarTupleExpr': - assert data['.class'] == 'TypeVarTupleExpr' + def deserialize(cls, data: JsonDict) -> "TypeVarTupleExpr": + assert data[".class"] == "TypeVarTupleExpr" return TypeVarTupleExpr( - data['name'], - data['fullname'], - mypy.types.deserialize_type(data['upper_bound']), - data['variance'] + data["name"], + data["fullname"], + mypy.types.deserialize_type(data["upper_bound"]), + data["variance"], ) class TypeAliasExpr(Expression): """Type alias expression (rvalue).""" - __slots__ = ('type', 'tvars', 'no_args', 'node') + __slots__ = ("type", "tvars", "no_args", "node") # The target type. type: "mypy.types.Type" @@ -2396,9 +2523,9 @@ class TypeAliasExpr(Expression): # and # A = List[Any] no_args: bool - node: 'TypeAlias' + node: "TypeAlias" - def __init__(self, node: 'TypeAlias') -> None: + def __init__(self, node: "TypeAlias") -> None: super().__init__() self.type = node.target self.tvars = node.alias_tvars @@ -2412,14 +2539,14 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class NamedTupleExpr(Expression): """Named tuple expression namedtuple(...) or NamedTuple(...).""" - __slots__ = ('info', 'is_typed') + __slots__ = ("info", "is_typed") # The class representation of this named tuple (its tuple_type attribute contains # the tuple item types) info: "TypeInfo" is_typed: bool # whether this class was created with typing(_extensions).NamedTuple - def __init__(self, info: 'TypeInfo', is_typed: bool = False) -> None: + def __init__(self, info: "TypeInfo", is_typed: bool = False) -> None: super().__init__() self.info = info self.is_typed = is_typed @@ -2431,12 +2558,12 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class TypedDictExpr(Expression): """Typed dict expression TypedDict(...).""" - __slots__ = ('info',) + __slots__ = ("info",) # The class representation of this typed dict info: "TypeInfo" - def __init__(self, info: 'TypeInfo') -> None: + def __init__(self, info: "TypeInfo") -> None: super().__init__() self.info = info @@ -2447,7 +2574,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class EnumCallExpr(Expression): """Named tuple expression Enum('name', 'val1 val2 ...').""" - __slots__ = ('info', 'items', 'values') + __slots__ = ("info", "items", "values") # The class representation of this enumerated type info: "TypeInfo" @@ -2455,8 +2582,9 @@ class EnumCallExpr(Expression): items: List[str] values: List[Optional[Expression]] - def __init__(self, info: 'TypeInfo', items: List[str], - values: List[Optional[Expression]]) -> None: + def __init__( + self, info: "TypeInfo", items: List[str], values: List[Optional[Expression]] + ) -> None: super().__init__() self.info = info self.items = items @@ -2469,11 +2597,11 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class PromoteExpr(Expression): """Ducktype class decorator expression _promote(...).""" - __slots__ = ('type',) + __slots__ = ("type",) type: "mypy.types.Type" - def __init__(self, type: 'mypy.types.Type') -> None: + def __init__(self, type: "mypy.types.Type") -> None: super().__init__() self.type = type @@ -2484,7 +2612,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class NewTypeExpr(Expression): """NewType expression NewType(...).""" - __slots__ = ('name', 'old_type', 'info') + __slots__ = ("name", "old_type", "info") name: str # The base type (the second argument to NewType) @@ -2492,8 +2620,9 @@ class NewTypeExpr(Expression): # The synthesized class representing the new type (inherits old_type) info: Optional["TypeInfo"] - def __init__(self, name: str, old_type: 'Optional[mypy.types.Type]', line: int, - column: int) -> None: + def __init__( + self, name: str, old_type: "Optional[mypy.types.Type]", line: int, column: int + ) -> None: super().__init__(line=line, column=column) self.name = name self.old_type = old_type @@ -2506,7 +2635,7 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class AwaitExpr(Expression): """Await expression (await ...).""" - __slots__ = ('expr',) + __slots__ = ("expr",) expr: Expression @@ -2529,18 +2658,16 @@ class TempNode(Expression): some fixed type. """ - __slots__ = ('type', 'no_rhs') + __slots__ = ("type", "no_rhs") type: "mypy.types.Type" # Is this TempNode used to indicate absence of a right hand side in an annotated assignment? # (e.g. for 'x: int' the rvalue is TempNode(AnyType(TypeOfAny.special_form), no_rhs=True)) no_rhs: bool - def __init__(self, - typ: 'mypy.types.Type', - no_rhs: bool = False, - *, - context: Optional[Context] = None) -> None: + def __init__( + self, typ: "mypy.types.Type", no_rhs: bool = False, *, context: Optional[Context] = None + ) -> None: """Construct a dummy node; optionally borrow line/column from context object.""" super().__init__() self.type = typ @@ -2550,7 +2677,7 @@ def __init__(self, self.column = context.column def __repr__(self) -> str: - return 'TempNode:%d(%s)' % (self.line, str(self.type)) + return "TempNode:%d(%s)" % (self.line, str(self.type)) def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_temp_node(self) @@ -2570,14 +2697,41 @@ class is generic then it will be a type constructor of higher kind. """ __slots__ = ( - '_fullname', 'module_name', 'defn', 'mro', '_mro_refs', 'bad_mro', 'is_final', - 'declared_metaclass', 'metaclass_type', 'names', 'is_abstract', - 'is_protocol', 'runtime_protocol', 'abstract_attributes', - 'deletable_attributes', 'slots', 'assuming', 'assuming_proper', - 'inferring', 'is_enum', 'fallback_to_any', 'type_vars', 'has_param_spec_type', - 'bases', '_promote', 'tuple_type', 'is_named_tuple', 'typeddict_type', - 'is_newtype', 'is_intersection', 'metadata', 'alt_promote', - 'has_type_var_tuple_type', 'type_var_tuple_prefix', 'type_var_tuple_suffix' + "_fullname", + "module_name", + "defn", + "mro", + "_mro_refs", + "bad_mro", + "is_final", + "declared_metaclass", + "metaclass_type", + "names", + "is_abstract", + "is_protocol", + "runtime_protocol", + "abstract_attributes", + "deletable_attributes", + "slots", + "assuming", + "assuming_proper", + "inferring", + "is_enum", + "fallback_to_any", + "type_vars", + "has_param_spec_type", + "bases", + "_promote", + "tuple_type", + "is_named_tuple", + "typeddict_type", + "is_newtype", + "is_intersection", + "metadata", + "alt_promote", + "has_type_var_tuple_type", + "type_var_tuple_prefix", + "type_var_tuple_suffix", ) _fullname: Bogus[str] # Fully qualified name @@ -2599,9 +2753,9 @@ class is generic then it will be a type constructor of higher kind. metaclass_type: Optional["mypy.types.Instance"] names: "SymbolTable" # Names defined directly in this type - is_abstract: bool # Does the class have any abstract attributes? - is_protocol: bool # Is this a protocol class? - runtime_protocol: bool # Does this protocol support isinstance checks? + is_abstract: bool # Does the class have any abstract attributes? + is_protocol: bool # Is this a protocol class? + runtime_protocol: bool # Does this protocol support isinstance checks? abstract_attributes: List[str] deletable_attributes: List[str] # Used by mypyc only # Does this type have concrete `__slots__` defined? @@ -2706,12 +2860,18 @@ class is generic then it will be a type constructor of higher kind. metadata: Dict[str, JsonDict] FLAGS: Final = [ - 'is_abstract', 'is_enum', 'fallback_to_any', 'is_named_tuple', - 'is_newtype', 'is_protocol', 'runtime_protocol', 'is_final', - 'is_intersection', + "is_abstract", + "is_enum", + "fallback_to_any", + "is_named_tuple", + "is_newtype", + "is_protocol", + "runtime_protocol", + "is_final", + "is_intersection", ] - def __init__(self, names: 'SymbolTable', defn: ClassDef, module_name: str) -> None: + def __init__(self, names: "SymbolTable", defn: ClassDef, module_name: str) -> None: """Initialize a TypeInfo.""" super().__init__() self._fullname = defn.fullname @@ -2779,14 +2939,14 @@ def is_generic(self) -> bool: """Is the type generic (i.e. does it have type variables)?""" return len(self.type_vars) > 0 - def get(self, name: str) -> 'Optional[SymbolTableNode]': + def get(self, name: str) -> "Optional[SymbolTableNode]": for cls in self.mro: n = cls.names.get(name) if n: return n return None - def get_containing_type_info(self, name: str) -> 'Optional[TypeInfo]': + def get_containing_type_info(self, name: str) -> "Optional[TypeInfo]": for cls in self.mro: if name in cls.names: return cls @@ -2804,7 +2964,7 @@ def protocol_members(self) -> List[str]: members.add(name) return sorted(list(members)) - def __getitem__(self, name: str) -> 'SymbolTableNode': + def __getitem__(self, name: str) -> "SymbolTableNode": n = self.get(name) if n: return n @@ -2812,7 +2972,7 @@ def __getitem__(self, name: str) -> 'SymbolTableNode': raise KeyError(name) def __repr__(self) -> str: - return f'' + return f"" def __bool__(self) -> bool: # We defined this here instead of just overriding it in @@ -2835,24 +2995,28 @@ def get_method(self, name: str) -> Union[FuncBase, Decorator, None]: return None return None - def calculate_metaclass_type(self) -> 'Optional[mypy.types.Instance]': + def calculate_metaclass_type(self) -> "Optional[mypy.types.Instance]": declared = self.declared_metaclass - if declared is not None and not declared.type.has_base('builtins.type'): + if declared is not None and not declared.type.has_base("builtins.type"): return declared - if self._fullname == 'builtins.type': + if self._fullname == "builtins.type": return mypy.types.Instance(self, []) - candidates = [s.declared_metaclass - for s in self.mro - if s.declared_metaclass is not None - and s.declared_metaclass.type is not None] + candidates = [ + s.declared_metaclass + for s in self.mro + if s.declared_metaclass is not None and s.declared_metaclass.type is not None + ] for c in candidates: if all(other.type in c.type.mro for other in candidates): return c return None def is_metaclass(self) -> bool: - return (self.has_base('builtins.type') or self.fullname == 'abc.ABCMeta' or - self.fallback_to_any) + return ( + self.has_base("builtins.type") + or self.fullname == "abc.ABCMeta" + or self.fallback_to_any + ) def has_base(self, fullname: str) -> bool: """Return True if type has a base type with the specified name. @@ -2864,7 +3028,7 @@ def has_base(self, fullname: str) -> bool: return True return False - def direct_base_classes(self) -> 'List[TypeInfo]': + def direct_base_classes(self) -> "List[TypeInfo]": """Return a direct base classes. Omit base classes of other base classes. @@ -2878,90 +3042,95 @@ def __str__(self) -> str: """ return self.dump() - def dump(self, - str_conv: 'Optional[mypy.strconv.StrConv]' = None, - type_str_conv: 'Optional[mypy.types.TypeStrVisitor]' = None) -> str: + def dump( + self, + str_conv: "Optional[mypy.strconv.StrConv]" = None, + type_str_conv: "Optional[mypy.types.TypeStrVisitor]" = None, + ) -> str: """Return a string dump of the contents of the TypeInfo.""" if not str_conv: str_conv = mypy.strconv.StrConv() base: str = "" - def type_str(typ: 'mypy.types.Type') -> str: + def type_str(typ: "mypy.types.Type") -> str: if type_str_conv: return typ.accept(type_str_conv) return str(typ) - head = 'TypeInfo' + str_conv.format_id(self) + head = "TypeInfo" + str_conv.format_id(self) if self.bases: base = f"Bases({', '.join(type_str(base) for base in self.bases)})" - mro = 'Mro({})'.format(', '.join(item.fullname + str_conv.format_id(item) - for item in self.mro)) + mro = "Mro({})".format( + ", ".join(item.fullname + str_conv.format_id(item) for item in self.mro) + ) names = [] for name in sorted(self.names): description = name + str_conv.format_id(self.names[name].node) node = self.names[name].node if isinstance(node, Var) and node.type: - description += f' ({type_str(node.type)})' + description += f" ({type_str(node.type)})" names.append(description) - items = [ - f'Name({self.fullname})', - base, - mro, - ('Names', names), - ] + items = [f"Name({self.fullname})", base, mro, ("Names", names)] if self.declared_metaclass: - items.append(f'DeclaredMetaclass({type_str(self.declared_metaclass)})') + items.append(f"DeclaredMetaclass({type_str(self.declared_metaclass)})") if self.metaclass_type: - items.append(f'MetaclassType({type_str(self.metaclass_type)})') - return mypy.strconv.dump_tagged( - items, - head, - str_conv=str_conv) + items.append(f"MetaclassType({type_str(self.metaclass_type)})") + return mypy.strconv.dump_tagged(items, head, str_conv=str_conv) def serialize(self) -> JsonDict: # NOTE: This is where all ClassDefs originate, so there shouldn't be duplicates. - data = {'.class': 'TypeInfo', - 'module_name': self.module_name, - 'fullname': self.fullname, - 'names': self.names.serialize(self.fullname), - 'defn': self.defn.serialize(), - 'abstract_attributes': self.abstract_attributes, - 'type_vars': self.type_vars, - 'has_param_spec_type': self.has_param_spec_type, - 'bases': [b.serialize() for b in self.bases], - 'mro': [c.fullname for c in self.mro], - '_promote': [p.serialize() for p in self._promote], - 'declared_metaclass': (None if self.declared_metaclass is None - else self.declared_metaclass.serialize()), - 'metaclass_type': - None if self.metaclass_type is None else self.metaclass_type.serialize(), - 'tuple_type': None if self.tuple_type is None else self.tuple_type.serialize(), - 'typeddict_type': - None if self.typeddict_type is None else self.typeddict_type.serialize(), - 'flags': get_flags(self, TypeInfo.FLAGS), - 'metadata': self.metadata, - 'slots': list(sorted(self.slots)) if self.slots is not None else None, - 'deletable_attributes': self.deletable_attributes, - } + data = { + ".class": "TypeInfo", + "module_name": self.module_name, + "fullname": self.fullname, + "names": self.names.serialize(self.fullname), + "defn": self.defn.serialize(), + "abstract_attributes": self.abstract_attributes, + "type_vars": self.type_vars, + "has_param_spec_type": self.has_param_spec_type, + "bases": [b.serialize() for b in self.bases], + "mro": [c.fullname for c in self.mro], + "_promote": [p.serialize() for p in self._promote], + "declared_metaclass": ( + None if self.declared_metaclass is None else self.declared_metaclass.serialize() + ), + "metaclass_type": None + if self.metaclass_type is None + else self.metaclass_type.serialize(), + "tuple_type": None if self.tuple_type is None else self.tuple_type.serialize(), + "typeddict_type": None + if self.typeddict_type is None + else self.typeddict_type.serialize(), + "flags": get_flags(self, TypeInfo.FLAGS), + "metadata": self.metadata, + "slots": list(sorted(self.slots)) if self.slots is not None else None, + "deletable_attributes": self.deletable_attributes, + } return data @classmethod - def deserialize(cls, data: JsonDict) -> 'TypeInfo': - names = SymbolTable.deserialize(data['names']) - defn = ClassDef.deserialize(data['defn']) - module_name = data['module_name'] + def deserialize(cls, data: JsonDict) -> "TypeInfo": + names = SymbolTable.deserialize(data["names"]) + defn = ClassDef.deserialize(data["defn"]) + module_name = data["module_name"] ti = TypeInfo(names, defn, module_name) - ti._fullname = data['fullname'] + ti._fullname = data["fullname"] # TODO: Is there a reason to reconstruct ti.subtypes? - ti.abstract_attributes = data['abstract_attributes'] - ti.type_vars = data['type_vars'] - ti.has_param_spec_type = data['has_param_spec_type'] - ti.bases = [mypy.types.Instance.deserialize(b) for b in data['bases']] - ti._promote = [mypy.types.deserialize_type(p) for p in data['_promote']] - ti.declared_metaclass = (None if data['declared_metaclass'] is None - else mypy.types.Instance.deserialize(data['declared_metaclass'])) - ti.metaclass_type = (None if data['metaclass_type'] is None - else mypy.types.Instance.deserialize(data['metaclass_type'])) + ti.abstract_attributes = data["abstract_attributes"] + ti.type_vars = data["type_vars"] + ti.has_param_spec_type = data["has_param_spec_type"] + ti.bases = [mypy.types.Instance.deserialize(b) for b in data["bases"]] + ti._promote = [mypy.types.deserialize_type(p) for p in data["_promote"]] + ti.declared_metaclass = ( + None + if data["declared_metaclass"] is None + else mypy.types.Instance.deserialize(data["declared_metaclass"]) + ) + ti.metaclass_type = ( + None + if data["metaclass_type"] is None + else mypy.types.Instance.deserialize(data["metaclass_type"]) + ) # NOTE: ti.mro will be set in the fixup phase based on these # names. The reason we need to store the mro instead of just # recomputing it from base classes has to do with a subtle @@ -2972,21 +3141,27 @@ def deserialize(cls, data: JsonDict) -> 'TypeInfo': # way to detect that the mro has changed! Thus we need to make # sure to load the original mro so that once the class is # rechecked, it can tell that the mro has changed. - ti._mro_refs = data['mro'] - ti.tuple_type = (None if data['tuple_type'] is None - else mypy.types.TupleType.deserialize(data['tuple_type'])) - ti.typeddict_type = (None if data['typeddict_type'] is None - else mypy.types.TypedDictType.deserialize(data['typeddict_type'])) - ti.metadata = data['metadata'] - ti.slots = set(data['slots']) if data['slots'] is not None else None - ti.deletable_attributes = data['deletable_attributes'] - set_flags(ti, data['flags']) + ti._mro_refs = data["mro"] + ti.tuple_type = ( + None + if data["tuple_type"] is None + else mypy.types.TupleType.deserialize(data["tuple_type"]) + ) + ti.typeddict_type = ( + None + if data["typeddict_type"] is None + else mypy.types.TypedDictType.deserialize(data["typeddict_type"]) + ) + ti.metadata = data["metadata"] + ti.slots = set(data["slots"]) if data["slots"] is not None else None + ti.deletable_attributes = data["deletable_attributes"] + set_flags(ti, data["flags"]) return ti class FakeInfo(TypeInfo): - __slots__ = ('msg',) + __slots__ = ("msg",) # types.py defines a single instance of this class, called types.NOT_READY. # This instance is used as a temporary placeholder in the process of de-serialization @@ -3013,9 +3188,9 @@ def __init__(self, msg: str) -> None: def __getattribute__(self, attr: str) -> None: # Handle __class__ so that isinstance still works... - if attr == '__class__': + if attr == "__class__": return object.__getattribute__(self, attr) - raise AssertionError(object.__getattribute__(self, 'msg')) + raise AssertionError(object.__getattribute__(self, "msg")) VAR_NO_INFO: Final[TypeInfo] = FakeInfo("Var is lacking info") @@ -3112,15 +3287,29 @@ def f(x: B[T]) -> T: ... # without T, Any would be used here eager: If True, immediately expand alias when referred to (useful for aliases within functions that can't be looked up from the symbol table) """ - __slots__ = ('target', '_fullname', 'alias_tvars', 'no_args', 'normalized', - '_is_recursive', 'eager') - - def __init__(self, target: 'mypy.types.Type', fullname: str, line: int, column: int, - *, - alias_tvars: Optional[List[str]] = None, - no_args: bool = False, - normalized: bool = False, - eager: bool = False) -> None: + + __slots__ = ( + "target", + "_fullname", + "alias_tvars", + "no_args", + "normalized", + "_is_recursive", + "eager", + ) + + def __init__( + self, + target: "mypy.types.Type", + fullname: str, + line: int, + column: int, + *, + alias_tvars: Optional[List[str]] = None, + no_args: bool = False, + normalized: bool = False, + eager: bool = False, + ) -> None: self._fullname = fullname self.target = target if alias_tvars is None: @@ -3136,7 +3325,7 @@ def __init__(self, target: 'mypy.types.Type', fullname: str, line: int, column: @property def name(self) -> str: - return self._fullname.split('.')[-1] + return self._fullname.split(".")[-1] @property def fullname(self) -> str: @@ -3159,17 +3348,24 @@ def accept(self, visitor: NodeVisitor[T]) -> T: return visitor.visit_type_alias(self) @classmethod - def deserialize(cls, data: JsonDict) -> 'TypeAlias': - assert data['.class'] == 'TypeAlias' - fullname = data['fullname'] - alias_tvars = data['alias_tvars'] - target = mypy.types.deserialize_type(data['target']) - no_args = data['no_args'] - normalized = data['normalized'] - line = data['line'] - column = data['column'] - return cls(target, fullname, line, column, alias_tvars=alias_tvars, - no_args=no_args, normalized=normalized) + def deserialize(cls, data: JsonDict) -> "TypeAlias": + assert data[".class"] == "TypeAlias" + fullname = data["fullname"] + alias_tvars = data["alias_tvars"] + target = mypy.types.deserialize_type(data["target"]) + no_args = data["no_args"] + normalized = data["normalized"] + line = data["line"] + column = data["column"] + return cls( + target, + fullname, + line, + column, + alias_tvars=alias_tvars, + no_args=no_args, + normalized=normalized, + ) class PlaceholderNode(SymbolNode): @@ -3221,10 +3417,11 @@ class C(Sequence[C]): ... something that can support general recursive types. """ - __slots__ = ('_fullname', 'node', 'becomes_typeinfo') + __slots__ = ("_fullname", "node", "becomes_typeinfo") - def __init__(self, fullname: str, node: Node, line: int, *, - becomes_typeinfo: bool = False) -> None: + def __init__( + self, fullname: str, node: Node, line: int, *, becomes_typeinfo: bool = False + ) -> None: self._fullname = fullname self.node = node self.becomes_typeinfo = becomes_typeinfo @@ -3232,7 +3429,7 @@ def __init__(self, fullname: str, node: Node, line: int, *, @property def name(self) -> str: - return self._fullname.split('.')[-1] + return self._fullname.split(".")[-1] @property def fullname(self) -> str: @@ -3305,25 +3502,28 @@ class SymbolTableNode: are shared by all node kinds. """ - __slots__ = ('kind', - 'node', - 'module_public', - 'module_hidden', - 'cross_ref', - 'implicit', - 'plugin_generated', - 'no_serialize', - ) - - def __init__(self, - kind: int, - node: Optional[SymbolNode], - module_public: bool = True, - implicit: bool = False, - module_hidden: bool = False, - *, - plugin_generated: bool = False, - no_serialize: bool = False) -> None: + __slots__ = ( + "kind", + "node", + "module_public", + "module_hidden", + "cross_ref", + "implicit", + "plugin_generated", + "no_serialize", + ) + + def __init__( + self, + kind: int, + node: Optional[SymbolNode], + module_public: bool = True, + implicit: bool = False, + module_hidden: bool = False, + *, + plugin_generated: bool = False, + no_serialize: bool = False, + ) -> None: self.kind = kind self.node = node self.module_public = module_public @@ -3341,7 +3541,7 @@ def fullname(self) -> Optional[str]: return None @property - def type(self) -> 'Optional[mypy.types.Type]': + def type(self) -> "Optional[mypy.types.Type]": node = self.node if isinstance(node, (Var, SYMBOL_FUNCBASE_TYPES)) and node.type is not None: return node.type @@ -3350,22 +3550,20 @@ def type(self) -> 'Optional[mypy.types.Type]': else: return None - def copy(self) -> 'SymbolTableNode': - new = SymbolTableNode(self.kind, - self.node, - self.module_public, - self.implicit, - self.module_hidden) + def copy(self) -> "SymbolTableNode": + new = SymbolTableNode( + self.kind, self.node, self.module_public, self.implicit, self.module_hidden + ) new.cross_ref = self.cross_ref return new def __str__(self) -> str: - s = f'{node_kinds[self.kind]}/{short_type(self.node)}' + s = f"{node_kinds[self.kind]}/{short_type(self.node)}" if isinstance(self.node, SymbolNode): - s += f' ({self.node.fullname})' + s += f" ({self.node.fullname})" # Include declared type of variables and functions. if self.type is not None: - s += f' : {self.type}' + s += f" : {self.type}" return s def serialize(self, prefix: str, name: str) -> JsonDict: @@ -3375,56 +3573,55 @@ def serialize(self, prefix: str, name: str) -> JsonDict: prefix: full name of the containing module or class; or None name: name of this object relative to the containing object """ - data: JsonDict = { - ".class": "SymbolTableNode", - "kind": node_kinds[self.kind], - } + data: JsonDict = {".class": "SymbolTableNode", "kind": node_kinds[self.kind]} if self.module_hidden: - data['module_hidden'] = True + data["module_hidden"] = True if not self.module_public: - data['module_public'] = False + data["module_public"] = False if self.implicit: - data['implicit'] = True + data["implicit"] = True if self.plugin_generated: - data['plugin_generated'] = True + data["plugin_generated"] = True if isinstance(self.node, MypyFile): - data['cross_ref'] = self.node.fullname + data["cross_ref"] = self.node.fullname else: - assert self.node is not None, f'{prefix}:{name}' + assert self.node is not None, f"{prefix}:{name}" if prefix is not None: fullname = self.node.fullname - if (fullname is not None and '.' in fullname - and fullname != prefix + '.' + name - and not (isinstance(self.node, Var) - and self.node.from_module_getattr)): - assert not isinstance(self.node, PlaceholderNode), ( - f'Definition of {fullname} is unexpectedly incomplete' - ) - data['cross_ref'] = fullname + if ( + fullname is not None + and "." in fullname + and fullname != prefix + "." + name + and not (isinstance(self.node, Var) and self.node.from_module_getattr) + ): + assert not isinstance( + self.node, PlaceholderNode + ), f"Definition of {fullname} is unexpectedly incomplete" + data["cross_ref"] = fullname return data - data['node'] = self.node.serialize() + data["node"] = self.node.serialize() return data @classmethod - def deserialize(cls, data: JsonDict) -> 'SymbolTableNode': - assert data['.class'] == 'SymbolTableNode' - kind = inverse_node_kinds[data['kind']] - if 'cross_ref' in data: + def deserialize(cls, data: JsonDict) -> "SymbolTableNode": + assert data[".class"] == "SymbolTableNode" + kind = inverse_node_kinds[data["kind"]] + if "cross_ref" in data: # This will be fixed up later. stnode = SymbolTableNode(kind, None) - stnode.cross_ref = data['cross_ref'] + stnode.cross_ref = data["cross_ref"] else: - assert 'node' in data, data - node = SymbolNode.deserialize(data['node']) + assert "node" in data, data + node = SymbolNode.deserialize(data["node"]) stnode = SymbolTableNode(kind, node) - if 'module_hidden' in data: - stnode.module_hidden = data['module_hidden'] - if 'module_public' in data: - stnode.module_public = data['module_public'] - if 'implicit' in data: - stnode.implicit = data['implicit'] - if 'plugin_generated' in data: - stnode.plugin_generated = data['plugin_generated'] + if "module_hidden" in data: + stnode.module_hidden = data["module_hidden"] + if "module_public" in data: + stnode.module_public = data["module_public"] + if "implicit" in data: + stnode.implicit = data["implicit"] + if "plugin_generated" in data: + stnode.plugin_generated = data["plugin_generated"] return stnode @@ -3441,20 +3638,20 @@ def __str__(self) -> str: for key, value in self.items(): # Filter out the implicit import of builtins. if isinstance(value, SymbolTableNode): - if (value.fullname != 'builtins' and - (value.fullname or '').split('.')[-1] not in - implicit_module_attrs): - a.append(' ' + str(key) + ' : ' + str(value)) + if ( + value.fullname != "builtins" + and (value.fullname or "").split(".")[-1] not in implicit_module_attrs + ): + a.append(" " + str(key) + " : " + str(value)) else: - a.append(' ') + a.append(" ") a = sorted(a) - a.insert(0, 'SymbolTable(') - a[-1] += ')' - return '\n'.join(a) + a.insert(0, "SymbolTable(") + a[-1] += ")" + return "\n".join(a) - def copy(self) -> 'SymbolTable': - return SymbolTable([(key, node.copy()) - for key, node in self.items()]) + def copy(self) -> "SymbolTable": + return SymbolTable([(key, node.copy()) for key, node in self.items()]) def serialize(self, fullname: str) -> JsonDict: data: JsonDict = {".class": "SymbolTable"} @@ -3463,17 +3660,17 @@ def serialize(self, fullname: str) -> JsonDict: # module that gets added to every module by # SemanticAnalyzerPass2.visit_file(), but it shouldn't be # accessed by users of the module. - if key == '__builtins__' or value.no_serialize: + if key == "__builtins__" or value.no_serialize: continue data[key] = value.serialize(fullname, key) return data @classmethod - def deserialize(cls, data: JsonDict) -> 'SymbolTable': - assert data['.class'] == 'SymbolTable' + def deserialize(cls, data: JsonDict) -> "SymbolTable": + assert data[".class"] == "SymbolTable" st = SymbolTable() for key, value in data.items(): - if key != '.class': + if key != ".class": st[key] = SymbolTableNode.deserialize(value) return st @@ -3500,19 +3697,22 @@ def get_member_expr_fullname(expr: MemberExpr) -> Optional[str]: initial = get_member_expr_fullname(expr.expr) else: return None - return f'{initial}.{expr.name}' + return f"{initial}.{expr.name}" deserialize_map: Final = { key: obj.deserialize for key, obj in globals().items() if type(obj) is not FakeInfo - and isinstance(obj, type) and issubclass(obj, SymbolNode) and obj is not SymbolNode + and isinstance(obj, type) + and issubclass(obj, SymbolNode) + and obj is not SymbolNode } def check_arg_kinds( - arg_kinds: List[ArgKind], nodes: List[T], fail: Callable[[str, T], None]) -> None: + arg_kinds: List[ArgKind], nodes: List[T], fail: Callable[[str, T], None] +) -> None: is_var_arg = False is_kw_arg = False seen_named = False @@ -3520,9 +3720,10 @@ def check_arg_kinds( for kind, node in zip(arg_kinds, nodes): if kind == ARG_POS: if is_var_arg or is_kw_arg or seen_named or seen_opt: - fail("Required positional args may not appear " - "after default, named or var args", - node) + fail( + "Required positional args may not appear " "after default, named or var args", + node, + ) break elif kind == ARG_OPT: if is_var_arg or is_kw_arg or seen_named: @@ -3546,8 +3747,12 @@ def check_arg_kinds( is_kw_arg = True -def check_arg_names(names: Sequence[Optional[str]], nodes: List[T], fail: Callable[[str, T], None], - description: str = 'function definition') -> None: +def check_arg_names( + names: Sequence[Optional[str]], + nodes: List[T], + fail: Callable[[str, T], None], + description: str = "function definition", +) -> None: seen_names: Set[Optional[str]] = set() for name, node in zip(names, nodes): if name is not None and name in seen_names: @@ -3568,9 +3773,9 @@ def is_final_node(node: Optional[SymbolNode]) -> bool: return isinstance(node, (Var, FuncDef, OverloadedFuncDef, Decorator)) and node.is_final -def local_definitions(names: SymbolTable, - name_prefix: str, - info: Optional[TypeInfo] = None) -> Iterator[Definition]: +def local_definitions( + names: SymbolTable, name_prefix: str, info: Optional[TypeInfo] = None +) -> Iterator[Definition]: """Iterate over local definitions (not imported) in a symbol table. Recursively iterate over class members and nested classes. @@ -3578,10 +3783,10 @@ def local_definitions(names: SymbolTable, # TODO: What should the name be? Or maybe remove it? for name, symnode in names.items(): shortname = name - if '-redef' in name: + if "-redef" in name: # Restore original name from mangled name of multiply defined function - shortname = name.split('-redef')[0] - fullname = name_prefix + '.' + shortname + shortname = name.split("-redef")[0] + fullname = name_prefix + "." + shortname node = symnode.node if node and node.fullname == fullname: yield fullname, symnode, info diff --git a/mypy/operators.py b/mypy/operators.py index 85cbfcb99528c..4655f9d184ad5 100644 --- a/mypy/operators.py +++ b/mypy/operators.py @@ -2,34 +2,33 @@ from typing_extensions import Final - # Map from binary operator id to related method name (in Python 3). op_methods: Final = { - '+': '__add__', - '-': '__sub__', - '*': '__mul__', - '/': '__truediv__', - '%': '__mod__', - 'divmod': '__divmod__', - '//': '__floordiv__', - '**': '__pow__', - '@': '__matmul__', - '&': '__and__', - '|': '__or__', - '^': '__xor__', - '<<': '__lshift__', - '>>': '__rshift__', - '==': '__eq__', - '!=': '__ne__', - '<': '__lt__', - '>=': '__ge__', - '>': '__gt__', - '<=': '__le__', - 'in': '__contains__', + "+": "__add__", + "-": "__sub__", + "*": "__mul__", + "/": "__truediv__", + "%": "__mod__", + "divmod": "__divmod__", + "//": "__floordiv__", + "**": "__pow__", + "@": "__matmul__", + "&": "__and__", + "|": "__or__", + "^": "__xor__", + "<<": "__lshift__", + ">>": "__rshift__", + "==": "__eq__", + "!=": "__ne__", + "<": "__lt__", + ">=": "__ge__", + ">": "__gt__", + "<=": "__le__", + "in": "__contains__", } op_methods_to_symbols: Final = {v: k for (k, v) in op_methods.items()} -op_methods_to_symbols['__div__'] = '/' +op_methods_to_symbols["__div__"] = "/" comparison_fallback_method: Final = "__cmp__" ops_falling_back_to_cmp: Final = {"__ne__", "__eq__", "__lt__", "__le__", "__gt__", "__ge__"} @@ -54,26 +53,26 @@ inplace_operator_methods: Final = {"__i" + op_methods[op][2:] for op in ops_with_inplace_method} reverse_op_methods: Final = { - '__add__': '__radd__', - '__sub__': '__rsub__', - '__mul__': '__rmul__', - '__truediv__': '__rtruediv__', - '__mod__': '__rmod__', - '__divmod__': '__rdivmod__', - '__floordiv__': '__rfloordiv__', - '__pow__': '__rpow__', - '__matmul__': '__rmatmul__', - '__and__': '__rand__', - '__or__': '__ror__', - '__xor__': '__rxor__', - '__lshift__': '__rlshift__', - '__rshift__': '__rrshift__', - '__eq__': '__eq__', - '__ne__': '__ne__', - '__lt__': '__gt__', - '__ge__': '__le__', - '__gt__': '__lt__', - '__le__': '__ge__', + "__add__": "__radd__", + "__sub__": "__rsub__", + "__mul__": "__rmul__", + "__truediv__": "__rtruediv__", + "__mod__": "__rmod__", + "__divmod__": "__rdivmod__", + "__floordiv__": "__rfloordiv__", + "__pow__": "__rpow__", + "__matmul__": "__rmatmul__", + "__and__": "__rand__", + "__or__": "__ror__", + "__xor__": "__rxor__", + "__lshift__": "__rlshift__", + "__rshift__": "__rrshift__", + "__eq__": "__eq__", + "__ne__": "__ne__", + "__lt__": "__gt__", + "__ge__": "__le__", + "__gt__": "__lt__", + "__le__": "__ge__", } reverse_op_method_names: Final = set(reverse_op_methods.values()) @@ -82,28 +81,24 @@ # the output of A().__add__(A()) and skip calling the __radd__ method entirely. # This shortcut is used only for the following methods: op_methods_that_shortcut: Final = { - '__add__', - '__sub__', - '__mul__', - '__div__', - '__truediv__', - '__mod__', - '__divmod__', - '__floordiv__', - '__pow__', - '__matmul__', - '__and__', - '__or__', - '__xor__', - '__lshift__', - '__rshift__', + "__add__", + "__sub__", + "__mul__", + "__div__", + "__truediv__", + "__mod__", + "__divmod__", + "__floordiv__", + "__pow__", + "__matmul__", + "__and__", + "__or__", + "__xor__", + "__lshift__", + "__rshift__", } normal_from_reverse_op: Final = {m: n for n, m in reverse_op_methods.items()} reverse_op_method_set: Final = set(reverse_op_methods.values()) -unary_op_methods: Final = { - '-': '__neg__', - '+': '__pos__', - '~': '__invert__', -} +unary_op_methods: Final = {"-": "__neg__", "+": "__pos__", "~": "__invert__"} diff --git a/mypy/options.py b/mypy/options.py index de76d549b20fd..15b474466e312 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -1,12 +1,12 @@ -from mypy.backports import OrderedDict -import re import pprint +import re import sys +from typing import Any, Callable, Dict, List, Mapping, Optional, Pattern, Set, Tuple -from typing_extensions import Final, TYPE_CHECKING -from typing import Dict, List, Mapping, Optional, Pattern, Set, Tuple, Callable, Any +from typing_extensions import TYPE_CHECKING, Final from mypy import defaults +from mypy.backports import OrderedDict from mypy.util import get_class_descriptors, replace_object_state if TYPE_CHECKING: @@ -85,7 +85,7 @@ def __init__(self) -> None: self.ignore_missing_imports = False # Is ignore_missing_imports set in a per-module section self.ignore_missing_imports_per_module = False - self.follow_imports = 'normal' # normal|silent|skip|error + self.follow_imports = "normal" # normal|silent|skip|error # Whether to respect the follow_imports setting even for stub files. # Intended to be used for disabling specific stubs. self.follow_imports_for_stubs = False @@ -319,18 +319,18 @@ def new_semantic_analyzer(self) -> bool: def snapshot(self) -> object: """Produce a comparable snapshot of this Option""" # Under mypyc, we don't have a __dict__, so we need to do worse things. - d = dict(getattr(self, '__dict__', ())) + d = dict(getattr(self, "__dict__", ())) for k in get_class_descriptors(Options): if hasattr(self, k) and k != "new_semantic_analyzer": d[k] = getattr(self, k) # Remove private attributes from snapshot - d = {k: v for k, v in d.items() if not k.startswith('_')} + d = {k: v for k, v in d.items() if not k.startswith("_")} return d def __repr__(self) -> str: - return f'Options({pprint.pformat(self.snapshot())})' + return f"Options({pprint.pformat(self.snapshot())})" - def apply_changes(self, changes: Dict[str, object]) -> 'Options': + def apply_changes(self, changes: Dict[str, object]) -> "Options": new_options = Options() # Under mypyc, we don't have a __dict__, so we need to do worse things. replace_object_state(new_options, self, copy_dict=True) @@ -360,12 +360,10 @@ def build_per_module_cache(self) -> None: # than foo.bar.*. # (A section being "processed last" results in its config "winning".) # Unstructured glob configs are stored and are all checked for each module. - unstructured_glob_keys = [k for k in self.per_module_options.keys() - if '*' in k[:-1]] - structured_keys = [k for k in self.per_module_options.keys() - if '*' not in k[:-1]] - wildcards = sorted(k for k in structured_keys if k.endswith('.*')) - concrete = [k for k in structured_keys if not k.endswith('.*')] + unstructured_glob_keys = [k for k in self.per_module_options.keys() if "*" in k[:-1]] + structured_keys = [k for k in self.per_module_options.keys() if "*" not in k[:-1]] + wildcards = sorted(k for k in structured_keys if k.endswith(".*")) + concrete = [k for k in structured_keys if not k.endswith(".*")] for glob in unstructured_glob_keys: self._glob_options.append((glob, self.compile_glob(glob))) @@ -387,7 +385,7 @@ def build_per_module_cache(self) -> None: # they only count as used if actually used by a real module. self.unused_configs.update(structured_keys) - def clone_for_module(self, module: str) -> 'Options': + def clone_for_module(self, module: str) -> "Options": """Create an Options object that incorporates per-module options. NOTE: Once this method is called all Options objects should be @@ -408,9 +406,9 @@ def clone_for_module(self, module: str) -> 'Options': # This is technically quadratic in the length of the path, but module paths # don't actually get all that long. options = self - path = module.split('.') + path = module.split(".") for i in range(len(path), 0, -1): - key = '.'.join(path[:i] + ['*']) + key = ".".join(path[:i] + ["*"]) if key in self._per_module_cache: self.unused_configs.discard(key) options = self._per_module_cache[key] @@ -418,7 +416,7 @@ def clone_for_module(self, module: str) -> 'Options': # OK and *now* we need to look for unstructured glob matches. # We only do this for concrete modules, not structured wildcards. - if not module.endswith('.*'): + if not module.endswith(".*"): for key, pattern in self._glob_options: if pattern.match(module): self.unused_configs.discard(key) @@ -434,11 +432,11 @@ def compile_glob(self, s: str) -> Pattern[str]: # Compile one of the glob patterns to a regex so that '.*' can # match *zero or more* module sections. This means we compile # '.*' into '(\..*)?'. - parts = s.split('.') - expr = re.escape(parts[0]) if parts[0] != '*' else '.*' + parts = s.split(".") + expr = re.escape(parts[0]) if parts[0] != "*" else ".*" for part in parts[1:]: - expr += re.escape('.' + part) if part != '*' else r'(\..*)?' - return re.compile(expr + '\\Z') + expr += re.escape("." + part) if part != "*" else r"(\..*)?" + return re.compile(expr + "\\Z") def select_options_affecting_cache(self) -> Mapping[str, object]: return {opt: getattr(self, opt) for opt in OPTIONS_AFFECTING_CACHE} diff --git a/mypy/parse.py b/mypy/parse.py index c39a2388028a9..4c18fe22a1dc7 100644 --- a/mypy/parse.py +++ b/mypy/parse.py @@ -1,15 +1,17 @@ -from typing import Union, Optional +from typing import Optional, Union from mypy.errors import Errors -from mypy.options import Options from mypy.nodes import MypyFile +from mypy.options import Options -def parse(source: Union[str, bytes], - fnam: str, - module: Optional[str], - errors: Optional[Errors], - options: Options) -> MypyFile: +def parse( + source: Union[str, bytes], + fnam: str, + module: Optional[str], + errors: Optional[Errors], + options: Options, +) -> MypyFile: """Parse a source file, without doing any semantic analysis. Return the parse tree. If errors is not provided, raise ParseError @@ -17,20 +19,18 @@ def parse(source: Union[str, bytes], The python_version (major, minor) option determines the Python syntax variant. """ - is_stub_file = fnam.endswith('.pyi') + is_stub_file = fnam.endswith(".pyi") if options.transform_source is not None: source = options.transform_source(source) if options.python_version[0] >= 3 or is_stub_file: import mypy.fastparse - return mypy.fastparse.parse(source, - fnam=fnam, - module=module, - errors=errors, - options=options) + + return mypy.fastparse.parse( + source, fnam=fnam, module=module, errors=errors, options=options + ) else: import mypy.fastparse2 - return mypy.fastparse2.parse(source, - fnam=fnam, - module=module, - errors=errors, - options=options) + + return mypy.fastparse2.parse( + source, fnam=fnam, module=module, errors=errors, options=options + ) diff --git a/mypy/patterns.py b/mypy/patterns.py index f7f5f56d0ed5e..11aec70655c64 100644 --- a/mypy/patterns.py +++ b/mypy/patterns.py @@ -1,13 +1,12 @@ """Classes for representing match statement patterns.""" -from typing import TypeVar, List, Optional, Union +from typing import List, Optional, TypeVar, Union from mypy_extensions import trait -from mypy.nodes import Node, RefExpr, NameExpr, Expression +from mypy.nodes import Expression, NameExpr, Node, RefExpr from mypy.visitor import PatternVisitor - -T = TypeVar('T') +T = TypeVar("T") @trait @@ -17,11 +16,12 @@ class Pattern(Node): __slots__ = () def accept(self, visitor: PatternVisitor[T]) -> T: - raise RuntimeError('Not implemented') + raise RuntimeError("Not implemented") class AsPattern(Pattern): """The pattern as """ + # The python ast, and therefore also our ast merges capture, wildcard and as patterns into one # for easier handling. # If pattern is None this is a capture pattern. If name and pattern are both none this is a @@ -41,6 +41,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class OrPattern(Pattern): """The pattern | | ...""" + patterns: List[Pattern] def __init__(self, patterns: List[Pattern]) -> None: @@ -53,6 +54,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class ValuePattern(Pattern): """The pattern x.y (or x.y.z, ...)""" + expr: Expression def __init__(self, expr: Expression): @@ -77,6 +79,7 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class SequencePattern(Pattern): """The pattern [, ...]""" + patterns: List[Pattern] def __init__(self, patterns: List[Pattern]): @@ -105,8 +108,7 @@ class MappingPattern(Pattern): values: List[Pattern] rest: Optional[NameExpr] - def __init__(self, keys: List[Expression], values: List[Pattern], - rest: Optional[NameExpr]): + def __init__(self, keys: List[Expression], values: List[Pattern], rest: Optional[NameExpr]): super().__init__() assert len(keys) == len(values) self.keys = keys @@ -119,13 +121,19 @@ def accept(self, visitor: PatternVisitor[T]) -> T: class ClassPattern(Pattern): """The pattern Cls(...)""" + class_ref: RefExpr positionals: List[Pattern] keyword_keys: List[str] keyword_values: List[Pattern] - def __init__(self, class_ref: RefExpr, positionals: List[Pattern], keyword_keys: List[str], - keyword_values: List[Pattern]): + def __init__( + self, + class_ref: RefExpr, + positionals: List[Pattern], + keyword_keys: List[str], + keyword_values: List[Pattern], + ): super().__init__() assert len(keyword_keys) == len(keyword_values) self.class_ref = class_ref diff --git a/mypy/plugin.py b/mypy/plugin.py index 2f571d7eecc6d..948d2b1e829af 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -120,21 +120,35 @@ class C: pass """ from abc import abstractmethod -from typing import Any, Callable, List, Tuple, Optional, NamedTuple, TypeVar, Dict, Union -from mypy_extensions import trait, mypyc_attr +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, TypeVar, Union +from mypy_extensions import mypyc_attr, trait + +from mypy.errorcodes import ErrorCode +from mypy.lookup import lookup_fully_qualified +from mypy.message_registry import ErrorMessage +from mypy.messages import MessageBuilder from mypy.nodes import ( - Expression, Context, ClassDef, SymbolTableNode, MypyFile, CallExpr, ArgKind, TypeInfo + ArgKind, + CallExpr, + ClassDef, + Context, + Expression, + MypyFile, + SymbolTableNode, + TypeInfo, ) +from mypy.options import Options from mypy.tvar_scope import TypeVarLikeScope from mypy.types import ( - Type, Instance, CallableType, TypeList, UnboundType, ProperType, FunctionLike + CallableType, + FunctionLike, + Instance, + ProperType, + Type, + TypeList, + UnboundType, ) -from mypy.messages import MessageBuilder -from mypy.options import Options -from mypy.lookup import lookup_fully_qualified -from mypy.errorcodes import ErrorCode -from mypy.message_registry import ErrorMessage @trait @@ -167,9 +181,9 @@ def analyze_type(self, typ: Type) -> Type: raise NotImplementedError @abstractmethod - def analyze_callable_args(self, arglist: TypeList) -> Optional[Tuple[List[Type], - List[ArgKind], - List[Optional[str]]]]: + def analyze_callable_args( + self, arglist: TypeList + ) -> Optional[Tuple[List[Type], List[ArgKind], List[Optional[str]]]]: """Find types, kinds, and names of arguments from extended callable syntax.""" raise NotImplementedError @@ -177,7 +191,7 @@ def analyze_callable_args(self, arglist: TypeList) -> Optional[Tuple[List[Type], # A context for a hook that semantically analyzes an unbound type. class AnalyzeTypeContext(NamedTuple): type: UnboundType # Type to analyze - context: Context # Relevant location context (e.g. for error messages) + context: Context # Relevant location context (e.g. for error messages) api: TypeAnalyzerPluginInterface @@ -223,8 +237,9 @@ def type_context(self) -> List[Optional[Type]]: raise NotImplementedError @abstractmethod - def fail(self, msg: Union[str, ErrorMessage], ctx: Context, *, - code: Optional[ErrorCode] = None) -> None: + def fail( + self, msg: Union[str, ErrorMessage], ctx: Context, *, code: Optional[ErrorCode] = None + ) -> None: """Emit an error message at given location.""" raise NotImplementedError @@ -251,8 +266,7 @@ class SemanticAnalyzerPluginInterface: msg: MessageBuilder @abstractmethod - def named_type(self, fullname: str, - args: Optional[List[Type]] = None) -> Instance: + def named_type(self, fullname: str, args: Optional[List[Type]] = None) -> Instance: """Construct an instance of a builtin type with given type arguments.""" raise NotImplementedError @@ -263,8 +277,9 @@ def builtin_type(self, fully_qualified_name: str) -> Instance: raise NotImplementedError @abstractmethod - def named_type_or_none(self, fullname: str, - args: Optional[List[Type]] = None) -> Optional[Instance]: + def named_type_or_none( + self, fullname: str, args: Optional[List[Type]] = None + ) -> Optional[Instance]: """Construct an instance of a type with given type arguments. Return None if a type could not be constructed for the qualified @@ -283,18 +298,29 @@ def parse_bool(self, expr: Expression) -> Optional[bool]: raise NotImplementedError @abstractmethod - def fail(self, msg: str, ctx: Context, serious: bool = False, *, - blocker: bool = False, code: Optional[ErrorCode] = None) -> None: + def fail( + self, + msg: str, + ctx: Context, + serious: bool = False, + *, + blocker: bool = False, + code: Optional[ErrorCode] = None, + ) -> None: """Emit an error message at given location.""" raise NotImplementedError @abstractmethod - def anal_type(self, t: Type, *, - tvar_scope: Optional[TypeVarLikeScope] = None, - allow_tuple_literal: bool = False, - allow_unbound_tvars: bool = False, - report_invalid_types: bool = True, - third_pass: bool = False) -> Optional[Type]: + def anal_type( + self, + t: Type, + *, + tvar_scope: Optional[TypeVarLikeScope] = None, + allow_tuple_literal: bool = False, + allow_unbound_tvars: bool = False, + report_invalid_types: bool = True, + third_pass: bool = False, + ) -> Optional[Type]: """Analyze an unbound type. Return None if some part of the type is not ready yet. In this @@ -325,8 +351,9 @@ def lookup_fully_qualified_or_none(self, name: str) -> Optional[SymbolTableNode] raise NotImplementedError @abstractmethod - def lookup_qualified(self, name: str, ctx: Context, - suppress_errors: bool = False) -> Optional[SymbolTableNode]: + def lookup_qualified( + self, name: str, ctx: Context, suppress_errors: bool = False + ) -> Optional[SymbolTableNode]: """Lookup symbol using a name in current scope. This follows Python local->non-local->global->builtins rules. @@ -463,7 +490,7 @@ class AttributeContext(NamedTuple): # A context for a class hook that modifies the class definition. class ClassDefContext(NamedTuple): cls: ClassDef # The class definition - reason: Expression # The expression being applied (decorator, metaclass, base class) + reason: Expression # The expression being applied (decorator, metaclass, base class) api: SemanticAnalyzerPluginInterface @@ -545,8 +572,9 @@ def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]: """ return [] - def get_type_analyze_hook(self, fullname: str - ) -> Optional[Callable[[AnalyzeTypeContext], Type]]: + def get_type_analyze_hook( + self, fullname: str + ) -> Optional[Callable[[AnalyzeTypeContext], Type]]: """Customize behaviour of the type analyzer for given full names. This method is called during the semantic analysis pass whenever mypy sees an @@ -564,8 +592,9 @@ def func(x: Other[int]) -> None: """ return None - def get_function_signature_hook(self, fullname: str - ) -> Optional[Callable[[FunctionSigContext], FunctionLike]]: + def get_function_signature_hook( + self, fullname: str + ) -> Optional[Callable[[FunctionSigContext], FunctionLike]]: """Adjust the signature of a function. This method is called before type checking a function call. Plugin @@ -580,8 +609,7 @@ def get_function_signature_hook(self, fullname: str """ return None - def get_function_hook(self, fullname: str - ) -> Optional[Callable[[FunctionContext], Type]]: + def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]: """Adjust the return type of a function call. This method is called after type checking a call. Plugin may adjust the return @@ -597,8 +625,9 @@ def get_function_hook(self, fullname: str """ return None - def get_method_signature_hook(self, fullname: str - ) -> Optional[Callable[[MethodSigContext], FunctionLike]]: + def get_method_signature_hook( + self, fullname: str + ) -> Optional[Callable[[MethodSigContext], FunctionLike]]: """Adjust the signature of a method. This method is called before type checking a method call. Plugin @@ -626,8 +655,7 @@ class Derived(Base): """ return None - def get_method_hook(self, fullname: str - ) -> Optional[Callable[[MethodContext], Type]]: + def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], Type]]: """Adjust return type of a method call. This is the same as get_function_hook(), but is called with the @@ -635,8 +663,7 @@ def get_method_hook(self, fullname: str """ return None - def get_attribute_hook(self, fullname: str - ) -> Optional[Callable[[AttributeContext], Type]]: + def get_attribute_hook(self, fullname: str) -> Optional[Callable[[AttributeContext], Type]]: """Adjust type of an instance attribute. This method is called with attribute full name using the class of the instance where @@ -667,8 +694,9 @@ class Derived(Base): """ return None - def get_class_attribute_hook(self, fullname: str - ) -> Optional[Callable[[AttributeContext], Type]]: + def get_class_attribute_hook( + self, fullname: str + ) -> Optional[Callable[[AttributeContext], Type]]: """ Adjust type of a class attribute. @@ -686,8 +714,9 @@ class Cls: """ return None - def get_class_decorator_hook(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: + def get_class_decorator_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: """Update class definition for given class decorators. The plugin can modify a TypeInfo _in place_ (for example add some generated @@ -705,8 +734,9 @@ def get_class_decorator_hook(self, fullname: str """ return None - def get_class_decorator_hook_2(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], bool]]: + def get_class_decorator_hook_2( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], bool]]: """Update class definition for given class decorators. Similar to get_class_decorator_hook, but this runs in a later pass when @@ -722,8 +752,7 @@ def get_class_decorator_hook_2(self, fullname: str """ return None - def get_metaclass_hook(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: + def get_metaclass_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: """Update class definition for given declared metaclasses. Same as get_class_decorator_hook() but for metaclasses. Note: @@ -734,8 +763,7 @@ def get_metaclass_hook(self, fullname: str """ return None - def get_base_class_hook(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: + def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: """Update class definition for given base classes. Same as get_class_decorator_hook() but for base classes. Base classes @@ -744,8 +772,9 @@ def get_base_class_hook(self, fullname: str """ return None - def get_customize_class_mro_hook(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: + def get_customize_class_mro_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: """Customize MRO for given classes. The plugin can modify the class MRO _in place_. This method is called @@ -753,8 +782,9 @@ def get_customize_class_mro_hook(self, fullname: str """ return None - def get_dynamic_class_hook(self, fullname: str - ) -> Optional[Callable[[DynamicClassDefContext], None]]: + def get_dynamic_class_hook( + self, fullname: str + ) -> Optional[Callable[[DynamicClassDefContext], None]]: """Semantically analyze a dynamic class definition. This plugin hook allows one to semantically analyze dynamic class definitions like: @@ -770,7 +800,7 @@ def get_dynamic_class_hook(self, fullname: str return None -T = TypeVar('T') +T = TypeVar("T") class ChainedPlugin(Plugin): @@ -807,56 +837,59 @@ def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]: deps.extend(plugin.get_additional_deps(file)) return deps - def get_type_analyze_hook(self, fullname: str - ) -> Optional[Callable[[AnalyzeTypeContext], Type]]: + def get_type_analyze_hook( + self, fullname: str + ) -> Optional[Callable[[AnalyzeTypeContext], Type]]: return self._find_hook(lambda plugin: plugin.get_type_analyze_hook(fullname)) - def get_function_signature_hook(self, fullname: str - ) -> Optional[Callable[[FunctionSigContext], FunctionLike]]: + def get_function_signature_hook( + self, fullname: str + ) -> Optional[Callable[[FunctionSigContext], FunctionLike]]: return self._find_hook(lambda plugin: plugin.get_function_signature_hook(fullname)) - def get_function_hook(self, fullname: str - ) -> Optional[Callable[[FunctionContext], Type]]: + def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]: return self._find_hook(lambda plugin: plugin.get_function_hook(fullname)) - def get_method_signature_hook(self, fullname: str - ) -> Optional[Callable[[MethodSigContext], FunctionLike]]: + def get_method_signature_hook( + self, fullname: str + ) -> Optional[Callable[[MethodSigContext], FunctionLike]]: return self._find_hook(lambda plugin: plugin.get_method_signature_hook(fullname)) - def get_method_hook(self, fullname: str - ) -> Optional[Callable[[MethodContext], Type]]: + def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], Type]]: return self._find_hook(lambda plugin: plugin.get_method_hook(fullname)) - def get_attribute_hook(self, fullname: str - ) -> Optional[Callable[[AttributeContext], Type]]: + def get_attribute_hook(self, fullname: str) -> Optional[Callable[[AttributeContext], Type]]: return self._find_hook(lambda plugin: plugin.get_attribute_hook(fullname)) - def get_class_attribute_hook(self, fullname: str - ) -> Optional[Callable[[AttributeContext], Type]]: + def get_class_attribute_hook( + self, fullname: str + ) -> Optional[Callable[[AttributeContext], Type]]: return self._find_hook(lambda plugin: plugin.get_class_attribute_hook(fullname)) - def get_class_decorator_hook(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: + def get_class_decorator_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: return self._find_hook(lambda plugin: plugin.get_class_decorator_hook(fullname)) - def get_class_decorator_hook_2(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], bool]]: + def get_class_decorator_hook_2( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], bool]]: return self._find_hook(lambda plugin: plugin.get_class_decorator_hook_2(fullname)) - def get_metaclass_hook(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: + def get_metaclass_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: return self._find_hook(lambda plugin: plugin.get_metaclass_hook(fullname)) - def get_base_class_hook(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: + def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]: return self._find_hook(lambda plugin: plugin.get_base_class_hook(fullname)) - def get_customize_class_mro_hook(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: + def get_customize_class_mro_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: return self._find_hook(lambda plugin: plugin.get_customize_class_mro_hook(fullname)) - def get_dynamic_class_hook(self, fullname: str - ) -> Optional[Callable[[DynamicClassDefContext], None]]: + def get_dynamic_class_hook( + self, fullname: str + ) -> Optional[Callable[[DynamicClassDefContext], None]]: return self._find_hook(lambda plugin: plugin.get_dynamic_class_hook(fullname)) def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]: diff --git a/mypy/plugins/attrs.py b/mypy/plugins/attrs.py index 06c11f130f11a..765753b71d313 100644 --- a/mypy/plugins/attrs.py +++ b/mypy/plugins/attrs.py @@ -1,55 +1,79 @@ """Plugin for supporting the attrs library (http://www.attrs.org)""" -from mypy.backports import OrderedDict +from typing import Dict, Iterable, List, Optional, Tuple, cast -from typing import Optional, Dict, List, cast, Tuple, Iterable from typing_extensions import Final import mypy.plugin # To avoid circular imports. -from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError +from mypy.backports import OrderedDict +from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type from mypy.nodes import ( - Context, Argument, Var, ARG_OPT, ARG_POS, TypeInfo, AssignmentStmt, - TupleExpr, ListExpr, NameExpr, CallExpr, RefExpr, FuncDef, - is_class_var, TempNode, Decorator, MemberExpr, Expression, - SymbolTableNode, MDEF, JsonDict, OverloadedFuncDef, ARG_NAMED_OPT, ARG_NAMED, - TypeVarExpr, PlaceholderNode, LambdaExpr + ARG_NAMED, + ARG_NAMED_OPT, + ARG_OPT, + ARG_POS, + MDEF, + Argument, + AssignmentStmt, + CallExpr, + Context, + Decorator, + Expression, + FuncDef, + JsonDict, + LambdaExpr, + ListExpr, + MemberExpr, + NameExpr, + OverloadedFuncDef, + PlaceholderNode, + RefExpr, + SymbolTableNode, + TempNode, + TupleExpr, + TypeInfo, + TypeVarExpr, + Var, + is_class_var, ) from mypy.plugin import SemanticAnalyzerPluginInterface from mypy.plugins.common import ( - _get_argument, _get_bool_argument, _get_decorator_bool_argument, add_method, - deserialize_and_fixup_type, add_attribute_to_class, + _get_argument, + _get_bool_argument, + _get_decorator_bool_argument, + add_attribute_to_class, + add_method, + deserialize_and_fixup_type, ) +from mypy.server.trigger import make_wildcard_trigger +from mypy.typeops import make_simplified_union, map_type_from_supertype from mypy.types import ( - TupleType, Type, AnyType, TypeOfAny, CallableType, NoneType, TypeVarType, - Overloaded, UnionType, FunctionLike, Instance, get_proper_type, + AnyType, + CallableType, + FunctionLike, + Instance, LiteralType, + NoneType, + Overloaded, + TupleType, + Type, + TypeOfAny, + TypeVarType, + UnionType, + get_proper_type, ) -from mypy.typeops import make_simplified_union, map_type_from_supertype from mypy.typevars import fill_typevars from mypy.util import unmangle -from mypy.server.trigger import make_wildcard_trigger KW_ONLY_PYTHON_2_UNSUPPORTED: Final = "kw_only is not supported in Python 2" # The names of the different functions that create classes or arguments. -attr_class_makers: Final = { - 'attr.s', - 'attr.attrs', - 'attr.attributes', -} -attr_dataclass_makers: Final = { - 'attr.dataclass', -} +attr_class_makers: Final = {"attr.s", "attr.attrs", "attr.attributes"} +attr_dataclass_makers: Final = {"attr.dataclass"} attr_frozen_makers: Final = {"attr.frozen", "attrs.frozen"} attr_define_makers: Final = {"attr.define", "attr.mutable", "attrs.define", "attrs.mutable"} -attr_attrib_makers: Final = { - 'attr.ib', - 'attr.attrib', - 'attr.attr', - 'attr.field', - 'attrs.field', -} -attr_optional_converters: Final = {'attr.converters.optional', 'attrs.converters.optional'} +attr_attrib_makers: Final = {"attr.ib", "attr.attrib", "attr.attr", "attr.field", "attrs.field"} +attr_optional_converters: Final = {"attr.converters.optional", "attrs.converters.optional"} SELF_TVAR_NAME: Final = "_AT" MAGIC_ATTR_NAME: Final = "__attrs_attrs__" @@ -59,19 +83,24 @@ class Converter: """Holds information about a `converter=` argument""" - def __init__(self, - init_type: Optional[Type] = None, - ) -> None: + def __init__(self, init_type: Optional[Type] = None) -> None: self.init_type = init_type class Attribute: """The value of an attr.ib() call.""" - def __init__(self, name: str, info: TypeInfo, - has_default: bool, init: bool, kw_only: bool, converter: Optional[Converter], - context: Context, - init_type: Optional[Type]) -> None: + def __init__( + self, + name: str, + info: TypeInfo, + has_default: bool, + init: bool, + kw_only: bool, + converter: Optional[Converter], + context: Context, + init_type: Optional[Type], + ) -> None: self.name = name self.info = info self.has_default = has_default @@ -81,7 +110,7 @@ def __init__(self, name: str, info: TypeInfo, self.context = context self.init_type = init_type - def argument(self, ctx: 'mypy.plugin.ClassDefContext') -> Argument: + def argument(self, ctx: "mypy.plugin.ClassDefContext") -> Argument: """Return this attribute as an argument to __init__.""" assert self.init @@ -121,44 +150,48 @@ def argument(self, ctx: 'mypy.plugin.ClassDefContext') -> Argument: arg_kind = ARG_OPT if self.has_default else ARG_POS # Attrs removes leading underscores when creating the __init__ arguments. - return Argument(Var(self.name.lstrip("_"), init_type), init_type, - None, - arg_kind) + return Argument(Var(self.name.lstrip("_"), init_type), init_type, None, arg_kind) def serialize(self) -> JsonDict: """Serialize this object so it can be saved and restored.""" return { - 'name': self.name, - 'has_default': self.has_default, - 'init': self.init, - 'kw_only': self.kw_only, - 'has_converter': self.converter is not None, - 'converter_init_type': self.converter.init_type.serialize() - if self.converter and self.converter.init_type else None, - 'context_line': self.context.line, - 'context_column': self.context.column, - 'init_type': self.init_type.serialize() if self.init_type else None, + "name": self.name, + "has_default": self.has_default, + "init": self.init, + "kw_only": self.kw_only, + "has_converter": self.converter is not None, + "converter_init_type": self.converter.init_type.serialize() + if self.converter and self.converter.init_type + else None, + "context_line": self.context.line, + "context_column": self.context.column, + "init_type": self.init_type.serialize() if self.init_type else None, } @classmethod - def deserialize(cls, info: TypeInfo, - data: JsonDict, - api: SemanticAnalyzerPluginInterface) -> 'Attribute': + def deserialize( + cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface + ) -> "Attribute": """Return the Attribute that was serialized.""" - raw_init_type = data['init_type'] + raw_init_type = data["init_type"] init_type = deserialize_and_fixup_type(raw_init_type, api) if raw_init_type else None - raw_converter_init_type = data['converter_init_type'] - converter_init_type = (deserialize_and_fixup_type(raw_converter_init_type, api) - if raw_converter_init_type else None) + raw_converter_init_type = data["converter_init_type"] + converter_init_type = ( + deserialize_and_fixup_type(raw_converter_init_type, api) + if raw_converter_init_type + else None + ) - return Attribute(data['name'], + return Attribute( + data["name"], info, - data['has_default'], - data['init'], - data['kw_only'], - Converter(converter_init_type) if data['has_converter'] else None, - Context(line=data['context_line'], column=data['context_column']), - init_type) + data["has_default"], + data["init"], + data["kw_only"], + Converter(converter_init_type) if data["has_converter"] else None, + Context(line=data["context_line"], column=data["context_column"]), + init_type, + ) def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: """Expands type vars in the context of a subtype when an attribute is inherited @@ -169,14 +202,14 @@ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: self.init_type = None -def _determine_eq_order(ctx: 'mypy.plugin.ClassDefContext') -> bool: +def _determine_eq_order(ctx: "mypy.plugin.ClassDefContext") -> bool: """ Validate the combination of *cmp*, *eq*, and *order*. Derive the effective value of order. """ - cmp = _get_decorator_optional_bool_argument(ctx, 'cmp') - eq = _get_decorator_optional_bool_argument(ctx, 'eq') - order = _get_decorator_optional_bool_argument(ctx, 'order') + cmp = _get_decorator_optional_bool_argument(ctx, "cmp") + eq = _get_decorator_optional_bool_argument(ctx, "eq") + order = _get_decorator_optional_bool_argument(ctx, "order") if cmp is not None and any((eq is not None, order is not None)): ctx.api.fail('Don\'t mix "cmp" with "eq" and "order"', ctx.reason) @@ -193,15 +226,13 @@ def _determine_eq_order(ctx: 'mypy.plugin.ClassDefContext') -> bool: order = eq if eq is False and order is True: - ctx.api.fail('eq must be True if order is True', ctx.reason) + ctx.api.fail("eq must be True if order is True", ctx.reason) return order def _get_decorator_optional_bool_argument( - ctx: 'mypy.plugin.ClassDefContext', - name: str, - default: Optional[bool] = None, + ctx: "mypy.plugin.ClassDefContext", name: str, default: Optional[bool] = None ) -> Optional[bool]: """Return the Optional[bool] argument for the decorator. @@ -211,11 +242,11 @@ def _get_decorator_optional_bool_argument( attr_value = _get_argument(ctx.reason, name) if attr_value: if isinstance(attr_value, NameExpr): - if attr_value.fullname == 'builtins.True': + if attr_value.fullname == "builtins.True": return True - if attr_value.fullname == 'builtins.False': + if attr_value.fullname == "builtins.False": return False - if attr_value.fullname == 'builtins.None': + if attr_value.fullname == "builtins.None": return None ctx.api.fail(f'"{name}" argument must be True or False.', ctx.reason) return default @@ -224,19 +255,21 @@ def _get_decorator_optional_bool_argument( return default -def attr_tag_callback(ctx: 'mypy.plugin.ClassDefContext') -> None: +def attr_tag_callback(ctx: "mypy.plugin.ClassDefContext") -> None: """Record that we have an attrs class in the main semantic analysis pass. The later pass implemented by attr_class_maker_callback will use this to detect attrs lasses in base classes. """ # The value is ignored, only the existence matters. - ctx.cls.info.metadata['attrs_tag'] = {} + ctx.cls.info.metadata["attrs_tag"] = {} -def attr_class_maker_callback(ctx: 'mypy.plugin.ClassDefContext', - auto_attribs_default: Optional[bool] = False, - frozen_default: bool = False) -> bool: +def attr_class_maker_callback( + ctx: "mypy.plugin.ClassDefContext", + auto_attribs_default: Optional[bool] = False, + frozen_default: bool = False, +) -> bool: """Add necessary dunder methods to classes decorated with attr.s. attrs is a package that lets you define classes without writing dull boilerplate code. @@ -253,14 +286,14 @@ def attr_class_maker_callback(ctx: 'mypy.plugin.ClassDefContext', """ info = ctx.cls.info - init = _get_decorator_bool_argument(ctx, 'init', True) + init = _get_decorator_bool_argument(ctx, "init", True) frozen = _get_frozen(ctx, frozen_default) order = _determine_eq_order(ctx) - slots = _get_decorator_bool_argument(ctx, 'slots', False) + slots = _get_decorator_bool_argument(ctx, "slots", False) - auto_attribs = _get_decorator_optional_bool_argument(ctx, 'auto_attribs', auto_attribs_default) - kw_only = _get_decorator_bool_argument(ctx, 'kw_only', False) - match_args = _get_decorator_bool_argument(ctx, 'match_args', True) + auto_attribs = _get_decorator_optional_bool_argument(ctx, "auto_attribs", auto_attribs_default) + kw_only = _get_decorator_bool_argument(ctx, "kw_only", False) + match_args = _get_decorator_bool_argument(ctx, "match_args", True) early_fail = False if ctx.api.options.python_version[0] < 3: @@ -279,7 +312,7 @@ def attr_class_maker_callback(ctx: 'mypy.plugin.ClassDefContext', return True for super_info in ctx.cls.info.mro[1:-1]: - if 'attrs_tag' in super_info.metadata and 'attrs' not in super_info.metadata: + if "attrs_tag" in super_info.metadata and "attrs" not in super_info.metadata: # Super class is not ready yet. Request another pass. return False @@ -303,9 +336,9 @@ def attr_class_maker_callback(ctx: 'mypy.plugin.ClassDefContext', _add_match_args(ctx, attributes) # Save the attributes so that subclasses can reuse them. - ctx.cls.info.metadata['attrs'] = { - 'attributes': [attr.serialize() for attr in attributes], - 'frozen': frozen, + ctx.cls.info.metadata["attrs"] = { + "attributes": [attr.serialize() for attr in attributes], + "frozen": frozen, } adder = MethodAdder(ctx) @@ -319,20 +352,20 @@ def attr_class_maker_callback(ctx: 'mypy.plugin.ClassDefContext', return True -def _get_frozen(ctx: 'mypy.plugin.ClassDefContext', frozen_default: bool) -> bool: +def _get_frozen(ctx: "mypy.plugin.ClassDefContext", frozen_default: bool) -> bool: """Return whether this class is frozen.""" - if _get_decorator_bool_argument(ctx, 'frozen', frozen_default): + if _get_decorator_bool_argument(ctx, "frozen", frozen_default): return True # Subclasses of frozen classes are frozen so check that. for super_info in ctx.cls.info.mro[1:-1]: - if 'attrs' in super_info.metadata and super_info.metadata['attrs']['frozen']: + if "attrs" in super_info.metadata and super_info.metadata["attrs"]["frozen"]: return True return False -def _analyze_class(ctx: 'mypy.plugin.ClassDefContext', - auto_attribs: Optional[bool], - kw_only: bool) -> List[Attribute]: +def _analyze_class( + ctx: "mypy.plugin.ClassDefContext", auto_attribs: Optional[bool], kw_only: bool +) -> List[Attribute]: """Analyze the class body of an attr maker, its parents, and return the Attributes found. auto_attribs=True means we'll generate attributes from type annotations also. @@ -372,14 +405,14 @@ def _analyze_class(ctx: 'mypy.plugin.ClassDefContext', taken_attr_names = set(own_attrs) super_attrs = [] for super_info in ctx.cls.info.mro[1:-1]: - if 'attrs' in super_info.metadata: + if "attrs" in super_info.metadata: # Each class depends on the set of attributes in its attrs ancestors. ctx.api.add_plugin_dependency(make_wildcard_trigger(super_info.fullname)) - for data in super_info.metadata['attrs']['attributes']: + for data in super_info.metadata["attrs"]["attributes"]: # Only add an attribute if it hasn't been defined before. This # allows for overwriting attribute definitions by subclassing. - if data['name'] not in taken_attr_names: + if data["name"] not in taken_attr_names: a = Attribute.deserialize(super_info, data, ctx.api) a.expand_typevar_from_subtype(ctx.cls.info) super_attrs.append(a) @@ -403,9 +436,7 @@ def _analyze_class(ctx: 'mypy.plugin.ClassDefContext', context = attribute.context if i >= len(super_attrs) else ctx.cls if not attribute.has_default and last_default: - ctx.api.fail( - "Non-default attributes not allowed after default attributes.", - context) + ctx.api.fail("Non-default attributes not allowed after default attributes.", context) last_default |= attribute.has_default return attributes @@ -413,13 +444,10 @@ def _analyze_class(ctx: 'mypy.plugin.ClassDefContext', def _add_empty_metadata(info: TypeInfo) -> None: """Add empty metadata to mark that we've finished processing this class.""" - info.metadata['attrs'] = { - 'attributes': [], - 'frozen': False, - } + info.metadata["attrs"] = {"attributes": [], "frozen": False} -def _detect_auto_attribs(ctx: 'mypy.plugin.ClassDefContext') -> bool: +def _detect_auto_attribs(ctx: "mypy.plugin.ClassDefContext") -> bool: """Return whether auto_attribs should be enabled or disabled. It's disabled if there are any unannotated attribs() @@ -436,19 +464,21 @@ def _detect_auto_attribs(ctx: 'mypy.plugin.ClassDefContext') -> bool: for lhs, rvalue in zip(lvalues, rvalues): # Check if the right hand side is a call to an attribute maker. - if (isinstance(rvalue, CallExpr) - and isinstance(rvalue.callee, RefExpr) - and rvalue.callee.fullname in attr_attrib_makers - and not stmt.new_syntax): + if ( + isinstance(rvalue, CallExpr) + and isinstance(rvalue.callee, RefExpr) + and rvalue.callee.fullname in attr_attrib_makers + and not stmt.new_syntax + ): # This means we have an attrib without an annotation and so # we can't do auto_attribs=True return False return True -def _attributes_from_assignment(ctx: 'mypy.plugin.ClassDefContext', - stmt: AssignmentStmt, auto_attribs: bool, - kw_only: bool) -> Iterable[Attribute]: +def _attributes_from_assignment( + ctx: "mypy.plugin.ClassDefContext", stmt: AssignmentStmt, auto_attribs: bool, kw_only: bool +) -> Iterable[Attribute]: """Return Attribute objects that are created by this assignment. The assignments can look like this: @@ -469,9 +499,11 @@ def _attributes_from_assignment(ctx: 'mypy.plugin.ClassDefContext', for lhs, rvalue in zip(lvalues, rvalues): # Check if the right hand side is a call to an attribute maker. - if (isinstance(rvalue, CallExpr) - and isinstance(rvalue.callee, RefExpr) - and rvalue.callee.fullname in attr_attrib_makers): + if ( + isinstance(rvalue, CallExpr) + and isinstance(rvalue.callee, RefExpr) + and rvalue.callee.fullname in attr_attrib_makers + ): attr = _attribute_from_attrib_maker(ctx, auto_attribs, kw_only, lhs, rvalue, stmt) if attr: yield attr @@ -487,14 +519,16 @@ def _cleanup_decorator(stmt: Decorator, attr_map: Dict[str, Attribute]) -> None: """ remove_me = [] for func_decorator in stmt.decorators: - if (isinstance(func_decorator, MemberExpr) - and isinstance(func_decorator.expr, NameExpr) - and func_decorator.expr.name in attr_map): + if ( + isinstance(func_decorator, MemberExpr) + and isinstance(func_decorator.expr, NameExpr) + and func_decorator.expr.name in attr_map + ): - if func_decorator.name == 'default': + if func_decorator.name == "default": attr_map[func_decorator.expr.name].has_default = True - if func_decorator.name in ('default', 'validator'): + if func_decorator.name in ("default", "validator"): # These are decorators on the attrib object that only exist during # class creation time. In order to not trigger a type error later we # just remove them. This might leave us with a Decorator with no @@ -508,11 +542,13 @@ def _cleanup_decorator(stmt: Decorator, attr_map: Dict[str, Attribute]) -> None: stmt.decorators.remove(dec) -def _attribute_from_auto_attrib(ctx: 'mypy.plugin.ClassDefContext', - kw_only: bool, - lhs: NameExpr, - rvalue: Expression, - stmt: AssignmentStmt) -> Attribute: +def _attribute_from_auto_attrib( + ctx: "mypy.plugin.ClassDefContext", + kw_only: bool, + lhs: NameExpr, + rvalue: Expression, + stmt: AssignmentStmt, +) -> Attribute: """Return an Attribute for a new type assignment.""" name = unmangle(lhs.name) # `x: int` (without equal sign) assigns rvalue to TempNode(AnyType()) @@ -522,12 +558,14 @@ def _attribute_from_auto_attrib(ctx: 'mypy.plugin.ClassDefContext', return Attribute(name, ctx.cls.info, has_rhs, True, kw_only, None, stmt, init_type) -def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext', - auto_attribs: bool, - kw_only: bool, - lhs: NameExpr, - rvalue: CallExpr, - stmt: AssignmentStmt) -> Optional[Attribute]: +def _attribute_from_attrib_maker( + ctx: "mypy.plugin.ClassDefContext", + auto_attribs: bool, + kw_only: bool, + lhs: NameExpr, + rvalue: CallExpr, + stmt: AssignmentStmt, +) -> Optional[Attribute]: """Return an Attribute from the assignment or None if you can't make one.""" if auto_attribs and not stmt.new_syntax: # auto_attribs requires an annotation on *every* attr.ib. @@ -543,17 +581,17 @@ def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext', init_type = stmt.type # Read all the arguments from the call. - init = _get_bool_argument(ctx, rvalue, 'init', True) + init = _get_bool_argument(ctx, rvalue, "init", True) # Note: If the class decorator says kw_only=True the attribute is ignored. # See https://github.com/python-attrs/attrs/issues/481 for explanation. - kw_only |= _get_bool_argument(ctx, rvalue, 'kw_only', False) + kw_only |= _get_bool_argument(ctx, rvalue, "kw_only", False) if kw_only and ctx.api.options.python_version[0] < 3: ctx.api.fail(KW_ONLY_PYTHON_2_UNSUPPORTED, stmt) return None # TODO: Check for attr.NOTHING - attr_has_default = bool(_get_argument(rvalue, 'default')) - attr_has_factory = bool(_get_argument(rvalue, 'factory')) + attr_has_default = bool(_get_argument(rvalue, "default")) + attr_has_factory = bool(_get_argument(rvalue, "factory")) if attr_has_default and attr_has_factory: ctx.api.fail('Can\'t pass both "default" and "factory".', rvalue) @@ -561,12 +599,12 @@ def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext', attr_has_default = True # If the type isn't set through annotation but is passed through `type=` use that. - type_arg = _get_argument(rvalue, 'type') + type_arg = _get_argument(rvalue, "type") if type_arg and not init_type: try: un_type = expr_to_unanalyzed_type(type_arg, ctx.api.options, ctx.api.is_stub_file) except TypeTranslationError: - ctx.api.fail('Invalid argument to type', type_arg) + ctx.api.fail("Invalid argument to type", type_arg) else: init_type = ctx.api.anal_type(un_type) if init_type and isinstance(lhs.node, Var) and not lhs.node.type: @@ -575,8 +613,8 @@ def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext', lhs.is_inferred_def = False # Note: convert is deprecated but works the same as converter. - converter = _get_argument(rvalue, 'converter') - convert = _get_argument(rvalue, 'convert') + converter = _get_argument(rvalue, "converter") + convert = _get_argument(rvalue, "convert") if convert and converter: ctx.api.fail('Can\'t pass both "convert" and "converter".', rvalue) elif convert: @@ -585,22 +623,26 @@ def _attribute_from_attrib_maker(ctx: 'mypy.plugin.ClassDefContext', converter_info = _parse_converter(ctx, converter) name = unmangle(lhs.name) - return Attribute(name, ctx.cls.info, attr_has_default, init, - kw_only, converter_info, stmt, init_type) + return Attribute( + name, ctx.cls.info, attr_has_default, init, kw_only, converter_info, stmt, init_type + ) -def _parse_converter(ctx: 'mypy.plugin.ClassDefContext', - converter_expr: Optional[Expression]) -> Optional[Converter]: +def _parse_converter( + ctx: "mypy.plugin.ClassDefContext", converter_expr: Optional[Expression] +) -> Optional[Converter]: """Return the Converter object from an Expression.""" # TODO: Support complex converters, e.g. lambdas, calls, etc. if not converter_expr: return None converter_info = Converter() - if (isinstance(converter_expr, CallExpr) - and isinstance(converter_expr.callee, RefExpr) - and converter_expr.callee.fullname in attr_optional_converters - and converter_expr.args - and converter_expr.args[0]): + if ( + isinstance(converter_expr, CallExpr) + and isinstance(converter_expr.callee, RefExpr) + and converter_expr.callee.fullname in attr_optional_converters + and converter_expr.args + and converter_expr.args[0] + ): # Special handling for attr.converters.optional(type) # We extract the type and add make the init_args Optional in Attribute.argument converter_expr = converter_expr.args[0] @@ -616,11 +658,13 @@ def _parse_converter(ctx: 'mypy.plugin.ClassDefContext', else: # The converter is an unannotated function. converter_info.init_type = AnyType(TypeOfAny.unannotated) return converter_info - elif (isinstance(converter_expr.node, OverloadedFuncDef) - and is_valid_overloaded_converter(converter_expr.node)): + elif isinstance(converter_expr.node, OverloadedFuncDef) and is_valid_overloaded_converter( + converter_expr.node + ): converter_type = converter_expr.node.type elif isinstance(converter_expr.node, TypeInfo): from mypy.checkmember import type_object_type # To avoid import cycle. + converter_type = type_object_type(converter_expr.node, ctx.api.named_type) if isinstance(converter_expr, LambdaExpr): # TODO: should we send a fail if converter_expr.min_args > 1? @@ -632,7 +676,7 @@ def _parse_converter(ctx: 'mypy.plugin.ClassDefContext', ctx.api.fail( "Unsupported converter, only named functions, types and lambdas are currently " "supported", - converter_expr + converter_expr, ) converter_info.init_type = AnyType(TypeOfAny.from_error) return converter_info @@ -663,13 +707,15 @@ def _parse_converter(ctx: 'mypy.plugin.ClassDefContext', def is_valid_overloaded_converter(defn: OverloadedFuncDef) -> bool: - return all((not isinstance(item, Decorator) or isinstance(item.func.type, FunctionLike)) - for item in defn.items) + return all( + (not isinstance(item, Decorator) or isinstance(item.func.type, FunctionLike)) + for item in defn.items + ) def _parse_assignments( - lvalue: Expression, - stmt: AssignmentStmt) -> Tuple[List[NameExpr], List[Expression]]: + lvalue: Expression, stmt: AssignmentStmt +) -> Tuple[List[NameExpr], List[Expression]]: """Convert a possibly complex assignment expression into lists of lvalues and rvalues.""" lvalues: List[NameExpr] = [] rvalues: List[Expression] = [] @@ -684,26 +730,28 @@ def _parse_assignments( return lvalues, rvalues -def _add_order(ctx: 'mypy.plugin.ClassDefContext', adder: 'MethodAdder') -> None: +def _add_order(ctx: "mypy.plugin.ClassDefContext", adder: "MethodAdder") -> None: """Generate all the ordering methods for this class.""" - bool_type = ctx.api.named_type('builtins.bool') - object_type = ctx.api.named_type('builtins.object') + bool_type = ctx.api.named_type("builtins.bool") + object_type = ctx.api.named_type("builtins.object") # Make the types be: # AT = TypeVar('AT') # def __lt__(self: AT, other: AT) -> bool # This way comparisons with subclasses will work correctly. - tvd = TypeVarType(SELF_TVAR_NAME, ctx.cls.info.fullname + '.' + SELF_TVAR_NAME, - -1, [], object_type) - self_tvar_expr = TypeVarExpr(SELF_TVAR_NAME, ctx.cls.info.fullname + '.' + SELF_TVAR_NAME, - [], object_type) + tvd = TypeVarType( + SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, -1, [], object_type + ) + self_tvar_expr = TypeVarExpr( + SELF_TVAR_NAME, ctx.cls.info.fullname + "." + SELF_TVAR_NAME, [], object_type + ) ctx.cls.info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr) - args = [Argument(Var('other', tvd), tvd, None, ARG_POS)] - for method in ['__lt__', '__le__', '__gt__', '__ge__']: + args = [Argument(Var("other", tvd), tvd, None, ARG_POS)] + for method in ["__lt__", "__le__", "__gt__", "__ge__"]: adder.add_method(method, args, bool_type, self_type=tvd, tvd=tvd) -def _make_frozen(ctx: 'mypy.plugin.ClassDefContext', attributes: List[Attribute]) -> None: +def _make_frozen(ctx: "mypy.plugin.ClassDefContext", attributes: List[Attribute]) -> None: """Turn all the attributes into properties to simulate frozen classes.""" for attribute in attributes: if attribute.name in ctx.cls.info.names: @@ -716,13 +764,14 @@ def _make_frozen(ctx: 'mypy.plugin.ClassDefContext', attributes: List[Attribute] # can modify it. var = Var(attribute.name, ctx.cls.info[attribute.name].type) var.info = ctx.cls.info - var._fullname = f'{ctx.cls.info.fullname}.{var.name}' + var._fullname = f"{ctx.cls.info.fullname}.{var.name}" ctx.cls.info.names[var.name] = SymbolTableNode(MDEF, var) var.is_property = True -def _add_init(ctx: 'mypy.plugin.ClassDefContext', attributes: List[Attribute], - adder: 'MethodAdder') -> None: +def _add_init( + ctx: "mypy.plugin.ClassDefContext", attributes: List[Attribute], adder: "MethodAdder" +) -> None: """Generate an __init__ method for the attributes and add it to the class.""" # Convert attributes to arguments with kw_only arguments at the end of # the argument list @@ -749,19 +798,20 @@ def _add_init(ctx: 'mypy.plugin.ClassDefContext', attributes: List[Attribute], for a in args: a.variable.type = AnyType(TypeOfAny.implementation_artifact) a.type_annotation = AnyType(TypeOfAny.implementation_artifact) - adder.add_method('__init__', args, NoneType()) + adder.add_method("__init__", args, NoneType()) -def _add_attrs_magic_attribute(ctx: 'mypy.plugin.ClassDefContext', - attrs: 'List[Tuple[str, Optional[Type]]]') -> None: +def _add_attrs_magic_attribute( + ctx: "mypy.plugin.ClassDefContext", attrs: "List[Tuple[str, Optional[Type]]]" +) -> None: any_type = AnyType(TypeOfAny.explicit) - attributes_types: 'List[Type]' = [ - ctx.api.named_type_or_none('attr.Attribute', [attr_type or any_type]) or any_type + attributes_types: "List[Type]" = [ + ctx.api.named_type_or_none("attr.Attribute", [attr_type or any_type]) or any_type for _, attr_type in attrs ] - fallback_type = ctx.api.named_type('builtins.tuple', [ - ctx.api.named_type_or_none('attr.Attribute', [any_type]) or any_type, - ]) + fallback_type = ctx.api.named_type( + "builtins.tuple", [ctx.api.named_type_or_none("attr.Attribute", [any_type]) or any_type] + ) ti = ctx.api.basic_new_typeinfo(MAGIC_ATTR_CLS_NAME, fallback_type, 0) ti.is_named_tuple = True @@ -781,40 +831,30 @@ def _add_attrs_magic_attribute(ctx: 'mypy.plugin.ClassDefContext', var._fullname = f"{ctx.cls.fullname}.{MAGIC_ATTR_CLS_NAME}" var.allow_incompatible_override = True ctx.cls.info.names[MAGIC_ATTR_NAME] = SymbolTableNode( - kind=MDEF, - node=var, - plugin_generated=True, - no_serialize=True, + kind=MDEF, node=var, plugin_generated=True, no_serialize=True ) -def _add_slots(ctx: 'mypy.plugin.ClassDefContext', - attributes: List[Attribute]) -> None: +def _add_slots(ctx: "mypy.plugin.ClassDefContext", attributes: List[Attribute]) -> None: # Unlike `@dataclasses.dataclass`, `__slots__` is rewritten here. ctx.cls.info.slots = {attr.name for attr in attributes} -def _add_match_args(ctx: 'mypy.plugin.ClassDefContext', - attributes: List[Attribute]) -> None: - if ('__match_args__' not in ctx.cls.info.names - or ctx.cls.info.names['__match_args__'].plugin_generated): - str_type = ctx.api.named_type('builtins.str') +def _add_match_args(ctx: "mypy.plugin.ClassDefContext", attributes: List[Attribute]) -> None: + if ( + "__match_args__" not in ctx.cls.info.names + or ctx.cls.info.names["__match_args__"].plugin_generated + ): + str_type = ctx.api.named_type("builtins.str") match_args = TupleType( [ - str_type.copy_modified( - last_known_value=LiteralType(attr.name, fallback=str_type), - ) + str_type.copy_modified(last_known_value=LiteralType(attr.name, fallback=str_type)) for attr in attributes if not attr.kw_only and attr.init ], - fallback=ctx.api.named_type('builtins.tuple'), - ) - add_attribute_to_class( - api=ctx.api, - cls=ctx.cls, - name='__match_args__', - typ=match_args, + fallback=ctx.api.named_type("builtins.tuple"), ) + add_attribute_to_class(api=ctx.api, cls=ctx.cls, name="__match_args__", typ=match_args) class MethodAdder: @@ -825,14 +865,18 @@ class MethodAdder: # TODO: Combine this with the code build_namedtuple_typeinfo to support both. - def __init__(self, ctx: 'mypy.plugin.ClassDefContext') -> None: + def __init__(self, ctx: "mypy.plugin.ClassDefContext") -> None: self.ctx = ctx self.self_type = fill_typevars(ctx.cls.info) - def add_method(self, - method_name: str, args: List[Argument], ret_type: Type, - self_type: Optional[Type] = None, - tvd: Optional[TypeVarType] = None) -> None: + def add_method( + self, + method_name: str, + args: List[Argument], + ret_type: Type, + self_type: Optional[Type] = None, + tvd: Optional[TypeVarType] = None, + ) -> None: """Add a method: def (self, ) -> ): ... to info. self_type: The type to use for the self argument or None to use the inferred self type. diff --git a/mypy/plugins/common.py b/mypy/plugins/common.py index 985a3f0fa6c7b..6832b5554410a 100644 --- a/mypy/plugins/common.py +++ b/mypy/plugins/common.py @@ -1,25 +1,38 @@ from typing import List, Optional, Union +from mypy.fixup import TypeFixer from mypy.nodes import ( - ARG_POS, MDEF, Argument, Block, CallExpr, ClassDef, Expression, SYMBOL_FUNCBASE_TYPES, - FuncDef, PassStmt, RefExpr, SymbolTableNode, Var, JsonDict, + ARG_POS, + MDEF, + SYMBOL_FUNCBASE_TYPES, + Argument, + Block, + CallExpr, + ClassDef, + Expression, + FuncDef, + JsonDict, + PassStmt, + RefExpr, + SymbolTableNode, + Var, ) from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface -from mypy.semanal import set_callable_name, ALLOW_INCOMPATIBLE_OVERRIDE +from mypy.semanal import ALLOW_INCOMPATIBLE_OVERRIDE, set_callable_name +from mypy.typeops import try_getting_str_literals # noqa: F401 # Part of public API from mypy.types import ( - CallableType, Overloaded, Type, TypeVarType, deserialize_type, get_proper_type, + CallableType, + Overloaded, + Type, + TypeVarType, + deserialize_type, + get_proper_type, ) from mypy.typevars import fill_typevars from mypy.util import get_unique_redefinition_name -from mypy.typeops import try_getting_str_literals # noqa: F401 # Part of public API -from mypy.fixup import TypeFixer -def _get_decorator_bool_argument( - ctx: ClassDefContext, - name: str, - default: bool, -) -> bool: +def _get_decorator_bool_argument(ctx: ClassDefContext, name: str, default: bool) -> bool: """Return the bool argument for the decorator. This handles both @decorator(...) and @decorator. @@ -30,8 +43,7 @@ def _get_decorator_bool_argument( return default -def _get_bool_argument(ctx: ClassDefContext, expr: CallExpr, - name: str, default: bool) -> bool: +def _get_bool_argument(ctx: ClassDefContext, expr: CallExpr, name: str, default: bool) -> bool: """Return the boolean value for an argument to a call or the default if it's not found. """ @@ -57,8 +69,7 @@ def _get_argument(call: CallExpr, name: str) -> Optional[Expression]: callee_type = None callee_node = call.callee.node - if (isinstance(callee_node, (Var, SYMBOL_FUNCBASE_TYPES)) - and callee_node.type): + if isinstance(callee_node, (Var, SYMBOL_FUNCBASE_TYPES)) and callee_node.type: callee_node_type = get_proper_type(callee_node.type) if isinstance(callee_node_type, Overloaded): # We take the last overload. @@ -83,33 +94,36 @@ def _get_argument(call: CallExpr, name: str) -> Optional[Expression]: def add_method( - ctx: ClassDefContext, - name: str, - args: List[Argument], - return_type: Type, - self_type: Optional[Type] = None, - tvar_def: Optional[TypeVarType] = None, + ctx: ClassDefContext, + name: str, + args: List[Argument], + return_type: Type, + self_type: Optional[Type] = None, + tvar_def: Optional[TypeVarType] = None, ) -> None: """ Adds a new method to a class. Deprecated, use add_method_to_class() instead. """ - add_method_to_class(ctx.api, ctx.cls, - name=name, - args=args, - return_type=return_type, - self_type=self_type, - tvar_def=tvar_def) + add_method_to_class( + ctx.api, + ctx.cls, + name=name, + args=args, + return_type=return_type, + self_type=self_type, + tvar_def=tvar_def, + ) def add_method_to_class( - api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface], - cls: ClassDef, - name: str, - args: List[Argument], - return_type: Type, - self_type: Optional[Type] = None, - tvar_def: Optional[TypeVarType] = None, + api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface], + cls: ClassDef, + name: str, + args: List[Argument], + return_type: Type, + self_type: Optional[Type] = None, + tvar_def: Optional[TypeVarType] = None, ) -> None: """Adds a new method to a class definition.""" info = cls.info @@ -123,14 +137,14 @@ def add_method_to_class( self_type = self_type or fill_typevars(info) if isinstance(api, SemanticAnalyzerPluginInterface): - function_type = api.named_type('builtins.function') + function_type = api.named_type("builtins.function") else: - function_type = api.named_generic_type('builtins.function', []) + function_type = api.named_generic_type("builtins.function", []) - args = [Argument(Var('self'), self_type, None, ARG_POS)] + args + args = [Argument(Var("self"), self_type, None, ARG_POS)] + args arg_types, arg_names, arg_kinds = [], [], [] for arg in args: - assert arg.type_annotation, 'All arguments must be fully typed.' + assert arg.type_annotation, "All arguments must be fully typed." arg_types.append(arg.type_annotation) arg_names.append(arg.variable.name) arg_kinds.append(arg.kind) @@ -142,7 +156,7 @@ def add_method_to_class( func = FuncDef(name, args, Block([PassStmt()])) func.info = info func.type = set_callable_name(signature, func) - func._fullname = info.fullname + '.' + name + func._fullname = info.fullname + "." + name func.line = info.line # NOTE: we would like the plugin generated node to dominate, but we still @@ -157,13 +171,13 @@ def add_method_to_class( def add_attribute_to_class( - api: SemanticAnalyzerPluginInterface, - cls: ClassDef, - name: str, - typ: Type, - final: bool = False, - no_serialize: bool = False, - override_allow_incompatible: bool = False, + api: SemanticAnalyzerPluginInterface, + cls: ClassDef, + name: str, + typ: Type, + final: bool = False, + no_serialize: bool = False, + override_allow_incompatible: bool = False, ) -> None: """ Adds a new attribute to a class definition. @@ -185,12 +199,9 @@ def add_attribute_to_class( node.allow_incompatible_override = True else: node.allow_incompatible_override = override_allow_incompatible - node._fullname = info.fullname + '.' + name + node._fullname = info.fullname + "." + name info.names[name] = SymbolTableNode( - MDEF, - node, - plugin_generated=True, - no_serialize=no_serialize, + MDEF, node, plugin_generated=True, no_serialize=no_serialize ) diff --git a/mypy/plugins/ctypes.py b/mypy/plugins/ctypes.py index 87ffcdfe33398..b2a12cc7ba1a1 100644 --- a/mypy/plugins/ctypes.py +++ b/mypy/plugins/ctypes.py @@ -8,46 +8,59 @@ from mypy.maptype import map_instance_to_supertype from mypy.messages import format_type from mypy.subtypes import is_subtype +from mypy.typeops import make_simplified_union from mypy.types import ( - AnyType, CallableType, Instance, NoneType, Type, TypeOfAny, UnionType, - union_items, ProperType, get_proper_type + AnyType, + CallableType, + Instance, + NoneType, + ProperType, + Type, + TypeOfAny, + UnionType, + get_proper_type, + union_items, ) -from mypy.typeops import make_simplified_union -def _get_bytes_type(api: 'mypy.plugin.CheckerPluginInterface') -> Instance: +def _get_bytes_type(api: "mypy.plugin.CheckerPluginInterface") -> Instance: """Return the type corresponding to bytes on the current Python version. This is bytes in Python 3, and str in Python 2. """ return api.named_generic_type( - 'builtins.bytes' if api.options.python_version >= (3,) else 'builtins.str', []) + "builtins.bytes" if api.options.python_version >= (3,) else "builtins.str", [] + ) -def _get_text_type(api: 'mypy.plugin.CheckerPluginInterface') -> Instance: +def _get_text_type(api: "mypy.plugin.CheckerPluginInterface") -> Instance: """Return the type corresponding to Text on the current Python version. This is str in Python 3, and unicode in Python 2. """ return api.named_generic_type( - 'builtins.str' if api.options.python_version >= (3,) else 'builtins.unicode', []) + "builtins.str" if api.options.python_version >= (3,) else "builtins.unicode", [] + ) -def _find_simplecdata_base_arg(tp: Instance, api: 'mypy.plugin.CheckerPluginInterface' - ) -> Optional[ProperType]: +def _find_simplecdata_base_arg( + tp: Instance, api: "mypy.plugin.CheckerPluginInterface" +) -> Optional[ProperType]: """Try to find a parametrized _SimpleCData in tp's bases and return its single type argument. None is returned if _SimpleCData appears nowhere in tp's (direct or indirect) bases. """ - if tp.type.has_base('ctypes._SimpleCData'): - simplecdata_base = map_instance_to_supertype(tp, - api.named_generic_type('ctypes._SimpleCData', [AnyType(TypeOfAny.special_form)]).type) - assert len(simplecdata_base.args) == 1, '_SimpleCData takes exactly one type argument' + if tp.type.has_base("ctypes._SimpleCData"): + simplecdata_base = map_instance_to_supertype( + tp, + api.named_generic_type("ctypes._SimpleCData", [AnyType(TypeOfAny.special_form)]).type, + ) + assert len(simplecdata_base.args) == 1, "_SimpleCData takes exactly one type argument" return get_proper_type(simplecdata_base.args[0]) return None -def _autoconvertible_to_cdata(tp: Type, api: 'mypy.plugin.CheckerPluginInterface') -> Type: +def _autoconvertible_to_cdata(tp: Type, api: "mypy.plugin.CheckerPluginInterface") -> Type: """Get a type that is compatible with all types that can be implicitly converted to the given CData type. @@ -72,10 +85,10 @@ def _autoconvertible_to_cdata(tp: Type, api: 'mypy.plugin.CheckerPluginInterface # the original "boxed" type. allowed_types.append(unboxed) - if t.type.has_base('ctypes._PointerLike'): + if t.type.has_base("ctypes._PointerLike"): # Pointer-like _SimpleCData subclasses can also be converted from # an int or None. - allowed_types.append(api.named_generic_type('builtins.int', [])) + allowed_types.append(api.named_generic_type("builtins.int", [])) allowed_types.append(NoneType()) return make_simplified_union(allowed_types) @@ -94,7 +107,7 @@ def _autounboxed_cdata(tp: Type) -> ProperType: return make_simplified_union([_autounboxed_cdata(t) for t in tp.items]) elif isinstance(tp, Instance): for base in tp.type.bases: - if base.type.fullname == 'ctypes._SimpleCData': + if base.type.fullname == "ctypes._SimpleCData": # If tp has _SimpleCData as a direct base class, # the auto-unboxed type is the single type argument of the _SimpleCData type. assert len(base.args) == 1 @@ -108,58 +121,67 @@ def _get_array_element_type(tp: Type) -> Optional[ProperType]: """Get the element type of the Array type tp, or None if not specified.""" tp = get_proper_type(tp) if isinstance(tp, Instance): - assert tp.type.fullname == 'ctypes.Array' + assert tp.type.fullname == "ctypes.Array" if len(tp.args) == 1: return get_proper_type(tp.args[0]) return None -def array_constructor_callback(ctx: 'mypy.plugin.FunctionContext') -> Type: +def array_constructor_callback(ctx: "mypy.plugin.FunctionContext") -> Type: """Callback to provide an accurate signature for the ctypes.Array constructor.""" # Extract the element type from the constructor's return type, i. e. the type of the array # being constructed. et = _get_array_element_type(ctx.default_return_type) if et is not None: allowed = _autoconvertible_to_cdata(et, ctx.api) - assert len(ctx.arg_types) == 1, \ - "The stub of the ctypes.Array constructor should have a single vararg parameter" + assert ( + len(ctx.arg_types) == 1 + ), "The stub of the ctypes.Array constructor should have a single vararg parameter" for arg_num, (arg_kind, arg_type) in enumerate(zip(ctx.arg_kinds[0], ctx.arg_types[0]), 1): if arg_kind == nodes.ARG_POS and not is_subtype(arg_type, allowed): ctx.api.msg.fail( - 'Array constructor argument {} of type {}' - ' is not convertible to the array element type {}' - .format(arg_num, format_type(arg_type), format_type(et)), ctx.context) + "Array constructor argument {} of type {}" + " is not convertible to the array element type {}".format( + arg_num, format_type(arg_type), format_type(et) + ), + ctx.context, + ) elif arg_kind == nodes.ARG_STAR: ty = ctx.api.named_generic_type("typing.Iterable", [allowed]) if not is_subtype(arg_type, ty): it = ctx.api.named_generic_type("typing.Iterable", [et]) ctx.api.msg.fail( - 'Array constructor argument {} of type {}' - ' is not convertible to the array element type {}' - .format(arg_num, format_type(arg_type), format_type(it)), ctx.context) + "Array constructor argument {} of type {}" + " is not convertible to the array element type {}".format( + arg_num, format_type(arg_type), format_type(it) + ), + ctx.context, + ) return ctx.default_return_type -def array_getitem_callback(ctx: 'mypy.plugin.MethodContext') -> Type: +def array_getitem_callback(ctx: "mypy.plugin.MethodContext") -> Type: """Callback to provide an accurate return type for ctypes.Array.__getitem__.""" et = _get_array_element_type(ctx.type) if et is not None: unboxed = _autounboxed_cdata(et) - assert len(ctx.arg_types) == 1, \ - 'The stub of ctypes.Array.__getitem__ should have exactly one parameter' - assert len(ctx.arg_types[0]) == 1, \ - "ctypes.Array.__getitem__'s parameter should not be variadic" + assert ( + len(ctx.arg_types) == 1 + ), "The stub of ctypes.Array.__getitem__ should have exactly one parameter" + assert ( + len(ctx.arg_types[0]) == 1 + ), "ctypes.Array.__getitem__'s parameter should not be variadic" index_type = get_proper_type(ctx.arg_types[0][0]) if isinstance(index_type, Instance): - if index_type.type.has_base('builtins.int'): + if index_type.type.has_base("builtins.int"): return unboxed - elif index_type.type.has_base('builtins.slice'): - return ctx.api.named_generic_type('builtins.list', [unboxed]) + elif index_type.type.has_base("builtins.slice"): + return ctx.api.named_generic_type("builtins.list", [unboxed]) return ctx.default_return_type -def array_setitem_callback(ctx: 'mypy.plugin.MethodSigContext') -> CallableType: +def array_setitem_callback(ctx: "mypy.plugin.MethodSigContext") -> CallableType: """Callback to provide an accurate signature for ctypes.Array.__setitem__.""" et = _get_array_element_type(ctx.type) if et is not None: @@ -168,29 +190,29 @@ def array_setitem_callback(ctx: 'mypy.plugin.MethodSigContext') -> CallableType: index_type = get_proper_type(ctx.default_signature.arg_types[0]) if isinstance(index_type, Instance): arg_type = None - if index_type.type.has_base('builtins.int'): + if index_type.type.has_base("builtins.int"): arg_type = allowed - elif index_type.type.has_base('builtins.slice'): - arg_type = ctx.api.named_generic_type('builtins.list', [allowed]) + elif index_type.type.has_base("builtins.slice"): + arg_type = ctx.api.named_generic_type("builtins.list", [allowed]) if arg_type is not None: # Note: arg_type can only be None if index_type is invalid, in which case we use # the default signature and let mypy report an error about it. return ctx.default_signature.copy_modified( - arg_types=ctx.default_signature.arg_types[:1] + [arg_type], + arg_types=ctx.default_signature.arg_types[:1] + [arg_type] ) return ctx.default_signature -def array_iter_callback(ctx: 'mypy.plugin.MethodContext') -> Type: +def array_iter_callback(ctx: "mypy.plugin.MethodContext") -> Type: """Callback to provide an accurate return type for ctypes.Array.__iter__.""" et = _get_array_element_type(ctx.type) if et is not None: unboxed = _autounboxed_cdata(et) - return ctx.api.named_generic_type('typing.Iterator', [unboxed]) + return ctx.api.named_generic_type("typing.Iterator", [unboxed]) return ctx.default_return_type -def array_value_callback(ctx: 'mypy.plugin.AttributeContext') -> Type: +def array_value_callback(ctx: "mypy.plugin.AttributeContext") -> Type: """Callback to provide an accurate type for ctypes.Array.value.""" et = _get_array_element_type(ctx.type) if et is not None: @@ -198,32 +220,37 @@ def array_value_callback(ctx: 'mypy.plugin.AttributeContext') -> Type: for tp in union_items(et): if isinstance(tp, AnyType): types.append(AnyType(TypeOfAny.from_another_any, source_any=tp)) - elif isinstance(tp, Instance) and tp.type.fullname == 'ctypes.c_char': + elif isinstance(tp, Instance) and tp.type.fullname == "ctypes.c_char": types.append(_get_bytes_type(ctx.api)) - elif isinstance(tp, Instance) and tp.type.fullname == 'ctypes.c_wchar': + elif isinstance(tp, Instance) and tp.type.fullname == "ctypes.c_wchar": types.append(_get_text_type(ctx.api)) else: ctx.api.msg.fail( 'Array attribute "value" is only available' - ' with element type "c_char" or "c_wchar", not {}' - .format(format_type(et)), ctx.context) + ' with element type "c_char" or "c_wchar", not {}'.format(format_type(et)), + ctx.context, + ) return make_simplified_union(types) return ctx.default_attr_type -def array_raw_callback(ctx: 'mypy.plugin.AttributeContext') -> Type: +def array_raw_callback(ctx: "mypy.plugin.AttributeContext") -> Type: """Callback to provide an accurate type for ctypes.Array.raw.""" et = _get_array_element_type(ctx.type) if et is not None: types: List[Type] = [] for tp in union_items(et): - if (isinstance(tp, AnyType) - or isinstance(tp, Instance) and tp.type.fullname == 'ctypes.c_char'): + if ( + isinstance(tp, AnyType) + or isinstance(tp, Instance) + and tp.type.fullname == "ctypes.c_char" + ): types.append(_get_bytes_type(ctx.api)) else: ctx.api.msg.fail( 'Array attribute "raw" is only available' - ' with element type "c_char", not {}' - .format(format_type(et)), ctx.context) + ' with element type "c_char", not {}'.format(format_type(et)), + ctx.context, + ) return make_simplified_union(types) return ctx.default_attr_type diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 87b42a499a1c0..f360d1577b144 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -1,35 +1,60 @@ """Plugin that provides support for dataclasses.""" -from typing import Dict, List, Set, Tuple, Optional +from typing import Dict, List, Optional, Set, Tuple + from typing_extensions import Final from mypy.nodes import ( - ARG_OPT, ARG_NAMED, ARG_NAMED_OPT, ARG_POS, ARG_STAR, ARG_STAR2, MDEF, - Argument, AssignmentStmt, CallExpr, TypeAlias, Context, Expression, JsonDict, - NameExpr, RefExpr, SymbolTableNode, TempNode, TypeInfo, Var, TypeVarExpr, - PlaceholderNode + ARG_NAMED, + ARG_NAMED_OPT, + ARG_OPT, + ARG_POS, + ARG_STAR, + ARG_STAR2, + MDEF, + Argument, + AssignmentStmt, + CallExpr, + Context, + Expression, + JsonDict, + NameExpr, + PlaceholderNode, + RefExpr, + SymbolTableNode, + TempNode, + TypeAlias, + TypeInfo, + TypeVarExpr, + Var, ) from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface from mypy.plugins.common import ( - add_method, _get_decorator_bool_argument, deserialize_and_fixup_type, add_attribute_to_class, + _get_decorator_bool_argument, + add_attribute_to_class, + add_method, + deserialize_and_fixup_type, ) +from mypy.server.trigger import make_wildcard_trigger +from mypy.state import state from mypy.typeops import map_type_from_supertype from mypy.types import ( - Type, Instance, NoneType, TypeVarType, CallableType, TupleType, LiteralType, - get_proper_type, AnyType, TypeOfAny, + AnyType, + CallableType, + Instance, + LiteralType, + NoneType, + TupleType, + Type, + TypeOfAny, + TypeVarType, + get_proper_type, ) -from mypy.server.trigger import make_wildcard_trigger -from mypy.state import state # The set of decorators that generate dataclasses. -dataclass_makers: Final = { - 'dataclass', - 'dataclasses.dataclass', -} +dataclass_makers: Final = {"dataclass", "dataclasses.dataclass"} # The set of functions that generate dataclass fields. -field_makers: Final = { - 'dataclasses.field', -} +field_makers: Final = {"dataclasses.field"} SELF_TVAR_NAME: Final = "_DT" @@ -37,16 +62,16 @@ class DataclassAttribute: def __init__( - self, - name: str, - is_in_init: bool, - is_init_var: bool, - has_default: bool, - line: int, - column: int, - type: Optional[Type], - info: TypeInfo, - kw_only: bool, + self, + name: str, + is_in_init: bool, + is_init_var: bool, + has_default: bool, + line: int, + column: int, + type: Optional[Type], + info: TypeInfo, + kw_only: bool, ) -> None: self.name = name self.is_in_init = is_in_init @@ -67,10 +92,7 @@ def to_argument(self) -> Argument: elif not self.kw_only and self.has_default: arg_kind = ARG_OPT return Argument( - variable=self.to_var(), - type_annotation=self.type, - initializer=None, - kind=arg_kind, + variable=self.to_var(), type_annotation=self.type, initializer=None, kind=arg_kind ) def to_var(self) -> Var: @@ -79,24 +101,24 @@ def to_var(self) -> Var: def serialize(self) -> JsonDict: assert self.type return { - 'name': self.name, - 'is_in_init': self.is_in_init, - 'is_init_var': self.is_init_var, - 'has_default': self.has_default, - 'line': self.line, - 'column': self.column, - 'type': self.type.serialize(), - 'kw_only': self.kw_only, + "name": self.name, + "is_in_init": self.is_in_init, + "is_init_var": self.is_init_var, + "has_default": self.has_default, + "line": self.line, + "column": self.column, + "type": self.type.serialize(), + "kw_only": self.kw_only, } @classmethod def deserialize( cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface - ) -> 'DataclassAttribute': + ) -> "DataclassAttribute": data = data.copy() - if data.get('kw_only') is None: - data['kw_only'] = False - typ = deserialize_and_fixup_type(data.pop('type'), api) + if data.get("kw_only") is None: + data["kw_only"] = False + typ = deserialize_and_fixup_type(data.pop("type"), api) return cls(type=typ, info=info, **data) def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: @@ -134,12 +156,12 @@ def transform(self) -> bool: if attr.type is None: return False decorator_arguments = { - 'init': _get_decorator_bool_argument(self._ctx, 'init', True), - 'eq': _get_decorator_bool_argument(self._ctx, 'eq', True), - 'order': _get_decorator_bool_argument(self._ctx, 'order', False), - 'frozen': _get_decorator_bool_argument(self._ctx, 'frozen', False), - 'slots': _get_decorator_bool_argument(self._ctx, 'slots', False), - 'match_args': _get_decorator_bool_argument(self._ctx, 'match_args', True), + "init": _get_decorator_bool_argument(self._ctx, "init", True), + "eq": _get_decorator_bool_argument(self._ctx, "eq", True), + "order": _get_decorator_bool_argument(self._ctx, "order", False), + "frozen": _get_decorator_bool_argument(self._ctx, "frozen", False), + "slots": _get_decorator_bool_argument(self._ctx, "slots", False), + "match_args": _get_decorator_bool_argument(self._ctx, "match_args", True), } py_version = self._ctx.api.options.python_version @@ -147,12 +169,17 @@ def transform(self) -> bool: # processed them yet. In order to work around this, we can simply skip generating # __init__ if there are no attributes, because if the user truly did not define any, # then the object default __init__ with an empty signature will be present anyway. - if (decorator_arguments['init'] and - ('__init__' not in info.names or info.names['__init__'].plugin_generated) and - attributes): - - args = [attr.to_argument() for attr in attributes if attr.is_in_init - and not self._is_kw_only_type(attr.type)] + if ( + decorator_arguments["init"] + and ("__init__" not in info.names or info.names["__init__"].plugin_generated) + and attributes + ): + + args = [ + attr.to_argument() + for attr in attributes + if attr.is_in_init and not self._is_kw_only_type(attr.type) + ] if info.fallback_to_any: # Make positional args optional since we don't know their order. @@ -162,48 +189,49 @@ def transform(self) -> bool: if arg.kind == ARG_POS: arg.kind = ARG_OPT - nameless_var = Var('') - args = [Argument(nameless_var, AnyType(TypeOfAny.explicit), None, ARG_STAR), - *args, - Argument(nameless_var, AnyType(TypeOfAny.explicit), None, ARG_STAR2), - ] - - add_method( - ctx, - '__init__', - args=args, - return_type=NoneType(), - ) + nameless_var = Var("") + args = [ + Argument(nameless_var, AnyType(TypeOfAny.explicit), None, ARG_STAR), + *args, + Argument(nameless_var, AnyType(TypeOfAny.explicit), None, ARG_STAR2), + ] - if (decorator_arguments['eq'] and info.get('__eq__') is None or - decorator_arguments['order']): + add_method(ctx, "__init__", args=args, return_type=NoneType()) + + if ( + decorator_arguments["eq"] + and info.get("__eq__") is None + or decorator_arguments["order"] + ): # Type variable for self types in generated methods. - obj_type = ctx.api.named_type('builtins.object') - self_tvar_expr = TypeVarExpr(SELF_TVAR_NAME, info.fullname + '.' + SELF_TVAR_NAME, - [], obj_type) + obj_type = ctx.api.named_type("builtins.object") + self_tvar_expr = TypeVarExpr( + SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, [], obj_type + ) info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr) # Add <, >, <=, >=, but only if the class has an eq method. - if decorator_arguments['order']: - if not decorator_arguments['eq']: - ctx.api.fail('eq must be True if order is True', ctx.cls) + if decorator_arguments["order"]: + if not decorator_arguments["eq"]: + ctx.api.fail("eq must be True if order is True", ctx.cls) - for method_name in ['__lt__', '__gt__', '__le__', '__ge__']: + for method_name in ["__lt__", "__gt__", "__le__", "__ge__"]: # Like for __eq__ and __ne__, we want "other" to match # the self type. - obj_type = ctx.api.named_type('builtins.object') - order_tvar_def = TypeVarType(SELF_TVAR_NAME, info.fullname + '.' + SELF_TVAR_NAME, - -1, [], obj_type) - order_return_type = ctx.api.named_type('builtins.bool') + obj_type = ctx.api.named_type("builtins.object") + order_tvar_def = TypeVarType( + SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, -1, [], obj_type + ) + order_return_type = ctx.api.named_type("builtins.bool") order_args = [ - Argument(Var('other', order_tvar_def), order_tvar_def, None, ARG_POS) + Argument(Var("other", order_tvar_def), order_tvar_def, None, ARG_POS) ] existing_method = info.get(method_name) if existing_method is not None and not existing_method.plugin_generated: assert existing_method.node ctx.api.fail( - f'You may not have a custom {method_name} method when order=True', + f"You may not have a custom {method_name} method when order=True", existing_method.node, ) @@ -216,62 +244,65 @@ def transform(self) -> bool: tvar_def=order_tvar_def, ) - if decorator_arguments['frozen']: + if decorator_arguments["frozen"]: self._propertize_callables(attributes, settable=False) self._freeze(attributes) else: self._propertize_callables(attributes) - if decorator_arguments['slots']: + if decorator_arguments["slots"]: self.add_slots(info, attributes, correct_version=py_version >= (3, 10)) self.reset_init_only_vars(info, attributes) - if (decorator_arguments['match_args'] and - ('__match_args__' not in info.names or - info.names['__match_args__'].plugin_generated) and - attributes and - py_version >= (3, 10)): + if ( + decorator_arguments["match_args"] + and ( + "__match_args__" not in info.names or info.names["__match_args__"].plugin_generated + ) + and attributes + and py_version >= (3, 10) + ): str_type = ctx.api.named_type("builtins.str") - literals: List[Type] = [LiteralType(attr.name, str_type) - for attr in attributes if attr.is_in_init] + literals: List[Type] = [ + LiteralType(attr.name, str_type) for attr in attributes if attr.is_in_init + ] match_args_type = TupleType(literals, ctx.api.named_type("builtins.tuple")) add_attribute_to_class(ctx.api, ctx.cls, "__match_args__", match_args_type) self._add_dataclass_fields_magic_attribute() - info.metadata['dataclass'] = { - 'attributes': [attr.serialize() for attr in attributes], - 'frozen': decorator_arguments['frozen'], + info.metadata["dataclass"] = { + "attributes": [attr.serialize() for attr in attributes], + "frozen": decorator_arguments["frozen"], } return True - def add_slots(self, - info: TypeInfo, - attributes: List[DataclassAttribute], - *, - correct_version: bool) -> None: + def add_slots( + self, info: TypeInfo, attributes: List[DataclassAttribute], *, correct_version: bool + ) -> None: if not correct_version: # This means that version is lower than `3.10`, # it is just a non-existent argument for `dataclass` function. self._ctx.api.fail( 'Keyword argument "slots" for "dataclass" ' - 'is only valid in Python 3.10 and higher', + "is only valid in Python 3.10 and higher", self._ctx.reason, ) return generated_slots = {attr.name for attr in attributes} - if ((info.slots is not None and info.slots != generated_slots) - or info.names.get('__slots__')): + if (info.slots is not None and info.slots != generated_slots) or info.names.get( + "__slots__" + ): # This means we have a slots conflict. # Class explicitly specifies a different `__slots__` field. # And `@dataclass(slots=True)` is used. # In runtime this raises a type error. self._ctx.api.fail( '"{}" both defines "__slots__" and is used with "slots=True"'.format( - self._ctx.cls.name, + self._ctx.cls.name ), self._ctx.cls, ) @@ -314,7 +345,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: cls = self._ctx.cls attrs: List[DataclassAttribute] = [] known_attrs: Set[str] = set() - kw_only = _get_decorator_bool_argument(ctx, 'kw_only', False) + kw_only = _get_decorator_bool_argument(ctx, "kw_only", False) for stmt in cls.defs.body: # Any assignment that doesn't use the new type declaration # syntax can be ignored out of hand. @@ -337,11 +368,8 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: if isinstance(node, TypeAlias): ctx.api.fail( - ( - 'Type aliases inside dataclass definitions ' - 'are not supported at runtime' - ), - node + ("Type aliases inside dataclass definitions " "are not supported at runtime"), + node, ) # Skip processing this node. This doesn't match the runtime behaviour, # but the only alternative would be to modify the SymbolTable, @@ -357,8 +385,10 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: # x: InitVar[int] is turned into x: int and is removed from the class. is_init_var = False node_type = get_proper_type(node.type) - if (isinstance(node_type, Instance) and - node_type.type.fullname == 'dataclasses.InitVar'): + if ( + isinstance(node_type, Instance) + and node_type.type.fullname == "dataclasses.InitVar" + ): is_init_var = True node.type = node_type.args[0] @@ -367,7 +397,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: has_field_call, field_args = _collect_field_args(stmt.rvalue, ctx) - is_in_init_param = field_args.get('init') + is_in_init_param = field_args.get("init") if is_in_init_param is None: is_in_init = True else: @@ -377,7 +407,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: # Ensure that something like x: int = field() is rejected # after an attribute with a default. if has_field_call: - has_default = 'default' in field_args or 'default_factory' in field_args + has_default = "default" in field_args or "default_factory" in field_args # All other assignments are already type checked. elif not isinstance(stmt.rvalue, TempNode): @@ -391,22 +421,24 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: is_kw_only = kw_only # Use the kw_only field arg if it is provided. Otherwise use the # kw_only value from the decorator parameter. - field_kw_only_param = field_args.get('kw_only') + field_kw_only_param = field_args.get("kw_only") if field_kw_only_param is not None: is_kw_only = bool(ctx.api.parse_bool(field_kw_only_param)) known_attrs.add(lhs.name) - attrs.append(DataclassAttribute( - name=lhs.name, - is_in_init=is_in_init, - is_init_var=is_init_var, - has_default=has_default, - line=stmt.line, - column=stmt.column, - type=sym.type, - info=cls.info, - kw_only=is_kw_only, - )) + attrs.append( + DataclassAttribute( + name=lhs.name, + is_in_init=is_in_init, + is_init_var=is_init_var, + has_default=has_default, + line=stmt.line, + column=stmt.column, + type=sym.type, + info=cls.info, + kw_only=is_kw_only, + ) + ) # Next, collect attributes belonging to any class in the MRO # as long as those attributes weren't already collected. This @@ -415,10 +447,10 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: # we'll have unmodified attrs laying around. all_attrs = attrs.copy() for info in cls.info.mro[1:-1]: - if 'dataclass_tag' in info.metadata and 'dataclass' not in info.metadata: + if "dataclass_tag" in info.metadata and "dataclass" not in info.metadata: # We haven't processed the base class yet. Need another pass. return None - if 'dataclass' not in info.metadata: + if "dataclass" not in info.metadata: continue super_attrs = [] @@ -461,21 +493,15 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]: if found_default and attr.is_in_init and not attr.has_default and not attr.kw_only: # If the issue comes from merging different classes, report it # at the class definition point. - context = (Context(line=attr.line, column=attr.column) if attr in attrs - else ctx.cls) + context = Context(line=attr.line, column=attr.column) if attr in attrs else ctx.cls ctx.api.fail( - 'Attributes without a default cannot follow attributes with one', - context, + "Attributes without a default cannot follow attributes with one", context ) found_default = found_default or (attr.has_default and attr.is_in_init) if found_kw_sentinel and self._is_kw_only_type(attr.type): - context = (Context(line=attr.line, column=attr.column) if attr in attrs - else ctx.cls) - ctx.api.fail( - 'There may not be more than one field with the KW_ONLY type', - context, - ) + context = Context(line=attr.line, column=attr.column) if attr in attrs else ctx.cls + ctx.api.fail("There may not be more than one field with the KW_ONLY type", context) found_kw_sentinel = found_kw_sentinel or self._is_kw_only_type(attr.type) return all_attrs @@ -495,12 +521,12 @@ def _freeze(self, attributes: List[DataclassAttribute]) -> None: var = attr.to_var() var.info = info var.is_property = True - var._fullname = info.fullname + '.' + var.name + var._fullname = info.fullname + "." + var.name info.names[var.name] = SymbolTableNode(MDEF, var) - def _propertize_callables(self, - attributes: List[DataclassAttribute], - settable: bool = True) -> None: + def _propertize_callables( + self, attributes: List[DataclassAttribute], settable: bool = True + ) -> None: """Converts all attributes with callable types to @property methods. This avoids the typechecker getting confused and thinking that @@ -515,7 +541,7 @@ def _propertize_callables(self, var.info = info var.is_property = True var.is_settable_property = settable - var._fullname = info.fullname + '.' + var.name + var._fullname = info.fullname + "." + var.name info.names[var.name] = SymbolTableNode(MDEF, var) def _is_kw_only_type(self, node: Optional[Type]) -> bool: @@ -525,23 +551,20 @@ def _is_kw_only_type(self, node: Optional[Type]) -> bool: node_type = get_proper_type(node) if not isinstance(node_type, Instance): return False - return node_type.type.fullname == 'dataclasses.KW_ONLY' + return node_type.type.fullname == "dataclasses.KW_ONLY" def _add_dataclass_fields_magic_attribute(self) -> None: - attr_name = '__dataclass_fields__' + attr_name = "__dataclass_fields__" any_type = AnyType(TypeOfAny.explicit) - field_type = self._ctx.api.named_type_or_none('dataclasses.Field', [any_type]) or any_type - attr_type = self._ctx.api.named_type('builtins.dict', [ - self._ctx.api.named_type('builtins.str'), - field_type, - ]) + field_type = self._ctx.api.named_type_or_none("dataclasses.Field", [any_type]) or any_type + attr_type = self._ctx.api.named_type( + "builtins.dict", [self._ctx.api.named_type("builtins.str"), field_type] + ) var = Var(name=attr_name, type=attr_type) var.info = self._ctx.cls.info - var._fullname = self._ctx.cls.info.fullname + '.' + attr_name + var._fullname = self._ctx.cls.info.fullname + "." + attr_name self._ctx.cls.info.names[attr_name] = SymbolTableNode( - kind=MDEF, - node=var, - plugin_generated=True, + kind=MDEF, node=var, plugin_generated=True ) @@ -552,26 +575,26 @@ def dataclass_tag_callback(ctx: ClassDefContext) -> None: to detect dataclasses in base classes. """ # The value is ignored, only the existence matters. - ctx.cls.info.metadata['dataclass_tag'] = {} + ctx.cls.info.metadata["dataclass_tag"] = {} def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool: - """Hooks into the class typechecking process to add support for dataclasses. - """ + """Hooks into the class typechecking process to add support for dataclasses.""" transformer = DataclassTransformer(ctx) return transformer.transform() -def _collect_field_args(expr: Expression, - ctx: ClassDefContext) -> Tuple[bool, Dict[str, Expression]]: +def _collect_field_args( + expr: Expression, ctx: ClassDefContext +) -> Tuple[bool, Dict[str, Expression]]: """Returns a tuple where the first value represents whether or not the expression is a call to dataclass.field and the second is a dictionary of the keyword arguments that field() was called with. """ if ( - isinstance(expr, CallExpr) and - isinstance(expr.callee, RefExpr) and - expr.callee.fullname in field_makers + isinstance(expr, CallExpr) + and isinstance(expr.callee, RefExpr) + and expr.callee.fullname in field_makers ): # field() only takes keyword arguments. args = {} diff --git a/mypy/plugins/default.py b/mypy/plugins/default.py index 40997803aa7e3..699e355542546 100644 --- a/mypy/plugins/default.py +++ b/mypy/plugins/default.py @@ -1,75 +1,90 @@ from functools import partial -from typing import Callable, Optional, List +from typing import Callable, List, Optional from mypy import message_registry -from mypy.nodes import StrExpr, IntExpr, DictExpr, UnaryExpr +from mypy.checkexpr import is_literal_type_like +from mypy.nodes import DictExpr, IntExpr, StrExpr, UnaryExpr from mypy.plugin import ( - Plugin, FunctionContext, MethodContext, MethodSigContext, AttributeContext, ClassDefContext + AttributeContext, + ClassDefContext, + FunctionContext, + MethodContext, + MethodSigContext, + Plugin, ) from mypy.plugins.common import try_getting_str_literals -from mypy.types import ( - FunctionLike, Type, Instance, AnyType, TypeOfAny, CallableType, NoneType, TypedDictType, - TypeVarType, TPDICT_FB_NAMES, get_proper_type, LiteralType, TupleType -) from mypy.subtypes import is_subtype from mypy.typeops import make_simplified_union -from mypy.checkexpr import is_literal_type_like +from mypy.types import ( + TPDICT_FB_NAMES, + AnyType, + CallableType, + FunctionLike, + Instance, + LiteralType, + NoneType, + TupleType, + Type, + TypedDictType, + TypeOfAny, + TypeVarType, + get_proper_type, +) class DefaultPlugin(Plugin): """Type checker plugin that is enabled by default.""" - def get_function_hook(self, fullname: str - ) -> Optional[Callable[[FunctionContext], Type]]: + def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]: from mypy.plugins import ctypes, singledispatch - if fullname in ('contextlib.contextmanager', 'contextlib.asynccontextmanager'): + if fullname in ("contextlib.contextmanager", "contextlib.asynccontextmanager"): return contextmanager_callback - elif fullname == 'ctypes.Array': + elif fullname == "ctypes.Array": return ctypes.array_constructor_callback - elif fullname == 'functools.singledispatch': + elif fullname == "functools.singledispatch": return singledispatch.create_singledispatch_function_callback return None - def get_method_signature_hook(self, fullname: str - ) -> Optional[Callable[[MethodSigContext], FunctionLike]]: + def get_method_signature_hook( + self, fullname: str + ) -> Optional[Callable[[MethodSigContext], FunctionLike]]: from mypy.plugins import ctypes, singledispatch - if fullname == 'typing.Mapping.get': + if fullname == "typing.Mapping.get": return typed_dict_get_signature_callback - elif fullname in {n + '.setdefault' for n in TPDICT_FB_NAMES}: + elif fullname in {n + ".setdefault" for n in TPDICT_FB_NAMES}: return typed_dict_setdefault_signature_callback - elif fullname in {n + '.pop' for n in TPDICT_FB_NAMES}: + elif fullname in {n + ".pop" for n in TPDICT_FB_NAMES}: return typed_dict_pop_signature_callback - elif fullname in {n + '.update' for n in TPDICT_FB_NAMES}: + elif fullname in {n + ".update" for n in TPDICT_FB_NAMES}: return typed_dict_update_signature_callback - elif fullname == 'ctypes.Array.__setitem__': + elif fullname == "ctypes.Array.__setitem__": return ctypes.array_setitem_callback elif fullname == singledispatch.SINGLEDISPATCH_CALLABLE_CALL_METHOD: return singledispatch.call_singledispatch_function_callback return None - def get_method_hook(self, fullname: str - ) -> Optional[Callable[[MethodContext], Type]]: + def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], Type]]: from mypy.plugins import ctypes, singledispatch - if fullname == 'typing.Mapping.get': + if fullname == "typing.Mapping.get": return typed_dict_get_callback - elif fullname == 'builtins.int.__pow__': + elif fullname == "builtins.int.__pow__": return int_pow_callback - elif fullname == 'builtins.int.__neg__': + elif fullname == "builtins.int.__neg__": return int_neg_callback - elif fullname in ('builtins.tuple.__mul__', 'builtins.tuple.__rmul__'): + elif fullname in ("builtins.tuple.__mul__", "builtins.tuple.__rmul__"): return tuple_mul_callback - elif fullname in {n + '.setdefault' for n in TPDICT_FB_NAMES}: + elif fullname in {n + ".setdefault" for n in TPDICT_FB_NAMES}: return typed_dict_setdefault_callback - elif fullname in {n + '.pop' for n in TPDICT_FB_NAMES}: + elif fullname in {n + ".pop" for n in TPDICT_FB_NAMES}: return typed_dict_pop_callback - elif fullname in {n + '.__delitem__' for n in TPDICT_FB_NAMES}: + elif fullname in {n + ".__delitem__" for n in TPDICT_FB_NAMES}: return typed_dict_delitem_callback - elif fullname == 'ctypes.Array.__getitem__': + elif fullname == "ctypes.Array.__getitem__": return ctypes.array_getitem_callback - elif fullname == 'ctypes.Array.__iter__': + elif fullname == "ctypes.Array.__iter__": return ctypes.array_iter_callback elif fullname == singledispatch.SINGLEDISPATCH_REGISTER_METHOD: return singledispatch.singledispatch_register_callback @@ -77,14 +92,12 @@ def get_method_hook(self, fullname: str return singledispatch.call_singledispatch_function_after_register_argument return None - def get_attribute_hook(self, fullname: str - ) -> Optional[Callable[[AttributeContext], Type]]: - from mypy.plugins import ctypes - from mypy.plugins import enums + def get_attribute_hook(self, fullname: str) -> Optional[Callable[[AttributeContext], Type]]: + from mypy.plugins import ctypes, enums - if fullname == 'ctypes.Array.value': + if fullname == "ctypes.Array.value": return ctypes.array_value_callback - elif fullname == 'ctypes.Array.raw': + elif fullname == "ctypes.Array.raw": return ctypes.array_raw_callback elif fullname in enums.ENUM_NAME_ACCESS: return enums.enum_name_callback @@ -92,10 +105,10 @@ def get_attribute_hook(self, fullname: str return enums.enum_value_callback return None - def get_class_decorator_hook(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], None]]: - from mypy.plugins import dataclasses - from mypy.plugins import attrs + def get_class_decorator_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + from mypy.plugins import attrs, dataclasses # These dataclass and attrs hooks run in the main semantic analysis pass # and only tag known dataclasses/attrs classes, so that the second @@ -103,19 +116,20 @@ def get_class_decorator_hook(self, fullname: str # in the MRO. if fullname in dataclasses.dataclass_makers: return dataclasses.dataclass_tag_callback - if (fullname in attrs.attr_class_makers - or fullname in attrs.attr_dataclass_makers - or fullname in attrs.attr_frozen_makers - or fullname in attrs.attr_define_makers): + if ( + fullname in attrs.attr_class_makers + or fullname in attrs.attr_dataclass_makers + or fullname in attrs.attr_frozen_makers + or fullname in attrs.attr_define_makers + ): return attrs.attr_tag_callback return None - def get_class_decorator_hook_2(self, fullname: str - ) -> Optional[Callable[[ClassDefContext], bool]]: - from mypy.plugins import dataclasses - from mypy.plugins import functools - from mypy.plugins import attrs + def get_class_decorator_hook_2( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], bool]]: + from mypy.plugins import attrs, dataclasses, functools if fullname in dataclasses.dataclass_makers: return dataclasses.dataclass_class_maker_callback @@ -124,21 +138,13 @@ def get_class_decorator_hook_2(self, fullname: str elif fullname in attrs.attr_class_makers: return attrs.attr_class_maker_callback elif fullname in attrs.attr_dataclass_makers: - return partial( - attrs.attr_class_maker_callback, - auto_attribs_default=True, - ) + return partial(attrs.attr_class_maker_callback, auto_attribs_default=True) elif fullname in attrs.attr_frozen_makers: return partial( - attrs.attr_class_maker_callback, - auto_attribs_default=None, - frozen_default=True, + attrs.attr_class_maker_callback, auto_attribs_default=None, frozen_default=True ) elif fullname in attrs.attr_define_makers: - return partial( - attrs.attr_class_maker_callback, - auto_attribs_default=None, - ) + return partial(attrs.attr_class_maker_callback, auto_attribs_default=None) return None @@ -149,8 +155,7 @@ def contextmanager_callback(ctx: FunctionContext) -> Type: if ctx.arg_types and len(ctx.arg_types[0]) == 1: arg_type = get_proper_type(ctx.arg_types[0][0]) default_return = get_proper_type(ctx.default_return_type) - if (isinstance(arg_type, CallableType) - and isinstance(default_return, CallableType)): + if isinstance(arg_type, CallableType) and isinstance(default_return, CallableType): # The stub signature doesn't preserve information about arguments so # add them back here. return default_return.copy_modified( @@ -158,7 +163,8 @@ def contextmanager_callback(ctx: FunctionContext) -> Type: arg_kinds=arg_type.arg_kinds, arg_names=arg_type.arg_names, variables=arg_type.variables, - is_ellipsis_args=arg_type.is_ellipsis_args) + is_ellipsis_args=arg_type.is_ellipsis_args, + ) return ctx.default_return_type @@ -169,21 +175,25 @@ def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType: depends on a TypedDict value type. """ signature = ctx.default_signature - if (isinstance(ctx.type, TypedDictType) - and len(ctx.args) == 2 - and len(ctx.args[0]) == 1 - and isinstance(ctx.args[0][0], StrExpr) - and len(signature.arg_types) == 2 - and len(signature.variables) == 1 - and len(ctx.args[1]) == 1): + if ( + isinstance(ctx.type, TypedDictType) + and len(ctx.args) == 2 + and len(ctx.args[0]) == 1 + and isinstance(ctx.args[0][0], StrExpr) + and len(signature.arg_types) == 2 + and len(signature.variables) == 1 + and len(ctx.args[1]) == 1 + ): key = ctx.args[0][0].value value_type = get_proper_type(ctx.type.items.get(key)) ret_type = signature.ret_type if value_type: default_arg = ctx.args[1][0] - if (isinstance(value_type, TypedDictType) - and isinstance(default_arg, DictExpr) - and len(default_arg.items) == 0): + if ( + isinstance(value_type, TypedDictType) + and isinstance(default_arg, DictExpr) + and len(default_arg.items) == 0 + ): # Caller has empty dict {} as default for typed dict. value_type = value_type.copy_modified(required_keys=set()) # Tweak the signature to include the value type as context. It's @@ -192,17 +202,19 @@ def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType: tv = signature.variables[0] assert isinstance(tv, TypeVarType) return signature.copy_modified( - arg_types=[signature.arg_types[0], - make_simplified_union([value_type, tv])], - ret_type=ret_type) + arg_types=[signature.arg_types[0], make_simplified_union([value_type, tv])], + ret_type=ret_type, + ) return signature def typed_dict_get_callback(ctx: MethodContext) -> Type: """Infer a precise return type for TypedDict.get with literal first argument.""" - if (isinstance(ctx.type, TypedDictType) - and len(ctx.arg_types) >= 1 - and len(ctx.arg_types[0]) == 1): + if ( + isinstance(ctx.type, TypedDictType) + and len(ctx.arg_types) >= 1 + and len(ctx.arg_types[0]) == 1 + ): keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0]) if keys is None: return ctx.default_return_type @@ -215,11 +227,13 @@ def typed_dict_get_callback(ctx: MethodContext) -> Type: if len(ctx.arg_types) == 1: output_types.append(value_type) - elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 - and len(ctx.args[1]) == 1): + elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1: default_arg = ctx.args[1][0] - if (isinstance(default_arg, DictExpr) and len(default_arg.items) == 0 - and isinstance(value_type, TypedDictType)): + if ( + isinstance(default_arg, DictExpr) + and len(default_arg.items) == 0 + and isinstance(value_type, TypedDictType) + ): # Special case '{}' as the default for a typed dict type. output_types.append(value_type.copy_modified(required_keys=set())) else: @@ -240,14 +254,16 @@ def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType: depends on a TypedDict value type. """ signature = ctx.default_signature - str_type = ctx.api.named_generic_type('builtins.str', []) - if (isinstance(ctx.type, TypedDictType) - and len(ctx.args) == 2 - and len(ctx.args[0]) == 1 - and isinstance(ctx.args[0][0], StrExpr) - and len(signature.arg_types) == 2 - and len(signature.variables) == 1 - and len(ctx.args[1]) == 1): + str_type = ctx.api.named_generic_type("builtins.str", []) + if ( + isinstance(ctx.type, TypedDictType) + and len(ctx.args) == 2 + and len(ctx.args[0]) == 1 + and isinstance(ctx.args[0][0], StrExpr) + and len(signature.arg_types) == 2 + and len(signature.variables) == 1 + and len(ctx.args[1]) == 1 + ): key = ctx.args[0][0].value value_type = ctx.type.items.get(key) if value_type: @@ -257,17 +273,17 @@ def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType: tv = signature.variables[0] assert isinstance(tv, TypeVarType) typ = make_simplified_union([value_type, tv]) - return signature.copy_modified( - arg_types=[str_type, typ], - ret_type=typ) + return signature.copy_modified(arg_types=[str_type, typ], ret_type=typ) return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]]) def typed_dict_pop_callback(ctx: MethodContext) -> Type: """Type check and infer a precise return type for TypedDict.pop.""" - if (isinstance(ctx.type, TypedDictType) - and len(ctx.arg_types) >= 1 - and len(ctx.arg_types[0]) == 1): + if ( + isinstance(ctx.type, TypedDictType) + and len(ctx.arg_types) >= 1 + and len(ctx.arg_types[0]) == 1 + ): keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0]) if keys is None: ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context) @@ -287,8 +303,7 @@ def typed_dict_pop_callback(ctx: MethodContext) -> Type: if len(ctx.args[1]) == 0: return make_simplified_union(value_types) - elif (len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 - and len(ctx.args[1]) == 1): + elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1: return make_simplified_union([*value_types, ctx.arg_types[1][0]]) return ctx.default_return_type @@ -300,13 +315,15 @@ def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableT depends on a TypedDict value type. """ signature = ctx.default_signature - str_type = ctx.api.named_generic_type('builtins.str', []) - if (isinstance(ctx.type, TypedDictType) - and len(ctx.args) == 2 - and len(ctx.args[0]) == 1 - and isinstance(ctx.args[0][0], StrExpr) - and len(signature.arg_types) == 2 - and len(ctx.args[1]) == 1): + str_type = ctx.api.named_generic_type("builtins.str", []) + if ( + isinstance(ctx.type, TypedDictType) + and len(ctx.args) == 2 + and len(ctx.args[0]) == 1 + and isinstance(ctx.args[0][0], StrExpr) + and len(signature.arg_types) == 2 + and len(ctx.args[1]) == 1 + ): key = ctx.args[0][0].value value_type = ctx.type.items.get(key) if value_type: @@ -316,10 +333,12 @@ def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableT def typed_dict_setdefault_callback(ctx: MethodContext) -> Type: """Type check TypedDict.setdefault and infer a precise return type.""" - if (isinstance(ctx.type, TypedDictType) - and len(ctx.arg_types) == 2 - and len(ctx.arg_types[0]) == 1 - and len(ctx.arg_types[1]) == 1): + if ( + isinstance(ctx.type, TypedDictType) + and len(ctx.arg_types) == 2 + and len(ctx.arg_types[0]) == 1 + and len(ctx.arg_types[1]) == 1 + ): keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0]) if keys is None: ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context) @@ -341,7 +360,8 @@ def typed_dict_setdefault_callback(ctx: MethodContext) -> Type: # default can be assigned to all key-value pairs we're updating. if not is_subtype(default_type, value_type): ctx.api.msg.typeddict_setdefault_arguments_inconsistent( - default_type, value_type, ctx.context) + default_type, value_type, ctx.context + ) return AnyType(TypeOfAny.from_error) value_types.append(value_type) @@ -352,9 +372,11 @@ def typed_dict_setdefault_callback(ctx: MethodContext) -> Type: def typed_dict_delitem_callback(ctx: MethodContext) -> Type: """Type check TypedDict.__delitem__.""" - if (isinstance(ctx.type, TypedDictType) - and len(ctx.arg_types) == 1 - and len(ctx.arg_types[0]) == 1): + if ( + isinstance(ctx.type, TypedDictType) + and len(ctx.arg_types) == 1 + and len(ctx.arg_types[0]) == 1 + ): keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0]) if keys is None: ctx.api.fail(message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL, ctx.context) @@ -371,8 +393,7 @@ def typed_dict_delitem_callback(ctx: MethodContext) -> Type: def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType: """Try to infer a better signature type for TypedDict.update.""" signature = ctx.default_signature - if (isinstance(ctx.type, TypedDictType) - and len(signature.arg_types) == 1): + if isinstance(ctx.type, TypedDictType) and len(signature.arg_types) == 1: arg_type = get_proper_type(signature.arg_types[0]) assert isinstance(arg_type, TypedDictType) arg_type = arg_type.as_anonymous() @@ -385,20 +406,19 @@ def int_pow_callback(ctx: MethodContext) -> Type: """Infer a more precise return type for int.__pow__.""" # int.__pow__ has an optional modulo argument, # so we expect 2 argument positions - if (len(ctx.arg_types) == 2 - and len(ctx.arg_types[0]) == 1 and len(ctx.arg_types[1]) == 0): + if len(ctx.arg_types) == 2 and len(ctx.arg_types[0]) == 1 and len(ctx.arg_types[1]) == 0: arg = ctx.args[0][0] if isinstance(arg, IntExpr): exponent = arg.value - elif isinstance(arg, UnaryExpr) and arg.op == '-' and isinstance(arg.expr, IntExpr): + elif isinstance(arg, UnaryExpr) and arg.op == "-" and isinstance(arg.expr, IntExpr): exponent = -arg.expr.value else: # Right operand not an int literal or a negated literal -- give up. return ctx.default_return_type if exponent >= 0: - return ctx.api.named_generic_type('builtins.int', []) + return ctx.api.named_generic_type("builtins.int", []) else: - return ctx.api.named_generic_type('builtins.float', []) + return ctx.api.named_generic_type("builtins.float", []) return ctx.default_return_type @@ -415,12 +435,11 @@ def int_neg_callback(ctx: MethodContext) -> Type: if is_literal_type_like(ctx.api.type_context[-1]): return LiteralType(value=-value, fallback=fallback) else: - return ctx.type.copy_modified(last_known_value=LiteralType( - value=-value, - fallback=ctx.type, - line=ctx.type.line, - column=ctx.type.column, - )) + return ctx.type.copy_modified( + last_known_value=LiteralType( + value=-value, fallback=ctx.type, line=ctx.type.line, column=ctx.type.column + ) + ) elif isinstance(ctx.type, LiteralType): value = ctx.type.value fallback = ctx.type.fallback diff --git a/mypy/plugins/enums.py b/mypy/plugins/enums.py index afd59bf0374d0..4451745f1589e 100644 --- a/mypy/plugins/enums.py +++ b/mypy/plugins/enums.py @@ -11,14 +11,15 @@ semanal_enum.py). """ from typing import Iterable, Optional, Sequence, TypeVar, cast + from typing_extensions import Final import mypy.plugin # To avoid circular imports. -from mypy.types import Type, Instance, LiteralType, CallableType, ProperType, get_proper_type -from mypy.typeops import make_simplified_union from mypy.nodes import TypeInfo -from mypy.subtypes import is_equivalent from mypy.semanal_enum import ENUM_BASES +from mypy.subtypes import is_equivalent +from mypy.typeops import make_simplified_union +from mypy.types import CallableType, Instance, LiteralType, ProperType, Type, get_proper_type ENUM_NAME_ACCESS: Final = {f"{prefix}.name" for prefix in ENUM_BASES} | { f"{prefix}._name_" for prefix in ENUM_BASES @@ -28,7 +29,7 @@ } -def enum_name_callback(ctx: 'mypy.plugin.AttributeContext') -> Type: +def enum_name_callback(ctx: "mypy.plugin.AttributeContext") -> Type: """This plugin refines the 'name' attribute in enums to act as if they were declared to be final. @@ -47,12 +48,12 @@ def enum_name_callback(ctx: 'mypy.plugin.AttributeContext') -> Type: if enum_field_name is None: return ctx.default_attr_type else: - str_type = ctx.api.named_generic_type('builtins.str', []) + str_type = ctx.api.named_generic_type("builtins.str", []) literal_type = LiteralType(enum_field_name, fallback=str_type) return str_type.copy_modified(last_known_value=literal_type) -_T = TypeVar('_T') +_T = TypeVar("_T") def _first(it: Iterable[_T]) -> Optional[_T]: @@ -66,8 +67,8 @@ def _first(it: Iterable[_T]) -> Optional[_T]: def _infer_value_type_with_auto_fallback( - ctx: 'mypy.plugin.AttributeContext', - proper_type: Optional[ProperType]) -> Optional[Type]: + ctx: "mypy.plugin.AttributeContext", proper_type: Optional[ProperType] +) -> Optional[Type]: """Figure out the type of an enum value accounting for `auto()`. This method is a no-op for a `None` proper_type and also in the case where @@ -75,28 +76,26 @@ def _infer_value_type_with_auto_fallback( """ if proper_type is None: return None - if not (isinstance(proper_type, Instance) and - proper_type.type.fullname == 'enum.auto'): + if not (isinstance(proper_type, Instance) and proper_type.type.fullname == "enum.auto"): return proper_type - assert isinstance(ctx.type, Instance), 'An incorrect ctx.type was passed.' + assert isinstance(ctx.type, Instance), "An incorrect ctx.type was passed." info = ctx.type.type # Find the first _generate_next_value_ on the mro. We need to know # if it is `Enum` because `Enum` types say that the return-value of # `_generate_next_value_` is `Any`. In reality the default `auto()` # returns an `int` (presumably the `Any` in typeshed is to make it # easier to subclass and change the returned type). - type_with_gnv = _first( - ti for ti in info.mro if ti.names.get('_generate_next_value_')) + type_with_gnv = _first(ti for ti in info.mro if ti.names.get("_generate_next_value_")) if type_with_gnv is None: return ctx.default_attr_type - stnode = type_with_gnv.names['_generate_next_value_'] + stnode = type_with_gnv.names["_generate_next_value_"] # This should be a `CallableType` node_type = get_proper_type(stnode.type) if isinstance(node_type, CallableType): - if type_with_gnv.fullname == 'enum.Enum': - int_type = ctx.api.named_generic_type('builtins.int', []) + if type_with_gnv.fullname == "enum.Enum": + int_type = ctx.api.named_generic_type("builtins.int", []) return int_type return get_proper_type(node_type.ret_type) return ctx.default_attr_type @@ -110,14 +109,14 @@ def _implements_new(info: TypeInfo) -> bool: type_with_new = _first( ti for ti in info.mro - if ti.names.get('__new__') and not ti.fullname.startswith('builtins.') + if ti.names.get("__new__") and not ti.fullname.startswith("builtins.") ) if type_with_new is None: return False - return type_with_new.fullname not in ('enum.Enum', 'enum.IntEnum', 'enum.StrEnum') + return type_with_new.fullname not in ("enum.Enum", "enum.IntEnum", "enum.StrEnum") -def enum_value_callback(ctx: 'mypy.plugin.AttributeContext') -> Type: +def enum_value_callback(ctx: "mypy.plugin.AttributeContext") -> Type: """This plugin refines the 'value' attribute in enums to refer to the original underlying value. For example, suppose we have the following: @@ -164,11 +163,13 @@ class SomeEnum: node_types = ( get_proper_type(n.type) if n else None for n in stnodes - if n is None or not n.implicit) + if n is None or not n.implicit + ) proper_types = list( _infer_value_type_with_auto_fallback(ctx, t) for t in node_types - if t is None or not isinstance(t, CallableType)) + if t is None or not isinstance(t, CallableType) + ) underlying_type = _first(proper_types) if underlying_type is None: return ctx.default_attr_type @@ -179,7 +180,8 @@ class SomeEnum: # See https://github.com/python/mypy/pull/9443 all_same_value_type = all( proper_type is not None and proper_type == underlying_type - for proper_type in proper_types) + for proper_type in proper_types + ) if all_same_value_type: if underlying_type is not None: return underlying_type @@ -200,7 +202,8 @@ class SomeEnum: # Result will be `Literal[1] | Literal[2] | Literal[3]` for this case. all_equivalent_types = all( proper_type is not None and is_equivalent(proper_type, underlying_type) - for proper_type in proper_types) + for proper_type in proper_types + ) if all_equivalent_types: return make_simplified_union(cast(Sequence[Type], proper_types)) return ctx.default_attr_type @@ -218,8 +221,7 @@ class SomeEnum: if stnode is None: return ctx.default_attr_type - underlying_type = _infer_value_type_with_auto_fallback( - ctx, get_proper_type(stnode.type)) + underlying_type = _infer_value_type_with_auto_fallback(ctx, get_proper_type(stnode.type)) if underlying_type is None: return ctx.default_attr_type diff --git a/mypy/plugins/functools.py b/mypy/plugins/functools.py index db10b7f1a2623..074d911557751 100644 --- a/mypy/plugins/functools.py +++ b/mypy/plugins/functools.py @@ -1,23 +1,16 @@ """Plugin for supporting the functools standard library module.""" from typing import Dict, NamedTuple, Optional + from typing_extensions import Final import mypy.plugin from mypy.nodes import ARG_POS, ARG_STAR2, Argument, FuncItem, Var from mypy.plugins.common import add_method_to_class -from mypy.types import AnyType, CallableType, get_proper_type, Type, TypeOfAny, UnboundType - +from mypy.types import AnyType, CallableType, Type, TypeOfAny, UnboundType, get_proper_type -functools_total_ordering_makers: Final = { - 'functools.total_ordering', -} +functools_total_ordering_makers: Final = {"functools.total_ordering"} -_ORDERING_METHODS: Final = { - '__lt__', - '__le__', - '__gt__', - '__ge__', -} +_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"} class _MethodInfo(NamedTuple): @@ -25,8 +18,9 @@ class _MethodInfo(NamedTuple): type: CallableType -def functools_total_ordering_maker_callback(ctx: mypy.plugin.ClassDefContext, - auto_attribs_default: bool = False) -> bool: +def functools_total_ordering_maker_callback( + ctx: mypy.plugin.ClassDefContext, auto_attribs_default: bool = False +) -> bool: """Add dunder methods to classes decorated with functools.total_ordering.""" if ctx.api.options.python_version < (3,): # This plugin is not supported in Python 2 mode (it's a no-op). @@ -36,7 +30,8 @@ def functools_total_ordering_maker_callback(ctx: mypy.plugin.ClassDefContext, if not comparison_methods: ctx.api.fail( 'No ordering operation defined when using "functools.total_ordering": < > <= >=', - ctx.reason) + ctx.reason, + ) return True # prefer __lt__ to __le__ to __gt__ to __ge__ @@ -47,18 +42,20 @@ def functools_total_ordering_maker_callback(ctx: mypy.plugin.ClassDefContext, return True other_type = _find_other_type(root_method) - bool_type = ctx.api.named_type('builtins.bool') + bool_type = ctx.api.named_type("builtins.bool") ret_type: Type = bool_type - if root_method.type.ret_type != ctx.api.named_type('builtins.bool'): + if root_method.type.ret_type != ctx.api.named_type("builtins.bool"): proper_ret_type = get_proper_type(root_method.type.ret_type) - if not (isinstance(proper_ret_type, UnboundType) - and proper_ret_type.name.split('.')[-1] == 'bool'): + if not ( + isinstance(proper_ret_type, UnboundType) + and proper_ret_type.name.split(".")[-1] == "bool" + ): ret_type = AnyType(TypeOfAny.implementation_artifact) for additional_op in _ORDERING_METHODS: # Either the method is not implemented # or has an unknown signature that we can now extrapolate. if not comparison_methods.get(additional_op): - args = [Argument(Var('other', other_type), other_type, None, ARG_POS)] + args = [Argument(Var("other", other_type), other_type, None, ARG_POS)] add_method_to_class(ctx.api, ctx.cls, additional_op, args, ret_type) return True diff --git a/mypy/plugins/singledispatch.py b/mypy/plugins/singledispatch.py index d6150836c562d..a01942d88ab8e 100644 --- a/mypy/plugins/singledispatch.py +++ b/mypy/plugins/singledispatch.py @@ -1,16 +1,23 @@ +from typing import List, NamedTuple, Optional, Sequence, TypeVar, Union + +from typing_extensions import Final + from mypy.messages import format_type +from mypy.nodes import ARG_POS, Argument, Block, ClassDef, Context, SymbolTable, TypeInfo, Var +from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext, MethodSigContext from mypy.plugins.common import add_method_to_class -from mypy.nodes import ( - ARG_POS, Argument, Block, ClassDef, SymbolTable, TypeInfo, Var, Context -) from mypy.subtypes import is_subtype from mypy.types import ( - AnyType, CallableType, Instance, NoneType, Overloaded, Type, TypeOfAny, get_proper_type, - FunctionLike + AnyType, + CallableType, + FunctionLike, + Instance, + NoneType, + Overloaded, + Type, + TypeOfAny, + get_proper_type, ) -from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext, MethodSigContext -from typing import List, NamedTuple, Optional, Sequence, TypeVar, Union -from typing_extensions import Final class SingledispatchTypeVars(NamedTuple): @@ -23,11 +30,11 @@ class RegisterCallableInfo(NamedTuple): singledispatch_obj: Instance -SINGLEDISPATCH_TYPE: Final = 'functools._SingleDispatchCallable' +SINGLEDISPATCH_TYPE: Final = "functools._SingleDispatchCallable" -SINGLEDISPATCH_REGISTER_METHOD: Final = f'{SINGLEDISPATCH_TYPE}.register' +SINGLEDISPATCH_REGISTER_METHOD: Final = f"{SINGLEDISPATCH_TYPE}.register" -SINGLEDISPATCH_CALLABLE_CALL_METHOD: Final = f'{SINGLEDISPATCH_TYPE}.__call__' +SINGLEDISPATCH_CALLABLE_CALL_METHOD: Final = f"{SINGLEDISPATCH_TYPE}.__call__" def get_singledispatch_info(typ: Instance) -> Optional[SingledispatchTypeVars]: @@ -36,7 +43,7 @@ def get_singledispatch_info(typ: Instance) -> Optional[SingledispatchTypeVars]: return None -T = TypeVar('T') +T = TypeVar("T") def get_first_arg(args: List[List[T]]) -> Optional[T]: @@ -46,23 +53,24 @@ def get_first_arg(args: List[List[T]]) -> Optional[T]: return None -REGISTER_RETURN_CLASS: Final = '_SingleDispatchRegisterCallable' +REGISTER_RETURN_CLASS: Final = "_SingleDispatchRegisterCallable" -REGISTER_CALLABLE_CALL_METHOD: Final = f'functools.{REGISTER_RETURN_CLASS}.__call__' +REGISTER_CALLABLE_CALL_METHOD: Final = f"functools.{REGISTER_RETURN_CLASS}.__call__" -def make_fake_register_class_instance(api: CheckerPluginInterface, type_args: Sequence[Type] - ) -> Instance: +def make_fake_register_class_instance( + api: CheckerPluginInterface, type_args: Sequence[Type] +) -> Instance: defn = ClassDef(REGISTER_RETURN_CLASS, Block([])) - defn.fullname = f'functools.{REGISTER_RETURN_CLASS}' + defn.fullname = f"functools.{REGISTER_RETURN_CLASS}" info = TypeInfo(SymbolTable(), defn, "functools") - obj_type = api.named_generic_type('builtins.object', []).type + obj_type = api.named_generic_type("builtins.object", []).type info.bases = [Instance(obj_type, [])] info.mro = [info, obj_type] defn.info = info - func_arg = Argument(Var('name'), AnyType(TypeOfAny.implementation_artifact), None, ARG_POS) - add_method_to_class(api, defn, '__call__', [func_arg], NoneType()) + func_arg = Argument(Var("name"), AnyType(TypeOfAny.implementation_artifact), None, ARG_POS) + add_method_to_class(api, defn, "__call__", [func_arg], NoneType()) return Instance(info, type_args) @@ -93,16 +101,14 @@ def create_singledispatch_function_callback(ctx: FunctionContext) -> Type: if len(func_type.arg_kinds) < 1: fail( - ctx, - 'Singledispatch function requires at least one argument', - func_type.definition, + ctx, "Singledispatch function requires at least one argument", func_type.definition ) return ctx.default_return_type elif not func_type.arg_kinds[0].is_positional(star=True): fail( ctx, - 'First argument to singledispatch function must be a positional argument', + "First argument to singledispatch function must be a positional argument", func_type.definition, ) return ctx.default_return_type @@ -132,10 +138,7 @@ def singledispatch_register_callback(ctx: MethodContext) -> Type: # actual type register_type = first_arg_type.items[0].ret_type type_args = RegisterCallableInfo(register_type, ctx.type) - register_callable = make_fake_register_class_instance( - ctx.api, - type_args - ) + register_callable = make_fake_register_class_instance(ctx.api, type_args) return register_callable elif isinstance(first_arg_type, CallableType): # TODO: do more checking for registered functions @@ -150,8 +153,12 @@ def singledispatch_register_callback(ctx: MethodContext) -> Type: return ctx.default_return_type -def register_function(ctx: PluginContext, singledispatch_obj: Instance, func: Type, - register_arg: Optional[Type] = None) -> None: +def register_function( + ctx: PluginContext, + singledispatch_obj: Instance, + func: Type, + register_arg: Optional[Type] = None, +) -> None: """Register a function""" func = get_proper_type(func) @@ -172,9 +179,13 @@ def register_function(ctx: PluginContext, singledispatch_obj: Instance, func: Ty fallback_dispatch_type = fallback.arg_types[0] if not is_subtype(dispatch_type, fallback_dispatch_type): - fail(ctx, 'Dispatch type {} must be subtype of fallback function first argument {}'.format( + fail( + ctx, + "Dispatch type {} must be subtype of fallback function first argument {}".format( format_type(dispatch_type), format_type(fallback_dispatch_type) - ), func.definition) + ), + func.definition, + ) return return diff --git a/mypy/pyinfo.py b/mypy/pyinfo.py index ed0fed3707009..278c16f3ae928 100644 --- a/mypy/pyinfo.py +++ b/mypy/pyinfo.py @@ -1,4 +1,5 @@ from __future__ import print_function + """Utilities to find the site and prefix information of a Python executable, which may be Python 2. This file MUST remain compatible with Python 2. Since we cannot make any assumptions about the @@ -13,15 +14,16 @@ MYPY = False if MYPY: - from typing import Tuple, List + from typing import List, Tuple -if __name__ == '__main__': +if __name__ == "__main__": # HACK: We don't want to pick up mypy.types as the top-level types # module. This could happen if this file is run as a script. # This workaround fixes it. old_sys_path = sys.path sys.path = sys.path[1:] import types # noqa + sys.path = old_sys_path @@ -47,7 +49,7 @@ def getsyspath(): stdlib_zip = os.path.join( sys.base_exec_prefix, getattr(sys, "platlibdir", "lib"), - "python{}{}.zip".format(sys.version_info.major, sys.version_info.minor) + "python{}{}.zip".format(sys.version_info.major, sys.version_info.minor), ) stdlib = sysconfig.get_path("stdlib") stdlib_ext = os.path.join(stdlib, "lib-dynload") @@ -72,14 +74,11 @@ def getsyspath(): def getsearchdirs(): # type: () -> Tuple[List[str], List[str]] - return ( - getsyspath(), - getsitepackages(), - ) + return (getsyspath(), getsitepackages()) -if __name__ == '__main__': - if sys.argv[-1] == 'getsearchdirs': +if __name__ == "__main__": + if sys.argv[-1] == "getsearchdirs": print(repr(getsearchdirs())) else: print("ERROR: incorrect argument to pyinfo.py.", file=sys.stderr) diff --git a/mypy/reachability.py b/mypy/reachability.py index eec472376317a..714b3bae07d30 100644 --- a/mypy/reachability.py +++ b/mypy/reachability.py @@ -1,17 +1,36 @@ """Utilities related to determining the reachability of code (in semantic analysis).""" -from typing import Tuple, TypeVar, Union, Optional +from typing import Optional, Tuple, TypeVar, Union + from typing_extensions import Final +from mypy.literals import literal from mypy.nodes import ( - Expression, IfStmt, Block, AssertStmt, MatchStmt, NameExpr, UnaryExpr, MemberExpr, OpExpr, - ComparisonExpr, StrExpr, UnicodeExpr, CallExpr, IntExpr, TupleExpr, IndexExpr, SliceExpr, - Import, ImportFrom, ImportAll, LITERAL_YES + LITERAL_YES, + AssertStmt, + Block, + CallExpr, + ComparisonExpr, + Expression, + IfStmt, + Import, + ImportAll, + ImportFrom, + IndexExpr, + IntExpr, + MatchStmt, + MemberExpr, + NameExpr, + OpExpr, + SliceExpr, + StrExpr, + TupleExpr, + UnaryExpr, + UnicodeExpr, ) from mypy.options import Options -from mypy.patterns import Pattern, AsPattern, OrPattern +from mypy.patterns import AsPattern, OrPattern, Pattern from mypy.traverser import TraverserVisitor -from mypy.literals import literal # Inferred truth value of an expression. ALWAYS_TRUE: Final = 1 @@ -28,14 +47,7 @@ MYPY_FALSE: MYPY_TRUE, } -reverse_op: Final = { - "==": "==", - "!=": "!=", - "<": ">", - ">": "<", - "<=": ">=", - ">=": "<=", -} +reverse_op: Final = {"==": "==", "!=": "!=", "<": ">", ">": "<", "<=": ">=", ">=": "<="} def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None: @@ -51,7 +63,7 @@ def infer_reachability_of_if_statement(s: IfStmt, options: Options) -> None: # This condition is false at runtime; this will affect # import priorities. mark_block_mypy_only(s.body[i]) - for body in s.body[i + 1:]: + for body in s.body[i + 1 :]: mark_block_unreachable(body) # Make sure else body always exists and is marked as @@ -73,13 +85,17 @@ def infer_reachability_of_match_statement(s: MatchStmt, options: Options) -> Non else: guard_value = ALWAYS_TRUE - if pattern_value in (ALWAYS_FALSE, MYPY_FALSE) \ - or guard_value in (ALWAYS_FALSE, MYPY_FALSE): + if pattern_value in (ALWAYS_FALSE, MYPY_FALSE) or guard_value in ( + ALWAYS_FALSE, + MYPY_FALSE, + ): # The case is considered always false, so we skip the case body. mark_block_unreachable(s.bodies[i]) - elif pattern_value in (ALWAYS_FALSE, MYPY_TRUE) \ - and guard_value in (ALWAYS_TRUE, MYPY_TRUE): - for body in s.bodies[i + 1:]: + elif pattern_value in (ALWAYS_FALSE, MYPY_TRUE) and guard_value in ( + ALWAYS_TRUE, + MYPY_TRUE, + ): + for body in s.bodies[i + 1 :]: mark_block_unreachable(body) if guard_value == MYPY_TRUE: @@ -100,11 +116,11 @@ def infer_condition_value(expr: Expression, options: Options) -> int: false under mypy and true at runtime, else TRUTH_VALUE_UNKNOWN. """ pyversion = options.python_version - name = '' + name = "" negated = False alias = expr if isinstance(alias, UnaryExpr): - if alias.op == 'not': + if alias.op == "not": expr = alias.expr negated = True result = TRUTH_VALUE_UNKNOWN @@ -112,10 +128,11 @@ def infer_condition_value(expr: Expression, options: Options) -> int: name = expr.name elif isinstance(expr, MemberExpr): name = expr.name - elif isinstance(expr, OpExpr) and expr.op in ('and', 'or'): + elif isinstance(expr, OpExpr) and expr.op in ("and", "or"): left = infer_condition_value(expr.left, options) - if ((left in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == 'and') or - (left in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == 'or')): + if (left in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == "and") or ( + left in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == "or" + ): # Either `True and ` or `False or `: the result will # always be the right-hand-side. return infer_condition_value(expr.right, options) @@ -128,11 +145,11 @@ def infer_condition_value(expr: Expression, options: Options) -> int: if result == TRUTH_VALUE_UNKNOWN: result = consider_sys_platform(expr, options.platform) if result == TRUTH_VALUE_UNKNOWN: - if name == 'PY2': + if name == "PY2": result = ALWAYS_TRUE if pyversion[0] == 2 else ALWAYS_FALSE - elif name == 'PY3': + elif name == "PY3": result = ALWAYS_TRUE if pyversion[0] == 3 else ALWAYS_FALSE - elif name == 'MYPY' or name == 'TYPE_CHECKING': + elif name == "MYPY" or name == "TYPE_CHECKING": result = MYPY_TRUE elif name in options.always_true: result = ALWAYS_TRUE @@ -146,8 +163,9 @@ def infer_condition_value(expr: Expression, options: Options) -> int: def infer_pattern_value(pattern: Pattern) -> int: if isinstance(pattern, AsPattern) and pattern.pattern is None: return ALWAYS_TRUE - elif isinstance(pattern, OrPattern) and \ - any(infer_pattern_value(p) == ALWAYS_TRUE for p in pattern.patterns): + elif isinstance(pattern, OrPattern) and any( + infer_pattern_value(p) == ALWAYS_TRUE for p in pattern.patterns + ): return ALWAYS_TRUE else: return TRUTH_VALUE_UNKNOWN @@ -169,7 +187,7 @@ def consider_sys_version_info(expr: Expression, pyversion: Tuple[int, ...]) -> i if len(expr.operators) > 1: return TRUTH_VALUE_UNKNOWN op = expr.operators[0] - if op not in ('==', '!=', '<=', '>=', '<', '>'): + if op not in ("==", "!=", "<=", ">=", "<", ">"): return TRUTH_VALUE_UNKNOWN index = contains_sys_version_info(expr.operands[0]) @@ -192,7 +210,7 @@ def consider_sys_version_info(expr: Expression, pyversion: Tuple[int, ...]) -> i hi = 2 if 0 <= lo < hi <= 2: val = pyversion[lo:hi] - if len(val) == len(thing) or len(val) > len(thing) and op not in ('==', '!='): + if len(val) == len(thing) or len(val) > len(thing) and op not in ("==", "!="): return fixed_comparison(val, op, thing) return TRUTH_VALUE_UNKNOWN @@ -211,9 +229,9 @@ def consider_sys_platform(expr: Expression, platform: str) -> int: if len(expr.operators) > 1: return TRUTH_VALUE_UNKNOWN op = expr.operators[0] - if op not in ('==', '!='): + if op not in ("==", "!="): return TRUTH_VALUE_UNKNOWN - if not is_sys_attr(expr.operands[0], 'platform'): + if not is_sys_attr(expr.operands[0], "platform"): return TRUTH_VALUE_UNKNOWN right = expr.operands[1] if not isinstance(right, (StrExpr, UnicodeExpr)): @@ -224,9 +242,9 @@ def consider_sys_platform(expr: Expression, platform: str) -> int: return TRUTH_VALUE_UNKNOWN if len(expr.args) != 1 or not isinstance(expr.args[0], (StrExpr, UnicodeExpr)): return TRUTH_VALUE_UNKNOWN - if not is_sys_attr(expr.callee.expr, 'platform'): + if not is_sys_attr(expr.callee.expr, "platform"): return TRUTH_VALUE_UNKNOWN - if expr.callee.name != 'startswith': + if expr.callee.name != "startswith": return TRUTH_VALUE_UNKNOWN if platform.startswith(expr.args[0].value): return ALWAYS_TRUE @@ -236,28 +254,29 @@ def consider_sys_platform(expr: Expression, platform: str) -> int: return TRUTH_VALUE_UNKNOWN -Targ = TypeVar('Targ', int, str, Tuple[int, ...]) +Targ = TypeVar("Targ", int, str, Tuple[int, ...]) def fixed_comparison(left: Targ, op: str, right: Targ) -> int: rmap = {False: ALWAYS_FALSE, True: ALWAYS_TRUE} - if op == '==': + if op == "==": return rmap[left == right] - if op == '!=': + if op == "!=": return rmap[left != right] - if op == '<=': + if op == "<=": return rmap[left <= right] - if op == '>=': + if op == ">=": return rmap[left >= right] - if op == '<': + if op == "<": return rmap[left < right] - if op == '>': + if op == ">": return rmap[left > right] return TRUTH_VALUE_UNKNOWN -def contains_int_or_tuple_of_ints(expr: Expression - ) -> Union[None, int, Tuple[int], Tuple[int, ...]]: +def contains_int_or_tuple_of_ints( + expr: Expression, +) -> Union[None, int, Tuple[int], Tuple[int, ...]]: if isinstance(expr, IntExpr): return expr.value if isinstance(expr, TupleExpr): @@ -271,11 +290,12 @@ def contains_int_or_tuple_of_ints(expr: Expression return None -def contains_sys_version_info(expr: Expression - ) -> Union[None, int, Tuple[Optional[int], Optional[int]]]: - if is_sys_attr(expr, 'version_info'): +def contains_sys_version_info( + expr: Expression, +) -> Union[None, int, Tuple[Optional[int], Optional[int]]]: + if is_sys_attr(expr, "version_info"): return (None, None) # Same as sys.version_info[:] - if isinstance(expr, IndexExpr) and is_sys_attr(expr.base, 'version_info'): + if isinstance(expr, IndexExpr) and is_sys_attr(expr.base, "version_info"): index = expr.index if isinstance(index, IntExpr): return index.value @@ -301,7 +321,7 @@ def is_sys_attr(expr: Expression, name: str) -> bool: # - import sys as _sys # - from sys import version_info if isinstance(expr, MemberExpr) and expr.name == name: - if isinstance(expr.expr, NameExpr) and expr.expr.name == 'sys': + if isinstance(expr.expr, NameExpr) and expr.expr.name == "sys": # TODO: Guard against a local named sys, etc. # (Though later passes will still do most checking.) return True diff --git a/mypy/renaming.py b/mypy/renaming.py index ae21631f0f0ae..6db8bbad7e14b 100644 --- a/mypy/renaming.py +++ b/mypy/renaming.py @@ -1,11 +1,31 @@ from contextlib import contextmanager from typing import Dict, Iterator, List, Set + from typing_extensions import Final from mypy.nodes import ( - Block, AssignmentStmt, NameExpr, MypyFile, FuncDef, Lvalue, ListExpr, TupleExpr, - WhileStmt, ForStmt, BreakStmt, ContinueStmt, TryStmt, WithStmt, MatchStmt, StarExpr, - ImportFrom, MemberExpr, IndexExpr, Import, ImportAll, ClassDef + AssignmentStmt, + Block, + BreakStmt, + ClassDef, + ContinueStmt, + ForStmt, + FuncDef, + Import, + ImportAll, + ImportFrom, + IndexExpr, + ListExpr, + Lvalue, + MatchStmt, + MemberExpr, + MypyFile, + NameExpr, + StarExpr, + TryStmt, + TupleExpr, + WhileStmt, + WithStmt, ) from mypy.patterns import AsPattern from mypy.traverser import TraverserVisitor @@ -90,7 +110,7 @@ def visit_func_def(self, fdef: FuncDef) -> None: name = arg.variable.name # 'self' can't be redefined since it's special as it allows definition of # attributes. 'cls' can't be used to define attributes so we can ignore it. - can_be_redefined = name != 'self' # TODO: Proper check + can_be_redefined = name != "self" # TODO: Proper check self.record_assignment(arg.variable.name, can_be_redefined) self.handle_arg(name) @@ -521,7 +541,7 @@ def enter_scope(self) -> Iterator[None]: self.flush_refs() def reject_redefinition_of_vars_in_scope(self) -> None: - self.record_skipped('*') + self.record_skipped("*") def record_skipped(self, name: str) -> None: self.skipped[-1].add(name) @@ -529,7 +549,7 @@ def record_skipped(self, name: str) -> None: def flush_refs(self) -> None: ref_dict = self.refs.pop() skipped = self.skipped.pop() - if '*' not in skipped: + if "*" not in skipped: for name, refs in ref_dict.items(): if len(refs) <= 1 or name in skipped: continue diff --git a/mypy/report.py b/mypy/report.py index 28fa5c274b74c..ca8aa03428c92 100644 --- a/mypy/report.py +++ b/mypy/report.py @@ -1,31 +1,32 @@ """Classes for producing HTML reports about imprecision.""" -from abc import ABCMeta, abstractmethod import collections +import itertools import json import os import shutil -import tokenize -import time import sys -import itertools +import time +import tokenize +import typing +from abc import ABCMeta, abstractmethod from operator import attrgetter +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, cast from urllib.request import pathname2url -import typing -from typing import Any, Callable, Dict, List, Optional, Tuple, cast, Iterator from typing_extensions import Final, TypeAlias as _TypeAlias -from mypy.nodes import MypyFile, Expression, FuncDef from mypy import stats +from mypy.defaults import REPORTER_NAMES +from mypy.nodes import Expression, FuncDef, MypyFile from mypy.options import Options from mypy.traverser import TraverserVisitor from mypy.types import Type, TypeOfAny from mypy.version import __version__ -from mypy.defaults import REPORTER_NAMES try: from lxml import etree # type: ignore + LXML_INSTALLED = True except ImportError: LXML_INSTALLED = False @@ -43,8 +44,7 @@ ) ReporterClasses: _TypeAlias = Dict[ - str, - Tuple[Callable[['Reports', str], 'AbstractReporter'], bool], + str, Tuple[Callable[["Reports", str], "AbstractReporter"], bool], ] reporter_classes: Final[ReporterClasses] = {} @@ -59,28 +59,34 @@ def __init__(self, data_dir: str, report_dirs: Dict[str, str]) -> None: for report_type, report_dir in sorted(report_dirs.items()): self.add_report(report_type, report_dir) - def add_report(self, report_type: str, report_dir: str) -> 'AbstractReporter': + def add_report(self, report_type: str, report_dir: str) -> "AbstractReporter": try: return self.named_reporters[report_type] except KeyError: pass reporter_cls, needs_lxml = reporter_classes[report_type] if needs_lxml and not LXML_INSTALLED: - print(('You must install the lxml package before you can run mypy' - ' with `--{}-report`.\n' - 'You can do this with `python3 -m pip install lxml`.').format(report_type), - file=sys.stderr) + print( + ( + "You must install the lxml package before you can run mypy" + " with `--{}-report`.\n" + "You can do this with `python3 -m pip install lxml`." + ).format(report_type), + file=sys.stderr, + ) raise ImportError reporter = reporter_cls(self, report_dir) self.reporters.append(reporter) self.named_reporters[report_type] = reporter return reporter - def file(self, - tree: MypyFile, - modules: Dict[str, MypyFile], - type_map: Dict[Expression, Type], - options: Options) -> None: + def file( + self, + tree: MypyFile, + modules: Dict[str, MypyFile], + type_map: Dict[Expression, Type], + options: Options, + ) -> None: for reporter in self.reporters: reporter.on_file(tree, modules, type_map, options) @@ -92,15 +98,17 @@ def finish(self) -> None: class AbstractReporter(metaclass=ABCMeta): def __init__(self, reports: Reports, output_dir: str) -> None: self.output_dir = output_dir - if output_dir != '': + if output_dir != "": stats.ensure_dir_exists(output_dir) @abstractmethod - def on_file(self, - tree: MypyFile, - modules: Dict[str, MypyFile], - type_map: Dict[Expression, Type], - options: Options) -> None: + def on_file( + self, + tree: MypyFile, + modules: Dict[str, MypyFile], + type_map: Dict[Expression, Type], + options: Options, + ) -> None: pass @abstractmethod @@ -108,9 +116,11 @@ def on_finish(self) -> None: pass -def register_reporter(report_name: str, - reporter: Callable[[Reports, str], AbstractReporter], - needs_lxml: bool = False) -> None: +def register_reporter( + report_name: str, + reporter: Callable[[Reports, str], AbstractReporter], + needs_lxml: bool = False, +) -> None: reporter_classes[report_name] = (reporter, needs_lxml) @@ -121,9 +131,9 @@ def alias_reporter(source_reporter: str, target_reporter: str) -> None: def should_skip_path(path: str) -> bool: if stats.is_special_module(path): return True - if path.startswith('..'): + if path.startswith(".."): return True - if 'stubs' in path.split('/') or 'stubs' in path.split(os.sep): + if "stubs" in path.split("/") or "stubs" in path.split(os.sep): return True return False @@ -148,14 +158,16 @@ def __init__(self, reports: Reports, output_dir: str) -> None: super().__init__(reports, output_dir) self.counts: Dict[str, Tuple[int, int, int, int]] = {} - def on_file(self, - tree: MypyFile, - modules: Dict[str, MypyFile], - type_map: Dict[Expression, Type], - options: Options) -> None: + def on_file( + self, + tree: MypyFile, + modules: Dict[str, MypyFile], + type_map: Dict[Expression, Type], + options: Options, + ) -> None: # Count physical lines. This assumes the file's encoding is a # superset of ASCII (or at least uses \n in its line endings). - with open(tree.path, 'rb') as f: + with open(tree.path, "rb") as f: physical_lines = len(f.readlines()) func_counter = FuncCounterVisitor() @@ -167,11 +179,16 @@ def on_file(self, if options.ignore_errors: annotated_funcs = 0 - imputed_annotated_lines = (physical_lines * annotated_funcs // total_funcs - if total_funcs else physical_lines) + imputed_annotated_lines = ( + physical_lines * annotated_funcs // total_funcs if total_funcs else physical_lines + ) - self.counts[tree._fullname] = (imputed_annotated_lines, physical_lines, - annotated_funcs, total_funcs) + self.counts[tree._fullname] = ( + imputed_annotated_lines, + physical_lines, + annotated_funcs, + total_funcs, + ) def on_finish(self) -> None: counts: List[Tuple[Tuple[int, int, int, int], str]] = sorted( @@ -181,10 +198,10 @@ def on_finish(self) -> None: with open(os.path.join(self.output_dir, "linecount.txt"), "w") as f: f.write("{:7} {:7} {:6} {:6} total\n".format(*total_counts)) for c, p in counts: - f.write(f'{c[0]:7} {c[1]:7} {c[2]:6} {c[3]:6} {p}\n') + f.write(f"{c[0]:7} {c[1]:7} {c[2]:6} {c[3]:6} {p}\n") -register_reporter('linecount', LineCountReporter) +register_reporter("linecount", LineCountReporter) class AnyExpressionsReporter(AbstractReporter): @@ -195,17 +212,21 @@ def __init__(self, reports: Reports, output_dir: str) -> None: self.counts: Dict[str, Tuple[int, int]] = {} self.any_types_counter: Dict[str, typing.Counter[int]] = {} - def on_file(self, - tree: MypyFile, - modules: Dict[str, MypyFile], - type_map: Dict[Expression, Type], - options: Options) -> None: - visitor = stats.StatisticsVisitor(inferred=True, - filename=tree.fullname, - modules=modules, - typemap=type_map, - all_nodes=True, - visit_untyped_defs=False) + def on_file( + self, + tree: MypyFile, + modules: Dict[str, MypyFile], + type_map: Dict[Expression, Type], + options: Options, + ) -> None: + visitor = stats.StatisticsVisitor( + inferred=True, + filename=tree.fullname, + modules=modules, + typemap=type_map, + all_nodes=True, + visit_untyped_defs=False, + ) tree.accept(visitor) self.any_types_counter[tree.fullname] = visitor.type_of_any_counter num_unanalyzed_lines = list(visitor.line_map.values()).count(stats.TYPE_UNANALYZED) @@ -219,12 +240,9 @@ def on_finish(self) -> None: self._report_any_exprs() self._report_types_of_anys() - def _write_out_report(self, - filename: str, - header: List[str], - rows: List[List[str]], - footer: List[str], - ) -> None: + def _write_out_report( + self, filename: str, header: List[str], rows: List[List[str]], footer: List[str] + ) -> None: row_len = len(header) assert all(len(row) == row_len for row in rows + [header, footer]) min_column_distance = 3 # minimum distance between numbers in two columns @@ -236,17 +254,17 @@ def _write_out_report(self, # Do not add min_column_distance to the first column. if i > 0: widths[i] = w + min_column_distance - with open(os.path.join(self.output_dir, filename), 'w') as f: + with open(os.path.join(self.output_dir, filename), "w") as f: header_str = ("{:>{}}" * len(widths)).format(*itertools.chain(*zip(header, widths))) - separator = '-' * len(header_str) - f.write(header_str + '\n') - f.write(separator + '\n') + separator = "-" * len(header_str) + f.write(header_str + "\n") + f.write(separator + "\n") for row_values in rows: r = ("{:>{}}" * len(widths)).format(*itertools.chain(*zip(row_values, widths))) - f.write(r + '\n') - f.write(separator + '\n') + f.write(r + "\n") + f.write(separator + "\n") footer_str = ("{:>{}}" * len(widths)).format(*itertools.chain(*zip(footer, widths))) - f.write(footer_str + '\n') + f.write(footer_str + "\n") def _report_any_exprs(self) -> None: total_any = sum(num_any for num_any, _ in self.counts.values()) @@ -260,11 +278,11 @@ def _report_any_exprs(self) -> None: for filename in sorted(self.counts): (num_any, num_total) = self.counts[filename] coverage = (float(num_total - num_any) / float(num_total)) * 100 - coverage_str = f'{coverage:.2f}%' + coverage_str = f"{coverage:.2f}%" rows.append([filename, str(num_any), str(num_total), coverage_str]) rows.sort(key=lambda x: x[0]) - total_row = ["Total", str(total_any), str(total_expr), f'{total_coverage:.2f}%'] - self._write_out_report('any-exprs.txt', column_names, rows, total_row) + total_row = ["Total", str(total_any), str(total_expr), f"{total_coverage:.2f}%"] + self._write_out_report("any-exprs.txt", column_names, rows, total_row) def _report_types_of_anys(self) -> None: total_counter: typing.Counter[int] = collections.Counter() @@ -278,12 +296,11 @@ def _report_types_of_anys(self) -> None: for filename, counter in self.any_types_counter.items(): rows.append([filename] + [str(counter[typ]) for typ in type_of_any_name_map]) rows.sort(key=lambda x: x[0]) - total_row = [total_row_name] + [str(total_counter[typ]) - for typ in type_of_any_name_map] - self._write_out_report('types-of-anys.txt', column_names, rows, total_row) + total_row = [total_row_name] + [str(total_counter[typ]) for typ in type_of_any_name_map] + self._write_out_report("types-of-anys.txt", column_names, rows, total_row) -register_reporter('any-exprs', AnyExpressionsReporter) +register_reporter("any-exprs", AnyExpressionsReporter) class LineCoverageVisitor(TraverserVisitor): @@ -313,14 +330,14 @@ def indentation_level(self, line_number: int) -> Optional[int]: line = self.source[line_number] indent = 0 for char in list(line): - if char == ' ': + if char == " ": indent += 1 - elif char == '\t': + elif char == "\t": indent = 8 * ((indent + 8) // 8) - elif char == '#': + elif char == "#": # Line is a comment; ignore it return None - elif char == '\n': + elif char == "\n": # Line is entirely whitespace; ignore it return None # TODO line continuation (\) @@ -391,11 +408,13 @@ def __init__(self, reports: Reports, output_dir: str) -> None: super().__init__(reports, output_dir) self.lines_covered: Dict[str, List[int]] = {} - def on_file(self, - tree: MypyFile, - modules: Dict[str, MypyFile], - type_map: Dict[Expression, Type], - options: Options) -> None: + def on_file( + self, + tree: MypyFile, + modules: Dict[str, MypyFile], + type_map: Dict[Expression, Type], + options: Options, + ) -> None: with open(tree.path) as f: tree_source = f.readlines() @@ -410,11 +429,11 @@ def on_file(self, self.lines_covered[os.path.abspath(tree.path)] = covered_lines def on_finish(self) -> None: - with open(os.path.join(self.output_dir, 'coverage.json'), 'w') as f: - json.dump({'lines': self.lines_covered}, f) + with open(os.path.join(self.output_dir, "coverage.json"), "w") as f: + json.dump({"lines": self.lines_covered}, f) -register_reporter('linecoverage', LineCoverageReporter) +register_reporter("linecoverage", LineCoverageReporter) class FileInfo: @@ -439,10 +458,10 @@ class MemoryXmlReporter(AbstractReporter): def __init__(self, reports: Reports, output_dir: str) -> None: super().__init__(reports, output_dir) - self.xslt_html_path = os.path.join(reports.data_dir, 'xml', 'mypy-html.xslt') - self.xslt_txt_path = os.path.join(reports.data_dir, 'xml', 'mypy-txt.xslt') - self.css_html_path = os.path.join(reports.data_dir, 'xml', 'mypy-html.css') - xsd_path = os.path.join(reports.data_dir, 'xml', 'mypy.xsd') + self.xslt_html_path = os.path.join(reports.data_dir, "xml", "mypy-html.xslt") + self.xslt_txt_path = os.path.join(reports.data_dir, "xml", "mypy-txt.xslt") + self.css_html_path = os.path.join(reports.data_dir, "xml", "mypy-html.css") + xsd_path = os.path.join(reports.data_dir, "xml", "mypy.xsd") self.schema = etree.XMLSchema(etree.parse(xsd_path)) self.last_xml: Optional[Any] = None self.files: List[FileInfo] = [] @@ -452,11 +471,13 @@ def __init__(self, reports: Reports, output_dir: str) -> None: # Tabs (#x09) are allowed in XML content. control_fixer: Final = str.maketrans("".join(chr(i) for i in range(32) if i != 9), "?" * 31) - def on_file(self, - tree: MypyFile, - modules: Dict[str, MypyFile], - type_map: Dict[Expression, Type], - options: Options) -> None: + def on_file( + self, + tree: MypyFile, + modules: Dict[str, MypyFile], + type_map: Dict[Expression, Type], + options: Options, + ) -> None: self.last_xml = None try: @@ -467,29 +488,35 @@ def on_file(self, if should_skip_path(path) or os.path.isdir(path): return # `path` can sometimes be a directory, see #11334 - visitor = stats.StatisticsVisitor(inferred=True, - filename=tree.fullname, - modules=modules, - typemap=type_map, - all_nodes=True) + visitor = stats.StatisticsVisitor( + inferred=True, + filename=tree.fullname, + modules=modules, + typemap=type_map, + all_nodes=True, + ) tree.accept(visitor) - root = etree.Element('mypy-report-file', name=path, module=tree._fullname) + root = etree.Element("mypy-report-file", name=path, module=tree._fullname) doc = etree.ElementTree(root) file_info = FileInfo(path, tree._fullname) for lineno, line_text in iterate_python_lines(path): status = visitor.line_map.get(lineno, stats.TYPE_EMPTY) file_info.counts[status] += 1 - etree.SubElement(root, 'line', - any_info=self._get_any_info_for_line(visitor, lineno), - content=line_text.rstrip('\n').translate(self.control_fixer), - number=str(lineno), - precision=stats.precision_names[status]) + etree.SubElement( + root, + "line", + any_info=self._get_any_info_for_line(visitor, lineno), + content=line_text.rstrip("\n").translate(self.control_fixer), + number=str(lineno), + precision=stats.precision_names[status], + ) # Assumes a layout similar to what XmlReporter uses. - xslt_path = os.path.relpath('mypy-html.xslt', path) - transform_pi = etree.ProcessingInstruction('xml-stylesheet', - f'type="text/xsl" href="{pathname2url(xslt_path)}"') + xslt_path = os.path.relpath("mypy-html.xslt", path) + transform_pi = etree.ProcessingInstruction( + "xml-stylesheet", f'type="text/xsl" href="{pathname2url(xslt_path)}"' + ) root.addprevious(transform_pi) self.schema.assertValid(doc) @@ -514,32 +541,36 @@ def on_finish(self) -> None: # index_path = os.path.join(self.output_dir, 'index.xml') output_files = sorted(self.files, key=lambda x: x.module) - root = etree.Element('mypy-report-index', name='index') + root = etree.Element("mypy-report-index", name="index") doc = etree.ElementTree(root) for file_info in output_files: - etree.SubElement(root, 'file', - file_info.attrib(), - module=file_info.module, - name=pathname2url(file_info.name), - total=str(file_info.total())) - xslt_path = os.path.relpath('mypy-html.xslt', '.') - transform_pi = etree.ProcessingInstruction('xml-stylesheet', - f'type="text/xsl" href="{pathname2url(xslt_path)}"') + etree.SubElement( + root, + "file", + file_info.attrib(), + module=file_info.module, + name=pathname2url(file_info.name), + total=str(file_info.total()), + ) + xslt_path = os.path.relpath("mypy-html.xslt", ".") + transform_pi = etree.ProcessingInstruction( + "xml-stylesheet", f'type="text/xsl" href="{pathname2url(xslt_path)}"' + ) root.addprevious(transform_pi) self.schema.assertValid(doc) self.last_xml = doc -register_reporter('memory-xml', MemoryXmlReporter, needs_lxml=True) +register_reporter("memory-xml", MemoryXmlReporter, needs_lxml=True) def get_line_rate(covered_lines: int, total_lines: int) -> str: if total_lines == 0: return str(1.0) else: - return f'{covered_lines / total_lines:.4f}' + return f"{covered_lines / total_lines:.4f}" class CoberturaPackage: @@ -553,12 +584,10 @@ def __init__(self, name: str) -> None: self.covered_lines = 0 def as_xml(self) -> Any: - package_element = etree.Element('package', - complexity='1.0', - name=self.name) - package_element.attrib['branch-rate'] = '0' - package_element.attrib['line-rate'] = get_line_rate(self.covered_lines, self.total_lines) - classes_element = etree.SubElement(package_element, 'classes') + package_element = etree.Element("package", complexity="1.0", name=self.name) + package_element.attrib["branch-rate"] = "0" + package_element.attrib["line-rate"] = get_line_rate(self.covered_lines, self.total_lines) + classes_element = etree.SubElement(package_element, "classes") for class_name in sorted(self.classes): classes_element.append(self.classes[class_name]) self.add_packages(package_element) @@ -566,8 +595,8 @@ def as_xml(self) -> Any: def add_packages(self, parent_element: Any) -> None: if self.packages: - packages_element = etree.SubElement(parent_element, 'packages') - for package in sorted(self.packages.values(), key=attrgetter('name')): + packages_element = etree.SubElement(parent_element, "packages") + for package in sorted(self.packages.values(), key=attrgetter("name")): packages_element.append(package.as_xml()) @@ -577,33 +606,32 @@ class CoberturaXmlReporter(AbstractReporter): def __init__(self, reports: Reports, output_dir: str) -> None: super().__init__(reports, output_dir) - self.root = etree.Element('coverage', - timestamp=str(int(time.time())), - version=__version__) + self.root = etree.Element("coverage", timestamp=str(int(time.time())), version=__version__) self.doc = etree.ElementTree(self.root) - self.root_package = CoberturaPackage('.') - - def on_file(self, - tree: MypyFile, - modules: Dict[str, MypyFile], - type_map: Dict[Expression, Type], - options: Options) -> None: + self.root_package = CoberturaPackage(".") + + def on_file( + self, + tree: MypyFile, + modules: Dict[str, MypyFile], + type_map: Dict[Expression, Type], + options: Options, + ) -> None: path = os.path.relpath(tree.path) - visitor = stats.StatisticsVisitor(inferred=True, - filename=tree.fullname, - modules=modules, - typemap=type_map, - all_nodes=True) + visitor = stats.StatisticsVisitor( + inferred=True, + filename=tree.fullname, + modules=modules, + typemap=type_map, + all_nodes=True, + ) tree.accept(visitor) class_name = os.path.basename(path) file_info = FileInfo(path, tree._fullname) - class_element = etree.Element('class', - complexity='1.0', - filename=path, - name=class_name) - etree.SubElement(class_element, 'methods') - lines_element = etree.SubElement(class_element, 'lines') + class_element = etree.Element("class", complexity="1.0", filename=path, name=class_name) + etree.SubElement(class_element, "methods") + lines_element = etree.SubElement(class_element, "lines") with tokenize.open(path) as input_file: class_lines_covered = 0 @@ -621,21 +649,25 @@ def on_file(self, if status == stats.TYPE_IMPRECISE: branch = True file_info.counts[status] += 1 - line_element = etree.SubElement(lines_element, 'line', - branch=str(branch).lower(), - hits=str(hits), - number=str(lineno), - precision=stats.precision_names[status]) + line_element = etree.SubElement( + lines_element, + "line", + branch=str(branch).lower(), + hits=str(hits), + number=str(lineno), + precision=stats.precision_names[status], + ) if branch: - line_element.attrib['condition-coverage'] = '50% (1/2)' - class_element.attrib['branch-rate'] = '0' - class_element.attrib['line-rate'] = get_line_rate(class_lines_covered, - class_total_lines) + line_element.attrib["condition-coverage"] = "50% (1/2)" + class_element.attrib["branch-rate"] = "0" + class_element.attrib["line-rate"] = get_line_rate( + class_lines_covered, class_total_lines + ) # parent_module is set to whichever module contains this file. For most files, we want # to simply strip the last element off of the module. But for __init__.py files, # the module == the parent module. - parent_module = file_info.module.rsplit('.', 1)[0] - if file_info.name.endswith('__init__.py'): + parent_module = file_info.module.rsplit(".", 1)[0] + if file_info.name.endswith("__init__.py"): parent_module = file_info.module if parent_module not in self.root_package.packages: @@ -648,19 +680,20 @@ def on_file(self, current_package.classes[class_name] = class_element def on_finish(self) -> None: - self.root.attrib['line-rate'] = get_line_rate(self.root_package.covered_lines, - self.root_package.total_lines) - self.root.attrib['branch-rate'] = '0' - sources = etree.SubElement(self.root, 'sources') - source_element = etree.SubElement(sources, 'source') + self.root.attrib["line-rate"] = get_line_rate( + self.root_package.covered_lines, self.root_package.total_lines + ) + self.root.attrib["branch-rate"] = "0" + sources = etree.SubElement(self.root, "sources") + source_element = etree.SubElement(sources, "source") source_element.text = os.getcwd() self.root_package.add_packages(self.root) - out_path = os.path.join(self.output_dir, 'cobertura.xml') - self.doc.write(out_path, encoding='utf-8', pretty_print=True) - print('Generated Cobertura report:', os.path.abspath(out_path)) + out_path = os.path.join(self.output_dir, "cobertura.xml") + self.doc.write(out_path, encoding="utf-8", pretty_print=True) + print("Generated Cobertura report:", os.path.abspath(out_path)) -register_reporter('cobertura-xml', CoberturaXmlReporter, needs_lxml=True) +register_reporter("cobertura-xml", CoberturaXmlReporter, needs_lxml=True) class AbstractXmlReporter(AbstractReporter): @@ -669,7 +702,7 @@ class AbstractXmlReporter(AbstractReporter): def __init__(self, reports: Reports, output_dir: str) -> None: super().__init__(reports, output_dir) - memory_reporter = reports.add_report('memory-xml', '') + memory_reporter = reports.add_report("memory-xml", "") # The dependency will be called first. self.memory_xml = cast(MemoryXmlReporter, memory_reporter) @@ -684,34 +717,36 @@ class XmlReporter(AbstractXmlReporter): that makes it fail from file:// URLs but work on http:// URLs. """ - def on_file(self, - tree: MypyFile, - modules: Dict[str, MypyFile], - type_map: Dict[Expression, Type], - options: Options) -> None: + def on_file( + self, + tree: MypyFile, + modules: Dict[str, MypyFile], + type_map: Dict[Expression, Type], + options: Options, + ) -> None: last_xml = self.memory_xml.last_xml if last_xml is None: return path = os.path.relpath(tree.path) - if path.startswith('..'): + if path.startswith(".."): return - out_path = os.path.join(self.output_dir, 'xml', path + '.xml') + out_path = os.path.join(self.output_dir, "xml", path + ".xml") stats.ensure_dir_exists(os.path.dirname(out_path)) - last_xml.write(out_path, encoding='utf-8') + last_xml.write(out_path, encoding="utf-8") def on_finish(self) -> None: last_xml = self.memory_xml.last_xml assert last_xml is not None - out_path = os.path.join(self.output_dir, 'index.xml') - out_xslt = os.path.join(self.output_dir, 'mypy-html.xslt') - out_css = os.path.join(self.output_dir, 'mypy-html.css') - last_xml.write(out_path, encoding='utf-8') + out_path = os.path.join(self.output_dir, "index.xml") + out_xslt = os.path.join(self.output_dir, "mypy-html.xslt") + out_css = os.path.join(self.output_dir, "mypy-html.css") + last_xml.write(out_path, encoding="utf-8") shutil.copyfile(self.memory_xml.xslt_html_path, out_xslt) shutil.copyfile(self.memory_xml.css_html_path, out_css) - print('Generated XML report:', os.path.abspath(out_path)) + print("Generated XML report:", os.path.abspath(out_path)) -register_reporter('xml', XmlReporter, needs_lxml=True) +register_reporter("xml", XmlReporter, needs_lxml=True) class XsltHtmlReporter(AbstractXmlReporter): @@ -725,38 +760,40 @@ def __init__(self, reports: Reports, output_dir: str) -> None: super().__init__(reports, output_dir) self.xslt_html = etree.XSLT(etree.parse(self.memory_xml.xslt_html_path)) - self.param_html = etree.XSLT.strparam('html') - - def on_file(self, - tree: MypyFile, - modules: Dict[str, MypyFile], - type_map: Dict[Expression, Type], - options: Options) -> None: + self.param_html = etree.XSLT.strparam("html") + + def on_file( + self, + tree: MypyFile, + modules: Dict[str, MypyFile], + type_map: Dict[Expression, Type], + options: Options, + ) -> None: last_xml = self.memory_xml.last_xml if last_xml is None: return path = os.path.relpath(tree.path) - if path.startswith('..'): + if path.startswith(".."): return - out_path = os.path.join(self.output_dir, 'html', path + '.html') + out_path = os.path.join(self.output_dir, "html", path + ".html") stats.ensure_dir_exists(os.path.dirname(out_path)) transformed_html = bytes(self.xslt_html(last_xml, ext=self.param_html)) - with open(out_path, 'wb') as out_file: + with open(out_path, "wb") as out_file: out_file.write(transformed_html) def on_finish(self) -> None: last_xml = self.memory_xml.last_xml assert last_xml is not None - out_path = os.path.join(self.output_dir, 'index.html') - out_css = os.path.join(self.output_dir, 'mypy-html.css') + out_path = os.path.join(self.output_dir, "index.html") + out_css = os.path.join(self.output_dir, "mypy-html.css") transformed_html = bytes(self.xslt_html(last_xml, ext=self.param_html)) - with open(out_path, 'wb') as out_file: + with open(out_path, "wb") as out_file: out_file.write(transformed_html) shutil.copyfile(self.memory_xml.css_html_path, out_css) - print('Generated HTML report (via XSLT):', os.path.abspath(out_path)) + print("Generated HTML report (via XSLT):", os.path.abspath(out_path)) -register_reporter('xslt-html', XsltHtmlReporter, needs_lxml=True) +register_reporter("xslt-html", XsltHtmlReporter, needs_lxml=True) class XsltTxtReporter(AbstractXmlReporter): @@ -770,27 +807,29 @@ def __init__(self, reports: Reports, output_dir: str) -> None: self.xslt_txt = etree.XSLT(etree.parse(self.memory_xml.xslt_txt_path)) - def on_file(self, - tree: MypyFile, - modules: Dict[str, MypyFile], - type_map: Dict[Expression, Type], - options: Options) -> None: + def on_file( + self, + tree: MypyFile, + modules: Dict[str, MypyFile], + type_map: Dict[Expression, Type], + options: Options, + ) -> None: pass def on_finish(self) -> None: last_xml = self.memory_xml.last_xml assert last_xml is not None - out_path = os.path.join(self.output_dir, 'index.txt') + out_path = os.path.join(self.output_dir, "index.txt") transformed_txt = bytes(self.xslt_txt(last_xml)) - with open(out_path, 'wb') as out_file: + with open(out_path, "wb") as out_file: out_file.write(transformed_txt) - print('Generated TXT report (via XSLT):', os.path.abspath(out_path)) + print("Generated TXT report (via XSLT):", os.path.abspath(out_path)) -register_reporter('xslt-txt', XsltTxtReporter, needs_lxml=True) +register_reporter("xslt-txt", XsltTxtReporter, needs_lxml=True) -alias_reporter('xslt-html', 'html') -alias_reporter('xslt-txt', 'txt') +alias_reporter("xslt-html", "html") +alias_reporter("xslt-txt", "txt") class LinePrecisionReporter(AbstractReporter): @@ -812,11 +851,13 @@ def __init__(self, reports: Reports, output_dir: str) -> None: super().__init__(reports, output_dir) self.files: List[FileInfo] = [] - def on_file(self, - tree: MypyFile, - modules: Dict[str, MypyFile], - type_map: Dict[Expression, Type], - options: Options) -> None: + def on_file( + self, + tree: MypyFile, + modules: Dict[str, MypyFile], + type_map: Dict[Expression, Type], + options: Options, + ) -> None: try: path = os.path.relpath(tree.path) @@ -826,11 +867,13 @@ def on_file(self, if should_skip_path(path): return - visitor = stats.StatisticsVisitor(inferred=True, - filename=tree.fullname, - modules=modules, - typemap=type_map, - all_nodes=True) + visitor = stats.StatisticsVisitor( + inferred=True, + filename=tree.fullname, + modules=modules, + typemap=type_map, + all_nodes=True, + ) tree.accept(visitor) file_info = FileInfo(path, tree._fullname) @@ -845,27 +888,30 @@ def on_finish(self) -> None: # Nothing to do. return output_files = sorted(self.files, key=lambda x: x.module) - report_file = os.path.join(self.output_dir, 'lineprecision.txt') + report_file = os.path.join(self.output_dir, "lineprecision.txt") width = max(4, max(len(info.module) for info in output_files)) - titles = ('Lines', 'Precise', 'Imprecise', 'Any', 'Empty', 'Unanalyzed') + titles = ("Lines", "Precise", "Imprecise", "Any", "Empty", "Unanalyzed") widths = (width,) + tuple(len(t) for t in titles) - fmt = '{:%d} {:%d} {:%d} {:%d} {:%d} {:%d} {:%d}\n' % widths - with open(report_file, 'w') as f: - f.write( - fmt.format('Name', *titles)) - f.write('-' * (width + 51) + '\n') + fmt = "{:%d} {:%d} {:%d} {:%d} {:%d} {:%d} {:%d}\n" % widths + with open(report_file, "w") as f: + f.write(fmt.format("Name", *titles)) + f.write("-" * (width + 51) + "\n") for file_info in output_files: counts = file_info.counts - f.write(fmt.format(file_info.module.ljust(width), - file_info.total(), - counts[stats.TYPE_PRECISE], - counts[stats.TYPE_IMPRECISE], - counts[stats.TYPE_ANY], - counts[stats.TYPE_EMPTY], - counts[stats.TYPE_UNANALYZED])) - - -register_reporter('lineprecision', LinePrecisionReporter) + f.write( + fmt.format( + file_info.module.ljust(width), + file_info.total(), + counts[stats.TYPE_PRECISE], + counts[stats.TYPE_IMPRECISE], + counts[stats.TYPE_ANY], + counts[stats.TYPE_EMPTY], + counts[stats.TYPE_UNANALYZED], + ) + ) + + +register_reporter("lineprecision", LinePrecisionReporter) # Reporter class names are defined twice to speed up mypy startup, as this diff --git a/mypy/sametypes.py b/mypy/sametypes.py index 4fbc9bfc48018..691af147d98fe 100644 --- a/mypy/sametypes.py +++ b/mypy/sametypes.py @@ -1,13 +1,33 @@ -from typing import Sequence, Tuple, Set, List +from typing import List, Sequence, Set, Tuple +from mypy.typeops import is_simple_literal, make_simplified_union, tuple_fallback from mypy.types import ( - Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType, - UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType, - Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType, - ProperType, get_proper_type, TypeAliasType, ParamSpecType, Parameters, - UnpackType, TypeVarTupleType, + AnyType, + CallableType, + DeletedType, + ErasedType, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + PartialType, + ProperType, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeType, + TypeVarTupleType, + TypeVarType, + TypeVisitor, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, + get_proper_type, ) -from mypy.typeops import tuple_fallback, make_simplified_union, is_simple_literal def is_same_type(left: Type, right: Type) -> bool: @@ -98,58 +118,67 @@ def visit_deleted_type(self, left: DeletedType) -> bool: return isinstance(self.right, DeletedType) def visit_instance(self, left: Instance) -> bool: - return (isinstance(self.right, Instance) and - left.type == self.right.type and - is_same_types(left.args, self.right.args) and - left.last_known_value == self.right.last_known_value) + return ( + isinstance(self.right, Instance) + and left.type == self.right.type + and is_same_types(left.args, self.right.args) + and left.last_known_value == self.right.last_known_value + ) def visit_type_alias_type(self, left: TypeAliasType) -> bool: # Similar to protocols, two aliases with the same targets return False here, # but both is_subtype(t, s) and is_subtype(s, t) return True. - return (isinstance(self.right, TypeAliasType) and - left.alias == self.right.alias and - is_same_types(left.args, self.right.args)) + return ( + isinstance(self.right, TypeAliasType) + and left.alias == self.right.alias + and is_same_types(left.args, self.right.args) + ) def visit_type_var(self, left: TypeVarType) -> bool: - return (isinstance(self.right, TypeVarType) and - left.id == self.right.id) + return isinstance(self.right, TypeVarType) and left.id == self.right.id def visit_param_spec(self, left: ParamSpecType) -> bool: # Ignore upper bound since it's derived from flavor. - return (isinstance(self.right, ParamSpecType) and - left.id == self.right.id and left.flavor == self.right.flavor) + return ( + isinstance(self.right, ParamSpecType) + and left.id == self.right.id + and left.flavor == self.right.flavor + ) def visit_type_var_tuple(self, left: TypeVarTupleType) -> bool: - return (isinstance(self.right, TypeVarTupleType) and - left.id == self.right.id) + return isinstance(self.right, TypeVarTupleType) and left.id == self.right.id def visit_unpack_type(self, left: UnpackType) -> bool: - return (isinstance(self.right, UnpackType) and - is_same_type(left.type, self.right.type)) + return isinstance(self.right, UnpackType) and is_same_type(left.type, self.right.type) def visit_parameters(self, left: Parameters) -> bool: - return (isinstance(self.right, Parameters) and - left.arg_names == self.right.arg_names and - is_same_types(left.arg_types, self.right.arg_types) and - left.arg_kinds == self.right.arg_kinds) + return ( + isinstance(self.right, Parameters) + and left.arg_names == self.right.arg_names + and is_same_types(left.arg_types, self.right.arg_types) + and left.arg_kinds == self.right.arg_kinds + ) def visit_callable_type(self, left: CallableType) -> bool: # FIX generics if isinstance(self.right, CallableType): cright = self.right - return (is_same_type(left.ret_type, cright.ret_type) and - is_same_types(left.arg_types, cright.arg_types) and - left.arg_names == cright.arg_names and - left.arg_kinds == cright.arg_kinds and - left.is_type_obj() == cright.is_type_obj() and - left.is_ellipsis_args == cright.is_ellipsis_args) + return ( + is_same_type(left.ret_type, cright.ret_type) + and is_same_types(left.arg_types, cright.arg_types) + and left.arg_names == cright.arg_names + and left.arg_kinds == cright.arg_kinds + and left.is_type_obj() == cright.is_type_obj() + and left.is_ellipsis_args == cright.is_ellipsis_args + ) else: return False def visit_tuple_type(self, left: TupleType) -> bool: if isinstance(self.right, TupleType): - return (is_same_type(tuple_fallback(left), tuple_fallback(self.right)) - and is_same_types(left.items, self.right.items)) + return is_same_type( + tuple_fallback(left), tuple_fallback(self.right) + ) and is_same_types(left.items, self.right.items) else: return False diff --git a/mypy/scope.py b/mypy/scope.py index fdc1c1a314fcf..cf7e0514ebe81 100644 --- a/mypy/scope.py +++ b/mypy/scope.py @@ -4,12 +4,12 @@ """ from contextlib import contextmanager -from typing import List, Optional, Iterator, Tuple +from typing import Iterator, List, Optional, Tuple + from typing_extensions import TypeAlias as _TypeAlias from mypy.backports import nullcontext -from mypy.nodes import TypeInfo, FuncBase - +from mypy.nodes import FuncBase, TypeInfo SavedScope: _TypeAlias = Tuple[str, Optional[TypeInfo], Optional[FuncBase]] @@ -33,7 +33,7 @@ def current_target(self) -> str: assert self.module if self.function: fullname = self.function.fullname - return fullname or '' + return fullname or "" return self.module def current_full_target(self) -> str: diff --git a/mypy/semanal.py b/mypy/semanal.py index b803de743c2f9..a5fd094a40572 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -49,103 +49,261 @@ """ from contextlib import contextmanager - from typing import ( - Any, List, Dict, Set, Tuple, cast, TypeVar, Union, Optional, Callable, Iterator, Iterable + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + TypeVar, + Union, + cast, ) + from typing_extensions import Final, TypeAlias as _TypeAlias +from mypy import errorcodes as codes, message_registry +from mypy.errorcodes import ErrorCode +from mypy.errors import Errors, report_internal_error +from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type +from mypy.messages import ( + SUGGESTED_TEST_FIXTURES, + TYPES_FOR_UNIMPORTED_HINTS, + MessageBuilder, + best_matches, + pretty_seq, +) +from mypy.mro import MroError, calculate_mro from mypy.nodes import ( - AssertTypeExpr, MypyFile, TypeInfo, Node, AssignmentStmt, FuncDef, OverloadedFuncDef, - ClassDef, Var, GDEF, FuncItem, Import, Expression, Lvalue, - ImportFrom, ImportAll, Block, LDEF, NameExpr, MemberExpr, - IndexExpr, TupleExpr, ListExpr, ExpressionStmt, ReturnStmt, - RaiseStmt, AssertStmt, OperatorAssignmentStmt, WhileStmt, - ForStmt, BreakStmt, ContinueStmt, IfStmt, TryStmt, WithStmt, DelStmt, - GlobalDecl, SuperExpr, DictExpr, CallExpr, RefExpr, OpExpr, UnaryExpr, - SliceExpr, CastExpr, RevealExpr, TypeApplication, Context, SymbolTable, - SymbolTableNode, ListComprehension, GeneratorExpr, - LambdaExpr, MDEF, Decorator, SetExpr, TypeVarExpr, - StrExpr, BytesExpr, PrintStmt, ConditionalExpr, PromoteExpr, - ComparisonExpr, StarExpr, ArgKind, ARG_POS, ARG_NAMED, type_aliases, - YieldFromExpr, NamedTupleExpr, NonlocalDecl, SymbolNode, - SetComprehension, DictionaryComprehension, TypeAlias, TypeAliasExpr, - YieldExpr, ExecStmt, BackquoteExpr, ImportBase, AwaitExpr, - IntExpr, FloatExpr, UnicodeExpr, TempNode, OverloadPart, - PlaceholderNode, COVARIANT, CONTRAVARIANT, INVARIANT, - get_nongen_builtins, get_member_expr_fullname, REVEAL_TYPE, - REVEAL_LOCALS, is_final_node, TypedDictExpr, type_aliases_source_versions, + ARG_NAMED, + ARG_POS, + CONTRAVARIANT, + COVARIANT, + GDEF, + INVARIANT, + LDEF, + MDEF, + REVEAL_LOCALS, + REVEAL_TYPE, + RUNTIME_PROTOCOL_DECOS, + ArgKind, + AssertStmt, + AssertTypeExpr, + AssignmentExpr, + AssignmentStmt, + AwaitExpr, + BackquoteExpr, + Block, + BreakStmt, + BytesExpr, + CallExpr, + CastExpr, + ClassDef, + ComparisonExpr, + ConditionalExpr, + Context, + ContinueStmt, + Decorator, + DelStmt, + DictExpr, + DictionaryComprehension, + EllipsisExpr, + EnumCallExpr, + ExecStmt, + Expression, + ExpressionStmt, + FakeExpression, + FloatExpr, + ForStmt, + FuncBase, + FuncDef, + FuncItem, + GeneratorExpr, + GlobalDecl, + IfStmt, + Import, + ImportAll, + ImportBase, + ImportFrom, + IndexExpr, + IntExpr, + LambdaExpr, + ListComprehension, + ListExpr, + Lvalue, + MatchStmt, + MemberExpr, + MypyFile, + NamedTupleExpr, + NameExpr, + Node, + NonlocalDecl, + OperatorAssignmentStmt, + OpExpr, + OverloadedFuncDef, + OverloadPart, + ParamSpecExpr, + PlaceholderNode, + PrintStmt, + PromoteExpr, + RaiseStmt, + RefExpr, + ReturnStmt, + RevealExpr, + SetComprehension, + SetExpr, + SliceExpr, + StarExpr, + Statement, + StrExpr, + SuperExpr, + SymbolNode, + SymbolTable, + SymbolTableNode, + TempNode, + TryStmt, + TupleExpr, + TypeAlias, + TypeAliasExpr, + TypeApplication, + TypedDictExpr, + TypeInfo, + TypeVarExpr, + TypeVarLikeExpr, + TypeVarTupleExpr, + UnaryExpr, + UnicodeExpr, + Var, + WhileStmt, + WithStmt, + YieldExpr, + YieldFromExpr, + get_member_expr_fullname, + get_nongen_builtins, + implicit_module_attrs, + is_final_node, + type_aliases, + type_aliases_source_versions, typing_extensions_aliases, - EnumCallExpr, RUNTIME_PROTOCOL_DECOS, FakeExpression, Statement, AssignmentExpr, - ParamSpecExpr, EllipsisExpr, TypeVarLikeExpr, implicit_module_attrs, - MatchStmt, FuncBase, TypeVarTupleExpr ) +from mypy.options import Options from mypy.patterns import ( - AsPattern, OrPattern, ValuePattern, SequencePattern, - StarredPattern, MappingPattern, ClassPattern, + AsPattern, + ClassPattern, + MappingPattern, + OrPattern, + SequencePattern, + StarredPattern, + ValuePattern, ) -from mypy.tvar_scope import TypeVarLikeScope -from mypy.typevars import fill_typevars -from mypy.visitor import NodeVisitor -from mypy.errors import Errors, report_internal_error -from mypy.messages import ( - best_matches, MessageBuilder, pretty_seq, SUGGESTED_TEST_FIXTURES, TYPES_FOR_UNIMPORTED_HINTS +from mypy.plugin import ( + ClassDefContext, + DynamicClassDefContext, + Plugin, + SemanticAnalyzerPluginInterface, ) -from mypy.errorcodes import ErrorCode -from mypy import message_registry, errorcodes as codes -from mypy.types import ( - NEVER_NAMES, FunctionLike, UnboundType, TypeVarType, TupleType, UnionType, StarType, - CallableType, Overloaded, Instance, Type, AnyType, LiteralType, LiteralValue, - TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType, - get_proper_type, get_proper_types, TypeAliasType, TypeVarLikeType, Parameters, ParamSpecType, - PROTOCOL_NAMES, TYPE_ALIAS_NAMES, FINAL_TYPE_NAMES, FINAL_DECORATOR_NAMES, REVEAL_TYPE_NAMES, - ASSERT_TYPE_NAMES, OVERLOAD_NAMES, TYPED_NAMEDTUPLE_NAMES, is_named_instance, +from mypy.reachability import ( + ALWAYS_FALSE, + ALWAYS_TRUE, + MYPY_FALSE, + MYPY_TRUE, + infer_condition_value, + infer_reachability_of_if_statement, + infer_reachability_of_match_statement, ) -from mypy.typeops import function_type, get_type_vars +from mypy.scope import Scope +from mypy.semanal_enum import EnumCallAnalyzer +from mypy.semanal_namedtuple import NamedTupleAnalyzer +from mypy.semanal_newtype import NewTypeAnalyzer +from mypy.semanal_shared import ( + PRIORITY_FALLBACKS, + SemanticAnalyzerInterface, + calculate_tuple_fallback, + set_callable_name, +) +from mypy.semanal_typeddict import TypedDictAnalyzer +from mypy.tvar_scope import TypeVarLikeScope from mypy.type_visitor import TypeQuery from mypy.typeanal import ( - TypeAnalyser, analyze_type_alias, no_subscript_builtin_alias, - TypeVarLikeQuery, TypeVarLikeList, remove_dups, has_any_from_unimported_type, - check_for_explicit_any, type_constructors, fix_instance_types + TypeAnalyser, + TypeVarLikeList, + TypeVarLikeQuery, + analyze_type_alias, + check_for_explicit_any, + fix_instance_types, + has_any_from_unimported_type, + no_subscript_builtin_alias, + remove_dups, + type_constructors, ) -from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError -from mypy.options import Options -from mypy.plugin import ( - Plugin, ClassDefContext, SemanticAnalyzerPluginInterface, - DynamicClassDefContext +from mypy.typeops import function_type, get_type_vars +from mypy.types import ( + ASSERT_TYPE_NAMES, + FINAL_DECORATOR_NAMES, + FINAL_TYPE_NAMES, + NEVER_NAMES, + OVERLOAD_NAMES, + PROTOCOL_NAMES, + REVEAL_TYPE_NAMES, + TPDICT_NAMES, + TYPE_ALIAS_NAMES, + TYPED_NAMEDTUPLE_NAMES, + AnyType, + CallableType, + FunctionLike, + Instance, + LiteralType, + LiteralValue, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + PlaceholderType, + ProperType, + StarType, + TupleType, + Type, + TypeAliasType, + TypeOfAny, + TypeTranslator, + TypeType, + TypeVarLikeType, + TypeVarType, + UnboundType, + UnionType, + get_proper_type, + get_proper_types, + is_named_instance, ) +from mypy.typevars import fill_typevars from mypy.util import ( - correct_relative_import, unmangle, module_prefix, is_typeshed_file, unnamed_function, + correct_relative_import, is_dunder, + is_typeshed_file, + module_prefix, + unmangle, + unnamed_function, ) -from mypy.scope import Scope -from mypy.semanal_shared import ( - SemanticAnalyzerInterface, set_callable_name, calculate_tuple_fallback, PRIORITY_FALLBACKS -) -from mypy.semanal_namedtuple import NamedTupleAnalyzer -from mypy.semanal_typeddict import TypedDictAnalyzer -from mypy.semanal_enum import EnumCallAnalyzer -from mypy.semanal_newtype import NewTypeAnalyzer -from mypy.reachability import ( - infer_reachability_of_if_statement, infer_reachability_of_match_statement, - infer_condition_value, ALWAYS_FALSE, ALWAYS_TRUE, MYPY_TRUE, MYPY_FALSE -) -from mypy.mro import calculate_mro, MroError +from mypy.visitor import NodeVisitor -T = TypeVar('T') +T = TypeVar("T") FUTURE_IMPORTS: Final = { - '__future__.nested_scopes': 'nested_scopes', - '__future__.generators': 'generators', - '__future__.division': 'division', - '__future__.absolute_import': 'absolute_import', - '__future__.with_statement': 'with_statement', - '__future__.print_function': 'print_function', - '__future__.unicode_literals': 'unicode_literals', - '__future__.barry_as_FLUFL': 'barry_as_FLUFL', - '__future__.generator_stop': 'generator_stop', - '__future__.annotations': 'annotations', + "__future__.nested_scopes": "nested_scopes", + "__future__.generators": "generators", + "__future__.division": "division", + "__future__.absolute_import": "absolute_import", + "__future__.with_statement": "with_statement", + "__future__.print_function": "print_function", + "__future__.unicode_literals": "unicode_literals", + "__future__.barry_as_FLUFL": "barry_as_FLUFL", + "__future__.generator_stop": "generator_stop", + "__future__.annotations": "annotations", } @@ -155,23 +313,23 @@ # Subclasses can override these Var attributes with incompatible types. This can also be # set for individual attributes using 'allow_incompatible_override' of Var. -ALLOW_INCOMPATIBLE_OVERRIDE: Final = ('__slots__', '__deletable__', '__match_args__') +ALLOW_INCOMPATIBLE_OVERRIDE: Final = ("__slots__", "__deletable__", "__match_args__") # Used for tracking incomplete references Tag: _TypeAlias = int -class SemanticAnalyzer(NodeVisitor[None], - SemanticAnalyzerInterface, - SemanticAnalyzerPluginInterface): +class SemanticAnalyzer( + NodeVisitor[None], SemanticAnalyzerInterface, SemanticAnalyzerPluginInterface +): """Semantically analyze parsed mypy files. The analyzer binds names and does various consistency checks for an AST. Note that type checking is performed as a separate pass. """ - __deletable__ = ['patches', 'options', 'cur_mod_node'] + __deletable__ = ["patches", "options", "cur_mod_node"] # Module name space modules: Dict[str, MypyFile] @@ -221,9 +379,9 @@ class SemanticAnalyzer(NodeVisitor[None], missing_names: List[Set[str]] # Callbacks that will be called after semantic analysis to tweak things. patches: List[Tuple[int, Callable[[], None]]] - loop_depth = 0 # Depth of breakable loops - cur_mod_id = '' # Current module id (or None) (phase 2) - _is_stub_file = False # Are we analyzing a stub file? + loop_depth = 0 # Depth of breakable loops + cur_mod_id = "" # Current module id (or None) (phase 2) + _is_stub_file = False # Are we analyzing a stub file? _is_typeshed_stub_file = False # Are we analyzing a typeshed stub file? imports: Set[str] # Imported modules (during phase 2 analysis) # Note: some imports (and therefore dependencies) might @@ -238,12 +396,14 @@ class SemanticAnalyzer(NodeVisitor[None], # type is stored in this mapping and that it still matches. wrapped_coro_return_types: Dict[FuncDef, Type] = {} - def __init__(self, - modules: Dict[str, MypyFile], - missing_modules: Set[str], - incomplete_namespaces: Set[str], - errors: Errors, - plugin: Plugin) -> None: + def __init__( + self, + modules: Dict[str, MypyFile], + missing_modules: Set[str], + incomplete_namespaces: Set[str], + errors: Errors, + plugin: Plugin, + ) -> None: """Construct semantic analyzer. We reuse the same semantic analyzer instance across multiple modules. @@ -315,18 +475,16 @@ def final_iteration(self) -> bool: def prepare_file(self, file_node: MypyFile) -> None: """Prepare a freshly parsed file for semantic analysis.""" - if 'builtins' in self.modules: - file_node.names['__builtins__'] = SymbolTableNode(GDEF, - self.modules['builtins']) - if file_node.fullname == 'builtins': + if "builtins" in self.modules: + file_node.names["__builtins__"] = SymbolTableNode(GDEF, self.modules["builtins"]) + if file_node.fullname == "builtins": self.prepare_builtins_namespace(file_node) - if file_node.fullname == 'typing': + if file_node.fullname == "typing": self.prepare_typing_namespace(file_node, type_aliases) - if file_node.fullname == 'typing_extensions': + if file_node.fullname == "typing_extensions": self.prepare_typing_namespace(file_node, typing_extensions_aliases) - def prepare_typing_namespace(self, file_node: MypyFile, - aliases: Dict[str, str]) -> None: + def prepare_typing_namespace(self, file_node: MypyFile, aliases: Dict[str, str]) -> None: """Remove dummy alias definitions such as List = TypeAlias(object) from typing. They will be replaced with real aliases when corresponding targets are ready. @@ -345,10 +503,13 @@ def helper(defs: List[Statement]) -> None: helper(body.body) if stmt.else_body: helper(stmt.else_body.body) - if (isinstance(stmt, AssignmentStmt) and len(stmt.lvalues) == 1 and - isinstance(stmt.lvalues[0], NameExpr)): + if ( + isinstance(stmt, AssignmentStmt) + and len(stmt.lvalues) == 1 + and isinstance(stmt.lvalues[0], NameExpr) + ): # Assignment to a simple name, remove it if it is a dummy alias. - if f'{file_node.fullname}.{stmt.lvalues[0].name}' in aliases: + if f"{file_node.fullname}.{stmt.lvalues[0].name}" in aliases: defs.remove(stmt) helper(file_node.defs) @@ -365,43 +526,45 @@ def prepare_builtins_namespace(self, file_node: MypyFile) -> None: # operation. These will be completed later on. for name in CORE_BUILTIN_CLASSES: cdef = ClassDef(name, Block([])) # Dummy ClassDef, will be replaced later - info = TypeInfo(SymbolTable(), cdef, 'builtins') - info._fullname = f'builtins.{name}' + info = TypeInfo(SymbolTable(), cdef, "builtins") + info._fullname = f"builtins.{name}" names[name] = SymbolTableNode(GDEF, info) - bool_info = names['bool'].node + bool_info = names["bool"].node assert isinstance(bool_info, TypeInfo) bool_type = Instance(bool_info, []) special_var_types: List[Tuple[str, Type]] = [ - ('None', NoneType()), + ("None", NoneType()), # reveal_type is a mypy-only function that gives an error with # the type of its arg. - ('reveal_type', AnyType(TypeOfAny.special_form)), + ("reveal_type", AnyType(TypeOfAny.special_form)), # reveal_locals is a mypy-only function that gives an error with the types of # locals - ('reveal_locals', AnyType(TypeOfAny.special_form)), - ('True', bool_type), - ('False', bool_type), - ('__debug__', bool_type), + ("reveal_locals", AnyType(TypeOfAny.special_form)), + ("True", bool_type), + ("False", bool_type), + ("__debug__", bool_type), ] for name, typ in special_var_types: v = Var(name, typ) - v._fullname = f'builtins.{name}' + v._fullname = f"builtins.{name}" file_node.names[name] = SymbolTableNode(GDEF, v) # # Analyzing a target # - def refresh_partial(self, - node: Union[MypyFile, FuncDef, OverloadedFuncDef], - patches: List[Tuple[int, Callable[[], None]]], - final_iteration: bool, - file_node: MypyFile, - options: Options, - active_type: Optional[TypeInfo] = None) -> None: + def refresh_partial( + self, + node: Union[MypyFile, FuncDef, OverloadedFuncDef], + patches: List[Tuple[int, Callable[[], None]]], + final_iteration: bool, + file_node: MypyFile, + options: Options, + active_type: Optional[TypeInfo] = None, + ) -> None: """Refresh a stale target in fine-grained incremental mode.""" self.patches = patches self.deferred = False @@ -423,9 +586,9 @@ def refresh_top_level(self, file_node: MypyFile) -> None: self.add_implicit_module_attrs(file_node) for d in file_node.defs: self.accept(d) - if file_node.fullname == 'typing': + if file_node.fullname == "typing": self.add_builtin_aliases(file_node) - if file_node.fullname == 'typing_extensions': + if file_node.fullname == "typing_extensions": self.add_typing_extension_aliases(file_node) self.adjust_public_exports() self.export_map[self.cur_mod_id] = self.all_exports @@ -435,13 +598,14 @@ def add_implicit_module_attrs(self, file_node: MypyFile) -> None: """Manually add implicit definitions of module '__name__' etc.""" for name, t in implicit_module_attrs.items(): # unicode docstrings should be accepted in Python 2 - if name == '__doc__': + if name == "__doc__": if self.options.python_version >= (3, 0): typ: Type = UnboundType("__builtins__.str") else: - typ = UnionType([UnboundType('__builtins__.str'), - UnboundType('__builtins__.unicode')]) - elif name == '__path__': + typ = UnionType( + [UnboundType("__builtins__.str"), UnboundType("__builtins__.unicode")] + ) + elif name == "__path__": if not file_node.is_package_init_file(): continue # Need to construct the type ourselves, to avoid issues with __builtins__.list @@ -452,7 +616,7 @@ def add_implicit_module_attrs(self, file_node: MypyFile) -> None: node = sym.node assert isinstance(node, TypeInfo) typ = Instance(node, [self.str_type()]) - elif name == '__annotations__': + elif name == "__annotations__": sym = self.lookup_qualified("__builtins__.dict", Context(), suppress_errors=True) if not sym: continue @@ -460,7 +624,7 @@ def add_implicit_module_attrs(self, file_node: MypyFile) -> None: assert isinstance(node, TypeInfo) typ = Instance(node, [self.str_type(), AnyType(TypeOfAny.special_form)]) else: - assert t is not None, f'type should be specified for {name}' + assert t is not None, f"type should be specified for {name}" typ = UnboundType(t) existing = file_node.names.get(name) @@ -475,9 +639,11 @@ def add_implicit_module_attrs(self, file_node: MypyFile) -> None: var.is_ready = True self.add_symbol(name, var, dummy_context()) else: - self.add_symbol(name, - PlaceholderNode(self.qualified_name(name), file_node, -1), - dummy_context()) + self.add_symbol( + name, + PlaceholderNode(self.qualified_name(name), file_node, -1), + dummy_context(), + ) def add_builtin_aliases(self, tree: MypyFile) -> None: """Add builtin type aliases to typing module. @@ -487,12 +653,12 @@ def add_builtin_aliases(self, tree: MypyFile) -> None: corresponding nodes on the fly. We explicitly mark these aliases as normalized, so that a user can write `typing.List[int]`. """ - assert tree.fullname == 'typing' + assert tree.fullname == "typing" for alias, target_name in type_aliases.items(): if type_aliases_source_versions[alias] > self.options.python_version: # This alias is not available on this Python version. continue - name = alias.split('.')[-1] + name = alias.split(".")[-1] if name in tree.names and not isinstance(tree.names[name].node, PlaceholderNode): continue self.create_alias(tree, target_name, alias, name) @@ -504,10 +670,10 @@ def add_typing_extension_aliases(self, tree: MypyFile) -> None: they are just defined as `_Alias()` call. Which is not supported natively. """ - assert tree.fullname == 'typing_extensions' + assert tree.fullname == "typing_extensions" for alias, target_name in typing_extensions_aliases.items(): - name = alias.split('.')[-1] + name = alias.split(".")[-1] if name in tree.names and isinstance(tree.names[name].node, TypeAlias): continue # Do not reset TypeAliases on the second pass. @@ -529,9 +695,14 @@ def create_alias(self, tree: MypyFile, target_name: str, alias: str, name: str) assert target is not None # Transform List to List[Any], etc. fix_instance_types(target, self.fail, self.note, self.options.python_version) - alias_node = TypeAlias(target, alias, - line=-1, column=-1, # there is no context - no_args=True, normalized=True) + alias_node = TypeAlias( + target, + alias, + line=-1, + column=-1, # there is no context + no_args=True, + normalized=True, + ) self.add_symbol(name, alias_node, tree) elif self.found_incomplete_ref(tag): # Built-in class target may not ready yet -- defer. @@ -545,7 +716,7 @@ def create_alias(self, tree: MypyFile, target_name: str, alias: str, name: str) def adjust_public_exports(self) -> None: """Adjust the module visibility of globals due to __all__.""" - if '__all__' in self.globals: + if "__all__" in self.globals: for name, g in self.globals.items(): # Being included in __all__ explicitly exports and makes public. if name in self.all_exports: @@ -557,10 +728,9 @@ def adjust_public_exports(self) -> None: g.module_public = False @contextmanager - def file_context(self, - file_node: MypyFile, - options: Options, - active_type: Optional[TypeInfo] = None) -> Iterator[None]: + def file_context( + self, file_node: MypyFile, options: Options, active_type: Optional[TypeInfo] = None + ) -> Iterator[None]: """Configure analyzer for analyzing targets within a file/class. Args: @@ -574,7 +744,7 @@ def file_context(self, self.cur_mod_node = file_node self.cur_mod_id = file_node.fullname with scope.module_scope(self.cur_mod_id): - self._is_stub_file = file_node.path.lower().endswith('.pyi') + self._is_stub_file = file_node.path.lower().endswith(".pyi") self._is_typeshed_stub_file = is_typeshed_file(file_node.path) self.globals = file_node.names self.tvar_scope = TypeVarLikeScope() @@ -648,7 +818,7 @@ def analyze_func_def(self, defn: FuncDef) -> None: # Method definition assert self.type is not None defn.info = self.type - if defn.type is not None and defn.name in ('__init__', '__init_subclass__'): + if defn.type is not None and defn.name in ("__init__", "__init_subclass__"): assert isinstance(defn.type, CallableType) if isinstance(get_proper_type(defn.type.ret_type), AnyType): defn.type = defn.type.copy_modified(ret_type=NoneType()) @@ -678,9 +848,11 @@ def analyze_func_def(self, defn: FuncDef) -> None: self.analyze_arg_initializers(defn) self.analyze_function_body(defn) - if (defn.is_coroutine and - isinstance(defn.type, CallableType) and - self.wrapped_coro_return_types.get(defn) != defn.type): + if ( + defn.is_coroutine + and isinstance(defn.type, CallableType) + and self.wrapped_coro_return_types.get(defn) != defn.type + ): if defn.is_async_generator: # Async generator types are handled elsewhere pass @@ -688,8 +860,9 @@ def analyze_func_def(self, defn: FuncDef) -> None: # A coroutine defined as `async def foo(...) -> T: ...` # has external return type `Coroutine[Any, Any, T]`. any_type = AnyType(TypeOfAny.special_form) - ret_type = self.named_type_or_none('typing.Coroutine', - [any_type, any_type, defn.type.ret_type]) + ret_type = self.named_type_or_none( + "typing.Coroutine", [any_type, any_type, defn.type.ret_type] + ) assert ret_type is not None, "Internal error: typing.Coroutine not found" defn.type = defn.type.copy_modified(ret_type=ret_type) self.wrapped_coro_return_types[defn] = defn.type @@ -699,15 +872,15 @@ def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None: # Only non-static methods are special. functype = func.type if not func.is_static: - if func.name in ['__init_subclass__', '__class_getitem__']: + if func.name in ["__init_subclass__", "__class_getitem__"]: func.is_class = True if not func.arguments: - self.fail('Method must have at least one argument', func) + self.fail("Method must have at least one argument", func) elif isinstance(functype, CallableType): self_type = get_proper_type(functype.arg_types[0]) if isinstance(self_type, AnyType): leading_type: Type = fill_typevars(info) - if func.is_class or func.name == '__new__': + if func.is_class or func.name == "__new__": leading_type = self.class_type(leading_type) func.type = replace_implicit_first_type(functype, leading_type) @@ -776,7 +949,7 @@ def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: # This is a property. first_item.func.is_overload = True self.analyze_property_with_multi_part_definition(defn) - typ = function_type(first_item.func, self.named_type('builtins.function')) + typ = function_type(first_item.func, self.named_type("builtins.function")) assert isinstance(typ, CallableType) types = [typ] else: @@ -786,8 +959,9 @@ def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: types, impl, non_overload_indexes = self.analyze_overload_sigs_and_impl(defn) defn.impl = impl if non_overload_indexes: - self.handle_missing_overload_decorators(defn, non_overload_indexes, - some_overload_decorators=len(types) > 0) + self.handle_missing_overload_decorators( + defn, non_overload_indexes, some_overload_decorators=len(types) > 0 + ) # If we found an implementation, remove it from the overload item list, # as it's special. if impl is not None: @@ -814,10 +988,8 @@ def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: self.process_static_or_class_method_in_overload(defn) def analyze_overload_sigs_and_impl( - self, - defn: OverloadedFuncDef) -> Tuple[List[CallableType], - Optional[OverloadPart], - List[int]]: + self, defn: OverloadedFuncDef + ) -> Tuple[List[CallableType], Optional[OverloadPart], List[int]]: """Find overload signatures, the implementation, and items with missing @overload. Assume that the first was already analyzed. As a side effect: @@ -833,10 +1005,9 @@ def analyze_overload_sigs_and_impl( item.accept(self) # TODO: support decorated overloaded functions properly if isinstance(item, Decorator): - callable = function_type(item.func, self.named_type('builtins.function')) + callable = function_type(item.func, self.named_type("builtins.function")) assert isinstance(callable, CallableType) - if not any(refers_to_fullname(dec, OVERLOAD_NAMES) - for dec in item.decorators): + if not any(refers_to_fullname(dec, OVERLOAD_NAMES) for dec in item.decorators): if i == len(defn.items) - 1 and not self.is_stub_file: # Last item outside a stub is impl impl = item @@ -855,10 +1026,12 @@ def analyze_overload_sigs_and_impl( non_overload_indexes.append(i) return types, impl, non_overload_indexes - def handle_missing_overload_decorators(self, - defn: OverloadedFuncDef, - non_overload_indexes: List[int], - some_overload_decorators: bool) -> None: + def handle_missing_overload_decorators( + self, + defn: OverloadedFuncDef, + non_overload_indexes: List[int], + some_overload_decorators: bool, + ) -> None: """Generate errors for overload items without @overload. Side effect: remote non-overload items. @@ -867,11 +1040,16 @@ def handle_missing_overload_decorators(self, # Some of them were overloads, but not all. for idx in non_overload_indexes: if self.is_stub_file: - self.fail("An implementation for an overloaded function " - "is not allowed in a stub file", defn.items[idx]) + self.fail( + "An implementation for an overloaded function " + "is not allowed in a stub file", + defn.items[idx], + ) else: - self.fail("The implementation for an overloaded function " - "must come last", defn.items[idx]) + self.fail( + "The implementation for an overloaded function " "must come last", + defn.items[idx], + ) else: for idx in non_overload_indexes[1:]: self.name_already_defined(defn.name, defn.items[idx], defn.items[0]) @@ -894,7 +1072,9 @@ def handle_missing_overload_implementation(self, defn: OverloadedFuncDef) -> Non else: self.fail( "An overloaded function outside a stub file must have an implementation", - defn, code=codes.NO_OVERLOAD_IMPL) + defn, + code=codes.NO_OVERLOAD_IMPL, + ) def process_final_in_overload(self, defn: OverloadedFuncDef) -> None: """Detect the @final status of an overloaded function (and perform checks).""" @@ -906,12 +1086,12 @@ def process_final_in_overload(self, defn: OverloadedFuncDef) -> None: # Only show the error once per overload bad_final = next(ov for ov in defn.items if ov.is_final) if not self.is_stub_file: - self.fail("@final should be applied only to overload implementation", - bad_final) + self.fail("@final should be applied only to overload implementation", bad_final) elif any(item.is_final for item in defn.items[1:]): bad_final = next(ov for ov in defn.items[1:] if ov.is_final) - self.fail("In a stub file @final must be applied only to the first overload", - bad_final) + self.fail( + "In a stub file @final must be applied only to the first overload", bad_final + ) if defn.impl is not None and defn.impl.is_final: defn.is_final = True @@ -939,9 +1119,9 @@ def process_static_or_class_method_in_overload(self, defn: OverloadedFuncDef) -> static_status.append(inner.is_static) if len(set(class_status)) != 1: - self.msg.overload_inconsistently_applies_decorator('classmethod', defn) + self.msg.overload_inconsistently_applies_decorator("classmethod", defn) elif len(set(static_status)) != 1: - self.msg.overload_inconsistently_applies_decorator('staticmethod', defn) + self.msg.overload_inconsistently_applies_decorator("staticmethod", defn) else: defn.is_class = class_status[0] defn.is_static = static_status[0] @@ -960,7 +1140,7 @@ def analyze_property_with_multi_part_definition(self, defn: OverloadedFuncDef) - if len(item.decorators) == 1: node = item.decorators[0] if isinstance(node, MemberExpr): - if node.name == 'setter': + if node.name == "setter": # The first item represents the entire property. first_item.var.is_settable_property = True # Get abstractness from the original definition. @@ -969,8 +1149,7 @@ def analyze_property_with_multi_part_definition(self, defn: OverloadedFuncDef) - self.fail("Decorated property not supported", item) item.func.accept(self) else: - self.fail(f'Unexpected definition for property "{first_item.func.name}"', - item) + self.fail(f'Unexpected definition for property "{first_item.func.name}"', item) deleted_items.append(i + 1) for i in reversed(deleted_items): del items[i] @@ -1027,13 +1206,13 @@ def check_function_signature(self, fdef: FuncItem) -> None: sig = fdef.type assert isinstance(sig, CallableType) if len(sig.arg_types) < len(fdef.arguments): - self.fail('Type signature has too few arguments', fdef) + self.fail("Type signature has too few arguments", fdef) # Add dummy Any arguments to prevent crashes later. num_extra_anys = len(fdef.arguments) - len(sig.arg_types) extra_anys = [AnyType(TypeOfAny.from_error)] * num_extra_anys sig.arg_types.extend(extra_anys) elif len(sig.arg_types) > len(fdef.arguments): - self.fail('Type signature has too many arguments', fdef, blocker=True) + self.fail("Type signature has too many arguments", fdef, blocker=True) def visit_decorator(self, dec: Decorator) -> None: self.statement = dec @@ -1049,38 +1228,37 @@ def visit_decorator(self, dec: Decorator) -> None: no_type_check = False for i, d in enumerate(dec.decorators): # A bunch of decorators are special cased here. - if refers_to_fullname(d, 'abc.abstractmethod'): + if refers_to_fullname(d, "abc.abstractmethod"): removed.append(i) dec.func.is_abstract = True - self.check_decorated_function_is_method('abstractmethod', dec) - elif refers_to_fullname(d, ('asyncio.coroutines.coroutine', 'types.coroutine')): + self.check_decorated_function_is_method("abstractmethod", dec) + elif refers_to_fullname(d, ("asyncio.coroutines.coroutine", "types.coroutine")): removed.append(i) dec.func.is_awaitable_coroutine = True - elif refers_to_fullname(d, 'builtins.staticmethod'): + elif refers_to_fullname(d, "builtins.staticmethod"): removed.append(i) dec.func.is_static = True dec.var.is_staticmethod = True - self.check_decorated_function_is_method('staticmethod', dec) - elif refers_to_fullname(d, 'builtins.classmethod'): + self.check_decorated_function_is_method("staticmethod", dec) + elif refers_to_fullname(d, "builtins.classmethod"): removed.append(i) dec.func.is_class = True dec.var.is_classmethod = True - self.check_decorated_function_is_method('classmethod', dec) - elif refers_to_fullname(d, ( - 'builtins.property', - 'abc.abstractproperty', - 'functools.cached_property')): + self.check_decorated_function_is_method("classmethod", dec) + elif refers_to_fullname( + d, ("builtins.property", "abc.abstractproperty", "functools.cached_property") + ): removed.append(i) dec.func.is_property = True dec.var.is_property = True - if refers_to_fullname(d, 'abc.abstractproperty'): + if refers_to_fullname(d, "abc.abstractproperty"): dec.func.is_abstract = True - elif refers_to_fullname(d, 'functools.cached_property'): + elif refers_to_fullname(d, "functools.cached_property"): dec.var.is_settable_property = True - self.check_decorated_function_is_method('property', dec) + self.check_decorated_function_is_method("property", dec) if len(dec.func.arguments) > 1: - self.fail('Too many arguments', dec.func) - elif refers_to_fullname(d, 'typing.no_type_check'): + self.fail("Too many arguments", dec.func) + elif refers_to_fullname(d, "typing.no_type_check"): dec.var.type = AnyType(TypeOfAny.special_form) no_type_check = True elif refers_to_fullname(d, FINAL_DECORATOR_NAMES): @@ -1102,12 +1280,11 @@ def visit_decorator(self, dec: Decorator) -> None: if not no_type_check and self.recurse_into_functions: dec.func.accept(self) if dec.decorators and dec.var.is_property: - self.fail('Decorated property not supported', dec) + self.fail("Decorated property not supported", dec) if dec.func.is_abstract and dec.func.is_final: self.fail(f"Method {dec.func.name} is both abstract and final", dec) - def check_decorated_function_is_method(self, decorator: str, - context: Context) -> None: + def check_decorated_function_is_method(self, decorator: str, context: Context) -> None: if not self.type or self.is_func_scope(): self.fail(f'"{decorator}" used with a non-method', context) @@ -1143,11 +1320,13 @@ def analyze_class(self, defn: ClassDef) -> None: bases = defn.base_type_exprs bases, tvar_defs, is_protocol = self.clean_up_bases_and_infer_type_variables( - defn, bases, context=defn) + defn, bases, context=defn + ) for tvd in tvar_defs: - if (isinstance(tvd, TypeVarType) - and any(has_placeholder(t) for t in [tvd.upper_bound] + tvd.values)): + if isinstance(tvd, TypeVarType) and any( + has_placeholder(t) for t in [tvd.upper_bound] + tvd.values + ): # Some type variable bounds or values are not ready, we need # to re-analyze this class. self.defer() @@ -1203,7 +1382,7 @@ def analyze_class(self, defn: ClassDef) -> None: self.analyze_class_body_common(defn) def is_core_builtin_class(self, defn: ClassDef) -> bool: - return self.cur_mod_id == 'builtins' and defn.name in CORE_BUILTIN_CLASSES + return self.cur_mod_id == "builtins" and defn.name in CORE_BUILTIN_CLASSES def analyze_class_body_common(self, defn: ClassDef) -> None: """Parts of class body analysis that are common to all kinds of class defs.""" @@ -1220,7 +1399,8 @@ def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: is_named_tuple, info = True, defn.info # type: bool, Optional[TypeInfo] else: is_named_tuple, info = self.named_tuple_analyzer.analyze_namedtuple_classdef( - defn, self.is_stub_file, self.is_func_scope()) + defn, self.is_stub_file, self.is_func_scope() + ) if is_named_tuple: if info is None: self.mark_incomplete(defn.name, defn) @@ -1286,7 +1466,7 @@ def enter_class(self, info: TypeInfo) -> None: self.missing_names.append(set()) def leave_class(self) -> None: - """ Restore analyzer state. """ + """Restore analyzer state.""" self.block_depth.pop() self.locals.pop() self.is_comprehension_stack.pop() @@ -1300,18 +1480,13 @@ def analyze_class_decorator(self, defn: ClassDef, decorator: Expression) -> None if defn.info.is_protocol: defn.info.runtime_protocol = True else: - self.fail('@runtime_checkable can only be used with protocol classes', - defn) + self.fail("@runtime_checkable can only be used with protocol classes", defn) elif decorator.fullname in FINAL_DECORATOR_NAMES: defn.info.is_final = True def clean_up_bases_and_infer_type_variables( - self, - defn: ClassDef, - base_type_exprs: List[Expression], - context: Context) -> Tuple[List[Expression], - List[TypeVarLikeType], - bool]: + self, defn: ClassDef, base_type_exprs: List[Expression], context: Context + ) -> Tuple[List[Expression], List[TypeVarLikeType], bool]: """Remove extra base classes such as Generic and infer type vars. For example, consider this class: @@ -1339,7 +1514,7 @@ class Foo(Bar, Generic[T]): ... result = self.analyze_class_typevar_declaration(base) if result is not None: if declared_tvars: - self.fail('Only single Generic[...] or Protocol[...] can be in bases', context) + self.fail("Only single Generic[...] or Protocol[...] can be in bases", context) removed.append(i) tvars = result[0] is_protocol |= result[1] @@ -1358,8 +1533,11 @@ class Foo(Bar, Generic[T]): ... self.fail("Duplicate type variables in Generic[...] or Protocol[...]", context) declared_tvars = remove_dups(declared_tvars) if not set(all_tvars).issubset(set(declared_tvars)): - self.fail("If Generic[...] or Protocol[...] is present" - " it should list all type variables", context) + self.fail( + "If Generic[...] or Protocol[...] is present" + " it should list all type variables", + context, + ) # In case of error, Generic tvars will go first declared_tvars = remove_dups(declared_tvars + all_tvars) else: @@ -1377,8 +1555,7 @@ class Foo(Bar, Generic[T]): ... return base_type_exprs, tvar_defs, is_protocol def analyze_class_typevar_declaration( - self, - base: Type + self, base: Type ) -> Optional[Tuple[TypeVarLikeList, bool]]: """Analyze type variables declared using Generic[...] or Protocol[...]. @@ -1394,9 +1571,12 @@ def analyze_class_typevar_declaration( sym = self.lookup_qualified(unbound.name, unbound) if sym is None or sym.node is None: return None - if (sym.node.fullname == 'typing.Generic' or - sym.node.fullname in PROTOCOL_NAMES and base.args): - is_proto = sym.node.fullname != 'typing.Generic' + if ( + sym.node.fullname == "typing.Generic" + or sym.node.fullname in PROTOCOL_NAMES + and base.args + ): + is_proto = sym.node.fullname != "typing.Generic" tvars: TypeVarLikeList = [] for arg in unbound.args: tag = self.track_incomplete_refs() @@ -1404,8 +1584,7 @@ def analyze_class_typevar_declaration( if tvar: tvars.append(tvar) elif not self.found_incomplete_ref(tag): - self.fail('Free type variable expected in %s[...]' % - sym.node.name, base) + self.fail("Free type variable expected in %s[...]" % sym.node.name, base) return tvars, is_proto return None @@ -1435,9 +1614,9 @@ def analyze_unbound_tvar(self, t: Type) -> Optional[Tuple[str, TypeVarLikeExpr]] assert isinstance(sym.node, TypeVarExpr) return unbound.name, sym.node - def get_all_bases_tvars(self, - base_type_exprs: List[Expression], - removed: List[int]) -> TypeVarLikeList: + def get_all_bases_tvars( + self, base_type_exprs: List[Expression], removed: List[int] + ) -> TypeVarLikeList: """Return all type variable references in bases.""" tvars: TypeVarLikeList = [] for i, base_expr in enumerate(base_type_exprs): @@ -1468,8 +1647,8 @@ def prepare_class_def(self, defn: ClassDef, info: Optional[TypeInfo] = None) -> else: info._fullname = info.name local_name = defn.name - if '@' in local_name: - local_name = local_name.split('@')[0] + if "@" in local_name: + local_name = local_name.split("@")[0] self.add_symbol(local_name, defn.info, defn) if self.is_nested_within_func_scope(): # We need to preserve local classes, let's store them @@ -1478,9 +1657,9 @@ def prepare_class_def(self, defn: ClassDef, info: Optional[TypeInfo] = None) -> # TODO: Putting local classes into globals breaks assumptions in fine-grained # incremental mode and we should avoid it. In general, this logic is too # ad-hoc and needs to be removed/refactored. - if '@' not in defn.info._fullname: - global_name = defn.info.name + '@' + str(defn.line) - defn.info._fullname = self.cur_mod_id + '.' + global_name + if "@" not in defn.info._fullname: + global_name = defn.info.name + "@" + str(defn.line) + defn.info._fullname = self.cur_mod_id + "." + global_name else: # Preserve name from previous fine-grained incremental run. global_name = defn.info.name @@ -1492,9 +1671,11 @@ def prepare_class_def(self, defn: ClassDef, info: Optional[TypeInfo] = None) -> self.globals[global_name] = SymbolTableNode(GDEF, defn.info) def make_empty_type_info(self, defn: ClassDef) -> TypeInfo: - if (self.is_module_scope() - and self.cur_mod_id == 'builtins' - and defn.name in CORE_BUILTIN_CLASSES): + if ( + self.is_module_scope() + and self.cur_mod_id == "builtins" + and defn.name in CORE_BUILTIN_CLASSES + ): # Special case core built-in classes. A TypeInfo was already # created for it before semantic analysis, but with a dummy # ClassDef. Patch the real ClassDef object. @@ -1518,10 +1699,8 @@ def get_name_repr_of_expr(self, expr: Expression) -> Optional[str]: return None def analyze_base_classes( - self, - base_type_exprs: List[Expression]) -> Optional[Tuple[List[Tuple[ProperType, - Expression]], - bool]]: + self, base_type_exprs: List[Expression] + ) -> Optional[Tuple[List[Tuple[ProperType, Expression]], bool]]: """Analyze base class types. Return None if some definition was incomplete. Otherwise, return a tuple @@ -1533,8 +1712,10 @@ def analyze_base_classes( is_error = False bases = [] for base_expr in base_type_exprs: - if (isinstance(base_expr, RefExpr) and - base_expr.fullname in TYPED_NAMEDTUPLE_NAMES + TPDICT_NAMES): + if ( + isinstance(base_expr, RefExpr) + and base_expr.fullname in TYPED_NAMEDTUPLE_NAMES + TPDICT_NAMES + ): # Ignore magic bases for now. continue @@ -1543,9 +1724,9 @@ def analyze_base_classes( except TypeTranslationError: name = self.get_name_repr_of_expr(base_expr) if isinstance(base_expr, CallExpr): - msg = 'Unsupported dynamic base class' + msg = "Unsupported dynamic base class" else: - msg = 'Invalid base class' + msg = "Invalid base class" if name: msg += f' "{name}"' self.fail(msg, base_expr) @@ -1557,9 +1738,9 @@ def analyze_base_classes( bases.append((base, base_expr)) return bases, is_error - def configure_base_classes(self, - defn: ClassDef, - bases: List[Tuple[ProperType, Expression]]) -> None: + def configure_base_classes( + self, defn: ClassDef, bases: List[Tuple[ProperType, Expression]] + ) -> None: """Set up base classes. This computes several attributes on the corresponding TypeInfo defn.info @@ -1587,7 +1768,7 @@ def configure_base_classes(self, self.fail(msg, base_expr) info.fallback_to_any = True else: - msg = 'Invalid base class' + msg = "Invalid base class" name = self.get_name_repr_of_expr(base_expr) if name: msg += f' "{name}"' @@ -1599,11 +1780,12 @@ def configure_base_classes(self, else: prefix = "Base type" self.msg.unimported_type_becomes_any(prefix, base, base_expr) - check_for_explicit_any(base, self.options, self.is_typeshed_stub_file, self.msg, - context=base_expr) + check_for_explicit_any( + base, self.options, self.is_typeshed_stub_file, self.msg, context=base_expr + ) # Add 'object' as implicit base if there is no other base class. - if not base_types and defn.fullname != 'builtins.object': + if not base_types and defn.fullname != "builtins.object": base_types.append(self.object_type()) info.bases = base_types @@ -1614,10 +1796,9 @@ def configure_base_classes(self, return self.calculate_class_mro(defn, self.object_type) - def configure_tuple_base_class(self, - defn: ClassDef, - base: TupleType, - base_expr: Expression) -> Instance: + def configure_tuple_base_class( + self, defn: ClassDef, base: TupleType, base_expr: Expression + ) -> Instance: info = defn.info # There may be an existing valid tuple type from previous semanal iterations. @@ -1631,7 +1812,7 @@ def configure_tuple_base_class(self, defn.analyzed.line = defn.line defn.analyzed.column = defn.column - if base.partial_fallback.type.fullname == 'builtins.tuple': + if base.partial_fallback.type.fullname == "builtins.tuple": # Fallback can only be safely calculated after semantic analysis, since base # classes may be incomplete. Postpone the calculation. self.schedule_patch(PRIORITY_FALLBACKS, lambda: calculate_tuple_fallback(base)) @@ -1643,8 +1824,9 @@ def set_dummy_mro(self, info: TypeInfo) -> None: info.mro = [info, self.object_type().type] info.bad_mro = True - def calculate_class_mro(self, defn: ClassDef, - obj_type: Optional[Callable[[], Instance]] = None) -> None: + def calculate_class_mro( + self, defn: ClassDef, obj_type: Optional[Callable[[], Instance]] = None + ) -> None: """Calculate method resolution order for a class. `obj_type` may be omitted in the third pass when all classes are already analyzed. @@ -1654,8 +1836,11 @@ def calculate_class_mro(self, defn: ClassDef, try: calculate_mro(defn.info, obj_type) except MroError: - self.fail('Cannot determine consistent method resolution ' - 'order (MRO) for "%s"' % defn.name, defn) + self.fail( + "Cannot determine consistent method resolution " + 'order (MRO) for "%s"' % defn.name, + defn, + ) self.set_dummy_mro(defn.info) # Allow plugins to alter the MRO to handle the fact that `def mro()` # on metaclasses permits MRO rewriting. @@ -1692,11 +1877,16 @@ def update_metaclass(self, defn: ClassDef) -> None: base_expr = defn.base_type_exprs[0] if isinstance(base_expr, CallExpr) and isinstance(base_expr.callee, RefExpr): base_expr.accept(self) - if (base_expr.callee.fullname in {'six.with_metaclass', - 'future.utils.with_metaclass', - 'past.utils.with_metaclass'} - and len(base_expr.args) >= 1 - and all(kind == ARG_POS for kind in base_expr.arg_kinds)): + if ( + base_expr.callee.fullname + in { + "six.with_metaclass", + "future.utils.with_metaclass", + "past.utils.with_metaclass", + } + and len(base_expr.args) >= 1 + and all(kind == ARG_POS for kind in base_expr.arg_kinds) + ): with_meta_expr = base_expr.args[0] defn.base_type_exprs = base_expr.args[1:] @@ -1705,9 +1895,11 @@ def update_metaclass(self, defn: ClassDef) -> None: for dec_expr in defn.decorators: if isinstance(dec_expr, CallExpr) and isinstance(dec_expr.callee, RefExpr): dec_expr.callee.accept(self) - if (dec_expr.callee.fullname == 'six.add_metaclass' + if ( + dec_expr.callee.fullname == "six.add_metaclass" and len(dec_expr.args) == 1 - and dec_expr.arg_kinds[0] == ARG_POS): + and dec_expr.arg_kinds[0] == ARG_POS + ): add_meta_expr = dec_expr.args[0] break @@ -1725,11 +1917,10 @@ def verify_base_classes(self, defn: ClassDef) -> bool: for base in info.bases: baseinfo = base.type if self.is_base_class(info, baseinfo): - self.fail('Cycle in inheritance hierarchy', defn) + self.fail("Cycle in inheritance hierarchy", defn) cycle = True - if baseinfo.fullname == 'builtins.bool': - self.fail('"%s" is not a valid base class' % - baseinfo.name, defn, blocker=True) + if baseinfo.fullname == "builtins.bool": + self.fail('"%s" is not a valid base class' % baseinfo.name, defn, blocker=True) return False dup = find_duplicate(info.direct_base_classes()) if dup: @@ -1780,19 +1971,22 @@ def analyze_metaclass(self, defn: ClassDef) -> None: self.fail(f'Invalid metaclass "{metaclass_name}"', defn.metaclass) return if not sym.node.is_metaclass(): - self.fail('Metaclasses not inheriting from "type" are not supported', - defn.metaclass) + self.fail( + 'Metaclasses not inheriting from "type" are not supported', defn.metaclass + ) return inst = fill_typevars(sym.node) assert isinstance(inst, Instance) defn.info.declared_metaclass = inst defn.info.metaclass_type = defn.info.calculate_metaclass_type() if any(info.is_protocol for info in defn.info.mro): - if (not defn.info.metaclass_type or - defn.info.metaclass_type.type.fullname == 'builtins.type'): + if ( + not defn.info.metaclass_type + or defn.info.metaclass_type.type.fullname == "builtins.type" + ): # All protocols and their subclasses have ABCMeta metaclass by default. # TODO: add a metaclass conflict check if there is another metaclass. - abc_meta = self.named_type_or_none('abc.ABCMeta', []) + abc_meta = self.named_type_or_none("abc.ABCMeta", []) if abc_meta is not None: # May be None in tests with incomplete lib-stub. defn.info.metaclass_type = abc_meta if defn.info.metaclass_type is None: @@ -1801,7 +1995,7 @@ def analyze_metaclass(self, defn: ClassDef) -> None: if defn.metaclass is not None: self.fail(f'Inconsistent metaclass structure for "{defn.name}"', defn) else: - if defn.info.metaclass_type.type.has_base('enum.EnumMeta'): + if defn.info.metaclass_type.type.has_base("enum.EnumMeta"): defn.info.is_enum = True if defn.type_vars: self.fail("Enum class cannot be generic", defn) @@ -1821,18 +2015,23 @@ def visit_import(self, i: Import) -> None: imported_id = as_id module_public = use_implicit_reexport or id.split(".")[-1] == as_id else: - base_id = id.split('.')[0] + base_id = id.split(".")[0] imported_id = base_id module_public = use_implicit_reexport - self.add_module_symbol(base_id, imported_id, context=i, module_public=module_public, - module_hidden=not module_public) + self.add_module_symbol( + base_id, + imported_id, + context=i, + module_public=module_public, + module_hidden=not module_public, + ) def visit_import_from(self, imp: ImportFrom) -> None: self.statement = imp module_id = self.correct_relative_import(imp) module = self.modules.get(module_id) for id, as_id in imp.names: - fullname = module_id + '.' + id + fullname = module_id + "." + id self.set_future_import_flags(fullname) if module is None: node = None @@ -1844,7 +2043,7 @@ def visit_import_from(self, imp: ImportFrom) -> None: # precedence, but doesn't seem to be important in most use cases. node = SymbolTableNode(GDEF, self.modules[fullname]) else: - if id == as_id == '__all__' and module_id in self.export_map: + if id == as_id == "__all__" and module_id in self.export_map: self.all_exports[:] = self.export_map[module_id] node = module.names.get(id) @@ -1867,16 +2066,23 @@ def visit_import_from(self, imp: ImportFrom) -> None: elif fullname in self.missing_modules: missing_submodule = True # If it is still not resolved, check for a module level __getattr__ - if (module and not node and (module.is_stub or self.options.python_version >= (3, 7)) - and '__getattr__' in module.names): + if ( + module + and not node + and (module.is_stub or self.options.python_version >= (3, 7)) + and "__getattr__" in module.names + ): # We store the fullname of the original definition so that we can # detect whether two imported names refer to the same thing. - fullname = module_id + '.' + id - gvar = self.create_getattr_var(module.names['__getattr__'], imported_id, fullname) + fullname = module_id + "." + id + gvar = self.create_getattr_var(module.names["__getattr__"], imported_id, fullname) if gvar: self.add_symbol( - imported_id, gvar, imp, module_public=module_public, - module_hidden=not module_public + imported_id, + gvar, + imp, + module_public=module_public, + module_hidden=not module_public, ) continue @@ -1887,24 +2093,33 @@ def visit_import_from(self, imp: ImportFrom) -> None: elif module and not missing_submodule: # Target module exists but the imported name is missing or hidden. self.report_missing_module_attribute( - module_id, id, imported_id, module_public=module_public, - module_hidden=not module_public, context=imp + module_id, + id, + imported_id, + module_public=module_public, + module_hidden=not module_public, + context=imp, ) else: # Import of a missing (sub)module. self.add_unknown_imported_symbol( - imported_id, imp, target_name=fullname, module_public=module_public, - module_hidden=not module_public + imported_id, + imp, + target_name=fullname, + module_public=module_public, + module_hidden=not module_public, ) - def process_imported_symbol(self, - node: SymbolTableNode, - module_id: str, - id: str, - imported_id: str, - fullname: str, - module_public: bool, - context: ImportBase) -> None: + def process_imported_symbol( + self, + node: SymbolTableNode, + module_id: str, + id: str, + imported_id: str, + fullname: str, + module_public: bool, + context: ImportBase, + ) -> None: module_hidden = not module_public and ( # `from package import submodule` should work regardless of whether package # re-exports submodule, so we shouldn't hide it @@ -1918,22 +2133,31 @@ def process_imported_symbol(self, if isinstance(node.node, PlaceholderNode): if self.final_iteration: self.report_missing_module_attribute( - module_id, id, imported_id, module_public=module_public, - module_hidden=module_hidden, context=context + module_id, + id, + imported_id, + module_public=module_public, + module_hidden=module_hidden, + context=context, ) return else: # This might become a type. - self.mark_incomplete(imported_id, node.node, - module_public=module_public, - module_hidden=module_hidden, - becomes_typeinfo=True) + self.mark_incomplete( + imported_id, + node.node, + module_public=module_public, + module_hidden=module_hidden, + becomes_typeinfo=True, + ) existing_symbol = self.globals.get(imported_id) - if (existing_symbol and not isinstance(existing_symbol.node, PlaceholderNode) and - not isinstance(node.node, PlaceholderNode)): + if ( + existing_symbol + and not isinstance(existing_symbol.node, PlaceholderNode) + and not isinstance(node.node, PlaceholderNode) + ): # Import can redefine a variable. They get special treatment. - if self.process_import_over_existing_name( - imported_id, existing_symbol, node, context): + if self.process_import_over_existing_name(imported_id, existing_symbol, node, context): return if existing_symbol and isinstance(node.node, PlaceholderNode): # Imports are special, some redefinitions are allowed, so wait until @@ -1941,13 +2165,18 @@ def process_imported_symbol(self, return # NOTE: we take the original node even for final `Var`s. This is to support # a common pattern when constants are re-exported (same applies to import *). - self.add_imported_symbol(imported_id, node, context, - module_public=module_public, - module_hidden=module_hidden) + self.add_imported_symbol( + imported_id, node, context, module_public=module_public, module_hidden=module_hidden + ) def report_missing_module_attribute( - self, import_id: str, source_id: str, imported_id: str, module_public: bool, - module_hidden: bool, context: Node + self, + import_id: str, + source_id: str, + imported_id: str, + module_public: bool, + module_hidden: bool, + context: Node, ) -> None: # Missing attribute. if self.is_incomplete_namespace(import_id): @@ -1962,8 +2191,10 @@ def report_missing_module_attribute( module = self.modules.get(import_id) if module: if not self.options.implicit_reexport and source_id in module.names.keys(): - message = ('Module "{}" does not explicitly export attribute "{}"' - '; implicit reexport disabled'.format(import_id, source_id)) + message = ( + 'Module "{}" does not explicitly export attribute "{}"' + "; implicit reexport disabled".format(import_id, source_id) + ) else: alternatives = set(module.names.keys()).difference({source_id}) matches = best_matches(source_id, alternatives)[:3] @@ -1972,27 +2203,36 @@ def report_missing_module_attribute( message += f"{suggestion}" self.fail(message, context, code=codes.ATTR_DEFINED) self.add_unknown_imported_symbol( - imported_id, context, target_name=None, module_public=module_public, - module_hidden=not module_public + imported_id, + context, + target_name=None, + module_public=module_public, + module_hidden=not module_public, ) - if import_id == 'typing': + if import_id == "typing": # The user probably has a missing definition in a test fixture. Let's verify. - fullname = f'builtins.{source_id.lower()}' - if (self.lookup_fully_qualified_or_none(fullname) is None and - fullname in SUGGESTED_TEST_FIXTURES): + fullname = f"builtins.{source_id.lower()}" + if ( + self.lookup_fully_qualified_or_none(fullname) is None + and fullname in SUGGESTED_TEST_FIXTURES + ): # Yes. Generate a helpful note. self.msg.add_fixture_note(fullname, context) - def process_import_over_existing_name(self, - imported_id: str, existing_symbol: SymbolTableNode, - module_symbol: SymbolTableNode, - import_node: ImportBase) -> bool: + def process_import_over_existing_name( + self, + imported_id: str, + existing_symbol: SymbolTableNode, + module_symbol: SymbolTableNode, + import_node: ImportBase, + ) -> bool: if existing_symbol.node is module_symbol.node: # We added this symbol on previous iteration. return False - if (existing_symbol.kind in (LDEF, GDEF, MDEF) and - isinstance(existing_symbol.node, (Var, FuncDef, TypeInfo, Decorator, TypeAlias))): + if existing_symbol.kind in (LDEF, GDEF, MDEF) and isinstance( + existing_symbol.node, (Var, FuncDef, TypeInfo, Decorator, TypeAlias) + ): # This is a valid import over an existing definition in the file. Construct a dummy # assignment that we'll use to type check the import. lvalue = NameExpr(imported_id) @@ -2013,8 +2253,9 @@ def process_import_over_existing_name(self, return False def correct_relative_import(self, node: Union[ImportFrom, ImportAll]) -> str: - import_id, ok = correct_relative_import(self.cur_mod_id, node.relative, node.id, - self.cur_mod_node.is_package_init_file()) + import_id, ok = correct_relative_import( + self.cur_mod_id, node.relative, node.id, self.cur_mod_node.is_package_init_file() + ) if not ok: self.fail("Relative import climbs too many namespaces", node) return import_id @@ -2026,28 +2267,27 @@ def visit_import_all(self, i: ImportAll) -> None: if self.is_incomplete_namespace(i_id): # Any names could be missing from the current namespace if the target module # namespace is incomplete. - self.mark_incomplete('*', i) + self.mark_incomplete("*", i) for name, node in m.names.items(): - fullname = i_id + '.' + name + fullname = i_id + "." + name self.set_future_import_flags(fullname) if node is None: continue # if '__all__' exists, all nodes not included have had module_public set to # False, and we can skip checking '_' because it's been explicitly included. - if node.module_public and (not name.startswith('_') or '__all__' in m.names): + if node.module_public and (not name.startswith("_") or "__all__" in m.names): if isinstance(node.node, MypyFile): # Star import of submodule from a package, add it as a dependency. self.imports.add(node.node.fullname) existing_symbol = self.lookup_current_scope(name) if existing_symbol and not isinstance(node.node, PlaceholderNode): # Import can redefine a variable. They get special treatment. - if self.process_import_over_existing_name( - name, existing_symbol, node, i): + if self.process_import_over_existing_name(name, existing_symbol, node, i): continue # `from x import *` always reexports symbols - self.add_imported_symbol(name, node, i, - module_public=True, - module_hidden=False) + self.add_imported_symbol( + name, node, i, module_public=True, module_hidden=False + ) else: # Don't add any dummy symbols for 'from x import *' if 'x' is unknown. @@ -2204,14 +2444,13 @@ def can_be_type_alias(self, rv: Expression, allow_none: bool = False) -> bool: return True if self.is_none_alias(rv): return True - if allow_none and isinstance(rv, NameExpr) and rv.fullname == 'builtins.None': + if allow_none and isinstance(rv, NameExpr) and rv.fullname == "builtins.None": return True - if isinstance(rv, OpExpr) and rv.op == '|': + if isinstance(rv, OpExpr) and rv.op == "|": if self.is_stub_file: return True - if ( - self.can_be_type_alias(rv.left, allow_none=True) - and self.can_be_type_alias(rv.right, allow_none=True) + if self.can_be_type_alias(rv.left, allow_none=True) and self.can_be_type_alias( + rv.right, allow_none=True ): return True return False @@ -2237,14 +2476,15 @@ def is_type_ref(self, rv: Expression, bare: bool = False) -> bool: if not isinstance(rv, RefExpr): return False if isinstance(rv.node, TypeVarExpr): - self.fail('Type variable "{}" is invalid as target for type alias'.format( - rv.fullname), rv) + self.fail( + 'Type variable "{}" is invalid as target for type alias'.format(rv.fullname), rv + ) return False if bare: # These three are valid even if bare, for example # A = Tuple is just equivalent to A = Tuple[Any, ...]. - valid_refs = {'typing.Any', 'typing.Tuple', 'typing.Callable'} + valid_refs = {"typing.Any", "typing.Tuple", "typing.Callable"} else: valid_refs = type_constructors @@ -2279,12 +2519,21 @@ def is_none_alias(self, node: Expression) -> bool: Void in type annotations. """ if isinstance(node, CallExpr): - if (isinstance(node.callee, NameExpr) and len(node.args) == 1 and - isinstance(node.args[0], NameExpr)): + if ( + isinstance(node.callee, NameExpr) + and len(node.args) == 1 + and isinstance(node.args[0], NameExpr) + ): call = self.lookup_qualified(node.callee.name, node.callee) arg = self.lookup_qualified(node.args[0].name, node.args[0]) - if (call is not None and call.node and call.node.fullname == 'builtins.type' and - arg is not None and arg.node and arg.node.fullname == 'builtins.None'): + if ( + call is not None + and call.node + and call.node.fullname == "builtins.type" + and arg is not None + and arg.node + and arg.node.fullname == "builtins.None" + ): return True return False @@ -2315,16 +2564,22 @@ def analyze_namedtuple_assign(self, s: AssignmentStmt) -> bool: return False lvalue = s.lvalues[0] name = lvalue.name - internal_name, info = self.named_tuple_analyzer.check_namedtuple(s.rvalue, name, - self.is_func_scope()) + internal_name, info = self.named_tuple_analyzer.check_namedtuple( + s.rvalue, name, self.is_func_scope() + ) if internal_name is None: return False if isinstance(lvalue, MemberExpr): self.fail("NamedTuple type as an attribute is not supported", lvalue) return False if internal_name != name: - self.fail('First argument to namedtuple() should be "{}", not "{}"'.format( - name, internal_name), s.rvalue, code=codes.NAME_MATCH) + self.fail( + 'First argument to namedtuple() should be "{}", not "{}"'.format( + name, internal_name + ), + s.rvalue, + code=codes.NAME_MATCH, + ) return True # Yes, it's a valid namedtuple, but defer if it is not ready. if not info: @@ -2339,8 +2594,9 @@ def analyze_typeddict_assign(self, s: AssignmentStmt) -> bool: return False lvalue = s.lvalues[0] name = lvalue.name - is_typed_dict, info = self.typed_dict_analyzer.check_typeddict(s.rvalue, name, - self.is_func_scope()) + is_typed_dict, info = self.typed_dict_analyzer.check_typeddict( + s.rvalue, name, self.is_func_scope() + ) if not is_typed_dict: return False if isinstance(lvalue, MemberExpr): @@ -2369,10 +2625,12 @@ def analyze_lvalues(self, s: AssignmentStmt) -> None: has_explicit_value = False for lval in s.lvalues: - self.analyze_lvalue(lval, - explicit_type=explicit, - is_final=s.is_final_def, - has_explicit_value=has_explicit_value) + self.analyze_lvalue( + lval, + explicit_type=explicit, + is_final=s.is_final_def, + has_explicit_value=has_explicit_value, + ) def apply_dynamic_class_hook(self, s: AssignmentStmt) -> None: if not isinstance(s.rvalue, CallExpr): @@ -2387,7 +2645,7 @@ def apply_dynamic_class_hook(self, s: AssignmentStmt) -> None: callee_expr = call.callee.expr if isinstance(callee_expr, RefExpr) and callee_expr.fullname: method_name = call.callee.name - fname = callee_expr.fullname + '.' + method_name + fname = callee_expr.fullname + "." + method_name elif isinstance(callee_expr, CallExpr): # check if chain call call = callee_expr @@ -2448,8 +2706,12 @@ def unwrap_final(self, s: AssignmentStmt) -> bool: self.fail("Cannot use Final inside a loop", s) if self.type and self.type.is_protocol: self.msg.protocol_members_cant_be_final(s) - if (isinstance(s.rvalue, TempNode) and s.rvalue.no_rhs and - not self.is_stub_file and not self.is_class_scope()): + if ( + isinstance(s.rvalue, TempNode) + and s.rvalue.no_rhs + and not self.is_stub_file + and not self.is_class_scope() + ): if not invalid_bare_final: # Skip extra error messages. self.msg.final_without_value(s) return True @@ -2470,7 +2732,7 @@ def check_final_implicit_def(self, s: AssignmentStmt) -> None: return else: assert self.function_stack - if self.function_stack[-1].name != '__init__': + if self.function_stack[-1].name != "__init__": self.fail("Can only declare a final attribute in class body or __init__", s) s.is_final_def = False return @@ -2483,8 +2745,9 @@ def store_final_status(self, s: AssignmentStmt) -> None: if isinstance(node, Var): node.is_final = True node.final_value = self.unbox_literal(s.rvalue) - if (self.is_class_scope() and - (isinstance(s.rvalue, TempNode) and s.rvalue.no_rhs)): + if self.is_class_scope() and ( + isinstance(s.rvalue, TempNode) and s.rvalue.no_rhs + ): node.final_unset_in_class = True else: for lval in self.flatten_lvalues(s.lvalues): @@ -2499,12 +2762,17 @@ def store_final_status(self, s: AssignmentStmt) -> None: # # will fail with `AttributeError: Cannot reassign members.` # That's why we need to replicate this. - if (isinstance(lval, NameExpr) and - isinstance(self.type, TypeInfo) and - self.type.is_enum): + if ( + isinstance(lval, NameExpr) + and isinstance(self.type, TypeInfo) + and self.type.is_enum + ): cur_node = self.type.names.get(lval.name, None) - if (cur_node and isinstance(cur_node.node, Var) and - not (isinstance(s.rvalue, TempNode) and s.rvalue.no_rhs)): + if ( + cur_node + and isinstance(cur_node.node, Var) + and not (isinstance(s.rvalue, TempNode) and s.rvalue.no_rhs) + ): # Double underscored members are writable on an `Enum`. # (Except read-only `__members__` but that is handled in type checker) cur_node.node.is_final = s.is_final_def = not is_dunder(cur_node.node.name) @@ -2518,10 +2786,12 @@ def store_final_status(self, s: AssignmentStmt) -> None: if cur_node and isinstance(cur_node.node, Var) and cur_node.node.is_final: assert self.function_stack top_function = self.function_stack[-1] - if (top_function.name == '__init__' and - cur_node.node.final_unset_in_class and - not cur_node.node.final_set_in_init and - not (isinstance(s.rvalue, TempNode) and s.rvalue.no_rhs)): + if ( + top_function.name == "__init__" + and cur_node.node.final_unset_in_class + and not cur_node.node.final_set_in_init + and not (isinstance(s.rvalue, TempNode) and s.rvalue.no_rhs) + ): cur_node.node.final_set_in_init = True s.is_final_def = True @@ -2537,8 +2807,8 @@ def flatten_lvalues(self, lvalues: List[Expression]) -> List[Expression]: def unbox_literal(self, e: Expression) -> Optional[Union[int, float, bool, str]]: if isinstance(e, (IntExpr, FloatExpr, StrExpr)): return e.value - elif isinstance(e, NameExpr) and e.name in ('True', 'False'): - return True if e.name == 'True' else False + elif isinstance(e, NameExpr) and e.name in ("True", "False"): + return True if e.name == "True" else False return None def process_type_annotation(self, s: AssignmentStmt) -> None: @@ -2551,14 +2821,23 @@ def process_type_annotation(self, s: AssignmentStmt) -> None: if analyzed is None or has_placeholder(analyzed): return s.type = analyzed - if (self.type and self.type.is_protocol and isinstance(lvalue, NameExpr) and - isinstance(s.rvalue, TempNode) and s.rvalue.no_rhs): + if ( + self.type + and self.type.is_protocol + and isinstance(lvalue, NameExpr) + and isinstance(s.rvalue, TempNode) + and s.rvalue.no_rhs + ): if isinstance(lvalue.node, Var): lvalue.node.is_abstract_var = True else: - if (self.type and self.type.is_protocol and - self.is_annotated_protocol_member(s) and not self.is_func_scope()): - self.fail('All protocol members must have explicitly declared types', s) + if ( + self.type + and self.type.is_protocol + and self.is_annotated_protocol_member(s) + and not self.is_func_scope() + ): + self.fail("All protocol members must have explicitly declared types", s) # Set the type if the rvalue is a simple literal (even if the above error occurred). if len(s.lvalues) == 1 and isinstance(s.lvalues[0], RefExpr): if s.lvalues[0].is_inferred_def: @@ -2573,11 +2852,7 @@ def is_annotated_protocol_member(self, s: AssignmentStmt) -> bool: There are some exceptions that can be left unannotated, like ``__slots__``.""" return any( - ( - isinstance(lv, NameExpr) - and lv.name != '__slots__' - and lv.is_inferred_def - ) + (isinstance(lv, NameExpr) and lv.name != "__slots__" and lv.is_inferred_def) for lv in s.lvalues ) @@ -2594,36 +2869,35 @@ def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Opt # AnyStr). return None if isinstance(rvalue, FloatExpr): - return self.named_type_or_none('builtins.float') + return self.named_type_or_none("builtins.float") value: Optional[LiteralValue] = None type_name: Optional[str] = None if isinstance(rvalue, IntExpr): - value, type_name = rvalue.value, 'builtins.int' + value, type_name = rvalue.value, "builtins.int" if isinstance(rvalue, StrExpr): - value, type_name = rvalue.value, 'builtins.str' + value, type_name = rvalue.value, "builtins.str" if isinstance(rvalue, BytesExpr): - value, type_name = rvalue.value, 'builtins.bytes' + value, type_name = rvalue.value, "builtins.bytes" if isinstance(rvalue, UnicodeExpr): - value, type_name = rvalue.value, 'builtins.unicode' + value, type_name = rvalue.value, "builtins.unicode" if type_name is not None: assert value is not None typ = self.named_type_or_none(type_name) if typ and is_final: - return typ.copy_modified(last_known_value=LiteralType( - value=value, - fallback=typ, - line=typ.line, - column=typ.column, - )) + return typ.copy_modified( + last_known_value=LiteralType( + value=value, fallback=typ, line=typ.line, column=typ.column + ) + ) return typ return None - def analyze_alias(self, rvalue: Expression, - allow_placeholder: bool = False) -> Tuple[Optional[Type], List[str], - Set[str], List[str]]: + def analyze_alias( + self, rvalue: Expression, allow_placeholder: bool = False + ) -> Tuple[Optional[Type], List[str], Set[str], List[str]]: """Check if 'rvalue' is a valid type allowed for aliasing (e.g. not a type variable). If yes, return the corresponding type, a list of @@ -2636,15 +2910,17 @@ def analyze_alias(self, rvalue: Expression, """ dynamic = bool(self.function_stack and self.function_stack[-1].is_dynamic()) global_scope = not self.type and not self.function_stack - res = analyze_type_alias(rvalue, - self, - self.tvar_scope, - self.plugin, - self.options, - self.is_typeshed_stub_file, - allow_placeholder=allow_placeholder, - in_dynamic_func=dynamic, - global_scope=global_scope) + res = analyze_type_alias( + rvalue, + self, + self.tvar_scope, + self.plugin, + self.options, + self.is_typeshed_stub_file, + allow_placeholder=allow_placeholder, + in_dynamic_func=dynamic, + global_scope=global_scope, + ) typ: Optional[Type] = None if res: typ, depends_on = res @@ -2688,18 +2964,19 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: # B = int # B = float # Error! # Don't create an alias in these cases: - if (existing - and (isinstance(existing.node, Var) # existing variable - or (isinstance(existing.node, TypeAlias) - and not s.is_alias_def) # existing alias - or (isinstance(existing.node, PlaceholderNode) - and existing.node.node.line < s.line))): # previous incomplete definition + if existing and ( + isinstance(existing.node, Var) # existing variable + or (isinstance(existing.node, TypeAlias) and not s.is_alias_def) # existing alias + or (isinstance(existing.node, PlaceholderNode) and existing.node.node.line < s.line) + ): # previous incomplete definition # TODO: find a more robust way to track the order of definitions. # Note: if is_alias_def=True, this is just a node from previous iteration. if isinstance(existing.node, TypeAlias) and not s.is_alias_def: - self.fail('Cannot assign multiple types to name "{}"' - ' without an explicit "Type[...]" annotation' - .format(lvalue.name), lvalue) + self.fail( + 'Cannot assign multiple types to name "{}"' + ' without an explicit "Type[...]" annotation'.format(lvalue.name), + lvalue, + ) return False non_global_scope = self.type or self.is_func_scope() @@ -2726,12 +3003,16 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: res: Optional[Type] = None if self.is_none_alias(rvalue): res = NoneType() - alias_tvars, depends_on, qualified_tvars = \ - [], set(), [] # type: List[str], Set[str], List[str] + alias_tvars, depends_on, qualified_tvars = ( + [], + set(), + [], + ) # type: List[str], Set[str], List[str] else: tag = self.track_incomplete_refs() - res, alias_tvars, depends_on, qualified_tvars = \ - self.analyze_alias(rvalue, allow_placeholder=True) + res, alias_tvars, depends_on, qualified_tvars = self.analyze_alias( + rvalue, allow_placeholder=True + ) if not res: return False # TODO: Maybe we only need to reject top-level placeholders, similar @@ -2748,8 +3029,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: # The above are only direct deps on other aliases. # For subscripted aliases, type deps from expansion are added in deps.py # (because the type is stored). - check_for_explicit_any(res, self.options, self.is_typeshed_stub_file, self.msg, - context=s) + check_for_explicit_any(res, self.options, self.is_typeshed_stub_file, self.msg, context=s) # When this type alias gets "inlined", the Any is not explicit anymore, # so we need to replace it with non-explicit Anys. if not has_placeholder(res): @@ -2763,13 +3043,15 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: # the function, since the symbol table will no longer # exist. Work around by expanding them eagerly when used. eager = self.is_func_scope() - alias_node = TypeAlias(res, - self.qualified_name(lvalue.name), - s.line, - s.column, - alias_tvars=alias_tvars, - no_args=no_args, - eager=eager) + alias_node = TypeAlias( + res, + self.qualified_name(lvalue.name), + s.line, + s.column, + alias_tvars=alias_tvars, + no_args=no_args, + eager=eager, + ) if isinstance(s.rvalue, (IndexExpr, CallExpr)): # CallExpr is for `void = type(None)` s.rvalue.analyzed = TypeAliasExpr(alias_node) s.rvalue.analyzed.line = s.line @@ -2795,7 +3077,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: updated = True if updated: if self.final_iteration: - self.cannot_resolve_name(lvalue.name, 'name', s) + self.cannot_resolve_name(lvalue.name, "name", s) return True else: self.progress = True @@ -2807,13 +3089,15 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool: alias_node.normalized = rvalue.node.normalized return True - def analyze_lvalue(self, - lval: Lvalue, - nested: bool = False, - explicit_type: bool = False, - is_final: bool = False, - escape_comprehensions: bool = False, - has_explicit_value: bool = False) -> None: + def analyze_lvalue( + self, + lval: Lvalue, + nested: bool = False, + explicit_type: bool = False, + is_final: bool = False, + escape_comprehensions: bool = False, + has_explicit_value: bool = False, + ) -> None: """Analyze an lvalue or assignment target. Args: @@ -2828,18 +3112,19 @@ def analyze_lvalue(self, assert isinstance(lval, NameExpr), "assignment expression target must be NameExpr" if isinstance(lval, NameExpr): self.analyze_name_lvalue( - lval, explicit_type, is_final, + lval, + explicit_type, + is_final, escape_comprehensions, has_explicit_value=has_explicit_value, ) elif isinstance(lval, MemberExpr): self.analyze_member_lvalue(lval, explicit_type, is_final) if explicit_type and not self.is_self_member_ref(lval): - self.fail('Type cannot be declared in assignment to non-self ' - 'attribute', lval) + self.fail("Type cannot be declared in assignment to non-self " "attribute", lval) elif isinstance(lval, IndexExpr): if explicit_type: - self.fail('Unexpected type declaration', lval) + self.fail("Unexpected type declaration", lval) lval.accept(self) elif isinstance(lval, TupleExpr): self.analyze_tuple_or_list_lvalue(lval, explicit_type) @@ -2847,16 +3132,18 @@ def analyze_lvalue(self, if nested: self.analyze_lvalue(lval.expr, nested, explicit_type) else: - self.fail('Starred assignment target must be in a list or tuple', lval) + self.fail("Starred assignment target must be in a list or tuple", lval) else: - self.fail('Invalid assignment target', lval) - - def analyze_name_lvalue(self, - lvalue: NameExpr, - explicit_type: bool, - is_final: bool, - escape_comprehensions: bool, - has_explicit_value: bool) -> None: + self.fail("Invalid assignment target", lval) + + def analyze_name_lvalue( + self, + lvalue: NameExpr, + explicit_type: bool, + is_final: bool, + escape_comprehensions: bool, + has_explicit_value: bool, + ) -> None: """Analyze an lvalue that targets a name expression. Arguments are similar to "analyze_lvalue". @@ -2880,9 +3167,12 @@ def analyze_name_lvalue(self, if kind == MDEF and isinstance(self.type, TypeInfo) and self.type.is_enum: # Special case: we need to be sure that `Enum` keys are unique. if existing is not None and not isinstance(existing.node, PlaceholderNode): - self.fail('Attempted to reuse member name "{}" in Enum definition "{}"'.format( - name, self.type.name, - ), lvalue) + self.fail( + 'Attempted to reuse member name "{}" in Enum definition "{}"'.format( + name, self.type.name + ), + lvalue, + ) if (not existing or isinstance(existing.node, PlaceholderNode)) and not outer: # Define new variable. @@ -2899,7 +3189,7 @@ def analyze_name_lvalue(self, else: lvalue.fullname = lvalue.name if self.is_func_scope(): - if unmangle(name) == '_': + if unmangle(name) == "_": # Special case for assignment to local named '_': always infer 'Any'. typ = AnyType(TypeOfAny.special_form) self.store_declared_types(lvalue, typ) @@ -2938,7 +3228,7 @@ def is_alias_for_final_name(self, name: str) -> bool: return existing is not None and is_final_node(existing.node) def make_name_lvalue_var( - self, lvalue: NameExpr, kind: int, inferred: bool, has_explicit_value: bool, + self, lvalue: NameExpr, kind: int, inferred: bool, has_explicit_value: bool ) -> Var: """Return a Var node for an lvalue that is a name expression.""" name = lvalue.name @@ -2960,10 +3250,8 @@ def make_name_lvalue_var( return v def make_name_lvalue_point_to_existing_def( - self, - lval: NameExpr, - explicit_type: bool, - is_final: bool) -> None: + self, lval: NameExpr, explicit_type: bool, is_final: bool + ) -> None: """Update an lvalue to point to existing definition in the same scope. Arguments are similar to "analyze_lvalue". @@ -2988,14 +3276,13 @@ def make_name_lvalue_point_to_existing_def( self.name_not_defined(lval.name, lval) self.check_lvalue_validity(lval.node, lval) - def analyze_tuple_or_list_lvalue(self, lval: TupleExpr, - explicit_type: bool = False) -> None: + def analyze_tuple_or_list_lvalue(self, lval: TupleExpr, explicit_type: bool = False) -> None: """Analyze an lvalue or assignment target that is a list or tuple.""" items = lval.items star_exprs = [item for item in items if isinstance(item, StarExpr)] if len(star_exprs) > 1: - self.fail('Two starred expressions in assignment', lval) + self.fail("Two starred expressions in assignment", lval) else: if len(star_exprs) == 1: star_exprs[0].valid = True @@ -3030,16 +3317,23 @@ def analyze_member_lvalue(self, lval: MemberExpr, explicit_type: bool, is_final: self.fail("Cannot redefine an existing name as final", lval) # On first encounter with this definition, if this attribute was defined before # with an inferred type and it's marked with an explicit type now, give an error. - if (not lval.node and cur_node and isinstance(cur_node.node, Var) and - cur_node.node.is_inferred and explicit_type): + if ( + not lval.node + and cur_node + and isinstance(cur_node.node, Var) + and cur_node.node.is_inferred + and explicit_type + ): self.attribute_already_defined(lval.name, lval, cur_node) # If the attribute of self is not defined in superclasses, create a new Var, ... - if (node is None - or (isinstance(node.node, Var) and node.node.is_abstract_var) - # ... also an explicit declaration on self also creates a new Var. - # Note that `explicit_type` might has been erased for bare `Final`, - # so we also check if `is_final` is passed. - or (cur_node is None and (explicit_type or is_final))): + if ( + node is None + or (isinstance(node.node, Var) and node.node.is_abstract_var) + # ... also an explicit declaration on self also creates a new Var. + # Note that `explicit_type` might has been erased for bare `Final`, + # so we also check if `is_final` is passed. + or (cur_node is None and (explicit_type or is_final)) + ): if self.type.is_protocol and node is None: self.fail("Protocol members cannot be defined via assignment to self", lval) else: @@ -3065,16 +3359,17 @@ def is_self_member_ref(self, memberexpr: MemberExpr) -> bool: node = memberexpr.expr.node return isinstance(node, Var) and node.is_self - def check_lvalue_validity(self, node: Union[Expression, SymbolNode, None], - ctx: Context) -> None: + def check_lvalue_validity( + self, node: Union[Expression, SymbolNode, None], ctx: Context + ) -> None: if isinstance(node, TypeVarExpr): - self.fail('Invalid assignment target', ctx) + self.fail("Invalid assignment target", ctx) elif isinstance(node, TypeInfo): self.fail(message_registry.CANNOT_ASSIGN_TO_TYPE, ctx) def store_declared_types(self, lvalue: Lvalue, typ: Type) -> None: if isinstance(typ, StarType) and not isinstance(lvalue, StarExpr): - self.fail('Star type only allowed for starred expressions', lvalue) + self.fail("Star type only allowed for starred expressions", lvalue) if isinstance(lvalue, RefExpr): lvalue.is_inferred_def = False if isinstance(lvalue.node, Var): @@ -3086,13 +3381,12 @@ def store_declared_types(self, lvalue: Lvalue, typ: Type) -> None: typ = get_proper_type(typ) if isinstance(typ, TupleType): if len(lvalue.items) != len(typ.items): - self.fail('Incompatible number of tuple items', lvalue) + self.fail("Incompatible number of tuple items", lvalue) return for item, itemtype in zip(lvalue.items, typ.items): self.store_declared_types(item, itemtype) else: - self.fail('Tuple type expected for multiple variables', - lvalue) + self.fail("Tuple type expected for multiple variables", lvalue) elif isinstance(lvalue, StarExpr): # Historical behavior for the old parser if isinstance(typ, StarType): @@ -3119,22 +3413,26 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool: # Constraining types n_values = call.arg_kinds[1:].count(ARG_POS) - values = self.analyze_value_types(call.args[1:1 + n_values]) - - res = self.process_typevar_parameters(call.args[1 + n_values:], - call.arg_names[1 + n_values:], - call.arg_kinds[1 + n_values:], - n_values, - s) + values = self.analyze_value_types(call.args[1 : 1 + n_values]) + + res = self.process_typevar_parameters( + call.args[1 + n_values :], + call.arg_names[1 + n_values :], + call.arg_kinds[1 + n_values :], + n_values, + s, + ) if res is None: return False variance, upper_bound = res existing = self.current_symbol_table().get(name) - if existing and not (isinstance(existing.node, PlaceholderNode) or - # Also give error for another type variable with the same name. - (isinstance(existing.node, TypeVarExpr) and - existing.node is call.analyzed)): + if existing and not ( + isinstance(existing.node, PlaceholderNode) + or + # Also give error for another type variable with the same name. + (isinstance(existing.node, TypeVarExpr) and existing.node is call.analyzed) + ): self.fail(f'Cannot redefine "{name}" as a type variable', s) return False @@ -3149,8 +3447,9 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool: self.msg.unimported_type_becomes_any(prefix, upper_bound, s) for t in values + [upper_bound]: - check_for_explicit_any(t, self.options, self.is_typeshed_stub_file, self.msg, - context=s) + check_for_explicit_any( + t, self.options, self.is_typeshed_stub_file, self.msg, context=s + ) # mypyc suppresses making copies of a function to check each # possible type, so set the upper bound to Any to prevent that @@ -3160,8 +3459,7 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> bool: # Yes, it's a valid type variable definition! Add it to the symbol table. if not call.analyzed: - type_var = TypeVarExpr(name, self.qualified_name(name), - values, upper_bound, variance) + type_var = TypeVarExpr(name, self.qualified_name(name), values, upper_bound, variance) type_var.line = call.line call.analyzed = type_var else: @@ -3184,10 +3482,11 @@ def check_typevarlike_name(self, call: CallExpr, name: str, context: Context) -> if len(call.args) < 1: self.fail(f"Too few arguments for {typevarlike_type}()", context) return False - if (not isinstance(call.args[0], (StrExpr, BytesExpr, UnicodeExpr)) - or not call.arg_kinds[0] == ARG_POS): - self.fail(f"{typevarlike_type}() expects a string literal as first argument", - context) + if ( + not isinstance(call.args[0], (StrExpr, BytesExpr, UnicodeExpr)) + or not call.arg_kinds[0] == ARG_POS + ): + self.fail(f"{typevarlike_type}() expects a string literal as first argument", context) return False elif call.args[0].value != name: msg = 'String argument 1 "{}" to {}(...) does not match variable name "{}"' @@ -3195,8 +3494,9 @@ def check_typevarlike_name(self, call: CallExpr, name: str, context: Context) -> return False return True - def get_typevarlike_declaration(self, s: AssignmentStmt, - typevarlike_types: Tuple[str, ...]) -> Optional[CallExpr]: + def get_typevarlike_declaration( + self, s: AssignmentStmt, typevarlike_types: Tuple[str, ...] + ) -> Optional[CallExpr]: """Returns the call expression if `s` is a declaration of `typevarlike_type` (TypeVar or ParamSpec), or None otherwise. """ @@ -3212,12 +3512,15 @@ def get_typevarlike_declaration(self, s: AssignmentStmt, return None return call - def process_typevar_parameters(self, args: List[Expression], - names: List[Optional[str]], - kinds: List[ArgKind], - num_values: int, - context: Context) -> Optional[Tuple[int, Type]]: - has_values = (num_values > 0) + def process_typevar_parameters( + self, + args: List[Expression], + names: List[Optional[str]], + kinds: List[ArgKind], + num_values: int, + context: Context, + ) -> Optional[Tuple[int, Type]]: + has_values = num_values > 0 covariant = False contravariant = False upper_bound: Type = self.object_type() @@ -3225,32 +3528,30 @@ def process_typevar_parameters(self, args: List[Expression], if not param_kind.is_named(): self.fail(message_registry.TYPEVAR_UNEXPECTED_ARGUMENT, context) return None - if param_name == 'covariant': - if (isinstance(param_value, NameExpr) - and param_value.name in ('True', 'False')): - covariant = param_value.name == 'True' + if param_name == "covariant": + if isinstance(param_value, NameExpr) and param_value.name in ("True", "False"): + covariant = param_value.name == "True" else: - self.fail(message_registry.TYPEVAR_VARIANCE_DEF.format( - 'covariant'), context) + self.fail(message_registry.TYPEVAR_VARIANCE_DEF.format("covariant"), context) return None - elif param_name == 'contravariant': - if (isinstance(param_value, NameExpr) - and param_value.name in ('True', 'False')): - contravariant = param_value.name == 'True' + elif param_name == "contravariant": + if isinstance(param_value, NameExpr) and param_value.name in ("True", "False"): + contravariant = param_value.name == "True" else: - self.fail(message_registry.TYPEVAR_VARIANCE_DEF.format( - 'contravariant'), context) + self.fail( + message_registry.TYPEVAR_VARIANCE_DEF.format("contravariant"), context + ) return None - elif param_name == 'bound': + elif param_name == "bound": if has_values: self.fail("TypeVar cannot have both values and an upper bound", context) return None try: # We want to use our custom error message below, so we suppress # the default error message for invalid types here. - analyzed = self.expr_to_analyzed_type(param_value, - allow_placeholder=True, - report_invalid_types=False) + analyzed = self.expr_to_analyzed_type( + param_value, allow_placeholder=True, report_invalid_types=False + ) if analyzed is None: # Type variables are special: we need to place them in the symbol table # soon, even if upper bound is not ready yet. Otherwise avoiding @@ -3267,16 +3568,18 @@ def process_typevar_parameters(self, args: List[Expression], except TypeTranslationError: self.fail(message_registry.TYPEVAR_BOUND_MUST_BE_TYPE, param_value) return None - elif param_name == 'values': + elif param_name == "values": # Probably using obsolete syntax with values=(...). Explain the current syntax. self.fail('TypeVar "values" argument not supported', context) - self.fail("Use TypeVar('T', t, ...) instead of TypeVar('T', values=(t, ...))", - context) + self.fail( + "Use TypeVar('T', t, ...) instead of TypeVar('T', values=(t, ...))", context + ) return None else: - self.fail('{}: "{}"'.format( - message_registry.TYPEVAR_UNEXPECTED_ARGUMENT, param_name, - ), context) + self.fail( + '{}: "{}"'.format(message_registry.TYPEVAR_UNEXPECTED_ARGUMENT, param_name), + context, + ) return None if covariant and contravariant: @@ -3329,10 +3632,7 @@ def process_paramspec_declaration(self, s: AssignmentStmt) -> bool: # arguments are not semantically valid. But, allowed in runtime. # So, we need to warn users about possible invalid usage. if len(call.args) > 1: - self.fail( - "Only the first argument to ParamSpec has defined semantics", - s, - ) + self.fail("Only the first argument to ParamSpec has defined semantics", s) # PEP 612 reserves the right to define bound, covariant and contravariant arguments to # ParamSpec in a later PEP. If and when that happens, we should do something @@ -3361,10 +3661,7 @@ def process_typevartuple_declaration(self, s: AssignmentStmt) -> bool: return False if len(call.args) > 1: - self.fail( - "Only the first argument to TypeVarTuple has defined semantics", - s, - ) + self.fail("Only the first argument to TypeVarTuple has defined semantics", s) if not self.options.enable_incomplete_features: self.fail('"TypeVarTuple" is not supported by mypy yet', s) @@ -3386,18 +3683,16 @@ def process_typevartuple_declaration(self, s: AssignmentStmt) -> bool: self.add_symbol(name, call.analyzed, s) return True - def basic_new_typeinfo(self, name: str, - basetype_or_fallback: Instance, - line: int) -> TypeInfo: - if self.is_func_scope() and not self.type and '@' not in name: - name += '@' + str(line) + def basic_new_typeinfo(self, name: str, basetype_or_fallback: Instance, line: int) -> TypeInfo: + if self.is_func_scope() and not self.type and "@" not in name: + name += "@" + str(line) class_def = ClassDef(name, Block([])) if self.is_func_scope() and not self.type: # Full names of generated classes should always be prefixed with the module names # even if they are nested in a function, since these classes will be (de-)serialized. # (Note that the caller should append @line to the name to avoid collisions.) # TODO: clean this up, see #6422. - class_def.fullname = self.cur_mod_id + '.' + self.qualified_name(name) + class_def.fullname = self.cur_mod_id + "." + self.qualified_name(name) else: class_def.fullname = self.qualified_name(name) @@ -3416,8 +3711,9 @@ def analyze_value_types(self, items: List[Expression]) -> List[Type]: result: List[Type] = [] for node in items: try: - analyzed = self.anal_type(self.expr_to_unanalyzed_type(node), - allow_placeholder=True) + analyzed = self.anal_type( + self.expr_to_unanalyzed_type(node), allow_placeholder=True + ) if analyzed is None: # Type variables are special: we need to place them in the symbol table # soon, even if some value is not ready yet, see process_typevar_parameters() @@ -3425,7 +3721,7 @@ def analyze_value_types(self, items: List[Expression]) -> List[Type]: analyzed = PlaceholderType(None, [], node.line) result.append(analyzed) except TypeTranslationError: - self.fail('Type expected', node) + self.fail("Type expected", node) result.append(AnyType(TypeOfAny.from_error)) return result @@ -3457,7 +3753,7 @@ def is_classvar(self, typ: Type) -> bool: sym = self.lookup_qualified(typ.name, typ) if not sym or not sym.node: return False - return sym.node.fullname == 'typing.ClassVar' + return sym.node.fullname == "typing.ClassVar" def is_final_type(self, typ: Optional[Type]) -> bool: if not isinstance(typ, UnboundType): @@ -3470,8 +3766,9 @@ def is_final_type(self, typ: Optional[Type]) -> bool: def fail_invalid_classvar(self, context: Context) -> None: self.fail(message_registry.CLASS_VAR_OUTSIDE_OF_CLASS, context) - def process_module_assignment(self, lvals: List[Lvalue], rval: Expression, - ctx: AssignmentStmt) -> None: + def process_module_assignment( + self, lvals: List[Lvalue], rval: Expression, ctx: AssignmentStmt + ) -> None: """Propagate module references across assignments. Recursively handles the simple form of iterable unpacking; doesn't @@ -3481,8 +3778,9 @@ def process_module_assignment(self, lvals: List[Lvalue], rval: Expression, y]. """ - if (isinstance(rval, (TupleExpr, ListExpr)) - and all(isinstance(v, TupleExpr) for v in lvals)): + if isinstance(rval, (TupleExpr, ListExpr)) and all( + isinstance(v, TupleExpr) for v in lvals + ): # rval and all lvals are either list or tuple, so we are dealing # with unpacking assignment like `x, y = a, b`. Mypy didn't # understand our all(isinstance(...)), so cast them as TupleExpr @@ -3515,7 +3813,7 @@ def process_module_assignment(self, lvals: List[Lvalue], rval: Expression, if not isinstance(lval, RefExpr): continue # respect explicitly annotated type - if (isinstance(lval.node, Var) and lval.node.type is not None): + if isinstance(lval.node, Var) and lval.node.type is not None: continue # We can handle these assignments to locals and to self @@ -3533,7 +3831,8 @@ def process_module_assignment(self, lvals: List[Lvalue], rval: Expression, self.fail( 'Cannot assign multiple modules to name "{}" ' 'without explicit "types.ModuleType" annotation'.format(lval.name), - ctx) + ctx, + ) # never create module alias except on initial var definition elif lval.is_inferred_def: assert rnode.node is not None @@ -3541,16 +3840,24 @@ def process_module_assignment(self, lvals: List[Lvalue], rval: Expression, def process__all__(self, s: AssignmentStmt) -> None: """Export names if argument is a __all__ assignment.""" - if (len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr) and - s.lvalues[0].name == '__all__' and s.lvalues[0].kind == GDEF and - isinstance(s.rvalue, (ListExpr, TupleExpr))): + if ( + len(s.lvalues) == 1 + and isinstance(s.lvalues[0], NameExpr) + and s.lvalues[0].name == "__all__" + and s.lvalues[0].kind == GDEF + and isinstance(s.rvalue, (ListExpr, TupleExpr)) + ): self.add_exports(s.rvalue.items) def process__deletable__(self, s: AssignmentStmt) -> None: if not self.options.mypyc: return - if (len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr) and - s.lvalues[0].name == '__deletable__' and s.lvalues[0].kind == MDEF): + if ( + len(s.lvalues) == 1 + and isinstance(s.lvalues[0], NameExpr) + and s.lvalues[0].name == "__deletable__" + and s.lvalues[0].kind == MDEF + ): rvalue = s.rvalue if not isinstance(rvalue, (ListExpr, TupleExpr)): self.fail('"__deletable__" must be initialized with a list or tuple expression', s) @@ -3572,9 +3879,13 @@ def process__slots__(self, s: AssignmentStmt) -> None: See: https://docs.python.org/3/reference/datamodel.html#slots """ # Later we can support `__slots__` defined as `__slots__ = other = ('a', 'b')` - if (isinstance(self.type, TypeInfo) and - len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr) and - s.lvalues[0].name == '__slots__' and s.lvalues[0].kind == MDEF): + if ( + isinstance(self.type, TypeInfo) + and len(s.lvalues) == 1 + and isinstance(s.lvalues[0], NameExpr) + and s.lvalues[0].name == "__slots__" + and s.lvalues[0].kind == MDEF + ): # We understand `__slots__` defined as string, tuple, list, set, and dict: if not isinstance(s.rvalue, (StrExpr, ListExpr, TupleExpr, SetExpr, DictExpr)): @@ -3606,7 +3917,7 @@ def process__slots__(self, s: AssignmentStmt) -> None: for item in rvalue: # Special case for `'__dict__'` value: # when specified it will still allow any attribute assignment. - if isinstance(item, StrExpr) and item.value != '__dict__': + if isinstance(item, StrExpr) and item.value != "__dict__": slots.append(item.value) else: concrete_slots = False @@ -3662,13 +3973,16 @@ def visit_assert_stmt(self, s: AssertStmt) -> None: if s.msg: s.msg.accept(self) - def visit_operator_assignment_stmt(self, - s: OperatorAssignmentStmt) -> None: + def visit_operator_assignment_stmt(self, s: OperatorAssignmentStmt) -> None: self.statement = s s.lvalue.accept(self) s.rvalue.accept(self) - if (isinstance(s.lvalue, NameExpr) and s.lvalue.name == '__all__' and - s.lvalue.kind == GDEF and isinstance(s.rvalue, (ListExpr, TupleExpr))): + if ( + isinstance(s.lvalue, NameExpr) + and s.lvalue.name == "__all__" + and s.lvalue.kind == GDEF + and isinstance(s.rvalue, (ListExpr, TupleExpr)) + ): self.add_exports(s.rvalue.items) def visit_while_stmt(self, s: WhileStmt) -> None: @@ -3793,7 +4107,7 @@ def visit_del_stmt(self, s: DelStmt) -> None: self.statement = s s.expr.accept(self) if not self.is_valid_del_target(s.expr): - self.fail('Invalid delete target', s) + self.fail("Invalid delete target", s) def is_valid_del_target(self, s: Expression) -> bool: if isinstance(s, (IndexExpr, NameExpr, MemberExpr)): @@ -3823,8 +4137,11 @@ def visit_nonlocal_decl(self, d: NonlocalDecl) -> None: self.fail(f'No binding for nonlocal "{name}" found', d) if self.locals[-1] is not None and name in self.locals[-1]: - self.fail('Name "{}" is already defined in local ' - 'scope before nonlocal declaration'.format(name), d) + self.fail( + 'Name "{}" is already defined in local ' + "scope before nonlocal declaration".format(name), + d, + ) if name in self.global_decls[-1]: self.fail(f'Name "{name}" is nonlocal and global', d) @@ -3868,10 +4185,11 @@ def visit_name_expr(self, expr: NameExpr) -> None: def bind_name_expr(self, expr: NameExpr, sym: SymbolTableNode) -> None: """Bind name expression to a symbol table node.""" if isinstance(sym.node, TypeVarExpr) and self.tvar_scope.get_binding(sym): - self.fail('"{}" is a type variable and only valid in type ' - 'context'.format(expr.name), expr) + self.fail( + '"{}" is a type variable and only valid in type ' "context".format(expr.name), expr + ) elif isinstance(sym.node, PlaceholderNode): - self.process_placeholder(expr.name, 'name', expr) + self.process_placeholder(expr.name, "name", expr) else: expr.kind = sym.kind expr.node = sym.node @@ -3912,7 +4230,7 @@ def visit_dict_expr(self, expr: DictExpr) -> None: def visit_star_expr(self, expr: StarExpr) -> None: if not expr.valid: # XXX TODO Change this error message - self.fail('Can use starred expression only as assignment target', expr) + self.fail("Can use starred expression only as assignment target", expr) else: expr.expr.accept(self) @@ -3920,8 +4238,12 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> None: if not self.is_func_scope(): self.fail('"yield from" outside function', e, serious=True, blocker=True) elif self.is_comprehension_stack[-1]: - self.fail('"yield from" inside comprehension or generator expression', - e, serious=True, blocker=True) + self.fail( + '"yield from" inside comprehension or generator expression', + e, + serious=True, + blocker=True, + ) elif self.function_stack[-1].is_coroutine: self.fail('"yield from" in async function', e, serious=True, blocker=True) else: @@ -3936,15 +4258,15 @@ def visit_call_expr(self, expr: CallExpr) -> None: cast(...). """ expr.callee.accept(self) - if refers_to_fullname(expr.callee, 'typing.cast'): + if refers_to_fullname(expr.callee, "typing.cast"): # Special form cast(...). - if not self.check_fixed_args(expr, 2, 'cast'): + if not self.check_fixed_args(expr, 2, "cast"): return # Translate first argument to an unanalyzed type. try: target = self.expr_to_unanalyzed_type(expr.args[0]) except TypeTranslationError: - self.fail('Cast target is not a type', expr) + self.fail("Cast target is not a type", expr) return # Piggyback CastExpr object to the CallExpr object; it takes # precedence over the CallExpr semantics. @@ -3953,26 +4275,26 @@ def visit_call_expr(self, expr: CallExpr) -> None: expr.analyzed.column = expr.column expr.analyzed.accept(self) elif refers_to_fullname(expr.callee, ASSERT_TYPE_NAMES): - if not self.check_fixed_args(expr, 2, 'assert_type'): + if not self.check_fixed_args(expr, 2, "assert_type"): return # Translate second argument to an unanalyzed type. try: target = self.expr_to_unanalyzed_type(expr.args[1]) except TypeTranslationError: - self.fail('assert_type() type is not a type', expr) + self.fail("assert_type() type is not a type", expr) return expr.analyzed = AssertTypeExpr(expr.args[0], target) expr.analyzed.line = expr.line expr.analyzed.column = expr.column expr.analyzed.accept(self) elif refers_to_fullname(expr.callee, REVEAL_TYPE_NAMES): - if not self.check_fixed_args(expr, 1, 'reveal_type'): + if not self.check_fixed_args(expr, 1, "reveal_type"): return expr.analyzed = RevealExpr(kind=REVEAL_TYPE, expr=expr.args[0]) expr.analyzed.line = expr.line expr.analyzed.column = expr.column expr.analyzed.accept(self) - elif refers_to_fullname(expr.callee, 'builtins.reveal_locals'): + elif refers_to_fullname(expr.callee, "builtins.reveal_locals"): # Store the local variable names into the RevealExpr for use in the # type checking pass local_nodes: List[Var] = [] @@ -3982,50 +4304,51 @@ def visit_call_expr(self, expr: CallExpr) -> None: # Each SymbolTableNode has an attribute node that is nodes.Var # look for variable nodes that marked as is_inferred # Each symboltable node has a Var node as .node - local_nodes = [n.node - for name, n in self.globals.items() - if getattr(n.node, 'is_inferred', False) - and isinstance(n.node, Var)] + local_nodes = [ + n.node + for name, n in self.globals.items() + if getattr(n.node, "is_inferred", False) and isinstance(n.node, Var) + ] elif self.is_class_scope(): # type = None # type: Optional[TypeInfo] if self.type is not None: - local_nodes = [st.node - for st in self.type.names.values() - if isinstance(st.node, Var)] + local_nodes = [ + st.node for st in self.type.names.values() if isinstance(st.node, Var) + ] elif self.is_func_scope(): # locals = None # type: List[Optional[SymbolTable]] if self.locals is not None: symbol_table = self.locals[-1] if symbol_table is not None: - local_nodes = [st.node - for st in symbol_table.values() - if isinstance(st.node, Var)] + local_nodes = [ + st.node for st in symbol_table.values() if isinstance(st.node, Var) + ] expr.analyzed = RevealExpr(kind=REVEAL_LOCALS, local_nodes=local_nodes) expr.analyzed.line = expr.line expr.analyzed.column = expr.column expr.analyzed.accept(self) - elif refers_to_fullname(expr.callee, 'typing.Any'): + elif refers_to_fullname(expr.callee, "typing.Any"): # Special form Any(...) no longer supported. - self.fail('Any(...) is no longer supported. Use cast(Any, ...) instead', expr) - elif refers_to_fullname(expr.callee, 'typing._promote'): + self.fail("Any(...) is no longer supported. Use cast(Any, ...) instead", expr) + elif refers_to_fullname(expr.callee, "typing._promote"): # Special form _promote(...). - if not self.check_fixed_args(expr, 1, '_promote'): + if not self.check_fixed_args(expr, 1, "_promote"): return # Translate first argument to an unanalyzed type. try: target = self.expr_to_unanalyzed_type(expr.args[0]) except TypeTranslationError: - self.fail('Argument 1 to _promote is not a type', expr) + self.fail("Argument 1 to _promote is not a type", expr) return expr.analyzed = PromoteExpr(target) expr.analyzed.line = expr.line expr.analyzed.accept(self) - elif refers_to_fullname(expr.callee, 'builtins.dict'): + elif refers_to_fullname(expr.callee, "builtins.dict"): expr.analyzed = self.translate_dict_call(expr) - elif refers_to_fullname(expr.callee, 'builtins.divmod'): - if not self.check_fixed_args(expr, 2, 'divmod'): + elif refers_to_fullname(expr.callee, "builtins.divmod"): + if not self.check_fixed_args(expr, 2, "divmod"): return - expr.analyzed = OpExpr('divmod', expr.args[0], expr.args[1]) + expr.analyzed = OpExpr("divmod", expr.args[0], expr.args[1]) expr.analyzed.line = expr.line expr.analyzed.accept(self) else: @@ -4033,15 +4356,20 @@ def visit_call_expr(self, expr: CallExpr) -> None: for a in expr.args: a.accept(self) - if (isinstance(expr.callee, MemberExpr) and - isinstance(expr.callee.expr, NameExpr) and - expr.callee.expr.name == '__all__' and - expr.callee.expr.kind == GDEF and - expr.callee.name in ('append', 'extend')): - if expr.callee.name == 'append' and expr.args: + if ( + isinstance(expr.callee, MemberExpr) + and isinstance(expr.callee.expr, NameExpr) + and expr.callee.expr.name == "__all__" + and expr.callee.expr.kind == GDEF + and expr.callee.name in ("append", "extend") + ): + if expr.callee.name == "append" and expr.args: self.add_exports(expr.args[0]) - elif (expr.callee.name == 'extend' and expr.args and - isinstance(expr.args[0], (ListExpr, TupleExpr))): + elif ( + expr.callee.name == "extend" + and expr.args + and isinstance(expr.args[0], (ListExpr, TupleExpr)) + ): self.add_exports(expr.args[0].items) def translate_dict_call(self, call: CallExpr) -> Optional[DictExpr]: @@ -4054,28 +4382,31 @@ def translate_dict_call(self, call: CallExpr) -> Optional[DictExpr]: for a in call.args: a.accept(self) return None - expr = DictExpr([(StrExpr(cast(str, key)), value) # since they are all ARG_NAMED - for key, value in zip(call.arg_names, call.args)]) + expr = DictExpr( + [ + (StrExpr(cast(str, key)), value) # since they are all ARG_NAMED + for key, value in zip(call.arg_names, call.args) + ] + ) expr.set_line(call) expr.accept(self) return expr - def check_fixed_args(self, expr: CallExpr, numargs: int, - name: str) -> bool: + def check_fixed_args(self, expr: CallExpr, numargs: int, name: str) -> bool: """Verify that expr has specified number of positional args. Return True if the arguments are valid. """ - s = 's' + s = "s" if numargs == 1: - s = '' + s = "" if len(expr.args) != numargs: - self.fail('"%s" expects %d argument%s' % (name, numargs, s), - expr) + self.fail('"%s" expects %d argument%s' % (name, numargs, s), expr) return False if expr.arg_kinds != [ARG_POS] * numargs: - self.fail('"%s" must be called with %s positional argument%s' % - (name, numargs, s), expr) + self.fail( + '"%s" must be called with %s positional argument%s' % (name, numargs, s), expr + ) return False return True @@ -4087,7 +4418,7 @@ def visit_member_expr(self, expr: MemberExpr) -> None: sym = self.get_module_symbol(base.node, expr.name) if sym: if isinstance(sym.node, PlaceholderNode): - self.process_placeholder(expr.name, 'attribute', expr) + self.process_placeholder(expr.name, "attribute", expr) return expr.kind = sym.kind expr.fullname = sym.fullname @@ -4128,14 +4459,16 @@ def visit_member_expr(self, expr: MemberExpr) -> None: def visit_op_expr(self, expr: OpExpr) -> None: expr.left.accept(self) - if expr.op in ('and', 'or'): + if expr.op in ("and", "or"): inferred = infer_condition_value(expr.left, self.options) - if ((inferred in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == 'and') or - (inferred in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == 'or')): + if (inferred in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == "and") or ( + inferred in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == "or" + ): expr.right_unreachable = True return - elif ((inferred in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == 'and') or - (inferred in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == 'or')): + elif (inferred in (ALWAYS_TRUE, MYPY_TRUE) and expr.op == "and") or ( + inferred in (ALWAYS_FALSE, MYPY_FALSE) and expr.op == "or" + ): expr.right_always = True expr.right.accept(self) @@ -4150,12 +4483,15 @@ def visit_unary_expr(self, expr: UnaryExpr) -> None: def visit_index_expr(self, expr: IndexExpr) -> None: base = expr.base base.accept(self) - if (isinstance(base, RefExpr) - and isinstance(base.node, TypeInfo) - and not base.node.is_generic()): + if ( + isinstance(base, RefExpr) + and isinstance(base.node, TypeInfo) + and not base.node.is_generic() + ): expr.index.accept(self) - elif ((isinstance(base, RefExpr) and isinstance(base.node, TypeAlias)) - or refers_to_class_or_function(base)): + elif ( + isinstance(base, RefExpr) and isinstance(base.node, TypeAlias) + ) or refers_to_class_or_function(base): # We need to do full processing on every iteration, since some type # arguments may contain placeholder types. self.analyze_type_application(expr) @@ -4178,16 +4514,22 @@ def analyze_type_application(self, expr: IndexExpr) -> None: target = get_proper_type(alias.target) if isinstance(target, Instance): name = target.type.fullname - if (alias.no_args and # this avoids bogus errors for already reported aliases - name in get_nongen_builtins(self.options.python_version) and - not self.is_stub_file and - not alias.normalized): + if ( + alias.no_args + and name # this avoids bogus errors for already reported aliases + in get_nongen_builtins(self.options.python_version) + and not self.is_stub_file + and not alias.normalized + ): self.fail(no_subscript_builtin_alias(name, propose_alt=False), expr) # ...or directly. else: n = self.lookup_type_node(base) - if (n and n.fullname in get_nongen_builtins(self.options.python_version) and - not self.is_stub_file): + if ( + n + and n.fullname in get_nongen_builtins(self.options.python_version) + and not self.is_stub_file + ): self.fail(no_subscript_builtin_alias(n.fullname, propose_alt=False), expr) def analyze_type_application_args(self, expr: IndexExpr) -> Optional[List[Type]]: @@ -4203,7 +4545,7 @@ def analyze_type_application_args(self, expr: IndexExpr) -> Optional[List[Type]] types: List[Type] = [] if isinstance(index, TupleExpr): items = index.items - is_tuple = isinstance(expr.base, RefExpr) and expr.base.fullname == 'builtins.tuple' + is_tuple = isinstance(expr.base, RefExpr) and expr.base.fullname == "builtins.tuple" if is_tuple and len(items) == 2 and isinstance(items[-1], EllipsisExpr): items = items[:-1] else: @@ -4233,23 +4575,31 @@ def analyze_type_application_args(self, expr: IndexExpr) -> Optional[List[Type]] try: typearg = self.expr_to_unanalyzed_type(item) except TypeTranslationError: - self.fail('Type expected within [...]', expr) + self.fail("Type expected within [...]", expr) return None # We always allow unbound type variables in IndexExpr, since we # may be analysing a type alias definition rvalue. The error will be # reported elsewhere if it is not the case. - analyzed = self.anal_type(typearg, allow_unbound_tvars=True, - allow_placeholder=True, - allow_param_spec_literals=has_param_spec) + analyzed = self.anal_type( + typearg, + allow_unbound_tvars=True, + allow_placeholder=True, + allow_param_spec_literals=has_param_spec, + ) if analyzed is None: return None types.append(analyzed) if has_param_spec and num_args == 1 and len(types) > 0: first_arg = get_proper_type(types[0]) - if not (len(types) == 1 and (isinstance(first_arg, Parameters) or - isinstance(first_arg, ParamSpecType) or - isinstance(first_arg, AnyType))): + if not ( + len(types) == 1 + and ( + isinstance(first_arg, Parameters) + or isinstance(first_arg, ParamSpecType) + or isinstance(first_arg, AnyType) + ) + ): types = [Parameters(types, [ARG_POS] * len(types), [None] * len(types))] return types @@ -4321,16 +4671,15 @@ def visit_generator_expr(self, expr: GeneratorExpr) -> None: expr.left_expr.accept(self) self.analyze_comp_for_2(expr) - def analyze_comp_for(self, expr: Union[GeneratorExpr, - DictionaryComprehension]) -> None: + def analyze_comp_for(self, expr: Union[GeneratorExpr, DictionaryComprehension]) -> None: """Analyses the 'comp_for' part of comprehensions (part 1). That is the part after 'for' in (x for x in l if p). This analyzes variables and conditions which are analyzed in a local scope. """ - for i, (index, sequence, conditions) in enumerate(zip(expr.indices, - expr.sequences, - expr.condlists)): + for i, (index, sequence, conditions) in enumerate( + zip(expr.indices, expr.sequences, expr.condlists) + ): if i > 0: sequence.accept(self) # Bind index variables. @@ -4338,8 +4687,7 @@ def analyze_comp_for(self, expr: Union[GeneratorExpr, for cond in conditions: cond.accept(self) - def analyze_comp_for_2(self, expr: Union[GeneratorExpr, - DictionaryComprehension]) -> None: + def analyze_comp_for_2(self, expr: Union[GeneratorExpr, DictionaryComprehension]) -> None: """Analyses the 'comp_for' part of comprehensions (part 2). That is the part after 'for' in (x for x in l if p). This analyzes @@ -4368,8 +4716,12 @@ def visit_yield_expr(self, e: YieldExpr) -> None: if not self.is_func_scope(): self.fail('"yield" outside function', e, serious=True, blocker=True) elif self.is_comprehension_stack[-1]: - self.fail('"yield" inside comprehension or generator expression', - e, serious=True, blocker=True) + self.fail( + '"yield" inside comprehension or generator expression', + e, + serious=True, + blocker=True, + ) elif self.function_stack[-1].is_coroutine: if self.options.python_version < (3, 6): self.fail('"yield" in async function', e, serious=True, blocker=True) @@ -4432,8 +4784,9 @@ def visit_class_pattern(self, p: ClassPattern) -> None: # Lookup functions # - def lookup(self, name: str, ctx: Context, - suppress_errors: bool = False) -> Optional[SymbolTableNode]: + def lookup( + self, name: str, ctx: Context, suppress_errors: bool = False + ) -> Optional[SymbolTableNode]: """Look up an unqualified (no dots) name in all active namespaces. Note that the result may contain a PlaceholderNode. The caller may @@ -4478,7 +4831,7 @@ def lookup(self, name: str, ctx: Context, if name in self.globals: return self.globals[name] # 5. Builtins - b = self.globals.get('__builtins__', None) + b = self.globals.get("__builtins__", None) if b: assert isinstance(b.node, MypyFile) table = b.node.names @@ -4517,11 +4870,13 @@ class C: if self.statement is None: # Assume it's fine -- don't have enough context to check return True - return (node is None - or self.is_textually_before_statement(node) - or not self.is_defined_in_current_module(node.fullname) - or isinstance(node, TypeInfo) - or (isinstance(node, PlaceholderNode) and node.becomes_typeinfo)) + return ( + node is None + or self.is_textually_before_statement(node) + or not self.is_defined_in_current_module(node.fullname) + or isinstance(node, TypeInfo) + or (isinstance(node, PlaceholderNode) and node.becomes_typeinfo) + ) def is_textually_before_statement(self, node: SymbolNode) -> bool: """Check if a node is defined textually before the current statement @@ -4545,11 +4900,13 @@ def is_textually_before_statement(self, node: SymbolNode) -> bool: def is_overloaded_item(self, node: SymbolNode, statement: Statement) -> bool: """Check whether the function belongs to the overloaded variants""" if isinstance(node, OverloadedFuncDef) and isinstance(statement, FuncDef): - in_items = statement in {item.func if isinstance(item, Decorator) - else item for item in node.items} - in_impl = (node.impl is not None and - ((isinstance(node.impl, Decorator) and statement is node.impl.func) - or statement is node.impl)) + in_items = statement in { + item.func if isinstance(item, Decorator) else item for item in node.items + } + in_impl = node.impl is not None and ( + (isinstance(node.impl, Decorator) and statement is node.impl.func) + or statement is node.impl + ) return in_items or in_impl return False @@ -4558,8 +4915,9 @@ def is_defined_in_current_module(self, fullname: Optional[str]) -> bool: return False return module_prefix(self.modules, fullname) == self.cur_mod_id - def lookup_qualified(self, name: str, ctx: Context, - suppress_errors: bool = False) -> Optional[SymbolTableNode]: + def lookup_qualified( + self, name: str, ctx: Context, suppress_errors: bool = False + ) -> Optional[SymbolTableNode]: """Lookup a qualified name in all activate namespaces. Note that the result may contain a PlaceholderNode. The caller may @@ -4569,10 +4927,10 @@ def lookup_qualified(self, name: str, ctx: Context, is true or the current namespace is incomplete. In the latter case defer. """ - if '.' not in name: + if "." not in name: # Simple case: look up a short name. return self.lookup(name, ctx, suppress_errors=suppress_errors) - parts = name.split('.') + parts = name.split(".") namespace = self.cur_mod_id sym = self.lookup(parts[0], ctx, suppress_errors=suppress_errors) if sym: @@ -4626,15 +4984,15 @@ def get_module_symbol(self, node: MypyFile, name: str) -> Optional[SymbolTableNo names = node.names sym = names.get(name) if not sym: - fullname = module + '.' + name + fullname = module + "." + name if fullname in self.modules: sym = SymbolTableNode(GDEF, self.modules[fullname]) elif self.is_incomplete_namespace(module): self.record_incomplete_ref() - elif ('__getattr__' in names - and (node.is_stub - or self.options.python_version >= (3, 7))): - gvar = self.create_getattr_var(names['__getattr__'], name, fullname) + elif "__getattr__" in names and ( + node.is_stub or self.options.python_version >= (3, 7) + ): + gvar = self.create_getattr_var(names["__getattr__"], name, fullname) if gvar: sym = SymbolTableNode(GDEF, gvar) elif self.is_missing_module(fullname): @@ -4651,8 +5009,9 @@ def get_module_symbol(self, node: MypyFile, name: str) -> Optional[SymbolTableNo def is_missing_module(self, module: str) -> bool: return module in self.missing_modules - def implicit_symbol(self, sym: SymbolTableNode, name: str, parts: List[str], - source_type: AnyType) -> SymbolTableNode: + def implicit_symbol( + self, sym: SymbolTableNode, name: str, parts: List[str], source_type: AnyType + ) -> SymbolTableNode: """Create symbol for a qualified name reference through Any type.""" if sym.node is None: basename = None @@ -4661,14 +5020,15 @@ def implicit_symbol(self, sym: SymbolTableNode, name: str, parts: List[str], if basename is None: fullname = name else: - fullname = basename + '.' + '.'.join(parts) + fullname = basename + "." + ".".join(parts) var_type = AnyType(TypeOfAny.from_another_any, source_type) var = Var(parts[-1], var_type) var._fullname = fullname return SymbolTableNode(GDEF, var) - def create_getattr_var(self, getattr_defn: SymbolTableNode, - name: str, fullname: str) -> Optional[Var]: + def create_getattr_var( + self, getattr_defn: SymbolTableNode, name: str, fullname: str + ) -> Optional[Var]: """Create a dummy variable using module-level __getattr__ return type. If not possible, return None. @@ -4708,8 +5068,8 @@ def lookup_fully_qualified_or_none(self, fullname: str) -> Optional[SymbolTableN # TODO: unify/clean-up/simplify lookup methods, see #4157. # TODO: support nested classes (but consider performance impact, # we might keep the module level only lookup for thing like 'builtins.int'). - assert '.' in fullname - module, name = fullname.rsplit('.', maxsplit=1) + assert "." in fullname + module, name = fullname.rsplit(".", maxsplit=1) if module not in self.modules: return None filenode = self.modules[module] @@ -4720,10 +5080,10 @@ def lookup_fully_qualified_or_none(self, fullname: str) -> Optional[SymbolTableN return result def object_type(self) -> Instance: - return self.named_type('builtins.object') + return self.named_type("builtins.object") def str_type(self) -> Instance: - return self.named_type('builtins.str') + return self.named_type("builtins.str") def named_type(self, fullname: str, args: Optional[List[Type]] = None) -> Instance: sym = self.lookup_fully_qualified(fullname) @@ -4735,8 +5095,9 @@ def named_type(self, fullname: str, args: Optional[List[Type]] = None) -> Instan return Instance(node, args) return Instance(node, [AnyType(TypeOfAny.special_form)] * len(node.defn.type_vars)) - def named_type_or_none(self, fullname: str, - args: Optional[List[Type]] = None) -> Optional[Instance]: + def named_type_or_none( + self, fullname: str, args: Optional[List[Type]] = None + ) -> Optional[Instance]: sym = self.lookup_fully_qualified_or_none(fullname) if not sym or isinstance(sym.node, PlaceholderNode): return None @@ -4766,14 +5127,16 @@ def lookup_current_scope(self, name: str) -> Optional[SymbolTableNode]: # Adding symbols # - def add_symbol(self, - name: str, - node: SymbolNode, - context: Context, - module_public: bool = True, - module_hidden: bool = False, - can_defer: bool = True, - escape_comprehensions: bool = False) -> bool: + def add_symbol( + self, + name: str, + node: SymbolNode, + context: Context, + module_public: bool = True, + module_hidden: bool = False, + can_defer: bool = True, + escape_comprehensions: bool = False, + ) -> bool: """Add symbol to the currently active symbol table. Generally additions to symbol table should go through this method or @@ -4791,10 +5154,9 @@ def add_symbol(self, kind = MDEF else: kind = GDEF - symbol = SymbolTableNode(kind, - node, - module_public=module_public, - module_hidden=module_hidden) + symbol = SymbolTableNode( + kind, node, module_public=module_public, module_hidden=module_hidden + ) return self.add_symbol_table_node(name, symbol, context, can_defer, escape_comprehensions) def add_symbol_skip_local(self, name: str, node: SymbolNode) -> None: @@ -4819,12 +5181,14 @@ def add_symbol_skip_local(self, name: str, node: SymbolNode) -> None: symbol = SymbolTableNode(kind, node) names[name] = symbol - def add_symbol_table_node(self, - name: str, - symbol: SymbolTableNode, - context: Optional[Context] = None, - can_defer: bool = True, - escape_comprehensions: bool = False) -> bool: + def add_symbol_table_node( + self, + name: str, + symbol: SymbolTableNode, + context: Optional[Context] = None, + can_defer: bool = True, + escape_comprehensions: bool = False, + ) -> bool: """Add symbol table node to the currently active symbol table. Return True if we actually added the symbol, or False if we refused @@ -4847,13 +5211,15 @@ def add_symbol_table_node(self, existing = names.get(name) if isinstance(symbol.node, PlaceholderNode) and can_defer: if context is not None: - self.process_placeholder(name, 'name', context) + self.process_placeholder(name, "name", context) else: # see note in docstring describing None contexts self.defer() - if (existing is not None - and context is not None - and not is_valid_replacement(existing, symbol)): + if ( + existing is not None + and context is not None + and not is_valid_replacement(existing, symbol) + ): # There is an existing node, so this may be a redefinition. # If the new node points to the same node as the old one, # or if both old and new nodes are placeholders, we don't @@ -4866,19 +5232,15 @@ def add_symbol_table_node(self, if not is_same_symbol(old, new): if isinstance(new, (FuncDef, Decorator, OverloadedFuncDef, TypeInfo)): self.add_redefinition(names, name, symbol) - if not (isinstance(new, (FuncDef, Decorator)) - and self.set_original_def(old, new)): + if not (isinstance(new, (FuncDef, Decorator)) and self.set_original_def(old, new)): self.name_already_defined(name, context, existing) - elif (name not in self.missing_names[-1] and '*' not in self.missing_names[-1]): + elif name not in self.missing_names[-1] and "*" not in self.missing_names[-1]: names[name] = symbol self.progress = True return True return False - def add_redefinition(self, - names: SymbolTable, - name: str, - symbol: SymbolTableNode) -> None: + def add_redefinition(self, names: SymbolTable, name: str, symbol: SymbolTableNode) -> None: """Add a symbol table node that reflects a redefinition as a function or a class. Redefinitions need to be added to the symbol table so that they can be found @@ -4897,9 +5259,9 @@ def add_redefinition(self, symbol.no_serialize = True while True: if i == 1: - new_name = f'{name}-redefinition' + new_name = f"{name}-redefinition" else: - new_name = f'{name}-redefinition{i}' + new_name = f"{name}-redefinition{i}" existing = names.get(new_name) if existing is None: names[new_name] = symbol @@ -4916,22 +5278,22 @@ def add_local(self, node: Union[Var, FuncDef, OverloadedFuncDef], context: Conte node._fullname = name self.add_symbol(name, node, context) - def add_module_symbol(self, - id: str, - as_id: str, - context: Context, - module_public: bool, - module_hidden: bool) -> None: + def add_module_symbol( + self, id: str, as_id: str, context: Context, module_public: bool, module_hidden: bool + ) -> None: """Add symbol that is a reference to a module object.""" if id in self.modules: node = self.modules[id] - self.add_symbol(as_id, node, context, - module_public=module_public, - module_hidden=module_hidden) + self.add_symbol( + as_id, node, context, module_public=module_public, module_hidden=module_hidden + ) else: self.add_unknown_imported_symbol( - as_id, context, target_name=id, module_public=module_public, - module_hidden=module_hidden + as_id, + context, + target_name=id, + module_public=module_public, + module_hidden=module_hidden, ) def _get_node_for_class_scoped_import( @@ -4965,7 +5327,7 @@ def _get_node_for_class_scoped_import( # In theory we could construct a new node here as well, but in practice # it doesn't work well, see #12197 typ: Optional[Type] = AnyType(TypeOfAny.from_error) - self.fail('Unsupported class scoped import', context) + self.fail("Unsupported class scoped import", context) else: typ = f(symbol_node).type symbol_node = Var(name, typ) @@ -4976,12 +5338,14 @@ def _get_node_for_class_scoped_import( symbol_node.column = context.column return symbol_node - def add_imported_symbol(self, - name: str, - node: SymbolTableNode, - context: Context, - module_public: bool, - module_hidden: bool) -> None: + def add_imported_symbol( + self, + name: str, + node: SymbolTableNode, + context: Context, + module_public: bool, + module_hidden: bool, + ) -> None: """Add an alias to an existing symbol through import.""" assert not module_hidden or not module_public @@ -4990,17 +5354,19 @@ def add_imported_symbol(self, if self.is_class_scope(): symbol_node = self._get_node_for_class_scoped_import(name, symbol_node, context) - symbol = SymbolTableNode(node.kind, symbol_node, - module_public=module_public, - module_hidden=module_hidden) + symbol = SymbolTableNode( + node.kind, symbol_node, module_public=module_public, module_hidden=module_hidden + ) self.add_symbol_table_node(name, symbol, context) - def add_unknown_imported_symbol(self, - name: str, - context: Context, - target_name: Optional[str], - module_public: bool, - module_hidden: bool) -> None: + def add_unknown_imported_symbol( + self, + name: str, + context: Context, + target_name: Optional[str], + module_public: bool, + module_hidden: bool, + ) -> None: """Add symbol that we don't know what it points to because resolving an import failed. This can happen if a module is missing, or it is present, but doesn't have @@ -5056,11 +5422,12 @@ def defer(self, debug_context: Optional[Context] = None) -> None: 'record_incomplete_ref', call this implicitly, or when needed. They are usually preferable to a direct defer() call. """ - assert not self.final_iteration, 'Must not defer during final iteration' + assert not self.final_iteration, "Must not defer during final iteration" self.deferred = True # Store debug info for this deferral. - line = (debug_context.line if debug_context else - self.statement.line if self.statement else -1) + line = ( + debug_context.line if debug_context else self.statement.line if self.statement else -1 + ) self.deferral_debug_context.append((self.cur_mod_id, line)) def track_incomplete_refs(self) -> Tag: @@ -5076,10 +5443,14 @@ def record_incomplete_ref(self) -> None: self.defer() self.num_incomplete_refs += 1 - def mark_incomplete(self, name: str, node: Node, - becomes_typeinfo: bool = False, - module_public: bool = True, - module_hidden: bool = False) -> None: + def mark_incomplete( + self, + name: str, + node: Node, + becomes_typeinfo: bool = False, + module_public: bool = True, + module_hidden: bool = False, + ) -> None: """Mark a definition as incomplete (and defer current analysis target). Also potentially mark the current namespace as incomplete. @@ -5091,16 +5462,21 @@ def mark_incomplete(self, name: str, node: Node, named tuples that will create TypeInfos). """ self.defer(node) - if name == '*': + if name == "*": self.incomplete = True elif not self.is_global_or_nonlocal(name): fullname = self.qualified_name(name) assert self.statement - placeholder = PlaceholderNode(fullname, node, self.statement.line, - becomes_typeinfo=becomes_typeinfo) - self.add_symbol(name, placeholder, - module_public=module_public, module_hidden=module_hidden, - context=dummy_context()) + placeholder = PlaceholderNode( + fullname, node, self.statement.line, becomes_typeinfo=becomes_typeinfo + ) + self.add_symbol( + name, + placeholder, + module_public=module_public, + module_hidden=module_hidden, + context=dummy_context(), + ) self.missing_names[-1].add(name) def is_incomplete_namespace(self, fullname: str) -> bool: @@ -5130,15 +5506,16 @@ def cannot_resolve_name(self, name: str, kind: str, ctx: Context) -> None: def qualified_name(self, name: str) -> str: if self.type is not None: - return self.type._fullname + '.' + name + return self.type._fullname + "." + name elif self.is_func_scope(): return name else: - return self.cur_mod_id + '.' + name + return self.cur_mod_id + "." + name @contextmanager - def enter(self, - function: Union[FuncItem, GeneratorExpr, DictionaryComprehension]) -> Iterator[None]: + def enter( + self, function: Union[FuncItem, GeneratorExpr, DictionaryComprehension] + ) -> Iterator[None]: """Enter a function, generator or comprehension scope.""" names = self.saved_locals.setdefault(function, SymbolTable()) self.locals.append(names) @@ -5194,8 +5571,9 @@ def current_symbol_table(self, escape_comprehensions: bool = False) -> SymbolTab names = self.globals else: names_candidate = self.locals[-1 - i] - assert names_candidate is not None, \ - "Escaping comprehension from invalid scope" + assert ( + names_candidate is not None + ), "Escaping comprehension from invalid scope" names = names_candidate break else: @@ -5210,9 +5588,9 @@ def current_symbol_table(self, escape_comprehensions: bool = False) -> SymbolTab return names def is_global_or_nonlocal(self, name: str) -> bool: - return (self.is_func_scope() - and (name in self.global_decls[-1] - or name in self.nonlocal_decls[-1])) + return self.is_func_scope() and ( + name in self.global_decls[-1] or name in self.nonlocal_decls[-1] + ) def add_exports(self, exp_or_exps: Union[Iterable[Expression], Expression]) -> None: exps = [exp_or_exps] if isinstance(exp_or_exps, Expression) else exp_or_exps @@ -5222,11 +5600,13 @@ def add_exports(self, exp_or_exps: Union[Iterable[Expression], Expression]) -> N def name_not_defined(self, name: str, ctx: Context, namespace: Optional[str] = None) -> None: incomplete = self.is_incomplete_namespace(namespace or self.cur_mod_id) - if (namespace is None - and self.type - and not self.is_func_scope() - and self.incomplete_type_stack[-1] - and not self.final_iteration): + if ( + namespace is None + and self.type + and not self.is_func_scope() + and self.incomplete_type_stack[-1] + and not self.final_iteration + ): # We are processing a class body for the first time, so it is incomplete. incomplete = True if incomplete: @@ -5237,37 +5617,35 @@ def name_not_defined(self, name: str, ctx: Context, namespace: Optional[str] = N message = f'Name "{name}" is not defined' self.fail(message, ctx, code=codes.NAME_DEFINED) - if f'builtins.{name}' in SUGGESTED_TEST_FIXTURES: + if f"builtins.{name}" in SUGGESTED_TEST_FIXTURES: # The user probably has a missing definition in a test fixture. Let's verify. - fullname = f'builtins.{name}' + fullname = f"builtins.{name}" if self.lookup_fully_qualified_or_none(fullname) is None: # Yes. Generate a helpful note. self.msg.add_fixture_note(fullname, ctx) modules_with_unimported_hints = { - name.split('.', 1)[0] - for name in TYPES_FOR_UNIMPORTED_HINTS - } - lowercased = { - name.lower(): name - for name in TYPES_FOR_UNIMPORTED_HINTS + name.split(".", 1)[0] for name in TYPES_FOR_UNIMPORTED_HINTS } + lowercased = {name.lower(): name for name in TYPES_FOR_UNIMPORTED_HINTS} for module in modules_with_unimported_hints: - fullname = f'{module}.{name}'.lower() + fullname = f"{module}.{name}".lower() if fullname not in lowercased: continue # User probably forgot to import these types. hint = ( 'Did you forget to import it from "{module}"?' ' (Suggestion: "from {module} import {name}")' - ).format(module=module, name=lowercased[fullname].rsplit('.', 1)[-1]) + ).format(module=module, name=lowercased[fullname].rsplit(".", 1)[-1]) self.note(hint, ctx, code=codes.NAME_DEFINED) - def already_defined(self, - name: str, - ctx: Context, - original_ctx: Optional[Union[SymbolTableNode, SymbolNode]], - noun: str) -> None: + def already_defined( + self, + name: str, + ctx: Context, + original_ctx: Optional[Union[SymbolTableNode, SymbolNode]], + noun: str, + ) -> None: if isinstance(original_ctx, SymbolTableNode): node: Optional[SymbolNode] = original_ctx.node elif isinstance(original_ctx, SymbolNode): @@ -5279,33 +5657,36 @@ def already_defined(self, # Since this is an import, original_ctx.node points to the module definition. # Therefore its line number is always 1, which is not useful for this # error message. - extra_msg = ' (by an import)' + extra_msg = " (by an import)" elif node and node.line != -1 and self.is_local_name(node.fullname): # TODO: Using previous symbol node may give wrong line. We should use # the line number where the binding was established instead. - extra_msg = f' on line {node.line}' + extra_msg = f" on line {node.line}" else: - extra_msg = ' (possibly by an import)' - self.fail(f'{noun} "{unmangle(name)}" already defined{extra_msg}', ctx, - code=codes.NO_REDEF) - - def name_already_defined(self, - name: str, - ctx: Context, - original_ctx: Optional[Union[SymbolTableNode, SymbolNode]] = None - ) -> None: - self.already_defined(name, ctx, original_ctx, noun='Name') - - def attribute_already_defined(self, - name: str, - ctx: Context, - original_ctx: Optional[Union[SymbolTableNode, SymbolNode]] = None - ) -> None: - self.already_defined(name, ctx, original_ctx, noun='Attribute') + extra_msg = " (possibly by an import)" + self.fail( + f'{noun} "{unmangle(name)}" already defined{extra_msg}', ctx, code=codes.NO_REDEF + ) + + def name_already_defined( + self, + name: str, + ctx: Context, + original_ctx: Optional[Union[SymbolTableNode, SymbolNode]] = None, + ) -> None: + self.already_defined(name, ctx, original_ctx, noun="Name") + + def attribute_already_defined( + self, + name: str, + ctx: Context, + original_ctx: Optional[Union[SymbolTableNode, SymbolNode]] = None, + ) -> None: + self.already_defined(name, ctx, original_ctx, noun="Attribute") def is_local_name(self, name: str) -> bool: """Does name look like reference to a definition in the current module?""" - return self.is_defined_in_current_module(name) or '.' not in name + return self.is_defined_in_current_module(name) or "." not in name def in_checked_function(self) -> bool: """Should we type-check the current function? @@ -5321,10 +5702,7 @@ def in_checked_function(self) -> bool: current_index = len(self.function_stack) - 1 while current_index >= 0: current_func = self.function_stack[current_index] - if ( - isinstance(current_func, FuncItem) - and not isinstance(current_func, LambdaExpr) - ): + if isinstance(current_func, FuncItem) and not isinstance(current_func, LambdaExpr): return not current_func.is_dynamic() # Special case, `lambda` inherits the "checked" state from its parent. @@ -5336,13 +5714,15 @@ def in_checked_function(self) -> bool: # no regular functions. return True - def fail(self, - msg: str, - ctx: Context, - serious: bool = False, - *, - code: Optional[ErrorCode] = None, - blocker: bool = False) -> None: + def fail( + self, + msg: str, + ctx: Context, + serious: bool = False, + *, + code: Optional[ErrorCode] = None, + blocker: bool = False, + ) -> None: if not serious and not self.in_checked_function(): return # In case it's a bug and we don't really have context @@ -5352,7 +5732,7 @@ def fail(self, def note(self, msg: str, ctx: Context, code: Optional[ErrorCode] = None) -> None: if not self.in_checked_function(): return - self.errors.report(ctx.get_line(), ctx.get_column(), msg, severity='note', code=code) + self.errors.report(ctx.get_line(), ctx.get_column(), msg, severity="note", code=code) def accept(self, node: Node) -> None: try: @@ -5360,14 +5740,14 @@ def accept(self, node: Node) -> None: except Exception as err: report_internal_error(err, self.errors.file, node.line, self.errors, self.options) - def expr_to_analyzed_type(self, - expr: Expression, - report_invalid_types: bool = True, - allow_placeholder: bool = False) -> Optional[Type]: + def expr_to_analyzed_type( + self, expr: Expression, report_invalid_types: bool = True, allow_placeholder: bool = False + ) -> Optional[Type]: if isinstance(expr, CallExpr): expr.accept(self) - internal_name, info = self.named_tuple_analyzer.check_namedtuple(expr, None, - self.is_func_scope()) + internal_name, info = self.named_tuple_analyzer.check_namedtuple( + expr, None, self.is_func_scope() + ) if internal_name is None: # Some form of namedtuple is the only valid type that looks like a call # expression. This isn't a valid type. @@ -5379,8 +5759,9 @@ def expr_to_analyzed_type(self, fallback = Instance(info, []) return TupleType(info.tuple_type.items, fallback=fallback) typ = self.expr_to_unanalyzed_type(expr) - return self.anal_type(typ, report_invalid_types=report_invalid_types, - allow_placeholder=allow_placeholder) + return self.anal_type( + typ, report_invalid_types=report_invalid_types, allow_placeholder=allow_placeholder + ) def analyze_type_expr(self, expr: Expression) -> None: # There are certain expressions that mypy does not need to semantically analyze, @@ -5392,27 +5773,32 @@ def analyze_type_expr(self, expr: Expression) -> None: with self.tvar_scope_frame(TypeVarLikeScope()): expr.accept(self) - def type_analyzer(self, *, - tvar_scope: Optional[TypeVarLikeScope] = None, - allow_tuple_literal: bool = False, - allow_unbound_tvars: bool = False, - allow_placeholder: bool = False, - allow_required: bool = False, - allow_param_spec_literals: bool = False, - report_invalid_types: bool = True) -> TypeAnalyser: + def type_analyzer( + self, + *, + tvar_scope: Optional[TypeVarLikeScope] = None, + allow_tuple_literal: bool = False, + allow_unbound_tvars: bool = False, + allow_placeholder: bool = False, + allow_required: bool = False, + allow_param_spec_literals: bool = False, + report_invalid_types: bool = True, + ) -> TypeAnalyser: if tvar_scope is None: tvar_scope = self.tvar_scope - tpan = TypeAnalyser(self, - tvar_scope, - self.plugin, - self.options, - self.is_typeshed_stub_file, - allow_unbound_tvars=allow_unbound_tvars, - allow_tuple_literal=allow_tuple_literal, - report_invalid_types=report_invalid_types, - allow_placeholder=allow_placeholder, - allow_required=allow_required, - allow_param_spec_literals=allow_param_spec_literals) + tpan = TypeAnalyser( + self, + tvar_scope, + self.plugin, + self.options, + self.is_typeshed_stub_file, + allow_unbound_tvars=allow_unbound_tvars, + allow_tuple_literal=allow_tuple_literal, + report_invalid_types=report_invalid_types, + allow_placeholder=allow_placeholder, + allow_required=allow_required, + allow_param_spec_literals=allow_param_spec_literals, + ) tpan.in_dynamic_func = bool(self.function_stack and self.function_stack[-1].is_dynamic()) tpan.global_scope = not self.type and not self.function_stack return tpan @@ -5420,16 +5806,19 @@ def type_analyzer(self, *, def expr_to_unanalyzed_type(self, node: Expression) -> ProperType: return expr_to_unanalyzed_type(node, self.options, self.is_stub_file) - def anal_type(self, - typ: Type, *, - tvar_scope: Optional[TypeVarLikeScope] = None, - allow_tuple_literal: bool = False, - allow_unbound_tvars: bool = False, - allow_placeholder: bool = False, - allow_required: bool = False, - allow_param_spec_literals: bool = False, - report_invalid_types: bool = True, - third_pass: bool = False) -> Optional[Type]: + def anal_type( + self, + typ: Type, + *, + tvar_scope: Optional[TypeVarLikeScope] = None, + allow_tuple_literal: bool = False, + allow_unbound_tvars: bool = False, + allow_placeholder: bool = False, + allow_required: bool = False, + allow_param_spec_literals: bool = False, + report_invalid_types: bool = True, + third_pass: bool = False, + ) -> Optional[Type]: """Semantically analyze a type. Args: @@ -5450,13 +5839,15 @@ def anal_type(self, NOTE: The caller shouldn't defer even if this returns None or a placeholder type. """ - a = self.type_analyzer(tvar_scope=tvar_scope, - allow_unbound_tvars=allow_unbound_tvars, - allow_tuple_literal=allow_tuple_literal, - allow_placeholder=allow_placeholder, - allow_required=allow_required, - allow_param_spec_literals=allow_param_spec_literals, - report_invalid_types=report_invalid_types) + a = self.type_analyzer( + tvar_scope=tvar_scope, + allow_unbound_tvars=allow_unbound_tvars, + allow_tuple_literal=allow_tuple_literal, + allow_placeholder=allow_placeholder, + allow_required=allow_required, + allow_param_spec_literals=allow_param_spec_literals, + report_invalid_types=report_invalid_types, + ) tag = self.track_incomplete_refs() typ = typ.accept(a) if self.found_incomplete_ref(tag): @@ -5472,12 +5863,15 @@ def schedule_patch(self, priority: int, patch: Callable[[], None]) -> None: self.patches.append((priority, patch)) def report_hang(self) -> None: - print('Deferral trace:') + print("Deferral trace:") for mod, line in self.deferral_debug_context: - print(f' {mod}:{line}') - self.errors.report(-1, -1, - 'INTERNAL ERROR: maximum semantic analysis iteration count reached', - blocker=True) + print(f" {mod}:{line}") + self.errors.report( + -1, + -1, + "INTERNAL ERROR: maximum semantic analysis iteration count reached", + blocker=True, + ) def add_plugin_dependency(self, trigger: str, target: Optional[str] = None) -> None: """Add dependency from trigger to a target. @@ -5488,9 +5882,9 @@ def add_plugin_dependency(self, trigger: str, target: Optional[str] = None) -> N target = self.scope.current_target() self.cur_mod_node.plugin_deps.setdefault(trigger, set()).add(target) - def add_type_alias_deps(self, - aliases_used: Iterable[str], - target: Optional[str] = None) -> None: + def add_type_alias_deps( + self, aliases_used: Iterable[str], target: Optional[str] = None + ) -> None: """Add full names of type aliases on which the current node depends. This is used by fine-grained incremental mode to re-check the corresponding nodes. @@ -5514,17 +5908,15 @@ def is_initial_mangled_global(self, name: str) -> bool: def parse_bool(self, expr: Expression) -> Optional[bool]: if isinstance(expr, NameExpr): - if expr.fullname == 'builtins.True': + if expr.fullname == "builtins.True": return True - if expr.fullname == 'builtins.False': + if expr.fullname == "builtins.False": return False return None def set_future_import_flags(self, module_name: str) -> None: if module_name in FUTURE_IMPORTS: - self.modules[self.cur_mod_id].future_import_flags.add( - FUTURE_IMPORTS[module_name], - ) + self.modules[self.cur_mod_id].future_import_flags.add(FUTURE_IMPORTS[module_name]) def is_future_flag_set(self, flag: str) -> bool: return self.modules[self.cur_mod_id].is_future_flag_set(flag) @@ -5549,8 +5941,9 @@ def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike: return sig return sig.copy_modified(arg_types=[new] + sig.arg_types[1:]) elif isinstance(sig, Overloaded): - return Overloaded([cast(CallableType, replace_implicit_first_type(i, new)) - for i in sig.items]) + return Overloaded( + [cast(CallableType, replace_implicit_first_type(i, new)) for i in sig.items] + ) else: assert False @@ -5571,8 +5964,9 @@ def refers_to_fullname(node: Expression, fullnames: Union[str, Tuple[str, ...]]) def refers_to_class_or_function(node: Expression) -> bool: """Does semantically analyzed node refer to a class?""" - return (isinstance(node, RefExpr) and - isinstance(node.node, (TypeInfo, FuncDef, OverloadedFuncDef))) + return isinstance(node, RefExpr) and isinstance( + node.node, (TypeInfo, FuncDef, OverloadedFuncDef) + ) def find_duplicate(list: List[T]) -> Optional[T]: @@ -5586,15 +5980,14 @@ def find_duplicate(list: List[T]) -> Optional[T]: return None -def remove_imported_names_from_symtable(names: SymbolTable, - module: str) -> None: +def remove_imported_names_from_symtable(names: SymbolTable, module: str) -> None: """Remove all imported names from the symbol table of a module.""" removed: List[str] = [] for name, node in names.items(): if node.node is None: continue fullname = node.node.fullname - prefix = fullname[:fullname.rfind('.')] + prefix = fullname[: fullname.rfind(".")] if prefix != module: removed.append(name) for name in removed: @@ -5650,11 +6043,13 @@ def names_modified_in_lvalue(lvalue: Lvalue) -> List[NameExpr]: def is_same_var_from_getattr(n1: Optional[SymbolNode], n2: Optional[SymbolNode]) -> bool: """Do n1 and n2 refer to the same Var derived from module-level __getattr__?""" - return (isinstance(n1, Var) - and n1.from_module_getattr - and isinstance(n2, Var) - and n2.from_module_getattr - and n1.fullname == n2.fullname) + return ( + isinstance(n1, Var) + and n1.from_module_getattr + and isinstance(n2, Var) + and n2.from_module_getattr + and n1.fullname == n2.fullname + ) def dummy_context() -> Context: @@ -5679,7 +6074,8 @@ def is_valid_replacement(old: SymbolTableNode, new: SymbolTableNode) -> bool: def is_same_symbol(a: Optional[SymbolNode], b: Optional[SymbolNode]) -> bool: - return (a == b - or (isinstance(a, PlaceholderNode) - and isinstance(b, PlaceholderNode)) - or is_same_var_from_getattr(a, b)) + return ( + a == b + or (isinstance(a, PlaceholderNode) and isinstance(b, PlaceholderNode)) + or is_same_var_from_getattr(a, b) + ) diff --git a/mypy/semanal_classprop.py b/mypy/semanal_classprop.py index 5344f321420f0..2fe22644929fe 100644 --- a/mypy/semanal_classprop.py +++ b/mypy/semanal_classprop.py @@ -3,24 +3,29 @@ These happen after semantic analysis and before type checking. """ -from typing import List, Set, Optional +from typing import List, Optional, Set + from typing_extensions import Final +from mypy.errors import Errors from mypy.nodes import ( - Node, TypeInfo, Var, Decorator, OverloadedFuncDef, SymbolTable, CallExpr, PromoteExpr, + CallExpr, + Decorator, + Node, + OverloadedFuncDef, + PromoteExpr, + SymbolTable, + TypeInfo, + Var, ) -from mypy.types import Instance, Type -from mypy.errors import Errors from mypy.options import Options +from mypy.types import Instance, Type # Hard coded type promotions (shared between all Python versions). # These add extra ad-hoc edges to the subtyping relation. For example, # int is considered a subtype of float, even though there is no # subclass relationship. -TYPE_PROMOTIONS: Final = { - 'builtins.int': 'float', - 'builtins.float': 'complex', -} +TYPE_PROMOTIONS: Final = {"builtins.int": "float", "builtins.float": "complex"} # Hard coded type promotions for Python 3. # @@ -28,10 +33,7 @@ # as some functions only accept bytes objects. Here convenience # trumps safety. TYPE_PROMOTIONS_PYTHON3: Final = TYPE_PROMOTIONS.copy() -TYPE_PROMOTIONS_PYTHON3.update({ - 'builtins.bytearray': 'bytes', - 'builtins.memoryview': 'bytes', -}) +TYPE_PROMOTIONS_PYTHON3.update({"builtins.bytearray": "bytes", "builtins.memoryview": "bytes"}) # Hard coded type promotions for Python 2. # @@ -39,11 +41,9 @@ # for convenience and also for Python 3 compatibility # (bytearray -> str). TYPE_PROMOTIONS_PYTHON2: Final = TYPE_PROMOTIONS.copy() -TYPE_PROMOTIONS_PYTHON2.update({ - 'builtins.str': 'unicode', - 'builtins.bytearray': 'str', - 'builtins.memoryview': 'str', -}) +TYPE_PROMOTIONS_PYTHON2.update( + {"builtins.str": "unicode", "builtins.bytearray": "str", "builtins.memoryview": "str"} +) def calculate_class_abstract_status(typ: TypeInfo, is_stub_file: bool, errors: Errors) -> None: @@ -97,32 +97,37 @@ def calculate_class_abstract_status(typ: TypeInfo, is_stub_file: bool, errors: E # implement some methods. typ.abstract_attributes = sorted(abstract) if is_stub_file: - if typ.declared_metaclass and typ.declared_metaclass.type.fullname == 'abc.ABCMeta': + if typ.declared_metaclass and typ.declared_metaclass.type.fullname == "abc.ABCMeta": return if typ.is_protocol: return if abstract and not abstract_in_this_class: + def report(message: str, severity: str) -> None: errors.report(typ.line, typ.column, message, severity=severity) attrs = ", ".join(f'"{attr}"' for attr in sorted(abstract)) - report(f"Class {typ.fullname} has abstract attributes {attrs}", 'error') - report("If it is meant to be abstract, add 'abc.ABCMeta' as an explicit metaclass", - 'note') + report(f"Class {typ.fullname} has abstract attributes {attrs}", "error") + report( + "If it is meant to be abstract, add 'abc.ABCMeta' as an explicit metaclass", "note" + ) if typ.is_final and abstract: attrs = ", ".join(f'"{attr}"' for attr in sorted(abstract)) - errors.report(typ.line, typ.column, - f"Final class {typ.fullname} has abstract attributes {attrs}") + errors.report( + typ.line, typ.column, f"Final class {typ.fullname} has abstract attributes {attrs}" + ) def check_protocol_status(info: TypeInfo, errors: Errors) -> None: """Check that all classes in MRO of a protocol are protocols""" if info.is_protocol: for type in info.bases: - if not type.type.is_protocol and type.type.fullname != 'builtins.object': + if not type.type.is_protocol and type.type.fullname != "builtins.object": + def report(message: str, severity: str) -> None: errors.report(info.line, info.column, message, severity=severity) - report('All bases of a protocol must be protocols', 'error') + + report("All bases of a protocol must be protocols", "error") def calculate_class_vars(info: TypeInfo) -> None: @@ -140,14 +145,13 @@ def calculate_class_vars(info: TypeInfo) -> None: if isinstance(node, Var) and node.info and node.is_inferred and not node.is_classvar: for base in info.mro[1:]: member = base.names.get(name) - if (member is not None - and isinstance(member.node, Var) - and member.node.is_classvar): + if member is not None and isinstance(member.node, Var) and member.node.is_classvar: node.is_classvar = True -def add_type_promotion(info: TypeInfo, module_names: SymbolTable, options: Options, - builtin_names: SymbolTable) -> None: +def add_type_promotion( + info: TypeInfo, module_names: SymbolTable, options: Options, builtin_names: SymbolTable +) -> None: """Setup extra, ad-hoc subtyping relationships between classes (promotion). This includes things like 'int' being compatible with 'float'. @@ -161,8 +165,9 @@ def add_type_promotion(info: TypeInfo, module_names: SymbolTable, options: Optio # _promote class decorator (undocumented feature). promote_targets.append(analyzed.type) if not promote_targets: - promotions = (TYPE_PROMOTIONS_PYTHON3 if options.python_version[0] >= 3 - else TYPE_PROMOTIONS_PYTHON2) + promotions = ( + TYPE_PROMOTIONS_PYTHON3 if options.python_version[0] >= 3 else TYPE_PROMOTIONS_PYTHON2 + ) if defn.fullname in promotions: target_sym = module_names.get(promotions[defn.fullname]) # With test stubs, the target may not exist. @@ -173,8 +178,8 @@ def add_type_promotion(info: TypeInfo, module_names: SymbolTable, options: Optio # Special case the promotions between 'int' and native integer types. # These have promotions going both ways, such as from 'int' to 'i64' # and 'i64' to 'int', for convenience. - if defn.fullname == 'mypy_extensions.i64' or defn.fullname == 'mypy_extensions.i32': - int_sym = builtin_names['int'] + if defn.fullname == "mypy_extensions.i64" or defn.fullname == "mypy_extensions.i32": + int_sym = builtin_names["int"] assert isinstance(int_sym.node, TypeInfo) int_sym.node._promote.append(Instance(defn.info, [])) defn.info.alt_promote = int_sym.node diff --git a/mypy/semanal_enum.py b/mypy/semanal_enum.py index 0f09a4bf94579..2b1481a90ba5e 100644 --- a/mypy/semanal_enum.py +++ b/mypy/semanal_enum.py @@ -3,28 +3,55 @@ This is conceptually part of mypy.semanal (semantic analyzer pass 2). """ -from typing import List, Tuple, Optional, Union, cast +from typing import List, Optional, Tuple, Union, cast + from typing_extensions import Final from mypy.nodes import ( - Expression, Context, TypeInfo, AssignmentStmt, NameExpr, CallExpr, RefExpr, StrExpr, - UnicodeExpr, TupleExpr, ListExpr, DictExpr, Var, SymbolTableNode, MDEF, ARG_POS, - ARG_NAMED, EnumCallExpr, MemberExpr + ARG_NAMED, + ARG_POS, + MDEF, + AssignmentStmt, + CallExpr, + Context, + DictExpr, + EnumCallExpr, + Expression, + ListExpr, + MemberExpr, + NameExpr, + RefExpr, + StrExpr, + SymbolTableNode, + TupleExpr, + TypeInfo, + UnicodeExpr, + Var, ) -from mypy.semanal_shared import SemanticAnalyzerInterface from mypy.options import Options -from mypy.types import get_proper_type, LiteralType, ENUM_REMOVED_PROPS +from mypy.semanal_shared import SemanticAnalyzerInterface +from mypy.types import ENUM_REMOVED_PROPS, LiteralType, get_proper_type # Note: 'enum.EnumMeta' is deliberately excluded from this list. Classes that directly use # enum.EnumMeta do not necessarily automatically have the 'name' and 'value' attributes. -ENUM_BASES: Final = frozenset(( - 'enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag', 'enum.StrEnum', -)) -ENUM_SPECIAL_PROPS: Final = frozenset(( - 'name', 'value', '_name_', '_value_', *ENUM_REMOVED_PROPS, - # Also attributes from `object`: - '__module__', '__annotations__', '__doc__', '__slots__', '__dict__', -)) +ENUM_BASES: Final = frozenset( + ("enum.Enum", "enum.IntEnum", "enum.Flag", "enum.IntFlag", "enum.StrEnum") +) +ENUM_SPECIAL_PROPS: Final = frozenset( + ( + "name", + "value", + "_name_", + "_value_", + *ENUM_REMOVED_PROPS, + # Also attributes from `object`: + "__module__", + "__annotations__", + "__doc__", + "__slots__", + "__dict__", + ) +) class EnumCallAnalyzer: @@ -52,10 +79,9 @@ def process_enum_call(self, s: AssignmentStmt, is_func_scope: bool) -> bool: self.api.add_symbol(name, enum_call, s) return True - def check_enum_call(self, - node: Expression, - var_name: str, - is_func_scope: bool) -> Optional[TypeInfo]: + def check_enum_call( + self, node: Expression, var_name: str, is_func_scope: bool + ) -> Optional[TypeInfo]: """Check if a call defines an Enum. Example: @@ -77,7 +103,7 @@ class A(enum.Enum): fullname = callee.fullname if fullname not in ENUM_BASES: return None - items, values, ok = self.parse_enum_call_args(call, fullname.split('.')[-1]) + items, values, ok = self.parse_enum_call_args(call, fullname.split(".")[-1]) if not ok: # Error. Construct dummy return value. info = self.build_enum_call_typeinfo(var_name, [], fullname, node.line) @@ -85,7 +111,7 @@ class A(enum.Enum): name = cast(Union[StrExpr, UnicodeExpr], call.args[0]).value if name != var_name or is_func_scope: # Give it a unique name derived from the line number. - name += '@' + str(call.line) + name += "@" + str(call.line) info = self.build_enum_call_typeinfo(name, items, fullname, call.line) # Store generated TypeInfo under both names, see semanal_namedtuple for more details. if name != var_name or is_func_scope: @@ -95,8 +121,9 @@ class A(enum.Enum): info.line = node.line return info - def build_enum_call_typeinfo(self, name: str, items: List[str], fullname: str, - line: int) -> TypeInfo: + def build_enum_call_typeinfo( + self, name: str, items: List[str], fullname: str, line: int + ) -> TypeInfo: base = self.api.named_type_or_none(fullname) assert base is not None info = self.api.basic_new_typeinfo(name, base, line) @@ -106,13 +133,13 @@ def build_enum_call_typeinfo(self, name: str, items: List[str], fullname: str, var = Var(item) var.info = info var.is_property = True - var._fullname = f'{info.fullname}.{item}' + var._fullname = f"{info.fullname}.{item}" info.names[item] = SymbolTableNode(MDEF, var) return info - def parse_enum_call_args(self, call: CallExpr, - class_name: str) -> Tuple[List[str], - List[Optional[Expression]], bool]: + def parse_enum_call_args( + self, call: CallExpr, class_name: str + ) -> Tuple[List[str], List[Optional[Expression]], bool]: """Parse arguments of an Enum call. Return a tuple of fields, values, was there an error. @@ -124,15 +151,15 @@ def parse_enum_call_args(self, call: CallExpr, return self.fail_enum_call_arg(f"Too few arguments for {class_name}()", call) if len(args) > 6: return self.fail_enum_call_arg(f"Too many arguments for {class_name}()", call) - valid_name = [None, 'value', 'names', 'module', 'qualname', 'type', 'start'] + valid_name = [None, "value", "names", "module", "qualname", "type", "start"] for arg_name in call.arg_names: if arg_name not in valid_name: self.fail_enum_call_arg(f'Unexpected keyword argument "{arg_name}"', call) value, names = None, None for arg_name, arg in zip(call.arg_names, args): - if arg_name == 'value': + if arg_name == "value": value = arg - if arg_name == 'names': + if arg_name == "names": names = arg if value is None: value = args[0] @@ -140,22 +167,26 @@ def parse_enum_call_args(self, call: CallExpr, names = args[1] if not isinstance(value, (StrExpr, UnicodeExpr)): return self.fail_enum_call_arg( - f"{class_name}() expects a string literal as the first argument", call) + f"{class_name}() expects a string literal as the first argument", call + ) items = [] values: List[Optional[Expression]] = [] if isinstance(names, (StrExpr, UnicodeExpr)): fields = names.value - for field in fields.replace(',', ' ').split(): + for field in fields.replace(",", " ").split(): items.append(field) elif isinstance(names, (TupleExpr, ListExpr)): seq_items = names.items if all(isinstance(seq_item, (StrExpr, UnicodeExpr)) for seq_item in seq_items): - items = [cast(Union[StrExpr, UnicodeExpr], seq_item).value - for seq_item in seq_items] - elif all(isinstance(seq_item, (TupleExpr, ListExpr)) - and len(seq_item.items) == 2 - and isinstance(seq_item.items[0], (StrExpr, UnicodeExpr)) - for seq_item in seq_items): + items = [ + cast(Union[StrExpr, UnicodeExpr], seq_item).value for seq_item in seq_items + ] + elif all( + isinstance(seq_item, (TupleExpr, ListExpr)) + and len(seq_item.items) == 2 + and isinstance(seq_item.items[0], (StrExpr, UnicodeExpr)) + for seq_item in seq_items + ): for seq_item in seq_items: assert isinstance(seq_item, (TupleExpr, ListExpr)) name, value = seq_item.items @@ -164,39 +195,44 @@ def parse_enum_call_args(self, call: CallExpr, values.append(value) else: return self.fail_enum_call_arg( - "%s() with tuple or list expects strings or (name, value) pairs" % - class_name, - call) + "%s() with tuple or list expects strings or (name, value) pairs" % class_name, + call, + ) elif isinstance(names, DictExpr): for key, value in names.items: if not isinstance(key, (StrExpr, UnicodeExpr)): return self.fail_enum_call_arg( - f"{class_name}() with dict literal requires string literals", call) + f"{class_name}() with dict literal requires string literals", call + ) items.append(key.value) values.append(value) elif isinstance(args[1], RefExpr) and isinstance(args[1].node, Var): proper_type = get_proper_type(args[1].node.type) - if (proper_type is not None - and isinstance(proper_type, LiteralType) - and isinstance(proper_type.value, str)): + if ( + proper_type is not None + and isinstance(proper_type, LiteralType) + and isinstance(proper_type.value, str) + ): fields = proper_type.value - for field in fields.replace(',', ' ').split(): + for field in fields.replace(",", " ").split(): items.append(field) elif args[1].node.is_final and isinstance(args[1].node.final_value, str): fields = args[1].node.final_value - for field in fields.replace(',', ' ').split(): + for field in fields.replace(",", " ").split(): items.append(field) else: return self.fail_enum_call_arg( - "%s() expects a string, tuple, list or dict literal as the second argument" % - class_name, - call) + "%s() expects a string, tuple, list or dict literal as the second argument" + % class_name, + call, + ) else: # TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}? return self.fail_enum_call_arg( - "%s() expects a string, tuple, list or dict literal as the second argument" % - class_name, - call) + "%s() expects a string, tuple, list or dict literal as the second argument" + % class_name, + call, + ) if len(items) == 0: return self.fail_enum_call_arg(f"{class_name}() needs at least one item", call) if not values: @@ -204,9 +240,9 @@ def parse_enum_call_args(self, call: CallExpr, assert len(items) == len(values) return items, values, True - def fail_enum_call_arg(self, message: str, - context: Context) -> Tuple[List[str], - List[Optional[Expression]], bool]: + def fail_enum_call_arg( + self, message: str, context: Context + ) -> Tuple[List[str], List[Optional[Expression]], bool]: self.fail(message, context) return [], [], False diff --git a/mypy/semanal_infer.py b/mypy/semanal_infer.py index 73a1077c57886..56b5046451607 100644 --- a/mypy/semanal_infer.py +++ b/mypy/semanal_infer.py @@ -2,17 +2,24 @@ from typing import Optional -from mypy.nodes import Expression, Decorator, CallExpr, FuncDef, RefExpr, Var, ARG_POS +from mypy.nodes import ARG_POS, CallExpr, Decorator, Expression, FuncDef, RefExpr, Var +from mypy.semanal_shared import SemanticAnalyzerInterface +from mypy.typeops import function_type from mypy.types import ( - Type, CallableType, AnyType, TypeOfAny, TypeVarType, ProperType, get_proper_type + AnyType, + CallableType, + ProperType, + Type, + TypeOfAny, + TypeVarType, + get_proper_type, ) -from mypy.typeops import function_type from mypy.typevars import has_no_typevars -from mypy.semanal_shared import SemanticAnalyzerInterface -def infer_decorator_signature_if_simple(dec: Decorator, - analyzer: SemanticAnalyzerInterface) -> None: +def infer_decorator_signature_if_simple( + dec: Decorator, analyzer: SemanticAnalyzerInterface +) -> None: """Try to infer the type of the decorated function. This lets us resolve additional references to decorated functions @@ -30,8 +37,9 @@ def infer_decorator_signature_if_simple(dec: Decorator, [ARG_POS], [None], AnyType(TypeOfAny.special_form), - analyzer.named_type('builtins.function'), - name=dec.var.name) + analyzer.named_type("builtins.function"), + name=dec.var.name, + ) elif isinstance(dec.func.type, CallableType): dec.var.type = dec.func.type return @@ -47,7 +55,7 @@ def infer_decorator_signature_if_simple(dec: Decorator, if decorator_preserves_type: # No non-identity decorators left. We can trivially infer the type # of the function here. - dec.var.type = function_type(dec.func, analyzer.named_type('builtins.function')) + dec.var.type = function_type(dec.func, analyzer.named_type("builtins.function")) if dec.decorators: return_type = calculate_return_type(dec.decorators[0]) if return_type and isinstance(return_type, AnyType): @@ -58,7 +66,7 @@ def infer_decorator_signature_if_simple(dec: Decorator, if sig: # The outermost decorator always returns the same kind of function, # so we know that this is the type of the decorated function. - orig_sig = function_type(dec.func, analyzer.named_type('builtins.function')) + orig_sig = function_type(dec.func, analyzer.named_type("builtins.function")) sig.name = orig_sig.items[0].name dec.var.type = sig diff --git a/mypy/semanal_main.py b/mypy/semanal_main.py index b25aa0e225a6d..e593960717b02 100644 --- a/mypy/semanal_main.py +++ b/mypy/semanal_main.py @@ -24,30 +24,33 @@ will be incomplete. """ -from typing import List, Tuple, Optional, Union, Callable +from typing import Callable, List, Optional, Tuple, Union + from typing_extensions import TYPE_CHECKING, Final, TypeAlias as _TypeAlias -from mypy.backports import nullcontext -from mypy.nodes import ( - MypyFile, TypeInfo, FuncDef, Decorator, OverloadedFuncDef, Var -) -from mypy.semanal_typeargs import TypeArgumentAnalyzer +import mypy.build import mypy.state +from mypy.backports import nullcontext +from mypy.checker import FineGrainedDeferredNode +from mypy.errors import Errors +from mypy.nodes import Decorator, FuncDef, MypyFile, OverloadedFuncDef, TypeInfo, Var +from mypy.options import Options +from mypy.plugin import ClassDefContext from mypy.semanal import ( - SemanticAnalyzer, apply_semantic_analyzer_patches, remove_imported_names_from_symtable + SemanticAnalyzer, + apply_semantic_analyzer_patches, + remove_imported_names_from_symtable, ) from mypy.semanal_classprop import ( - calculate_class_abstract_status, calculate_class_vars, check_protocol_status, - add_type_promotion + add_type_promotion, + calculate_class_abstract_status, + calculate_class_vars, + check_protocol_status, ) -from mypy.errors import Errors from mypy.semanal_infer import infer_decorator_signature_if_simple -from mypy.checker import FineGrainedDeferredNode +from mypy.semanal_typeargs import TypeArgumentAnalyzer from mypy.server.aststrip import SavedAttributes from mypy.util import is_typeshed_file -from mypy.options import Options -from mypy.plugin import ClassDefContext -import mypy.build if TYPE_CHECKING: from mypy.build import Graph, State @@ -62,10 +65,10 @@ # Number of passes over core modules before going on to the rest of the builtin SCC. CORE_WARMUP: Final = 2 -core_modules: Final = ['typing', 'builtins', 'abc', 'collections'] +core_modules: Final = ["typing", "builtins", "abc", "collections"] -def semantic_analysis_for_scc(graph: 'Graph', scc: List[str], errors: Errors) -> None: +def semantic_analysis_for_scc(graph: "Graph", scc: List[str], errors: Errors) -> None: """Perform semantic analysis for all modules in a SCC (import cycle). Assume that reachability analysis has already been performed. @@ -89,11 +92,11 @@ def semantic_analysis_for_scc(graph: 'Graph', scc: List[str], errors: Errors) -> calculate_class_properties(graph, scc, errors) check_blockers(graph, scc) # Clean-up builtins, so that TypeVar etc. are not accessible without importing. - if 'builtins' in scc: - cleanup_builtin_scc(graph['builtins']) + if "builtins" in scc: + cleanup_builtin_scc(graph["builtins"]) -def cleanup_builtin_scc(state: 'State') -> None: +def cleanup_builtin_scc(state: "State") -> None: """Remove imported names from builtins namespace. This way names imported from typing in builtins.pyi aren't available @@ -102,14 +105,15 @@ def cleanup_builtin_scc(state: 'State') -> None: processing builtins.pyi itself. """ assert state.tree is not None - remove_imported_names_from_symtable(state.tree.names, 'builtins') + remove_imported_names_from_symtable(state.tree.names, "builtins") def semantic_analysis_for_targets( - state: 'State', - nodes: List[FineGrainedDeferredNode], - graph: 'Graph', - saved_attrs: SavedAttributes) -> None: + state: "State", + nodes: List[FineGrainedDeferredNode], + graph: "Graph", + saved_attrs: SavedAttributes, +) -> None: """Semantically analyze only selected nodes in a given module. This essentially mirrors the logic of semantic_analysis_for_scc() @@ -130,8 +134,9 @@ def semantic_analysis_for_targets( if isinstance(n.node, MypyFile): # Already done above. continue - process_top_level_function(analyzer, state, state.id, - n.node.fullname, n.node, n.active_typeinfo, patches) + process_top_level_function( + analyzer, state, state.id, n.node.fullname, n.node, n.active_typeinfo, patches + ) apply_semantic_analyzer_patches(patches) apply_class_plugin_hooks(graph, [state.id], state.manager.errors) check_type_arguments_in_targets(nodes, state, state.manager.errors) @@ -148,16 +153,21 @@ def restore_saved_attrs(saved_attrs: SavedAttributes) -> None: # This needs to mimic the logic in SemanticAnalyzer.analyze_member_lvalue() # regarding the existing variable in class body or in a superclass: # If the attribute of self is not defined in superclasses, create a new Var. - if (existing is None or - # (An abstract Var is considered as not defined.) - (isinstance(existing.node, Var) and existing.node.is_abstract_var) or - # Also an explicit declaration on self creates a new Var unless - # there is already one defined in the class body. - sym.node.explicit_self_type and not defined_in_this_class): + if ( + existing is None + or + # (An abstract Var is considered as not defined.) + (isinstance(existing.node, Var) and existing.node.is_abstract_var) + or + # Also an explicit declaration on self creates a new Var unless + # there is already one defined in the class body. + sym.node.explicit_self_type + and not defined_in_this_class + ): info.names[name] = sym -def process_top_levels(graph: 'Graph', scc: List[str], patches: Patches) -> None: +def process_top_levels(graph: "Graph", scc: List[str], patches: Patches) -> None: # Process top levels until everything has been bound. # Reverse order of the scc so the first modules in the original list will be @@ -200,24 +210,22 @@ def process_top_levels(graph: 'Graph', scc: List[str], patches: Patches) -> None next_id = worklist.pop() state = graph[next_id] assert state.tree is not None - deferred, incomplete, progress = semantic_analyze_target(next_id, state, - state.tree, - None, - final_iteration, - patches) + deferred, incomplete, progress = semantic_analyze_target( + next_id, state, state.tree, None, final_iteration, patches + ) all_deferred += deferred any_progress = any_progress or progress if not incomplete: state.manager.incomplete_namespaces.discard(next_id) if final_iteration: - assert not all_deferred, 'Must not defer during final iteration' + assert not all_deferred, "Must not defer during final iteration" # Reverse to process the targets in the same order on every iteration. This avoids # processing the same target twice in a row, which is inefficient. worklist = list(reversed(all_deferred)) final_iteration = not any_progress -def process_functions(graph: 'Graph', scc: List[str], patches: Patches) -> None: +def process_functions(graph: "Graph", scc: List[str], patches: Patches) -> None: # Process functions. for module in scc: tree = graph[module].tree @@ -234,22 +242,20 @@ def process_functions(graph: 'Graph', scc: List[str], patches: Patches) -> None: targets = sorted(get_all_leaf_targets(tree), key=lambda x: (x[1].line, x[0])) for target, node, active_type in targets: assert isinstance(node, (FuncDef, OverloadedFuncDef, Decorator)) - process_top_level_function(analyzer, - graph[module], - module, - target, - node, - active_type, - patches) - - -def process_top_level_function(analyzer: 'SemanticAnalyzer', - state: 'State', - module: str, - target: str, - node: Union[FuncDef, OverloadedFuncDef, Decorator], - active_type: Optional[TypeInfo], - patches: Patches) -> None: + process_top_level_function( + analyzer, graph[module], module, target, node, active_type, patches + ) + + +def process_top_level_function( + analyzer: "SemanticAnalyzer", + state: "State", + module: str, + target: str, + node: Union[FuncDef, OverloadedFuncDef, Decorator], + active_type: Optional[TypeInfo], + patches: Patches, +) -> None: """Analyze single top-level function or method. Process the body of the function (including nested functions) again and again, @@ -275,10 +281,11 @@ def process_top_level_function(analyzer: 'SemanticAnalyzer', if not (deferred or incomplete) or final_iteration: # OK, this is one last pass, now missing names will be reported. analyzer.incomplete_namespaces.discard(module) - deferred, incomplete, progress = semantic_analyze_target(target, state, node, active_type, - final_iteration, patches) + deferred, incomplete, progress = semantic_analyze_target( + target, state, node, active_type, final_iteration, patches + ) if final_iteration: - assert not deferred, 'Must not defer during final iteration' + assert not deferred, "Must not defer during final iteration" if not progress: final_iteration = True @@ -300,12 +307,14 @@ def get_all_leaf_targets(file: MypyFile) -> List[TargetInfo]: return result -def semantic_analyze_target(target: str, - state: 'State', - node: Union[MypyFile, FuncDef, OverloadedFuncDef, Decorator], - active_type: Optional[TypeInfo], - final_iteration: bool, - patches: Patches) -> Tuple[List[str], bool, bool]: +def semantic_analyze_target( + target: str, + state: "State", + node: Union[MypyFile, FuncDef, OverloadedFuncDef, Decorator], + active_type: Optional[TypeInfo], + final_iteration: bool, + patches: Patches, +) -> Tuple[List[str], bool, bool]: """Semantically analyze a single target. Return tuple with these items: @@ -327,12 +336,14 @@ def semantic_analyze_target(target: str, if isinstance(refresh_node, Decorator): # Decorator expressions will be processed as part of the module top level. refresh_node = refresh_node.func - analyzer.refresh_partial(refresh_node, - patches, - final_iteration, - file_node=tree, - options=state.options, - active_type=active_type) + analyzer.refresh_partial( + refresh_node, + patches, + final_iteration, + file_node=tree, + options=state.options, + active_type=active_type, + ) if isinstance(node, Decorator): infer_decorator_signature_if_simple(node, analyzer) for dep in analyzer.imports: @@ -352,28 +363,25 @@ def semantic_analyze_target(target: str, return [], analyzer.incomplete, analyzer.progress -def check_type_arguments(graph: 'Graph', scc: List[str], errors: Errors) -> None: +def check_type_arguments(graph: "Graph", scc: List[str], errors: Errors) -> None: for module in scc: state = graph[module] assert state.tree - analyzer = TypeArgumentAnalyzer(errors, - state.options, - is_typeshed_file(state.path or '')) + analyzer = TypeArgumentAnalyzer(errors, state.options, is_typeshed_file(state.path or "")) with state.wrap_context(): with mypy.state.state.strict_optional_set(state.options.strict_optional): state.tree.accept(analyzer) -def check_type_arguments_in_targets(targets: List[FineGrainedDeferredNode], state: 'State', - errors: Errors) -> None: +def check_type_arguments_in_targets( + targets: List[FineGrainedDeferredNode], state: "State", errors: Errors +) -> None: """Check type arguments against type variable bounds and restrictions. This mirrors the logic in check_type_arguments() except that we process only some targets. This is used in fine grained incremental mode. """ - analyzer = TypeArgumentAnalyzer(errors, - state.options, - is_typeshed_file(state.path or '')) + analyzer = TypeArgumentAnalyzer(errors, state.options, is_typeshed_file(state.path or "")) with state.wrap_context(): with mypy.state.state.strict_optional_set(state.options.strict_optional): for target in targets: @@ -386,7 +394,7 @@ def check_type_arguments_in_targets(targets: List[FineGrainedDeferredNode], stat target.node.accept(analyzer) -def apply_class_plugin_hooks(graph: 'Graph', scc: List[str], errors: Errors) -> None: +def apply_class_plugin_hooks(graph: "Graph", scc: List[str], errors: Errors) -> None: """Apply class plugin hooks within a SCC. We run these after to the main semantic analysis so that the hooks @@ -410,17 +418,25 @@ def apply_class_plugin_hooks(graph: 'Graph', scc: List[str], errors: Errors) -> assert tree for _, node, _ in tree.local_definitions(): if isinstance(node.node, TypeInfo): - if not apply_hooks_to_class(state.manager.semantic_analyzer, - module, node.node, state.options, tree, errors): + if not apply_hooks_to_class( + state.manager.semantic_analyzer, + module, + node.node, + state.options, + tree, + errors, + ): incomplete = True -def apply_hooks_to_class(self: SemanticAnalyzer, - module: str, - info: TypeInfo, - options: Options, - file_node: MypyFile, - errors: Errors) -> bool: +def apply_hooks_to_class( + self: SemanticAnalyzer, + module: str, + info: TypeInfo, + options: Options, + file_node: MypyFile, + errors: Errors, +) -> bool: # TODO: Move more class-related hooks here? defn = info.defn ok = True @@ -434,8 +450,8 @@ def apply_hooks_to_class(self: SemanticAnalyzer, return ok -def calculate_class_properties(graph: 'Graph', scc: List[str], errors: Errors) -> None: - builtins = graph['builtins'].tree +def calculate_class_properties(graph: "Graph", scc: List[str], errors: Errors) -> None: + builtins = graph["builtins"].tree assert builtins for module in scc: state = graph[module] @@ -447,10 +463,11 @@ def calculate_class_properties(graph: 'Graph', scc: List[str], errors: Errors) - calculate_class_abstract_status(node.node, tree.is_stub, errors) check_protocol_status(node.node, errors) calculate_class_vars(node.node) - add_type_promotion(node.node, tree.names, graph[module].options, - builtins.names) + add_type_promotion( + node.node, tree.names, graph[module].options, builtins.names + ) -def check_blockers(graph: 'Graph', scc: List[str]) -> None: +def check_blockers(graph: "Graph", scc: List[str]) -> None: for module in scc: graph[module].check_blockers() diff --git a/mypy/semanal_namedtuple.py b/mypy/semanal_namedtuple.py index e63be53878108..fb7e2e532398c 100644 --- a/mypy/semanal_namedtuple.py +++ b/mypy/semanal_namedtuple.py @@ -4,24 +4,63 @@ """ from contextlib import contextmanager -from typing import Tuple, List, Dict, Mapping, Optional, Union, cast, Iterator +from typing import Dict, Iterator, List, Mapping, Optional, Tuple, Union, cast + from typing_extensions import Final -from mypy.types import ( - Type, TupleType, AnyType, TypeOfAny, CallableType, TypeType, TypeVarType, - UnboundType, LiteralType, TYPED_NAMEDTUPLE_NAMES +from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type +from mypy.nodes import ( + ARG_NAMED_OPT, + ARG_OPT, + ARG_POS, + MDEF, + Argument, + AssignmentStmt, + Block, + BytesExpr, + CallExpr, + ClassDef, + Context, + Decorator, + EllipsisExpr, + Expression, + ExpressionStmt, + FuncBase, + FuncDef, + ListExpr, + NamedTupleExpr, + NameExpr, + PassStmt, + RefExpr, + StrExpr, + SymbolTable, + SymbolTableNode, + TempNode, + TupleExpr, + TypeInfo, + TypeVarExpr, + UnicodeExpr, + Var, ) +from mypy.options import Options from mypy.semanal_shared import ( - SemanticAnalyzerInterface, set_callable_name, calculate_tuple_fallback, PRIORITY_FALLBACKS + PRIORITY_FALLBACKS, + SemanticAnalyzerInterface, + calculate_tuple_fallback, + set_callable_name, ) -from mypy.nodes import ( - Var, EllipsisExpr, Argument, StrExpr, BytesExpr, UnicodeExpr, ExpressionStmt, NameExpr, - AssignmentStmt, PassStmt, Decorator, FuncBase, ClassDef, Expression, RefExpr, TypeInfo, - NamedTupleExpr, CallExpr, Context, TupleExpr, ListExpr, SymbolTableNode, FuncDef, Block, - TempNode, SymbolTable, TypeVarExpr, ARG_POS, ARG_NAMED_OPT, ARG_OPT, MDEF +from mypy.types import ( + TYPED_NAMEDTUPLE_NAMES, + AnyType, + CallableType, + LiteralType, + TupleType, + Type, + TypeOfAny, + TypeType, + TypeVarType, + UnboundType, ) -from mypy.options import Options -from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.util import get_unique_redefinition_name # Matches "_prohibited" in typing.py, but adds __annotations__, which works at runtime but can't @@ -53,9 +92,9 @@ def __init__(self, options: Options, api: SemanticAnalyzerInterface) -> None: self.options = options self.api = api - def analyze_namedtuple_classdef(self, defn: ClassDef, is_stub_file: bool, - is_func_scope: bool - ) -> Tuple[bool, Optional[TypeInfo]]: + def analyze_namedtuple_classdef( + self, defn: ClassDef, is_stub_file: bool, is_func_scope: bool + ) -> Tuple[bool, Optional[TypeInfo]]: """Analyze if given class definition can be a named tuple definition. Return a tuple where first item indicates whether this can possibly be a named tuple, @@ -71,10 +110,11 @@ def analyze_namedtuple_classdef(self, defn: ClassDef, is_stub_file: bool, # This is a valid named tuple, but some types are incomplete. return True, None items, types, default_items = result - if is_func_scope and '@' not in defn.name: - defn.name += '@' + str(defn.line) + if is_func_scope and "@" not in defn.name: + defn.name += "@" + str(defn.line) info = self.build_namedtuple_typeinfo( - defn.name, items, types, default_items, defn.line) + defn.name, items, types, default_items, defn.line + ) defn.info = info defn.analyzed = NamedTupleExpr(info, is_typed=True) defn.analyzed.line = defn.line @@ -84,10 +124,9 @@ def analyze_namedtuple_classdef(self, defn: ClassDef, is_stub_file: bool, # This can't be a valid named tuple. return False, None - def check_namedtuple_classdef(self, defn: ClassDef, is_stub_file: bool - ) -> Optional[Tuple[List[str], - List[Type], - Dict[str, Expression]]]: + def check_namedtuple_classdef( + self, defn: ClassDef, is_stub_file: bool + ) -> Optional[Tuple[List[str], List[Type], Dict[str, Expression]]]: """Parse and validate fields in named tuple class definition. Return a three tuple: @@ -97,26 +136,25 @@ def check_namedtuple_classdef(self, defn: ClassDef, is_stub_file: bool or None, if any of the types are not ready. """ if self.options.python_version < (3, 6) and not is_stub_file: - self.fail('NamedTuple class syntax is only supported in Python 3.6', defn) + self.fail("NamedTuple class syntax is only supported in Python 3.6", defn) return [], [], {} if len(defn.base_type_exprs) > 1: - self.fail('NamedTuple should be a single base', defn) + self.fail("NamedTuple should be a single base", defn) items: List[str] = [] types: List[Type] = [] default_items: Dict[str, Expression] = {} for stmt in defn.defs.body: if not isinstance(stmt, AssignmentStmt): # Still allow pass or ... (for empty namedtuples). - if (isinstance(stmt, PassStmt) or - (isinstance(stmt, ExpressionStmt) and - isinstance(stmt.expr, EllipsisExpr))): + if isinstance(stmt, PassStmt) or ( + isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, EllipsisExpr) + ): continue # Also allow methods, including decorated ones. if isinstance(stmt, (Decorator, FuncBase)): continue # And docstrings. - if (isinstance(stmt, ExpressionStmt) and - isinstance(stmt.expr, StrExpr)): + if isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, StrExpr): continue self.fail(NAMEDTUP_CLASS_ERROR, stmt) elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): @@ -135,24 +173,26 @@ def check_namedtuple_classdef(self, defn: ClassDef, is_stub_file: bool return None types.append(analyzed) # ...despite possible minor failures that allow further analyzis. - if name.startswith('_'): - self.fail('NamedTuple field name cannot start with an underscore: {}' - .format(name), stmt) - if stmt.type is None or hasattr(stmt, 'new_syntax') and not stmt.new_syntax: + if name.startswith("_"): + self.fail( + "NamedTuple field name cannot start with an underscore: {}".format(name), + stmt, + ) + if stmt.type is None or hasattr(stmt, "new_syntax") and not stmt.new_syntax: self.fail(NAMEDTUP_CLASS_ERROR, stmt) elif isinstance(stmt.rvalue, TempNode): # x: int assigns rvalue to TempNode(AnyType()) if default_items: - self.fail('Non-default NamedTuple fields cannot follow default fields', - stmt) + self.fail( + "Non-default NamedTuple fields cannot follow default fields", stmt + ) else: default_items[name] = stmt.rvalue return items, types, default_items - def check_namedtuple(self, - node: Expression, - var_name: Optional[str], - is_func_scope: bool) -> Tuple[Optional[str], Optional[TypeInfo]]: + def check_namedtuple( + self, node: Expression, var_name: Optional[str], is_func_scope: bool + ) -> Tuple[Optional[str], Optional[TypeInfo]]: """Check if a call defines a namedtuple. The optional var_name argument is the name of the variable to @@ -173,7 +213,7 @@ def check_namedtuple(self, if not isinstance(callee, RefExpr): return None, None fullname = callee.fullname - if fullname == 'collections.namedtuple': + if fullname == "collections.namedtuple": is_typed = False elif fullname in TYPED_NAMEDTUPLE_NAMES: is_typed = True @@ -187,9 +227,9 @@ def check_namedtuple(self, if var_name: name = var_name if is_func_scope: - name += '@' + str(call.line) + name += "@" + str(call.line) else: - name = var_name = 'namedtuple@' + str(call.line) + name = var_name = "namedtuple@" + str(call.line) info = self.build_namedtuple_typeinfo(name, [], [], {}, node.line) self.store_namedtuple_info(info, var_name, call, is_typed) if name != var_name or is_func_scope: @@ -219,11 +259,10 @@ def check_namedtuple(self, # * This is a local (function or method level) named tuple, since # two methods of a class can define a named tuple with the same name, # and they will be stored in the same namespace (see below). - name += '@' + str(call.line) + name += "@" + str(call.line) if len(defaults) > 0: default_items = { - arg_name: default - for arg_name, default in zip(items[-len(defaults):], defaults) + arg_name: default for arg_name, default in zip(items[-len(defaults) :], defaults) } else: default_items = {} @@ -250,15 +289,16 @@ def check_namedtuple(self, self.api.add_symbol_skip_local(name, info) return typename, info - def store_namedtuple_info(self, info: TypeInfo, name: str, - call: CallExpr, is_typed: bool) -> None: + def store_namedtuple_info( + self, info: TypeInfo, name: str, call: CallExpr, is_typed: bool + ) -> None: self.api.add_symbol(name, info, call) call.analyzed = NamedTupleExpr(info, is_typed=is_typed) call.analyzed.set_line(call.line, call.column) - def parse_namedtuple_args(self, call: CallExpr, fullname: str - ) -> Optional[Tuple[List[str], List[Type], List[Expression], - str, bool]]: + def parse_namedtuple_args( + self, call: CallExpr, fullname: str + ) -> Optional[Tuple[List[str], List[Type], List[Expression], str, bool]]: """Parse a namedtuple() call into data needed to construct a type. Returns a 5-tuple: @@ -270,7 +310,7 @@ def parse_namedtuple_args(self, call: CallExpr, fullname: str Return None if the definition didn't typecheck. """ - type_name = 'NamedTuple' if fullname in TYPED_NAMEDTUPLE_NAMES else 'namedtuple' + type_name = "NamedTuple" if fullname in TYPED_NAMEDTUPLE_NAMES else "namedtuple" # TODO: Share code with check_argument_count in checkexpr.py? args = call.args if len(args) < 2: @@ -283,7 +323,7 @@ def parse_namedtuple_args(self, call: CallExpr, fullname: str self.fail('Too many arguments for "NamedTuple()"', call) return None for i, arg_name in enumerate(call.arg_names[2:], 2): - if arg_name == 'defaults': + if arg_name == "defaults": arg = args[i] # We don't care what the values are, as long as the argument is an iterable # and we can count how many defaults there are. @@ -293,41 +333,45 @@ def parse_namedtuple_args(self, call: CallExpr, fullname: str self.fail( "List or tuple literal expected as the defaults argument to " "{}()".format(type_name), - arg + arg, ) break if call.arg_kinds[:2] != [ARG_POS, ARG_POS]: self.fail(f'Unexpected arguments to "{type_name}()"', call) return None if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): - self.fail( - f'"{type_name}()" expects a string literal as the first argument', call) + self.fail(f'"{type_name}()" expects a string literal as the first argument', call) return None typename = cast(Union[StrExpr, BytesExpr, UnicodeExpr], call.args[0]).value types: List[Type] = [] if not isinstance(args[1], (ListExpr, TupleExpr)): - if (fullname == 'collections.namedtuple' - and isinstance(args[1], (StrExpr, BytesExpr, UnicodeExpr))): + if fullname == "collections.namedtuple" and isinstance( + args[1], (StrExpr, BytesExpr, UnicodeExpr) + ): str_expr = args[1] - items = str_expr.value.replace(',', ' ').split() + items = str_expr.value.replace(",", " ").split() else: self.fail( 'List or tuple literal expected as the second argument to "{}()"'.format( - type_name, + type_name ), call, ) return None else: listexpr = args[1] - if fullname == 'collections.namedtuple': + if fullname == "collections.namedtuple": # The fields argument contains just names, with implicit Any types. - if any(not isinstance(item, (StrExpr, BytesExpr, UnicodeExpr)) - for item in listexpr.items): + if any( + not isinstance(item, (StrExpr, BytesExpr, UnicodeExpr)) + for item in listexpr.items + ): self.fail('String literal expected as "namedtuple()" item', call) return None - items = [cast(Union[StrExpr, BytesExpr, UnicodeExpr], item).value - for item in listexpr.items] + items = [ + cast(Union[StrExpr, BytesExpr, UnicodeExpr], item).value + for item in listexpr.items + ] else: # The fields argument contains (name, type) tuples. result = self.parse_namedtuple_fields_with_types(listexpr.items, call) @@ -339,18 +383,21 @@ def parse_namedtuple_args(self, call: CallExpr, fullname: str return [], [], [], typename, False if not types: types = [AnyType(TypeOfAny.unannotated) for _ in items] - underscore = [item for item in items if item.startswith('_')] + underscore = [item for item in items if item.startswith("_")] if underscore: - self.fail(f'"{type_name}()" field names cannot start with an underscore: ' - + ', '.join(underscore), call) + self.fail( + f'"{type_name}()" field names cannot start with an underscore: ' + + ", ".join(underscore), + call, + ) if len(defaults) > len(items): self.fail(f'Too many defaults given in call to "{type_name}()"', call) - defaults = defaults[:len(items)] + defaults = defaults[: len(items)] return items, types, defaults, typename, True - def parse_namedtuple_fields_with_types(self, nodes: List[Expression], context: Context - ) -> Optional[Tuple[List[str], List[Type], - List[Expression], bool]]: + def parse_namedtuple_fields_with_types( + self, nodes: List[Expression], context: Context + ) -> Optional[Tuple[List[str], List[Type], List[Expression], bool]]: """Parse typed named tuple fields. Return (names, types, defaults, whether types are all ready), or None if error occurred. @@ -371,7 +418,7 @@ def parse_namedtuple_fields_with_types(self, nodes: List[Expression], context: C try: type = expr_to_unanalyzed_type(type_node, self.options, self.api.is_stub_file) except TypeTranslationError: - self.fail('Invalid field type', type_node) + self.fail("Invalid field type", type_node) return None analyzed = self.api.anal_type(type) # Workaround #4987 and avoid introducing a bogus UnboundType @@ -386,25 +433,29 @@ def parse_namedtuple_fields_with_types(self, nodes: List[Expression], context: C return None return items, types, [], True - def build_namedtuple_typeinfo(self, - name: str, - items: List[str], - types: List[Type], - default_items: Mapping[str, Expression], - line: int) -> TypeInfo: - strtype = self.api.named_type('builtins.str') + def build_namedtuple_typeinfo( + self, + name: str, + items: List[str], + types: List[Type], + default_items: Mapping[str, Expression], + line: int, + ) -> TypeInfo: + strtype = self.api.named_type("builtins.str") implicit_any = AnyType(TypeOfAny.special_form) - basetuple_type = self.api.named_type('builtins.tuple', [implicit_any]) - dictype = (self.api.named_type_or_none('builtins.dict', [strtype, implicit_any]) - or self.api.named_type('builtins.object')) + basetuple_type = self.api.named_type("builtins.tuple", [implicit_any]) + dictype = self.api.named_type_or_none( + "builtins.dict", [strtype, implicit_any] + ) or self.api.named_type("builtins.object") # Actual signature should return OrderedDict[str, Union[types]] - ordereddictype = (self.api.named_type_or_none('builtins.dict', [strtype, implicit_any]) - or self.api.named_type('builtins.object')) - fallback = self.api.named_type('builtins.tuple', [implicit_any]) + ordereddictype = self.api.named_type_or_none( + "builtins.dict", [strtype, implicit_any] + ) or self.api.named_type("builtins.object") + fallback = self.api.named_type("builtins.tuple", [implicit_any]) # Note: actual signature should accept an invariant version of Iterable[UnionType[types]]. # but it can't be expressed. 'new' and 'len' should be callable types. - iterable_type = self.api.named_type_or_none('typing.Iterable', [implicit_any]) - function_type = self.api.named_type('builtins.function') + iterable_type = self.api.named_type_or_none("typing.Iterable", [implicit_any]) + function_type = self.api.named_type("builtins.function") literals: List[Type] = [LiteralType(item, strtype) for item in items] match_args_type = TupleType(literals, basetuple_type) @@ -415,20 +466,20 @@ def build_namedtuple_typeinfo(self, info.tuple_type = tuple_base info.line = line # For use by mypyc. - info.metadata['namedtuple'] = {'fields': items.copy()} + info.metadata["namedtuple"] = {"fields": items.copy()} # We can't calculate the complete fallback type until after semantic # analysis, since otherwise base classes might be incomplete. Postpone a # callback function that patches the fallback. - self.api.schedule_patch(PRIORITY_FALLBACKS, - lambda: calculate_tuple_fallback(tuple_base)) + self.api.schedule_patch(PRIORITY_FALLBACKS, lambda: calculate_tuple_fallback(tuple_base)) - def add_field(var: Var, is_initialized_in_class: bool = False, - is_property: bool = False) -> None: + def add_field( + var: Var, is_initialized_in_class: bool = False, is_property: bool = False + ) -> None: var.info = info var.is_initialized_in_class = is_initialized_in_class var.is_property = is_property - var._fullname = f'{info.fullname}.{var.name}' + var._fullname = f"{info.fullname}.{var.name}" info.names[var.name] = SymbolTableNode(MDEF, var) fields = [Var(item, typ) for item, typ in zip(items, types)] @@ -441,43 +492,44 @@ def add_field(var: Var, is_initialized_in_class: bool = False, vars = [Var(item, typ) for item, typ in zip(items, types)] tuple_of_strings = TupleType([strtype for _ in items], basetuple_type) - add_field(Var('_fields', tuple_of_strings), is_initialized_in_class=True) - add_field(Var('_field_types', dictype), is_initialized_in_class=True) - add_field(Var('_field_defaults', dictype), is_initialized_in_class=True) - add_field(Var('_source', strtype), is_initialized_in_class=True) - add_field(Var('__annotations__', ordereddictype), is_initialized_in_class=True) - add_field(Var('__doc__', strtype), is_initialized_in_class=True) + add_field(Var("_fields", tuple_of_strings), is_initialized_in_class=True) + add_field(Var("_field_types", dictype), is_initialized_in_class=True) + add_field(Var("_field_defaults", dictype), is_initialized_in_class=True) + add_field(Var("_source", strtype), is_initialized_in_class=True) + add_field(Var("__annotations__", ordereddictype), is_initialized_in_class=True) + add_field(Var("__doc__", strtype), is_initialized_in_class=True) if self.options.python_version >= (3, 10): - add_field(Var('__match_args__', match_args_type), is_initialized_in_class=True) + add_field(Var("__match_args__", match_args_type), is_initialized_in_class=True) - tvd = TypeVarType(SELF_TVAR_NAME, info.fullname + '.' + SELF_TVAR_NAME, - -1, [], info.tuple_type) + tvd = TypeVarType( + SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, -1, [], info.tuple_type + ) selftype = tvd - def add_method(funcname: str, - ret: Type, - args: List[Argument], - is_classmethod: bool = False, - is_new: bool = False, - ) -> None: + def add_method( + funcname: str, + ret: Type, + args: List[Argument], + is_classmethod: bool = False, + is_new: bool = False, + ) -> None: if is_classmethod or is_new: - first = [Argument(Var('_cls'), TypeType.make_normalized(selftype), None, ARG_POS)] + first = [Argument(Var("_cls"), TypeType.make_normalized(selftype), None, ARG_POS)] else: - first = [Argument(Var('_self'), selftype, None, ARG_POS)] + first = [Argument(Var("_self"), selftype, None, ARG_POS)] args = first + args types = [arg.type_annotation for arg in args] items = [arg.variable.name for arg in args] arg_kinds = [arg.kind for arg in args] assert None not in types - signature = CallableType(cast(List[Type], types), arg_kinds, items, ret, - function_type) + signature = CallableType(cast(List[Type], types), arg_kinds, items, ret, function_type) signature.variables = [tvd] func = FuncDef(funcname, args, Block([])) func.info = info func.is_class = is_classmethod func.type = set_callable_name(signature, func) - func._fullname = info.fullname + '.' + funcname + func._fullname = info.fullname + "." + funcname func.line = line if is_classmethod: v = Var(funcname, func.type) @@ -485,7 +537,7 @@ def add_method(funcname: str, v.info = info v._fullname = func._fullname func.is_decorated = True - dec = Decorator(func, [NameExpr('classmethod')], v) + dec = Decorator(func, [NameExpr("classmethod")], v) dec.line = line sym = SymbolTableNode(MDEF, dec) else: @@ -493,26 +545,34 @@ def add_method(funcname: str, sym.plugin_generated = True info.names[funcname] = sym - add_method('_replace', ret=selftype, - args=[Argument(var, var.type, EllipsisExpr(), ARG_NAMED_OPT) for var in vars]) + add_method( + "_replace", + ret=selftype, + args=[Argument(var, var.type, EllipsisExpr(), ARG_NAMED_OPT) for var in vars], + ) def make_init_arg(var: Var) -> Argument: default = default_items.get(var.name, None) kind = ARG_POS if default is None else ARG_OPT return Argument(var, var.type, default, kind) - add_method('__new__', ret=selftype, - args=[make_init_arg(var) for var in vars], - is_new=True) - add_method('_asdict', args=[], ret=ordereddictype) + add_method("__new__", ret=selftype, args=[make_init_arg(var) for var in vars], is_new=True) + add_method("_asdict", args=[], ret=ordereddictype) special_form_any = AnyType(TypeOfAny.special_form) - add_method('_make', ret=selftype, is_classmethod=True, - args=[Argument(Var('iterable', iterable_type), iterable_type, None, ARG_POS), - Argument(Var('new'), special_form_any, EllipsisExpr(), ARG_NAMED_OPT), - Argument(Var('len'), special_form_any, EllipsisExpr(), ARG_NAMED_OPT)]) - - self_tvar_expr = TypeVarExpr(SELF_TVAR_NAME, info.fullname + '.' + SELF_TVAR_NAME, - [], info.tuple_type) + add_method( + "_make", + ret=selftype, + is_classmethod=True, + args=[ + Argument(Var("iterable", iterable_type), iterable_type, None, ARG_POS), + Argument(Var("new"), special_form_any, EllipsisExpr(), ARG_NAMED_OPT), + Argument(Var("len"), special_form_any, EllipsisExpr(), ARG_NAMED_OPT), + ], + ) + + self_tvar_expr = TypeVarExpr( + SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, [], info.tuple_type + ) info.names[SELF_TVAR_NAME] = SymbolTableNode(MDEF, self_tvar_expr) return info @@ -537,15 +597,14 @@ def save_namedtuple_body(self, named_tuple_info: TypeInfo) -> Iterator[None]: continue ctx = named_tuple_info.names[prohibited].node assert ctx is not None - self.fail(f'Cannot overwrite NamedTuple attribute "{prohibited}"', - ctx) + self.fail(f'Cannot overwrite NamedTuple attribute "{prohibited}"', ctx) # Restore the names in the original symbol table. This ensures that the symbol # table contains the field objects created by build_namedtuple_typeinfo. Exclude # __doc__, which can legally be overwritten by the class. for key, value in nt_names.items(): if key in named_tuple_info.names: - if key == '__doc__': + if key == "__doc__": continue sym = named_tuple_info.names[key] if isinstance(sym.node, (FuncBase, Decorator)) and not sym.plugin_generated: diff --git a/mypy/semanal_newtype.py b/mypy/semanal_newtype.py index 948c5b36052f1..f59b8b6f62707 100644 --- a/mypy/semanal_newtype.py +++ b/mypy/semanal_newtype.py @@ -3,31 +3,52 @@ This is conceptually part of mypy.semanal (semantic analyzer pass 2). """ -from typing import Tuple, Optional +from typing import Optional, Tuple -from mypy.types import ( - Type, Instance, CallableType, NoneType, TupleType, AnyType, PlaceholderType, - TypeOfAny, get_proper_type -) +from mypy import errorcodes as codes +from mypy.errorcodes import ErrorCode +from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type +from mypy.messages import MessageBuilder, format_type from mypy.nodes import ( - AssignmentStmt, NewTypeExpr, CallExpr, NameExpr, RefExpr, Context, StrExpr, BytesExpr, - UnicodeExpr, Block, FuncDef, Argument, TypeInfo, Var, SymbolTableNode, MDEF, ARG_POS, - PlaceholderNode + ARG_POS, + MDEF, + Argument, + AssignmentStmt, + Block, + BytesExpr, + CallExpr, + Context, + FuncDef, + NameExpr, + NewTypeExpr, + PlaceholderNode, + RefExpr, + StrExpr, + SymbolTableNode, + TypeInfo, + UnicodeExpr, + Var, ) -from mypy.semanal_shared import SemanticAnalyzerInterface from mypy.options import Options -from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError +from mypy.semanal_shared import SemanticAnalyzerInterface from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type -from mypy.messages import MessageBuilder, format_type -from mypy.errorcodes import ErrorCode -from mypy import errorcodes as codes +from mypy.types import ( + AnyType, + CallableType, + Instance, + NoneType, + PlaceholderType, + TupleType, + Type, + TypeOfAny, + get_proper_type, +) class NewTypeAnalyzer: - def __init__(self, - options: Options, - api: SemanticAnalyzerInterface, - msg: MessageBuilder) -> None: + def __init__( + self, options: Options, api: SemanticAnalyzerInterface, msg: MessageBuilder + ) -> None: self.options = options self.api = api self.msg = msg @@ -50,11 +71,10 @@ def process_newtype_declaration(self, s: AssignmentStmt) -> bool: # add placeholder as we do for ClassDef. if self.api.is_func_scope(): - name += '@' + str(s.line) + name += "@" + str(s.line) fullname = self.api.qualified_name(name) - if (not call.analyzed or - isinstance(call.analyzed, NewTypeExpr) and not call.analyzed.info): + if not call.analyzed or isinstance(call.analyzed, NewTypeExpr) and not call.analyzed.info: # Start from labeling this as a future class, as we do for normal ClassDefs. placeholder = PlaceholderNode(fullname, s, s.line, becomes_typeinfo=True) self.api.add_symbol(var_name, placeholder, s, can_defer=False) @@ -71,8 +91,9 @@ def process_newtype_declaration(self, s: AssignmentStmt) -> bool: # Create the corresponding class definition if the aliased type is subtypeable if isinstance(old_type, TupleType): - newtype_class_info = self.build_newtype_typeinfo(name, old_type, - old_type.partial_fallback, s.line) + newtype_class_info = self.build_newtype_typeinfo( + name, old_type, old_type.partial_fallback, s.line + ) newtype_class_info.tuple_type = old_type elif isinstance(old_type, Instance): if old_type.type.is_protocol: @@ -84,12 +105,13 @@ def process_newtype_declaration(self, s: AssignmentStmt) -> bool: self.fail(message.format(format_type(old_type)), s, code=codes.VALID_NEWTYPE) # Otherwise the error was already reported. old_type = AnyType(TypeOfAny.from_error) - object_type = self.api.named_type('builtins.object') + object_type = self.api.named_type("builtins.object") newtype_class_info = self.build_newtype_typeinfo(name, old_type, object_type, s.line) newtype_class_info.fallback_to_any = True - check_for_explicit_any(old_type, self.options, self.api.is_typeshed_stub_file, self.msg, - context=s) + check_for_explicit_any( + old_type, self.options, self.api.is_typeshed_stub_file, self.msg, context=s + ) if self.options.disallow_any_unimported and has_any_from_unimported_type(old_type): self.msg.unimported_type_becomes_any("Argument 2 to NewType(...)", old_type, s) @@ -108,15 +130,18 @@ def process_newtype_declaration(self, s: AssignmentStmt) -> bool: newtype_class_info.line = s.line return True - def analyze_newtype_declaration(self, - s: AssignmentStmt) -> Tuple[Optional[str], Optional[CallExpr]]: + def analyze_newtype_declaration( + self, s: AssignmentStmt + ) -> Tuple[Optional[str], Optional[CallExpr]]: """Return the NewType call expression if `s` is a newtype declaration or None otherwise.""" name, call = None, None - if (len(s.lvalues) == 1 - and isinstance(s.lvalues[0], NameExpr) - and isinstance(s.rvalue, CallExpr) - and isinstance(s.rvalue.callee, RefExpr) - and s.rvalue.callee.fullname == 'typing.NewType'): + if ( + len(s.lvalues) == 1 + and isinstance(s.lvalues[0], NameExpr) + and isinstance(s.rvalue, CallExpr) + and isinstance(s.rvalue.callee, RefExpr) + and s.rvalue.callee.fullname == "typing.NewType" + ): name = s.lvalues[0].name if s.type: @@ -125,8 +150,11 @@ def analyze_newtype_declaration(self, names = self.api.current_symbol_table() existing = names.get(name) # Give a better error message than generic "Name already defined". - if (existing and - not isinstance(existing.node, PlaceholderNode) and not s.rvalue.analyzed): + if ( + existing + and not isinstance(existing.node, PlaceholderNode) + and not s.rvalue.analyzed + ): self.fail(f'Cannot redefine "{name}" as a NewType', s) # This dummy NewTypeExpr marks the call as sufficiently analyzed; it will be @@ -136,8 +164,9 @@ def analyze_newtype_declaration(self, return name, call - def check_newtype_args(self, name: str, call: CallExpr, - context: Context) -> Tuple[Optional[Type], bool]: + def check_newtype_args( + self, name: str, call: CallExpr, context: Context + ) -> Tuple[Optional[Type], bool]: """Ananlyze base type in NewType call. Return a tuple (type, should defer). @@ -167,8 +196,7 @@ def check_newtype_args(self, name: str, call: CallExpr, # We want to use our custom error message (see above), so we suppress # the default error message for invalid types here. - old_type = get_proper_type(self.api.anal_type(unanalyzed_type, - report_invalid_types=False)) + old_type = get_proper_type(self.api.anal_type(unanalyzed_type, report_invalid_types=False)) should_defer = False if old_type is None or isinstance(old_type, PlaceholderType): should_defer = True @@ -181,25 +209,29 @@ def check_newtype_args(self, name: str, call: CallExpr, return None if has_failed else old_type, should_defer - def build_newtype_typeinfo(self, name: str, old_type: Type, base_type: Instance, - line: int) -> TypeInfo: + def build_newtype_typeinfo( + self, name: str, old_type: Type, base_type: Instance, line: int + ) -> TypeInfo: info = self.api.basic_new_typeinfo(name, base_type, line) info.is_newtype = True # Add __init__ method - args = [Argument(Var('self'), NoneType(), None, ARG_POS), - self.make_argument('item', old_type)] + args = [ + Argument(Var("self"), NoneType(), None, ARG_POS), + self.make_argument("item", old_type), + ] signature = CallableType( arg_types=[Instance(info, []), old_type], arg_kinds=[arg.kind for arg in args], - arg_names=['self', 'item'], + arg_names=["self", "item"], ret_type=NoneType(), - fallback=self.api.named_type('builtins.function'), - name=name) - init_func = FuncDef('__init__', args, Block([]), typ=signature) + fallback=self.api.named_type("builtins.function"), + name=name, + ) + init_func = FuncDef("__init__", args, Block([]), typ=signature) init_func.info = info - init_func._fullname = info.fullname + '.__init__' - info.names['__init__'] = SymbolTableNode(MDEF, init_func) + init_func._fullname = info.fullname + ".__init__" + info.names["__init__"] = SymbolTableNode(MDEF, init_func) return info diff --git a/mypy/semanal_pass1.py b/mypy/semanal_pass1.py index 2b096f08082a4..4f5292797d8f9 100644 --- a/mypy/semanal_pass1.py +++ b/mypy/semanal_pass1.py @@ -1,15 +1,28 @@ """Block/import reachability analysis.""" from mypy.nodes import ( - MypyFile, AssertStmt, IfStmt, Block, AssignmentStmt, ExpressionStmt, ReturnStmt, ForStmt, - MatchStmt, Import, ImportAll, ImportFrom, ClassDef, FuncDef + AssertStmt, + AssignmentStmt, + Block, + ClassDef, + ExpressionStmt, + ForStmt, + FuncDef, + IfStmt, + Import, + ImportAll, + ImportFrom, + MatchStmt, + MypyFile, + ReturnStmt, ) -from mypy.traverser import TraverserVisitor from mypy.options import Options from mypy.reachability import ( - infer_reachability_of_if_statement, assert_will_always_fail, - infer_reachability_of_match_statement + assert_will_always_fail, + infer_reachability_of_if_statement, + infer_reachability_of_match_statement, ) +from mypy.traverser import TraverserVisitor class SemanticAnalyzerPreAnalysis(TraverserVisitor): @@ -55,7 +68,7 @@ def visit_file(self, file: MypyFile, fnam: str, mod_id: str, options: Options) - # We've encountered an assert that's always false, # e.g. assert sys.platform == 'lol'. Truncate the # list of statements. This mutates file.defs too. - del file.defs[i + 1:] + del file.defs[i + 1 :] break def visit_func_def(self, node: FuncDef) -> None: @@ -64,10 +77,12 @@ def visit_func_def(self, node: FuncDef) -> None: super().visit_func_def(node) self.is_global_scope = old_global_scope file_node = self.cur_mod_node - if (self.is_global_scope - and file_node.is_stub - and node.name == '__getattr__' - and file_node.is_package_init_file()): + if ( + self.is_global_scope + and file_node.is_stub + and node.name == "__getattr__" + and file_node.is_package_init_file() + ): # __init__.pyi with __getattr__ means that any submodules are assumed # to exist, even if there is no stub. Note that we can't verify that the # return type is compatible, since we haven't bound types yet. diff --git a/mypy/semanal_shared.py b/mypy/semanal_shared.py index 6d6c4ac9f0d4c..85bf3b18d4998 100644 --- a/mypy/semanal_shared.py +++ b/mypy/semanal_shared.py @@ -1,22 +1,37 @@ """Shared definitions used by different parts of semantic analysis.""" from abc import abstractmethod +from typing import Callable, List, Optional, Union -from typing import Optional, List, Callable, Union -from typing_extensions import Final, Protocol from mypy_extensions import trait +from typing_extensions import Final, Protocol +from mypy import join +from mypy.errorcodes import ErrorCode from mypy.nodes import ( - Context, SymbolTableNode, FuncDef, Node, TypeInfo, Expression, - SymbolNode, SymbolTable + Context, + Expression, + FuncDef, + Node, + SymbolNode, + SymbolTable, + SymbolTableNode, + TypeInfo, ) +from mypy.tvar_scope import TypeVarLikeScope from mypy.types import ( - Type, FunctionLike, Instance, TupleType, TPDICT_FB_NAMES, ProperType, get_proper_type, - ParamSpecType, ParamSpecFlavor, Parameters, TypeVarId + TPDICT_FB_NAMES, + FunctionLike, + Instance, + Parameters, + ParamSpecFlavor, + ParamSpecType, + ProperType, + TupleType, + Type, + TypeVarId, + get_proper_type, ) -from mypy.tvar_scope import TypeVarLikeScope -from mypy.errorcodes import ErrorCode -from mypy import join # Priorities for ordering of patches within the "patch" phase of semantic analysis # (after the main pass): @@ -33,8 +48,9 @@ class SemanticAnalyzerCoreInterface: """ @abstractmethod - def lookup_qualified(self, name: str, ctx: Context, - suppress_errors: bool = False) -> Optional[SymbolTableNode]: + def lookup_qualified( + self, name: str, ctx: Context, suppress_errors: bool = False + ) -> Optional[SymbolTableNode]: raise NotImplementedError @abstractmethod @@ -46,8 +62,15 @@ def lookup_fully_qualified_or_none(self, name: str) -> Optional[SymbolTableNode] raise NotImplementedError @abstractmethod - def fail(self, msg: str, ctx: Context, serious: bool = False, *, - blocker: bool = False, code: Optional[ErrorCode] = None) -> None: + def fail( + self, + msg: str, + ctx: Context, + serious: bool = False, + *, + blocker: bool = False, + code: Optional[ErrorCode] = None, + ) -> None: raise NotImplementedError @abstractmethod @@ -96,18 +119,19 @@ class SemanticAnalyzerInterface(SemanticAnalyzerCoreInterface): """ @abstractmethod - def lookup(self, name: str, ctx: Context, - suppress_errors: bool = False) -> Optional[SymbolTableNode]: + def lookup( + self, name: str, ctx: Context, suppress_errors: bool = False + ) -> Optional[SymbolTableNode]: raise NotImplementedError @abstractmethod - def named_type(self, fullname: str, - args: Optional[List[Type]] = None) -> Instance: + def named_type(self, fullname: str, args: Optional[List[Type]] = None) -> Instance: raise NotImplementedError @abstractmethod - def named_type_or_none(self, fullname: str, - args: Optional[List[Type]] = None) -> Optional[Instance]: + def named_type_or_none( + self, fullname: str, args: Optional[List[Type]] = None + ) -> Optional[Instance]: raise NotImplementedError @abstractmethod @@ -115,12 +139,16 @@ def accept(self, node: Node) -> None: raise NotImplementedError @abstractmethod - def anal_type(self, t: Type, *, - tvar_scope: Optional[TypeVarLikeScope] = None, - allow_tuple_literal: bool = False, - allow_unbound_tvars: bool = False, - allow_required: bool = False, - report_invalid_types: bool = True) -> Optional[Type]: + def anal_type( + self, + t: Type, + *, + tvar_scope: Optional[TypeVarLikeScope] = None, + allow_tuple_literal: bool = False, + allow_unbound_tvars: bool = False, + allow_required: bool = False, + report_invalid_types: bool = True, + ) -> Optional[Type]: raise NotImplementedError @abstractmethod @@ -145,9 +173,15 @@ def current_symbol_table(self) -> SymbolTable: raise NotImplementedError @abstractmethod - def add_symbol(self, name: str, node: SymbolNode, context: Context, - module_public: bool = True, module_hidden: bool = False, - can_defer: bool = True) -> bool: + def add_symbol( + self, + name: str, + node: SymbolNode, + context: Context, + module_public: bool = True, + module_hidden: bool = False, + can_defer: bool = True, + ) -> bool: """Add symbol to the current symbol table.""" raise NotImplementedError @@ -185,11 +219,10 @@ def set_callable_name(sig: Type, fdef: FuncDef) -> ProperType: if fdef.info: if fdef.info.fullname in TPDICT_FB_NAMES: # Avoid exposing the internal _TypedDict name. - class_name = 'TypedDict' + class_name = "TypedDict" else: class_name = fdef.info.name - return sig.with_name( - f'{fdef.name} of {class_name}') + return sig.with_name(f"{fdef.name} of {class_name}") else: return sig.with_name(fdef.name) else: @@ -211,37 +244,46 @@ def calculate_tuple_fallback(typ: TupleType) -> None: we don't prevent their existence). """ fallback = typ.partial_fallback - assert fallback.type.fullname == 'builtins.tuple' + assert fallback.type.fullname == "builtins.tuple" fallback.args = (join.join_type_list(list(typ.items)),) + fallback.args[1:] class _NamedTypeCallback(Protocol): - def __call__( - self, fully_qualified_name: str, args: Optional[List[Type]] = None - ) -> Instance: ... + def __call__(self, fully_qualified_name: str, args: Optional[List[Type]] = None) -> Instance: + ... def paramspec_args( - name: str, fullname: str, id: Union[TypeVarId, int], *, - named_type_func: _NamedTypeCallback, line: int = -1, column: int = -1, - prefix: Optional[Parameters] = None + name: str, + fullname: str, + id: Union[TypeVarId, int], + *, + named_type_func: _NamedTypeCallback, + line: int = -1, + column: int = -1, + prefix: Optional[Parameters] = None, ) -> ParamSpecType: return ParamSpecType( name, fullname, id, flavor=ParamSpecFlavor.ARGS, - upper_bound=named_type_func('builtins.tuple', [named_type_func('builtins.object')]), + upper_bound=named_type_func("builtins.tuple", [named_type_func("builtins.object")]), line=line, column=column, - prefix=prefix + prefix=prefix, ) def paramspec_kwargs( - name: str, fullname: str, id: Union[TypeVarId, int], *, - named_type_func: _NamedTypeCallback, line: int = -1, column: int = -1, - prefix: Optional[Parameters] = None + name: str, + fullname: str, + id: Union[TypeVarId, int], + *, + named_type_func: _NamedTypeCallback, + line: int = -1, + column: int = -1, + prefix: Optional[Parameters] = None, ) -> ParamSpecType: return ParamSpecType( name, @@ -249,10 +291,9 @@ def paramspec_kwargs( id, flavor=ParamSpecFlavor.KWARGS, upper_bound=named_type_func( - 'builtins.dict', - [named_type_func('builtins.str'), named_type_func('builtins.object')] + "builtins.dict", [named_type_func("builtins.str"), named_type_func("builtins.object")] ), line=line, column=column, - prefix=prefix + prefix=prefix, ) diff --git a/mypy/semanal_typeargs.py b/mypy/semanal_typeargs.py index 483154000d1bd..8e1cae3717df3 100644 --- a/mypy/semanal_typeargs.py +++ b/mypy/semanal_typeargs.py @@ -7,20 +7,30 @@ from typing import List, Optional, Set -from mypy.nodes import TypeInfo, Context, MypyFile, FuncItem, ClassDef, Block, FakeInfo -from mypy.types import ( - Type, Instance, TypeVarType, AnyType, get_proper_types, TypeAliasType, ParamSpecType, - UnpackType, TupleType, TypeVarTupleType, TypeOfAny, get_proper_type -) +from mypy import errorcodes as codes, message_registry +from mypy.errorcodes import ErrorCode +from mypy.errors import Errors +from mypy.messages import format_type from mypy.mixedtraverser import MixedTraverserVisitor -from mypy.subtypes import is_subtype +from mypy.nodes import Block, ClassDef, Context, FakeInfo, FuncItem, MypyFile, TypeInfo +from mypy.options import Options from mypy.sametypes import is_same_type -from mypy.errors import Errors from mypy.scope import Scope -from mypy.options import Options -from mypy.errorcodes import ErrorCode -from mypy import message_registry, errorcodes as codes -from mypy.messages import format_type +from mypy.subtypes import is_subtype +from mypy.types import ( + AnyType, + Instance, + ParamSpecType, + TupleType, + Type, + TypeAliasType, + TypeOfAny, + TypeVarTupleType, + TypeVarType, + UnpackType, + get_proper_type, + get_proper_types, +) class TypeArgumentAnalyzer(MixedTraverserVisitor): @@ -82,8 +92,11 @@ def visit_instance(self, t: Instance) -> None: if not arg_values: self.fail( message_registry.INVALID_TYPEVAR_AS_TYPEARG.format( - arg.name, info.name), - t, code=codes.TYPE_VAR) + arg.name, info.name + ), + t, + code=codes.TYPE_VAR, + ) continue else: arg_values = [arg] @@ -91,8 +104,11 @@ def visit_instance(self, t: Instance) -> None: if not is_subtype(arg, tvar.upper_bound): self.fail( message_registry.INVALID_TYPEVAR_ARG_BOUND.format( - format_type(arg), info.name, format_type(tvar.upper_bound)), - t, code=codes.TYPE_VAR) + format_type(arg), info.name, format_type(tvar.upper_bound) + ), + t, + code=codes.TYPE_VAR, + ) super().visit_instance(t) def visit_unpack_type(self, typ: UnpackType) -> None: @@ -110,24 +126,35 @@ def visit_unpack_type(self, typ: UnpackType) -> None: # typechecking to work. self.fail(message_registry.INVALID_UNPACK.format(proper_type), typ) - def check_type_var_values(self, type: TypeInfo, actuals: List[Type], arg_name: str, - valids: List[Type], arg_number: int, context: Context) -> None: + def check_type_var_values( + self, + type: TypeInfo, + actuals: List[Type], + arg_name: str, + valids: List[Type], + arg_number: int, + context: Context, + ) -> None: for actual in get_proper_types(actuals): - if (not isinstance(actual, AnyType) and - not any(is_same_type(actual, value) - for value in valids)): + if not isinstance(actual, AnyType) and not any( + is_same_type(actual, value) for value in valids + ): if len(actuals) > 1 or not isinstance(actual, Instance): self.fail( message_registry.INVALID_TYPEVAR_ARG_VALUE.format(type.name), - context, code=codes.TYPE_VAR) + context, + code=codes.TYPE_VAR, + ) else: class_name = f'"{type.name}"' actual_type_name = f'"{actual.type.name}"' self.fail( message_registry.INCOMPATIBLE_TYPEVAR_VALUE.format( - arg_name, class_name, actual_type_name), + arg_name, class_name, actual_type_name + ), context, - code=codes.TYPE_VAR) + code=codes.TYPE_VAR, + ) def fail(self, msg: str, context: Context, *, code: Optional[ErrorCode] = None) -> None: self.errors.report(context.get_line(), context.get_column(), msg, code=code) diff --git a/mypy/semanal_typeddict.py b/mypy/semanal_typeddict.py index 4087f477c597e..dd6659bf10658 100644 --- a/mypy/semanal_typeddict.py +++ b/mypy/semanal_typeddict.py @@ -1,24 +1,39 @@ """Semantic analysis of TypedDict definitions.""" -from mypy.backports import OrderedDict -from typing import Optional, List, Set, Tuple +from typing import List, Optional, Set, Tuple + from typing_extensions import Final -from mypy.types import ( - Type, AnyType, TypeOfAny, TypedDictType, TPDICT_NAMES, RequiredType, -) +from mypy import errorcodes as codes +from mypy.backports import OrderedDict +from mypy.errorcodes import ErrorCode +from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type +from mypy.messages import MessageBuilder from mypy.nodes import ( - CallExpr, TypedDictExpr, Expression, NameExpr, Context, StrExpr, BytesExpr, UnicodeExpr, - ClassDef, RefExpr, TypeInfo, AssignmentStmt, PassStmt, ExpressionStmt, EllipsisExpr, TempNode, - DictExpr, ARG_POS, ARG_NAMED + ARG_NAMED, + ARG_POS, + AssignmentStmt, + BytesExpr, + CallExpr, + ClassDef, + Context, + DictExpr, + EllipsisExpr, + Expression, + ExpressionStmt, + NameExpr, + PassStmt, + RefExpr, + StrExpr, + TempNode, + TypedDictExpr, + TypeInfo, + UnicodeExpr, ) -from mypy.semanal_shared import SemanticAnalyzerInterface -from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.options import Options +from mypy.semanal_shared import SemanticAnalyzerInterface from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type -from mypy.messages import MessageBuilder -from mypy.errorcodes import ErrorCode -from mypy import errorcodes as codes +from mypy.types import TPDICT_NAMES, AnyType, RequiredType, Type, TypedDictType, TypeOfAny TPDICT_CLASS_ERROR: Final = ( "Invalid statement in TypedDict definition; " 'expected "field_name: field_type"' @@ -26,10 +41,9 @@ class TypedDictAnalyzer: - def __init__(self, - options: Options, - api: SemanticAnalyzerInterface, - msg: MessageBuilder) -> None: + def __init__( + self, options: Options, api: SemanticAnalyzerInterface, msg: MessageBuilder + ) -> None: self.options = options self.api = api self.msg = msg @@ -56,15 +70,18 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> Tuple[bool, Optional[Typ if base_expr.fullname in TPDICT_NAMES or self.is_typeddict(base_expr): possible = True if possible: - if (len(defn.base_type_exprs) == 1 and - isinstance(defn.base_type_exprs[0], RefExpr) and - defn.base_type_exprs[0].fullname in TPDICT_NAMES): + if ( + len(defn.base_type_exprs) == 1 + and isinstance(defn.base_type_exprs[0], RefExpr) + and defn.base_type_exprs[0].fullname in TPDICT_NAMES + ): # Building a new TypedDict fields, types, required_keys = self.analyze_typeddict_classdef_fields(defn) if fields is None: return True, None # Defer - info = self.build_typeddict_typeinfo(defn.name, fields, types, required_keys, - defn.line) + info = self.build_typeddict_typeinfo( + defn.name, fields, types, required_keys, defn.line + ) defn.analyzed = TypedDictExpr(info) defn.analyzed.line = defn.line defn.analyzed.column = defn.column @@ -75,8 +92,8 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> Tuple[bool, Optional[Typ typeddict_bases_set = set() for expr in defn.base_type_exprs: if isinstance(expr, RefExpr) and expr.fullname in TPDICT_NAMES: - if 'TypedDict' not in typeddict_bases_set: - typeddict_bases_set.add('TypedDict') + if "TypedDict" not in typeddict_bases_set: + typeddict_bases_set.add("TypedDict") else: self.fail('Duplicate base class "TypedDict"', defn) elif isinstance(expr, RefExpr) and self.is_typeddict(expr): @@ -103,13 +120,15 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> Tuple[bool, Optional[Typ valid_items = base_items.copy() for key in base_items: if key in keys: - self.fail('Overwriting TypedDict field "{}" while merging' - .format(key), defn) + self.fail( + 'Overwriting TypedDict field "{}" while merging'.format(key), defn + ) keys.extend(valid_items.keys()) types.extend(valid_items.values()) required_keys.update(base_typed_dict.required_keys) - new_keys, new_types, new_required_keys = self.analyze_typeddict_classdef_fields(defn, - keys) + new_keys, new_types, new_required_keys = self.analyze_typeddict_classdef_fields( + defn, keys + ) if new_keys is None: return True, None # Defer keys.extend(new_keys) @@ -123,11 +142,8 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> Tuple[bool, Optional[Typ return False, None def analyze_typeddict_classdef_fields( - self, - defn: ClassDef, - oldfields: Optional[List[str]] = None) -> Tuple[Optional[List[str]], - List[Type], - Set[str]]: + self, defn: ClassDef, oldfields: Optional[List[str]] = None + ) -> Tuple[Optional[List[str]], List[Type], Set[str]]: """Analyze fields defined in a TypedDict class definition. This doesn't consider inherited fields (if any). Also consider totality, @@ -143,9 +159,10 @@ def analyze_typeddict_classdef_fields( for stmt in defn.defs.body: if not isinstance(stmt, AssignmentStmt): # Still allow pass or ... (for empty TypedDict's). - if (not isinstance(stmt, PassStmt) and - not (isinstance(stmt, ExpressionStmt) and - isinstance(stmt.expr, (EllipsisExpr, StrExpr)))): + if not isinstance(stmt, PassStmt) and not ( + isinstance(stmt, ExpressionStmt) + and isinstance(stmt.expr, (EllipsisExpr, StrExpr)) + ): self.fail(TPDICT_CLASS_ERROR, stmt) elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): # An assignment, but an invalid one. @@ -153,8 +170,9 @@ def analyze_typeddict_classdef_fields( else: name = stmt.lvalues[0].name if name in (oldfields or []): - self.fail('Overwriting TypedDict field "{}" while extending' - .format(name), stmt) + self.fail( + 'Overwriting TypedDict field "{}" while extending'.format(name), stmt + ) if name in fields: self.fail(f'Duplicate TypedDict key "{name}"', stmt) continue @@ -168,39 +186,32 @@ def analyze_typeddict_classdef_fields( return None, [], set() # Need to defer types.append(analyzed) # ...despite possible minor failures that allow further analyzis. - if stmt.type is None or hasattr(stmt, 'new_syntax') and not stmt.new_syntax: + if stmt.type is None or hasattr(stmt, "new_syntax") and not stmt.new_syntax: self.fail(TPDICT_CLASS_ERROR, stmt) elif not isinstance(stmt.rvalue, TempNode): # x: int assigns rvalue to TempNode(AnyType()) - self.fail('Right hand side values are not supported in TypedDict', stmt) + self.fail("Right hand side values are not supported in TypedDict", stmt) total: Optional[bool] = True - if 'total' in defn.keywords: - total = self.api.parse_bool(defn.keywords['total']) + if "total" in defn.keywords: + total = self.api.parse_bool(defn.keywords["total"]) if total is None: self.fail('Value of "total" must be True or False', defn) total = True required_keys = { field for (field, t) in zip(fields, types) - if (total or ( - isinstance(t, RequiredType) and # type: ignore[misc] - t.required - )) and not ( - isinstance(t, RequiredType) and # type: ignore[misc] - not t.required - ) + if (total or (isinstance(t, RequiredType) and t.required)) # type: ignore[misc] + and not (isinstance(t, RequiredType) and not t.required) # type: ignore[misc] } types = [ # unwrap Required[T] to just T - t.item if isinstance(t, RequiredType) else t # type: ignore[misc] - for t in types + t.item if isinstance(t, RequiredType) else t for t in types # type: ignore[misc] ] return fields, types, required_keys - def check_typeddict(self, - node: Expression, - var_name: Optional[str], - is_func_scope: bool) -> Tuple[bool, Optional[TypeInfo]]: + def check_typeddict( + self, node: Expression, var_name: Optional[str], is_func_scope: bool + ) -> Tuple[bool, Optional[TypeInfo]]: """Check if a call defines a TypedDict. The optional var_name argument is the name of the variable to @@ -229,29 +240,27 @@ def check_typeddict(self, name, items, types, total, ok = res if not ok: # Error. Construct dummy return value. - info = self.build_typeddict_typeinfo('TypedDict', [], [], set(), call.line) + info = self.build_typeddict_typeinfo("TypedDict", [], [], set(), call.line) else: if var_name is not None and name != var_name: self.fail( 'First argument "{}" to TypedDict() does not match variable name "{}"'.format( - name, var_name), node, code=codes.NAME_MATCH) + name, var_name + ), + node, + code=codes.NAME_MATCH, + ) if name != var_name or is_func_scope: # Give it a unique name derived from the line number. - name += '@' + str(call.line) + name += "@" + str(call.line) required_keys = { field for (field, t) in zip(items, types) - if (total or ( - isinstance(t, RequiredType) and # type: ignore[misc] - t.required - )) and not ( - isinstance(t, RequiredType) and # type: ignore[misc] - not t.required - ) + if (total or (isinstance(t, RequiredType) and t.required)) # type: ignore[misc] + and not (isinstance(t, RequiredType) and not t.required) # type: ignore[misc] } types = [ # unwrap Required[T] to just T - t.item if isinstance(t, RequiredType) else t # type: ignore[misc] - for t in types + t.item if isinstance(t, RequiredType) else t for t in types # type: ignore[misc] ] info = self.build_typeddict_typeinfo(name, items, types, required_keys, call.line) info.line = node.line @@ -265,7 +274,8 @@ def check_typeddict(self, return True, info def parse_typeddict_args( - self, call: CallExpr) -> Optional[Tuple[str, List[str], List[Type], bool, bool]]: + self, call: CallExpr + ) -> Optional[Tuple[str, List[str], List[Type], bool, bool]]: """Parse typed dict call expression. Return names, types, totality, was there an error during parsing. @@ -280,21 +290,25 @@ def parse_typeddict_args( # TODO: Support keyword arguments if call.arg_kinds not in ([ARG_POS, ARG_POS], [ARG_POS, ARG_POS, ARG_NAMED]): return self.fail_typeddict_arg("Unexpected arguments to TypedDict()", call) - if len(args) == 3 and call.arg_names[2] != 'total': + if len(args) == 3 and call.arg_names[2] != "total": return self.fail_typeddict_arg( - f'Unexpected keyword argument "{call.arg_names[2]}" for "TypedDict"', call) + f'Unexpected keyword argument "{call.arg_names[2]}" for "TypedDict"', call + ) if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): return self.fail_typeddict_arg( - "TypedDict() expects a string literal as the first argument", call) + "TypedDict() expects a string literal as the first argument", call + ) if not isinstance(args[1], DictExpr): return self.fail_typeddict_arg( - "TypedDict() expects a dictionary literal as the second argument", call) + "TypedDict() expects a dictionary literal as the second argument", call + ) total: Optional[bool] = True if len(args) == 3: total = self.api.parse_bool(call.args[2]) if total is None: return self.fail_typeddict_arg( - 'TypedDict() "total" argument must be True or False', call) + 'TypedDict() "total" argument must be True or False', call + ) dictexpr = args[1] res = self.parse_typeddict_fields_with_types(dictexpr.items, call) if res is None: @@ -302,8 +316,9 @@ def parse_typeddict_args( return None items, types, ok = res for t in types: - check_for_explicit_any(t, self.options, self.api.is_typeshed_stub_file, self.msg, - context=call) + check_for_explicit_any( + t, self.options, self.api.is_typeshed_stub_file, self.msg, context=call + ) if self.options.disallow_any_unimported: for t in types: @@ -313,9 +328,8 @@ def parse_typeddict_args( return args[0].value, items, types, total, ok def parse_typeddict_fields_with_types( - self, - dict_items: List[Tuple[Optional[Expression], Expression]], - context: Context) -> Optional[Tuple[List[str], List[Type], bool]]: + self, dict_items: List[Tuple[Optional[Expression], Expression]], context: Context + ) -> Optional[Tuple[List[str], List[Type], bool]]: """Parse typed dict items passed as pairs (name expression, type expression). Return names, types, was there an error. If some type is not ready, return None. @@ -335,17 +349,21 @@ def parse_typeddict_fields_with_types( self.fail_typeddict_arg("Invalid TypedDict() field name", name_context) return [], [], False try: - type = expr_to_unanalyzed_type(field_type_expr, self.options, - self.api.is_stub_file) + type = expr_to_unanalyzed_type( + field_type_expr, self.options, self.api.is_stub_file + ) except TypeTranslationError: - if (isinstance(field_type_expr, CallExpr) and - isinstance(field_type_expr.callee, RefExpr) and - field_type_expr.callee.fullname in TPDICT_NAMES): + if ( + isinstance(field_type_expr, CallExpr) + and isinstance(field_type_expr.callee, RefExpr) + and field_type_expr.callee.fullname in TPDICT_NAMES + ): self.fail_typeddict_arg( - 'Inline TypedDict types not supported; use assignment to define TypedDict', - field_type_expr) + "Inline TypedDict types not supported; use assignment to define TypedDict", + field_type_expr, + ) else: - self.fail_typeddict_arg('Invalid field type', field_type_expr) + self.fail_typeddict_arg("Invalid field type", field_type_expr) return [], [], False analyzed = self.api.anal_type(type, allow_required=True) if analyzed is None: @@ -353,30 +371,36 @@ def parse_typeddict_fields_with_types( types.append(analyzed) return items, types, True - def fail_typeddict_arg(self, message: str, - context: Context) -> Tuple[str, List[str], List[Type], bool, bool]: + def fail_typeddict_arg( + self, message: str, context: Context + ) -> Tuple[str, List[str], List[Type], bool, bool]: self.fail(message, context) - return '', [], [], True, False + return "", [], [], True, False - def build_typeddict_typeinfo(self, name: str, items: List[str], - types: List[Type], - required_keys: Set[str], - line: int) -> TypeInfo: + def build_typeddict_typeinfo( + self, name: str, items: List[str], types: List[Type], required_keys: Set[str], line: int + ) -> TypeInfo: # Prefer typing then typing_extensions if available. - fallback = (self.api.named_type_or_none('typing._TypedDict', []) or - self.api.named_type_or_none('typing_extensions._TypedDict', []) or - self.api.named_type_or_none('mypy_extensions._TypedDict', [])) + fallback = ( + self.api.named_type_or_none("typing._TypedDict", []) + or self.api.named_type_or_none("typing_extensions._TypedDict", []) + or self.api.named_type_or_none("mypy_extensions._TypedDict", []) + ) assert fallback is not None info = self.api.basic_new_typeinfo(name, fallback, line) - info.typeddict_type = TypedDictType(OrderedDict(zip(items, types)), required_keys, - fallback) + info.typeddict_type = TypedDictType( + OrderedDict(zip(items, types)), required_keys, fallback + ) return info # Helpers def is_typeddict(self, expr: Expression) -> bool: - return (isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo) and - expr.node.typeddict_type is not None) + return ( + isinstance(expr, RefExpr) + and isinstance(expr.node, TypeInfo) + and expr.node.typeddict_type is not None + ) def fail(self, msg: str, ctx: Context, *, code: Optional[ErrorCode] = None) -> None: self.api.fail(msg, ctx, code=code) diff --git a/mypy/server/astdiff.py b/mypy/server/astdiff.py index 1f1c6b65f385b..5b7af991b5a05 100644 --- a/mypy/server/astdiff.py +++ b/mypy/server/astdiff.py @@ -50,21 +50,50 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method' fine-grained dependencies. """ -from typing import Set, Dict, Tuple, Optional, Sequence, Union +from typing import Dict, Optional, Sequence, Set, Tuple, Union from mypy.nodes import ( - SymbolTable, TypeInfo, Var, SymbolNode, Decorator, TypeVarExpr, TypeAlias, - FuncBase, OverloadedFuncDef, FuncItem, MypyFile, ParamSpecExpr, UNBOUND_IMPORTED + UNBOUND_IMPORTED, + Decorator, + FuncBase, + FuncItem, + MypyFile, + OverloadedFuncDef, + ParamSpecExpr, + SymbolNode, + SymbolTable, + TypeAlias, + TypeInfo, + TypeVarExpr, + Var, ) from mypy.types import ( - Type, TypeVisitor, UnboundType, AnyType, NoneType, UninhabitedType, - ErasedType, DeletedType, Instance, TypeVarType, CallableType, TupleType, TypedDictType, - UnionType, Overloaded, PartialType, TypeType, LiteralType, TypeAliasType, ParamSpecType, - Parameters, UnpackType, TypeVarTupleType, + AnyType, + CallableType, + DeletedType, + ErasedType, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + PartialType, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeType, + TypeVarTupleType, + TypeVarType, + TypeVisitor, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, ) from mypy.util import get_prefix - # Snapshot representation of a symbol table node or type. The representation is # opaque -- the only supported operations are comparing for equality and # hashing (latter for type snapshots only). Snapshots can contain primitive @@ -76,9 +105,8 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method' def compare_symbol_table_snapshots( - name_prefix: str, - snapshot1: Dict[str, SnapshotItem], - snapshot2: Dict[str, SnapshotItem]) -> Set[str]: + name_prefix: str, snapshot1: Dict[str, SnapshotItem], snapshot2: Dict[str, SnapshotItem] +) -> Set[str]: """Return names that are different in two snapshots of a symbol table. Only shallow (intra-module) differences are considered. References to things defined @@ -89,8 +117,8 @@ def compare_symbol_table_snapshots( Return a set of fully-qualified names (e.g., 'mod.func' or 'mod.Class.method'). """ # Find names only defined only in one version. - names1 = {f'{name_prefix}.{name}' for name in snapshot1} - names2 = {f'{name_prefix}.{name}' for name in snapshot2} + names1 = {f"{name_prefix}.{name}" for name in snapshot1} + names2 = {f"{name_prefix}.{name}" for name in snapshot2} triggers = names1 ^ names2 # Look for names defined in both versions that are different. @@ -99,11 +127,11 @@ def compare_symbol_table_snapshots( item2 = snapshot2[name] kind1 = item1[0] kind2 = item2[0] - item_name = f'{name_prefix}.{name}' + item_name = f"{name_prefix}.{name}" if kind1 != kind2: # Different kind of node in two snapshots -> trivially different. triggers.add(item_name) - elif kind1 == 'TypeInfo': + elif kind1 == "TypeInfo": if item1[:-1] != item2[:-1]: # Record major difference (outside class symbol tables). triggers.add(item_name) @@ -140,34 +168,37 @@ def snapshot_symbol_table(name_prefix: str, table: SymbolTable) -> Dict[str, Sna # If the reference is busted because the other module is missing, # the node will be a "stale_info" TypeInfo produced by fixup, # but that doesn't really matter to us here. - result[name] = ('Moduleref', common) + result[name] = ("Moduleref", common) elif isinstance(node, TypeVarExpr): - result[name] = ('TypeVar', - node.variance, - [snapshot_type(value) for value in node.values], - snapshot_type(node.upper_bound)) + result[name] = ( + "TypeVar", + node.variance, + [snapshot_type(value) for value in node.values], + snapshot_type(node.upper_bound), + ) elif isinstance(node, TypeAlias): - result[name] = ('TypeAlias', - node.alias_tvars, - node.normalized, - node.no_args, - snapshot_optional_type(node.target)) + result[name] = ( + "TypeAlias", + node.alias_tvars, + node.normalized, + node.no_args, + snapshot_optional_type(node.target), + ) elif isinstance(node, ParamSpecExpr): - result[name] = ('ParamSpec', - node.variance, - snapshot_type(node.upper_bound)) + result[name] = ("ParamSpec", node.variance, snapshot_type(node.upper_bound)) else: assert symbol.kind != UNBOUND_IMPORTED if node and get_prefix(node.fullname) != name_prefix: # This is a cross-reference to a node defined in another module. - result[name] = ('CrossRef', common) + result[name] = ("CrossRef", common) else: result[name] = snapshot_definition(node, common) return result -def snapshot_definition(node: Optional[SymbolNode], - common: Tuple[object, ...]) -> Tuple[object, ...]: +def snapshot_definition( + node: Optional[SymbolNode], common: Tuple[object, ...] +) -> Tuple[object, ...]: """Create a snapshot description of a symbol table node. The representation is nested tuples and dicts. Only externally @@ -179,14 +210,17 @@ def snapshot_definition(node: Optional[SymbolNode], signature = snapshot_type(node.type) else: signature = snapshot_untyped_signature(node) - return ('Func', common, - node.is_property, node.is_final, - node.is_class, node.is_static, - signature) + return ( + "Func", + common, + node.is_property, + node.is_final, + node.is_class, + node.is_static, + signature, + ) elif isinstance(node, Var): - return ('Var', common, - snapshot_optional_type(node.type), - node.is_final) + return ("Var", common, snapshot_optional_type(node.type), node.is_final) elif isinstance(node, Decorator): # Note that decorated methods are represented by Decorator instances in # a symbol table since we need to preserve information about the @@ -194,38 +228,42 @@ def snapshot_definition(node: Optional[SymbolNode], # example). Top-level decorated functions, however, are represented by # the corresponding Var node, since that happens to provide enough # context. - return ('Decorator', - node.is_overload, - snapshot_optional_type(node.var.type), - snapshot_definition(node.func, common)) + return ( + "Decorator", + node.is_overload, + snapshot_optional_type(node.var.type), + snapshot_definition(node.func, common), + ) elif isinstance(node, TypeInfo): - attrs = (node.is_abstract, - node.is_enum, - node.is_protocol, - node.fallback_to_any, - node.is_named_tuple, - node.is_newtype, - # We need this to e.g. trigger metaclass calculation in subclasses. - snapshot_optional_type(node.metaclass_type), - snapshot_optional_type(node.tuple_type), - snapshot_optional_type(node.typeddict_type), - [base.fullname for base in node.mro], - # Note that the structure of type variables is a part of the external interface, - # since creating instances might fail, for example: - # T = TypeVar('T', bound=int) - # class C(Generic[T]): - # ... - # x: C[str] <- this is invalid, and needs to be re-checked if `T` changes. - # An alternative would be to create both deps: <...> -> C, and <...> -> , - # but this currently seems a bit ad hoc. - tuple(snapshot_type(tdef) for tdef in node.defn.type_vars), - [snapshot_type(base) for base in node.bases], - [snapshot_type(p) for p in node._promote]) + attrs = ( + node.is_abstract, + node.is_enum, + node.is_protocol, + node.fallback_to_any, + node.is_named_tuple, + node.is_newtype, + # We need this to e.g. trigger metaclass calculation in subclasses. + snapshot_optional_type(node.metaclass_type), + snapshot_optional_type(node.tuple_type), + snapshot_optional_type(node.typeddict_type), + [base.fullname for base in node.mro], + # Note that the structure of type variables is a part of the external interface, + # since creating instances might fail, for example: + # T = TypeVar('T', bound=int) + # class C(Generic[T]): + # ... + # x: C[str] <- this is invalid, and needs to be re-checked if `T` changes. + # An alternative would be to create both deps: <...> -> C, and <...> -> , + # but this currently seems a bit ad hoc. + tuple(snapshot_type(tdef) for tdef in node.defn.type_vars), + [snapshot_type(base) for base in node.bases], + [snapshot_type(p) for p in node._promote], + ) prefix = node.fullname symbol_table = snapshot_symbol_table(prefix, node.names) # Special dependency for abstract attribute handling. - symbol_table['(abstract)'] = ('Abstract', tuple(sorted(node.abstract_attributes))) - return ('TypeInfo', common, attrs, symbol_table) + symbol_table["(abstract)"] = ("Abstract", tuple(sorted(node.abstract_attributes))) + return ("TypeInfo", common, attrs, symbol_table) else: # Other node types are handled elsewhere. assert False, type(node) @@ -253,7 +291,7 @@ def snapshot_simple_type(typ: Type) -> SnapshotItem: def encode_optional_str(s: Optional[str]) -> str: if s is None: - return '' + return "" else: return s @@ -274,11 +312,13 @@ class SnapshotTypeVisitor(TypeVisitor[SnapshotItem]): """ def visit_unbound_type(self, typ: UnboundType) -> SnapshotItem: - return ('UnboundType', - typ.name, - typ.optional, - typ.empty_tuple_index, - snapshot_types(typ.args)) + return ( + "UnboundType", + typ.name, + typ.optional, + typ.empty_tuple_index, + snapshot_types(typ.args), + ) def visit_any(self, typ: AnyType) -> SnapshotItem: return snapshot_simple_type(typ) @@ -296,74 +336,85 @@ def visit_deleted_type(self, typ: DeletedType) -> SnapshotItem: return snapshot_simple_type(typ) def visit_instance(self, typ: Instance) -> SnapshotItem: - return ('Instance', - encode_optional_str(typ.type.fullname), - snapshot_types(typ.args), - ('None',) if typ.last_known_value is None else snapshot_type(typ.last_known_value)) + return ( + "Instance", + encode_optional_str(typ.type.fullname), + snapshot_types(typ.args), + ("None",) if typ.last_known_value is None else snapshot_type(typ.last_known_value), + ) def visit_type_var(self, typ: TypeVarType) -> SnapshotItem: - return ('TypeVar', - typ.name, - typ.fullname, - typ.id.raw_id, - typ.id.meta_level, - snapshot_types(typ.values), - snapshot_type(typ.upper_bound), - typ.variance) + return ( + "TypeVar", + typ.name, + typ.fullname, + typ.id.raw_id, + typ.id.meta_level, + snapshot_types(typ.values), + snapshot_type(typ.upper_bound), + typ.variance, + ) def visit_param_spec(self, typ: ParamSpecType) -> SnapshotItem: - return ('ParamSpec', - typ.id.raw_id, - typ.id.meta_level, - typ.flavor, - snapshot_type(typ.upper_bound)) + return ( + "ParamSpec", + typ.id.raw_id, + typ.id.meta_level, + typ.flavor, + snapshot_type(typ.upper_bound), + ) def visit_type_var_tuple(self, typ: TypeVarTupleType) -> SnapshotItem: - return ('TypeVarTupleType', - typ.id.raw_id, - typ.id.meta_level, - snapshot_type(typ.upper_bound)) + return ( + "TypeVarTupleType", + typ.id.raw_id, + typ.id.meta_level, + snapshot_type(typ.upper_bound), + ) def visit_unpack_type(self, typ: UnpackType) -> SnapshotItem: - return ('UnpackType', snapshot_type(typ.type)) + return ("UnpackType", snapshot_type(typ.type)) def visit_parameters(self, typ: Parameters) -> SnapshotItem: - return ('Parameters', - snapshot_types(typ.arg_types), - tuple(encode_optional_str(name) for name in typ.arg_names), - tuple(typ.arg_kinds)) + return ( + "Parameters", + snapshot_types(typ.arg_types), + tuple(encode_optional_str(name) for name in typ.arg_names), + tuple(typ.arg_kinds), + ) def visit_callable_type(self, typ: CallableType) -> SnapshotItem: # FIX generics - return ('CallableType', - snapshot_types(typ.arg_types), - snapshot_type(typ.ret_type), - tuple(encode_optional_str(name) for name in typ.arg_names), - tuple(typ.arg_kinds), - typ.is_type_obj(), - typ.is_ellipsis_args) + return ( + "CallableType", + snapshot_types(typ.arg_types), + snapshot_type(typ.ret_type), + tuple(encode_optional_str(name) for name in typ.arg_names), + tuple(typ.arg_kinds), + typ.is_type_obj(), + typ.is_ellipsis_args, + ) def visit_tuple_type(self, typ: TupleType) -> SnapshotItem: - return ('TupleType', snapshot_types(typ.items)) + return ("TupleType", snapshot_types(typ.items)) def visit_typeddict_type(self, typ: TypedDictType) -> SnapshotItem: - items = tuple((key, snapshot_type(item_type)) - for key, item_type in typ.items.items()) + items = tuple((key, snapshot_type(item_type)) for key, item_type in typ.items.items()) required = tuple(sorted(typ.required_keys)) - return ('TypedDictType', items, required) + return ("TypedDictType", items, required) def visit_literal_type(self, typ: LiteralType) -> SnapshotItem: - return ('LiteralType', snapshot_type(typ.fallback), typ.value) + return ("LiteralType", snapshot_type(typ.fallback), typ.value) def visit_union_type(self, typ: UnionType) -> SnapshotItem: # Sort and remove duplicates so that we can use equality to test for # equivalent union type snapshots. items = {snapshot_type(item) for item in typ.items} normalized = tuple(sorted(items)) - return ('UnionType', normalized) + return ("UnionType", normalized) def visit_overloaded(self, typ: Overloaded) -> SnapshotItem: - return ('Overloaded', snapshot_types(typ.items)) + return ("Overloaded", snapshot_types(typ.items)) def visit_partial_type(self, typ: PartialType) -> SnapshotItem: # A partial type is not fully defined, so the result is indeterminate. We shouldn't @@ -371,11 +422,11 @@ def visit_partial_type(self, typ: PartialType) -> SnapshotItem: raise RuntimeError def visit_type_type(self, typ: TypeType) -> SnapshotItem: - return ('TypeType', snapshot_type(typ.item)) + return ("TypeType", snapshot_type(typ.item)) def visit_type_alias_type(self, typ: TypeAliasType) -> SnapshotItem: assert typ.alias is not None - return ('TypeAliasType', typ.alias.fullname, snapshot_types(typ.args)) + return ("TypeAliasType", typ.alias.fullname, snapshot_types(typ.args)) def snapshot_untyped_signature(func: Union[OverloadedFuncDef, FuncItem]) -> Tuple[object, ...]: @@ -396,7 +447,7 @@ def snapshot_untyped_signature(func: Union[OverloadedFuncDef, FuncItem]) -> Tupl if item.var.type: result.append(snapshot_type(item.var.type)) else: - result.append(('DecoratorWithoutType',)) + result.append(("DecoratorWithoutType",)) else: result.append(snapshot_untyped_signature(item)) return tuple(result) diff --git a/mypy/server/astmerge.py b/mypy/server/astmerge.py index be69b3c00d97c..d90061b60cf7c 100644 --- a/mypy/server/astmerge.py +++ b/mypy/server/astmerge.py @@ -45,29 +45,76 @@ See the main entry point merge_asts for more details. """ -from typing import Dict, List, cast, TypeVar, Optional +from typing import Dict, List, Optional, TypeVar, cast from mypy.nodes import ( - MypyFile, SymbolTable, Block, AssignmentStmt, NameExpr, MemberExpr, RefExpr, TypeInfo, - FuncDef, ClassDef, NamedTupleExpr, SymbolNode, Var, Statement, SuperExpr, NewTypeExpr, - OverloadedFuncDef, LambdaExpr, TypedDictExpr, EnumCallExpr, FuncBase, TypeAliasExpr, CallExpr, - CastExpr, TypeAlias, AssertTypeExpr, - MDEF + MDEF, + AssertTypeExpr, + AssignmentStmt, + Block, + CallExpr, + CastExpr, + ClassDef, + EnumCallExpr, + FuncBase, + FuncDef, + LambdaExpr, + MemberExpr, + MypyFile, + NamedTupleExpr, + NameExpr, + NewTypeExpr, + OverloadedFuncDef, + RefExpr, + Statement, + SuperExpr, + SymbolNode, + SymbolTable, + TypeAlias, + TypeAliasExpr, + TypedDictExpr, + TypeInfo, + Var, ) from mypy.traverser import TraverserVisitor from mypy.types import ( - Type, SyntheticTypeVisitor, Instance, AnyType, NoneType, CallableType, ErasedType, DeletedType, - TupleType, TypeType, TypedDictType, UnboundType, UninhabitedType, UnionType, - Overloaded, TypeVarType, TypeList, CallableArgument, EllipsisType, StarType, LiteralType, - RawExpressionType, PartialType, PlaceholderType, TypeAliasType, ParamSpecType, Parameters, - UnpackType, TypeVarTupleType, + AnyType, + CallableArgument, + CallableType, + DeletedType, + EllipsisType, + ErasedType, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + PartialType, + PlaceholderType, + RawExpressionType, + StarType, + SyntheticTypeVisitor, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeList, + TypeType, + TypeVarTupleType, + TypeVarType, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, ) -from mypy.util import get_prefix, replace_object_state from mypy.typestate import TypeState +from mypy.util import get_prefix, replace_object_state -def merge_asts(old: MypyFile, old_symbols: SymbolTable, - new: MypyFile, new_symbols: SymbolTable) -> None: +def merge_asts( + old: MypyFile, old_symbols: SymbolTable, new: MypyFile, new_symbols: SymbolTable +) -> None: """Merge a new version of a module AST to a previous version. The main idea is to preserve the identities of externally visible @@ -82,7 +129,8 @@ def merge_asts(old: MypyFile, old_symbols: SymbolTable, # Find the mapping from new to old node identities for all nodes # whose identities should be preserved. replacement_map = replacement_map_from_symbol_table( - old_symbols, new_symbols, prefix=old.fullname) + old_symbols, new_symbols, prefix=old.fullname + ) # Also replace references to the new MypyFile node. replacement_map[new] = old # Perform replacements to everywhere within the new AST (not including symbol @@ -96,7 +144,8 @@ def merge_asts(old: MypyFile, old_symbols: SymbolTable, def replacement_map_from_symbol_table( - old: SymbolTable, new: SymbolTable, prefix: str) -> Dict[SymbolNode, SymbolNode]: + old: SymbolTable, new: SymbolTable, prefix: str +) -> Dict[SymbolNode, SymbolNode]: """Create a new-to-old object identity map by comparing two symbol table revisions. Both symbol tables must refer to revisions of the same module id. The symbol tables @@ -106,25 +155,29 @@ def replacement_map_from_symbol_table( """ replacements: Dict[SymbolNode, SymbolNode] = {} for name, node in old.items(): - if (name in new and (node.kind == MDEF - or node.node and get_prefix(node.node.fullname) == prefix)): + if name in new and ( + node.kind == MDEF or node.node and get_prefix(node.node.fullname) == prefix + ): new_node = new[name] - if (type(new_node.node) == type(node.node) # noqa - and new_node.node and node.node and - new_node.node.fullname == node.node.fullname and - new_node.kind == node.kind): + if ( + type(new_node.node) == type(node.node) # noqa + and new_node.node + and node.node + and new_node.node.fullname == node.node.fullname + and new_node.kind == node.kind + ): replacements[new_node.node] = node.node if isinstance(node.node, TypeInfo) and isinstance(new_node.node, TypeInfo): type_repl = replacement_map_from_symbol_table( - node.node.names, - new_node.node.names, - prefix) + node.node.names, new_node.node.names, prefix + ) replacements.update(type_repl) return replacements -def replace_nodes_in_ast(node: SymbolNode, - replacements: Dict[SymbolNode, SymbolNode]) -> SymbolNode: +def replace_nodes_in_ast( + node: SymbolNode, replacements: Dict[SymbolNode, SymbolNode] +) -> SymbolNode: """Replace all references to replacement map keys within an AST node, recursively. Also replace the *identity* of any nodes that have replacements. Return the @@ -136,7 +189,7 @@ def replace_nodes_in_ast(node: SymbolNode, return replacements.get(node, node) -SN = TypeVar('SN', bound=SymbolNode) +SN = TypeVar("SN", bound=SymbolNode) class NodeReplaceVisitor(TraverserVisitor): @@ -475,8 +528,9 @@ def fixup(self, node: SN) -> SN: return node -def replace_nodes_in_symbol_table(symbols: SymbolTable, - replacements: Dict[SymbolNode, SymbolNode]) -> None: +def replace_nodes_in_symbol_table( + symbols: SymbolTable, replacements: Dict[SymbolNode, SymbolNode] +) -> None: for name, node in symbols.items(): if node.node: if node.node in replacements: diff --git a/mypy/server/aststrip.py b/mypy/server/aststrip.py index 4363223c1cf0b..936765160e920 100644 --- a/mypy/server/aststrip.py +++ b/mypy/server/aststrip.py @@ -32,25 +32,45 @@ """ import contextlib -from typing import Union, Iterator, Optional, Dict, Tuple +from typing import Dict, Iterator, Optional, Tuple, Union from mypy.backports import nullcontext from mypy.nodes import ( - FuncDef, NameExpr, MemberExpr, RefExpr, MypyFile, ClassDef, AssignmentStmt, - ImportFrom, CallExpr, Decorator, OverloadedFuncDef, Node, TupleExpr, ListExpr, - SuperExpr, IndexExpr, ImportAll, ForStmt, Block, CLASSDEF_NO_INFO, TypeInfo, - StarExpr, Var, SymbolTableNode + CLASSDEF_NO_INFO, + AssignmentStmt, + Block, + CallExpr, + ClassDef, + Decorator, + ForStmt, + FuncDef, + ImportAll, + ImportFrom, + IndexExpr, + ListExpr, + MemberExpr, + MypyFile, + NameExpr, + Node, + OverloadedFuncDef, + RefExpr, + StarExpr, + SuperExpr, + SymbolTableNode, + TupleExpr, + TypeInfo, + Var, ) from mypy.traverser import TraverserVisitor from mypy.types import CallableType from mypy.typestate import TypeState - SavedAttributes = Dict[Tuple[ClassDef, str], SymbolTableNode] -def strip_target(node: Union[MypyFile, FuncDef, OverloadedFuncDef], - saved_attrs: SavedAttributes) -> None: +def strip_target( + node: Union[MypyFile, FuncDef, OverloadedFuncDef], saved_attrs: SavedAttributes +) -> None: """Reset a fine-grained incremental target to state before semantic analysis. All TypeInfos are killed. Therefore we need to preserve the variables @@ -91,7 +111,7 @@ def strip_file_top_level(self, file_node: MypyFile) -> None: for name in file_node.names.copy(): # TODO: this is a hot fix, we should delete all names, # see https://github.com/python/mypy/issues/6422. - if '@' not in name: + if "@" not in name: del file_node.names[name] def visit_block(self, b: Block) -> None: @@ -113,8 +133,9 @@ def visit_class_def(self, node: ClassDef) -> None: node.type_vars = [] node.base_type_exprs.extend(node.removed_base_type_exprs) node.removed_base_type_exprs = [] - node.defs.body = [s for s in node.defs.body - if s not in to_delete] # type: ignore[comparison-overlap] + node.defs.body = [ + s for s in node.defs.body if s not in to_delete # type: ignore[comparison-overlap] + ] with self.enter_class(node.info): super().visit_class_def(node) TypeState.reset_subtype_caches_for(node.info) diff --git a/mypy/server/deps.py b/mypy/server/deps.py index f339344e79b58..078ce9bb8c7fa 100644 --- a/mypy/server/deps.py +++ b/mypy/server/deps.py @@ -79,51 +79,122 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a Test cases for this module live in 'test-data/unit/deps*.test'. """ -from typing import Dict, List, Set, Optional, Tuple +from typing import Dict, List, Optional, Set, Tuple + from typing_extensions import DefaultDict from mypy.checkmember import bind_self from mypy.nodes import ( - Node, Expression, MypyFile, FuncDef, ClassDef, AssignmentStmt, NameExpr, MemberExpr, Import, - ImportFrom, CallExpr, CastExpr, TypeVarExpr, TypeApplication, IndexExpr, UnaryExpr, OpExpr, - ComparisonExpr, GeneratorExpr, DictionaryComprehension, StarExpr, PrintStmt, ForStmt, WithStmt, - TupleExpr, OperatorAssignmentStmt, DelStmt, YieldFromExpr, Decorator, Block, - TypeInfo, FuncBase, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr, TypedDictExpr, - LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr, + GDEF, + LDEF, + MDEF, AssertTypeExpr, + AssignmentStmt, + AwaitExpr, + Block, + CallExpr, + CastExpr, + ClassDef, + ComparisonExpr, + Decorator, + DelStmt, + DictionaryComprehension, + EnumCallExpr, + Expression, + ForStmt, + FuncBase, + FuncDef, + GeneratorExpr, + Import, + ImportAll, + ImportFrom, + IndexExpr, + MemberExpr, + MypyFile, + NamedTupleExpr, + NameExpr, + NewTypeExpr, + Node, + OperatorAssignmentStmt, + OpExpr, + OverloadedFuncDef, + PrintStmt, + RefExpr, + StarExpr, + SuperExpr, + TupleExpr, + TypeAliasExpr, + TypeApplication, + TypedDictExpr, + TypeInfo, + TypeVarExpr, + UnaryExpr, + Var, + WithStmt, + YieldFromExpr, ) from mypy.operators import ( - op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods + op_methods, + ops_with_inplace_method, + reverse_op_methods, + unary_op_methods, ) +from mypy.options import Options +from mypy.scope import Scope +from mypy.server.trigger import make_trigger, make_wildcard_trigger from mypy.traverser import TraverserVisitor from mypy.types import ( - Type, Instance, AnyType, NoneType, TypeVisitor, CallableType, DeletedType, PartialType, - TupleType, TypeType, TypeVarType, TypedDictType, UnboundType, UninhabitedType, UnionType, - FunctionLike, Overloaded, TypeOfAny, LiteralType, ErasedType, get_proper_type, ProperType, - TypeAliasType, ParamSpecType, Parameters, UnpackType, TypeVarTupleType, + AnyType, + CallableType, + DeletedType, + ErasedType, + FunctionLike, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + PartialType, + ProperType, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeOfAny, + TypeType, + TypeVarTupleType, + TypeVarType, + TypeVisitor, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, + get_proper_type, ) -from mypy.server.trigger import make_trigger, make_wildcard_trigger -from mypy.util import correct_relative_import -from mypy.scope import Scope from mypy.typestate import TypeState -from mypy.options import Options +from mypy.util import correct_relative_import -def get_dependencies(target: MypyFile, - type_map: Dict[Expression, Type], - python_version: Tuple[int, int], - options: Options) -> Dict[str, Set[str]]: +def get_dependencies( + target: MypyFile, + type_map: Dict[Expression, Type], + python_version: Tuple[int, int], + options: Options, +) -> Dict[str, Set[str]]: """Get all dependencies of a node, recursively.""" visitor = DependencyVisitor(type_map, python_version, target.alias_deps, options) target.accept(visitor) return visitor.map -def get_dependencies_of_target(module_id: str, - module_tree: MypyFile, - target: Node, - type_map: Dict[Expression, Type], - python_version: Tuple[int, int]) -> Dict[str, Set[str]]: +def get_dependencies_of_target( + module_id: str, + module_tree: MypyFile, + target: Node, + type_map: Dict[Expression, Type], + python_version: Tuple[int, int], +) -> Dict[str, Set[str]]: """Get dependencies of a target -- don't recursive into nested targets.""" # TODO: Add tests for this function. visitor = DependencyVisitor(type_map, python_version, module_tree.alias_deps) @@ -146,11 +217,13 @@ def get_dependencies_of_target(module_id: str, class DependencyVisitor(TraverserVisitor): - def __init__(self, - type_map: Dict[Expression, Type], - python_version: Tuple[int, int], - alias_deps: 'DefaultDict[str, Set[str]]', - options: Optional[Options] = None) -> None: + def __init__( + self, + type_map: Dict[Expression, Type], + python_version: Tuple[int, int], + alias_deps: "DefaultDict[str, Set[str]]", + options: Optional[Options] = None, + ) -> None: self.scope = Scope() self.type_map = type_map self.python2 = python_version[0] == 2 @@ -192,8 +265,8 @@ def visit_func_def(self, o: FuncDef) -> None: for base in non_trivial_bases(o.info): # Base class __init__/__new__ doesn't generate a logical # dependency since the override can be incompatible. - if not self.use_logical_deps() or o.name not in ('__init__', '__new__'): - self.add_dependency(make_trigger(base.fullname + '.' + o.name)) + if not self.use_logical_deps() or o.name not in ("__init__", "__new__"): + self.add_dependency(make_trigger(base.fullname + "." + o.name)) self.add_type_alias_deps(self.scope.current_target()) super().visit_func_def(o) variants = set(o.expanded) - {o} @@ -219,8 +292,11 @@ def visit_decorator(self, o: Decorator) -> None: tname: Optional[str] = None if isinstance(d, RefExpr) and d.fullname is not None: tname = d.fullname - if (isinstance(d, CallExpr) and isinstance(d.callee, RefExpr) and - d.callee.fullname is not None): + if ( + isinstance(d, CallExpr) + and isinstance(d.callee, RefExpr) + and d.callee.fullname is not None + ): tname = d.callee.fullname if tname is not None: self.add_dependency(make_trigger(tname), make_trigger(o.func.fullname)) @@ -266,8 +342,9 @@ def process_type_info(self, info: TypeInfo) -> None: # # In this example we add -> , to invalidate Sub if # a new member is added to Super. - self.add_dependency(make_wildcard_trigger(base_info.fullname), - target=make_trigger(target)) + self.add_dependency( + make_wildcard_trigger(base_info.fullname), target=make_trigger(target) + ) # More protocol dependencies are collected in TypeState._snapshot_protocol_deps # after a full run or update is finished. @@ -276,12 +353,14 @@ def process_type_info(self, info: TypeInfo) -> None: if isinstance(node.node, Var): # Recheck Liskov if needed, self definitions are checked in the defining method if node.node.is_initialized_in_class and has_user_bases(info): - self.add_dependency(make_trigger(info.fullname + '.' + name)) + self.add_dependency(make_trigger(info.fullname + "." + name)) for base_info in non_trivial_bases(info): # If the type of an attribute changes in a base class, we make references # to the attribute in the subclass stale. - self.add_dependency(make_trigger(base_info.fullname + '.' + name), - target=make_trigger(info.fullname + '.' + name)) + self.add_dependency( + make_trigger(base_info.fullname + "." + name), + target=make_trigger(info.fullname + "." + name), + ) for base_info in non_trivial_bases(info): for name, node in base_info.names.items(): if self.use_logical_deps(): @@ -298,26 +377,34 @@ def process_type_info(self, info: TypeInfo) -> None: continue # __init__ and __new__ can be overridden with different signatures, so no # logical dependency. - if name in ('__init__', '__new__'): + if name in ("__init__", "__new__"): continue - self.add_dependency(make_trigger(base_info.fullname + '.' + name), - target=make_trigger(info.fullname + '.' + name)) + self.add_dependency( + make_trigger(base_info.fullname + "." + name), + target=make_trigger(info.fullname + "." + name), + ) if not self.use_logical_deps(): # These dependencies are only useful for propagating changes -- # they aren't logical dependencies since __init__ and __new__ can be # overridden with a different signature. - self.add_dependency(make_trigger(base_info.fullname + '.__init__'), - target=make_trigger(info.fullname + '.__init__')) - self.add_dependency(make_trigger(base_info.fullname + '.__new__'), - target=make_trigger(info.fullname + '.__new__')) + self.add_dependency( + make_trigger(base_info.fullname + ".__init__"), + target=make_trigger(info.fullname + ".__init__"), + ) + self.add_dependency( + make_trigger(base_info.fullname + ".__new__"), + target=make_trigger(info.fullname + ".__new__"), + ) # If the set of abstract attributes change, this may invalidate class # instantiation, or change the generated error message, since Python checks # class abstract status when creating an instance. - self.add_dependency(make_trigger(base_info.fullname + '.(abstract)'), - target=make_trigger(info.fullname + '.__init__')) + self.add_dependency( + make_trigger(base_info.fullname + ".(abstract)"), + target=make_trigger(info.fullname + ".__init__"), + ) # If the base class abstract attributes change, subclass abstract # attributes need to be recalculated. - self.add_dependency(make_trigger(base_info.fullname + '.(abstract)')) + self.add_dependency(make_trigger(base_info.fullname + ".(abstract)")) def visit_import(self, o: Import) -> None: for id, as_id in o.ids: @@ -327,19 +414,17 @@ def visit_import_from(self, o: ImportFrom) -> None: if self.use_logical_deps(): # Just importing a name doesn't create a logical dependency. return - module_id, _ = correct_relative_import(self.scope.current_module_id(), - o.relative, - o.id, - self.is_package_init_file) + module_id, _ = correct_relative_import( + self.scope.current_module_id(), o.relative, o.id, self.is_package_init_file + ) self.add_dependency(make_trigger(module_id)) # needed if module is added/removed for name, as_name in o.names: - self.add_dependency(make_trigger(module_id + '.' + name)) + self.add_dependency(make_trigger(module_id + "." + name)) def visit_import_all(self, o: ImportAll) -> None: - module_id, _ = correct_relative_import(self.scope.current_module_id(), - o.relative, - o.id, - self.is_package_init_file) + module_id, _ = correct_relative_import( + self.scope.current_module_id(), o.relative, o.id, self.is_package_init_file + ) # The current target needs to be rechecked if anything "significant" changes in the # target module namespace (as the imported definitions will need to be updated). self.add_dependency(make_wildcard_trigger(module_id)) @@ -352,8 +437,9 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: rvalue = o.rvalue if isinstance(rvalue, CallExpr) and isinstance(rvalue.analyzed, TypeVarExpr): analyzed = rvalue.analyzed - self.add_type_dependencies(analyzed.upper_bound, - target=make_trigger(analyzed.fullname)) + self.add_type_dependencies( + analyzed.upper_bound, target=make_trigger(analyzed.fullname) + ) for val in analyzed.values: self.add_type_dependencies(val, target=make_trigger(analyzed.fullname)) # We need to re-analyze the definition if bound or value is deleted. @@ -361,20 +447,20 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: elif isinstance(rvalue, CallExpr) and isinstance(rvalue.analyzed, NamedTupleExpr): # Depend on types of named tuple items. info = rvalue.analyzed.info - prefix = f'{self.scope.current_full_target()}.{info.name}' + prefix = f"{self.scope.current_full_target()}.{info.name}" for name, symnode in info.names.items(): - if not name.startswith('_') and isinstance(symnode.node, Var): + if not name.startswith("_") and isinstance(symnode.node, Var): typ = symnode.node.type if typ: self.add_type_dependencies(typ) self.add_type_dependencies(typ, target=make_trigger(prefix)) - attr_target = make_trigger(f'{prefix}.{name}') + attr_target = make_trigger(f"{prefix}.{name}") self.add_type_dependencies(typ, target=attr_target) elif isinstance(rvalue, CallExpr) and isinstance(rvalue.analyzed, TypedDictExpr): # Depend on the underlying typeddict type info = rvalue.analyzed.info assert info.typeddict_type is not None - prefix = f'{self.scope.current_full_target()}.{info.name}' + prefix = f"{self.scope.current_full_target()}.{info.name}" self.add_type_dependencies(info.typeddict_type, target=make_trigger(prefix)) elif isinstance(rvalue, CallExpr) and isinstance(rvalue.analyzed, EnumCallExpr): # Enum values are currently not checked, but for future we add the deps on them @@ -388,8 +474,8 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: typ = get_proper_type(self.type_map.get(lvalue)) if isinstance(typ, FunctionLike) and typ.is_type_obj(): class_name = typ.type_object().fullname - self.add_dependency(make_trigger(class_name + '.__init__')) - self.add_dependency(make_trigger(class_name + '.__new__')) + self.add_dependency(make_trigger(class_name + ".__init__")) + self.add_dependency(make_trigger(class_name + ".__new__")) if isinstance(rvalue, IndexExpr) and isinstance(rvalue.analyzed, TypeAliasExpr): self.add_type_dependencies(rvalue.analyzed.type) elif typ: @@ -404,7 +490,7 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: lvalue = items[i] rvalue = items[i + 1] if isinstance(lvalue, TupleExpr): - self.add_attribute_dependency_for_expr(rvalue, '__iter__') + self.add_attribute_dependency_for_expr(rvalue, "__iter__") if o.type: self.add_type_dependencies(o.type) if self.use_logical_deps() and o.unanalyzed_type is None: @@ -412,12 +498,15 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: # x = func(...) # we add a logical dependency -> , because if `func` is not annotated, # then it will make all points of use of `x` unchecked. - if (isinstance(rvalue, CallExpr) and isinstance(rvalue.callee, RefExpr) - and rvalue.callee.fullname is not None): + if ( + isinstance(rvalue, CallExpr) + and isinstance(rvalue.callee, RefExpr) + and rvalue.callee.fullname is not None + ): fname: Optional[str] = None if isinstance(rvalue.callee.node, TypeInfo): # use actual __init__ as a dependency source - init = rvalue.callee.node.get('__init__') + init = rvalue.callee.node.get("__init__") if init and isinstance(init.node, FuncBase): fname = init.node.fullname else: @@ -433,15 +522,16 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: def process_lvalue(self, lvalue: Expression) -> None: """Generate additional dependencies for an lvalue.""" if isinstance(lvalue, IndexExpr): - self.add_operator_method_dependency(lvalue.base, '__setitem__') + self.add_operator_method_dependency(lvalue.base, "__setitem__") elif isinstance(lvalue, NameExpr): if lvalue.kind in (MDEF, GDEF): # Assignment to an attribute in the class body, or direct assignment to a # global variable. lvalue_type = self.get_non_partial_lvalue_type(lvalue) type_triggers = self.get_type_triggers(lvalue_type) - attr_trigger = make_trigger('{}.{}'.format(self.scope.current_full_target(), - lvalue.name)) + attr_trigger = make_trigger( + "{}.{}".format(self.scope.current_full_target(), lvalue.name) + ) for type_trigger in type_triggers: self.add_dependency(type_trigger, attr_trigger) elif isinstance(lvalue, MemberExpr): @@ -451,7 +541,7 @@ def process_lvalue(self, lvalue: Expression) -> None: info = node.info if info and has_user_bases(info): # Recheck Liskov for self definitions - self.add_dependency(make_trigger(info.fullname + '.' + lvalue.name)) + self.add_dependency(make_trigger(info.fullname + "." + lvalue.name)) if lvalue.kind is None: # Reference to a non-module attribute if lvalue.expr not in self.type_map: @@ -503,7 +593,7 @@ def visit_operator_assignment_stmt(self, o: OperatorAssignmentStmt) -> None: method = op_methods[o.op] self.add_attribute_dependency_for_expr(o.lvalue, method) if o.op in ops_with_inplace_method: - inplace_method = '__i' + method[2:] + inplace_method = "__i" + method[2:] self.add_attribute_dependency_for_expr(o.lvalue, inplace_method) def visit_for_stmt(self, o: ForStmt) -> None: @@ -511,18 +601,18 @@ def visit_for_stmt(self, o: ForStmt) -> None: if not o.is_async: # __getitem__ is only used if __iter__ is missing but for simplicity we # just always depend on both. - self.add_attribute_dependency_for_expr(o.expr, '__iter__') - self.add_attribute_dependency_for_expr(o.expr, '__getitem__') + self.add_attribute_dependency_for_expr(o.expr, "__iter__") + self.add_attribute_dependency_for_expr(o.expr, "__getitem__") if o.inferred_iterator_type: if self.python2: - method = 'next' + method = "next" else: - method = '__next__' + method = "__next__" self.add_attribute_dependency(o.inferred_iterator_type, method) else: - self.add_attribute_dependency_for_expr(o.expr, '__aiter__') + self.add_attribute_dependency_for_expr(o.expr, "__aiter__") if o.inferred_iterator_type: - self.add_attribute_dependency(o.inferred_iterator_type, '__anext__') + self.add_attribute_dependency(o.inferred_iterator_type, "__anext__") self.process_lvalue(o.index) if isinstance(o.index, TupleExpr): @@ -530,8 +620,8 @@ def visit_for_stmt(self, o: ForStmt) -> None: item_type = o.inferred_item_type if item_type: # This is similar to above. - self.add_attribute_dependency(item_type, '__iter__') - self.add_attribute_dependency(item_type, '__getitem__') + self.add_attribute_dependency(item_type, "__iter__") + self.add_attribute_dependency(item_type, "__getitem__") if o.index_type: self.add_type_dependencies(o.index_type) @@ -539,23 +629,23 @@ def visit_with_stmt(self, o: WithStmt) -> None: super().visit_with_stmt(o) for e in o.expr: if not o.is_async: - self.add_attribute_dependency_for_expr(e, '__enter__') - self.add_attribute_dependency_for_expr(e, '__exit__') + self.add_attribute_dependency_for_expr(e, "__enter__") + self.add_attribute_dependency_for_expr(e, "__exit__") else: - self.add_attribute_dependency_for_expr(e, '__aenter__') - self.add_attribute_dependency_for_expr(e, '__aexit__') + self.add_attribute_dependency_for_expr(e, "__aenter__") + self.add_attribute_dependency_for_expr(e, "__aexit__") for typ in o.analyzed_types: self.add_type_dependencies(typ) def visit_print_stmt(self, o: PrintStmt) -> None: super().visit_print_stmt(o) if o.target: - self.add_attribute_dependency_for_expr(o.target, 'write') + self.add_attribute_dependency_for_expr(o.target, "write") def visit_del_stmt(self, o: DelStmt) -> None: super().visit_del_stmt(o) if isinstance(o.expr, IndexExpr): - self.add_attribute_dependency_for_expr(o.expr.base, '__delitem__') + self.add_attribute_dependency_for_expr(o.expr.base, "__delitem__") # Expressions @@ -570,8 +660,8 @@ def process_global_ref_expr(self, o: RefExpr) -> None: typ = get_proper_type(self.type_map.get(o)) if isinstance(typ, FunctionLike) and typ.is_type_obj(): class_name = typ.type_object().fullname - self.add_dependency(make_trigger(class_name + '.__init__')) - self.add_dependency(make_trigger(class_name + '.__new__')) + self.add_dependency(make_trigger(class_name + ".__init__")) + self.add_dependency(make_trigger(class_name + ".__new__")) def visit_name_expr(self, o: NameExpr) -> None: if o.kind == LDEF: @@ -601,7 +691,7 @@ def visit_member_expr(self, e: MemberExpr) -> None: return if isinstance(e.expr, RefExpr) and isinstance(e.expr.node, MypyFile): # Special case: reference to a missing module attribute. - self.add_dependency(make_trigger(e.expr.node.fullname + '.' + e.name)) + self.add_dependency(make_trigger(e.expr.node.fullname + "." + e.name)) return typ = get_proper_type(self.type_map[e.expr]) self.add_attribute_dependency(typ, e.name) @@ -623,13 +713,13 @@ def get_unimported_fullname(self, e: MemberExpr, typ: AnyType) -> Optional[str]: Return None if e doesn't refer to an unimported definition or if we can't determine the name. """ - suffix = '' + suffix = "" # Unwrap nested member expression to handle cases like "a.b.c.d" where # "a.b" is a known reference to an unimported module. Find the base # reference to an unimported module (such as "a.b") and the name suffix # (such as "c.d") needed to build a full name. while typ.type_of_any == TypeOfAny.from_another_any and isinstance(e.expr, MemberExpr): - suffix = '.' + e.name + suffix + suffix = "." + e.name + suffix e = e.expr if e.expr not in self.type_map: return None @@ -640,7 +730,7 @@ def get_unimported_fullname(self, e: MemberExpr, typ: AnyType) -> Optional[str]: typ = obj_type if typ.type_of_any == TypeOfAny.from_unimported_type and typ.missing_import_name: # Infer the full name of the unimported definition. - return typ.missing_import_name + '.' + e.name + suffix + return typ.missing_import_name + "." + e.name + suffix return None def visit_super_expr(self, e: SuperExpr) -> None: @@ -650,7 +740,7 @@ def visit_super_expr(self, e: SuperExpr) -> None: if e.info is not None: name = e.name for base in non_trivial_bases(e.info): - self.add_dependency(make_trigger(base.fullname + '.' + name)) + self.add_dependency(make_trigger(base.fullname + "." + name)) if name in base.names: # No need to depend on further base classes, since we found # the target. This is safe since if the target gets @@ -658,7 +748,7 @@ def visit_super_expr(self, e: SuperExpr) -> None: break def visit_call_expr(self, e: CallExpr) -> None: - if isinstance(e.callee, RefExpr) and e.callee.fullname == 'builtins.isinstance': + if isinstance(e.callee, RefExpr) and e.callee.fullname == "builtins.isinstance": self.process_isinstance_call(e) else: super().visit_call_expr(e) @@ -666,16 +756,18 @@ def visit_call_expr(self, e: CallExpr) -> None: if typ is not None: typ = get_proper_type(typ) if not isinstance(typ, FunctionLike): - self.add_attribute_dependency(typ, '__call__') + self.add_attribute_dependency(typ, "__call__") def process_isinstance_call(self, e: CallExpr) -> None: """Process "isinstance(...)" in a way to avoid some extra dependencies.""" if len(e.args) == 2: arg = e.args[1] - if (isinstance(arg, RefExpr) - and arg.kind == GDEF - and isinstance(arg.node, TypeInfo) - and arg.fullname): + if ( + isinstance(arg, RefExpr) + and arg.kind == GDEF + and isinstance(arg.node, TypeInfo) + and arg.fullname + ): # Special case to avoid redundant dependencies from "__init__". self.add_dependency(make_trigger(arg.fullname)) return @@ -698,7 +790,7 @@ def visit_type_application(self, e: TypeApplication) -> None: def visit_index_expr(self, e: IndexExpr) -> None: super().visit_index_expr(e) - self.add_operator_method_dependency(e.base, '__getitem__') + self.add_operator_method_dependency(e.base, "__getitem__") def visit_unary_expr(self, e: UnaryExpr) -> None: super().visit_unary_expr(e) @@ -717,14 +809,14 @@ def visit_comparison_expr(self, e: ComparisonExpr) -> None: left = e.operands[i] right = e.operands[i + 1] self.process_binary_op(op, left, right) - if self.python2 and op in ('==', '!=', '<', '<=', '>', '>='): - self.add_operator_method_dependency(left, '__cmp__') - self.add_operator_method_dependency(right, '__cmp__') + if self.python2 and op in ("==", "!=", "<", "<=", ">", ">="): + self.add_operator_method_dependency(left, "__cmp__") + self.add_operator_method_dependency(right, "__cmp__") def process_binary_op(self, op: str, left: Expression, right: Expression) -> None: method = op_methods.get(op) if method: - if op == 'in': + if op == "in": self.add_operator_method_dependency(right, method) else: self.add_operator_method_dependency(left, method) @@ -745,7 +837,7 @@ def add_operator_method_dependency_for_type(self, typ: ProperType, method: str) if isinstance(typ, TupleType): typ = typ.partial_fallback if isinstance(typ, Instance): - trigger = make_trigger(typ.type.fullname + '.' + method) + trigger = make_trigger(typ.type.fullname + "." + method) self.add_dependency(trigger) elif isinstance(typ, UnionType): for item in typ.items: @@ -776,7 +868,7 @@ def visit_yield_from_expr(self, e: YieldFromExpr) -> None: def visit_await_expr(self, e: AwaitExpr) -> None: super().visit_await_expr(e) - self.add_attribute_dependency_for_expr(e.expr, '__await__') + self.add_attribute_dependency_for_expr(e.expr, "__await__") # Helpers @@ -792,8 +884,9 @@ def add_dependency(self, trigger: str, target: Optional[str] = None) -> None: If the target is not given explicitly, use the current target. """ - if trigger.startswith((' List[str]: if isinstance(typ, TupleType): typ = typ.partial_fallback if isinstance(typ, Instance): - member = f'{typ.type.fullname}.{name}' + member = f"{typ.type.fullname}.{name}" return [make_trigger(member)] elif isinstance(typ, FunctionLike) and typ.is_type_obj(): - member = f'{typ.type_object().fullname}.{name}' + member = f"{typ.type_object().fullname}.{name}" triggers = [make_trigger(member)] triggers.extend(self.attribute_triggers(typ.fallback, name)) return triggers @@ -842,9 +935,9 @@ def attribute_triggers(self, typ: Type, name: str) -> List[str]: elif isinstance(typ, TypeType): triggers = self.attribute_triggers(typ.item, name) if isinstance(typ.item, Instance) and typ.item.type.metaclass_type is not None: - triggers.append(make_trigger('%s.%s' % - (typ.item.type.metaclass_type.type.fullname, - name))) + triggers.append( + make_trigger("%s.%s" % (typ.item.type.metaclass_type.type.fullname, name)) + ) return triggers else: return [] @@ -857,7 +950,7 @@ def add_attribute_dependency_for_expr(self, e: Expression, name: str) -> None: def add_iter_dependency(self, node: Expression) -> None: typ = self.type_map.get(node) if typ: - self.add_attribute_dependency(typ, '__iter__') + self.add_attribute_dependency(typ, "__iter__") def use_logical_deps(self) -> bool: return self.options is not None and self.options.logical_deps @@ -945,8 +1038,8 @@ def visit_type_type(self, typ: TypeType) -> List[str]: if not self.use_logical_deps: old_triggers = triggers[:] for trigger in old_triggers: - triggers.append(trigger.rstrip('>') + '.__init__>') - triggers.append(trigger.rstrip('>') + '.__new__>') + triggers.append(trigger.rstrip(">") + ".__init__>") + triggers.append(trigger.rstrip(">") + ".__new__>") return triggers def visit_type_var(self, typ: TypeVarType) -> List[str]: @@ -1005,31 +1098,31 @@ def visit_union_type(self, typ: UnionType) -> List[str]: return triggers -def merge_dependencies(new_deps: Dict[str, Set[str]], - deps: Dict[str, Set[str]]) -> None: +def merge_dependencies(new_deps: Dict[str, Set[str]], deps: Dict[str, Set[str]]) -> None: for trigger, targets in new_deps.items(): deps.setdefault(trigger, set()).update(targets) def non_trivial_bases(info: TypeInfo) -> List[TypeInfo]: - return [base for base in info.mro[1:] - if base.fullname != 'builtins.object'] + return [base for base in info.mro[1:] if base.fullname != "builtins.object"] def has_user_bases(info: TypeInfo) -> bool: - return any(base.module_name not in ('builtins', 'typing', 'enum') for base in info.mro[1:]) + return any(base.module_name not in ("builtins", "typing", "enum") for base in info.mro[1:]) -def dump_all_dependencies(modules: Dict[str, MypyFile], - type_map: Dict[Expression, Type], - python_version: Tuple[int, int], - options: Options) -> None: +def dump_all_dependencies( + modules: Dict[str, MypyFile], + type_map: Dict[Expression, Type], + python_version: Tuple[int, int], + options: Options, +) -> None: """Generate dependencies for all interesting modules and print them to stdout.""" all_deps: Dict[str, Set[str]] = {} for id, node in modules.items(): # Uncomment for debugging: # print('processing', id) - if id in ('builtins', 'typing') or '/typeshed/' in node.path: + if id in ("builtins", "typing") or "/typeshed/" in node.path: continue assert id == node.fullname deps = get_dependencies(node, type_map, python_version, options) @@ -1040,4 +1133,4 @@ def dump_all_dependencies(modules: Dict[str, MypyFile], for trigger, targets in sorted(all_deps.items(), key=lambda x: x[0]): print(trigger) for target in sorted(targets): - print(f' {target}') + print(f" {target}") diff --git a/mypy/server/mergecheck.py b/mypy/server/mergecheck.py index 41d19f60f436f..44db789a71059 100644 --- a/mypy/server/mergecheck.py +++ b/mypy/server/mergecheck.py @@ -1,10 +1,11 @@ """Check for duplicate AST nodes after merge.""" from typing import Dict, List, Tuple + from typing_extensions import Final -from mypy.nodes import FakeInfo, SymbolNode, Var, Decorator, FuncDef -from mypy.server.objgraph import get_reachable_graph, get_path +from mypy.nodes import Decorator, FakeInfo, FuncDef, SymbolNode, Var +from mypy.server.objgraph import get_path, get_reachable_graph # If True, print more verbose output on failure. DUMP_MISMATCH_NODES: Final = False @@ -50,34 +51,35 @@ def check_consistency(o: object) -> None: path2 = get_path(sym2, seen, parents) if fn in m: - print('\nDuplicate {!r} nodes with fullname {!r} found:'.format( - type(sym).__name__, fn)) - print('[1] %d: %s' % (id(sym1), path_to_str(path1))) - print('[2] %d: %s' % (id(sym2), path_to_str(path2))) + print( + "\nDuplicate {!r} nodes with fullname {!r} found:".format(type(sym).__name__, fn) + ) + print("[1] %d: %s" % (id(sym1), path_to_str(path1))) + print("[2] %d: %s" % (id(sym2), path_to_str(path2))) if DUMP_MISMATCH_NODES and fn in m: # Add verbose output with full AST node contents. - print('---') + print("---") print(id(sym1), sym1) - print('---') + print("---") print(id(sym2), sym2) assert sym.fullname not in m def path_to_str(path: List[Tuple[object, object]]) -> str: - result = '' + result = "" for attr, obj in path: t = type(obj).__name__ - if t in ('dict', 'tuple', 'SymbolTable', 'list'): - result += f'[{repr(attr)}]' + if t in ("dict", "tuple", "SymbolTable", "list"): + result += f"[{repr(attr)}]" else: if isinstance(obj, Var): - result += f'.{attr}({t}:{obj.name})' - elif t in ('BuildManager', 'FineGrainedBuildManager'): + result += f".{attr}({t}:{obj.name})" + elif t in ("BuildManager", "FineGrainedBuildManager"): # Omit class name for some classes that aren't part of a class # hierarchy since there isn't much ambiguity. - result += f'.{attr}' + result += f".{attr}" else: - result += f'.{attr}({t})' + result += f".{attr}({t})" return result diff --git a/mypy/server/objgraph.py b/mypy/server/objgraph.py index 236f70d04e38a..053c26eef1d15 100644 --- a/mypy/server/objgraph.py +++ b/mypy/server/objgraph.py @@ -1,10 +1,10 @@ """Find all objects reachable from a root object.""" -from collections.abc import Iterable -import weakref import types +import weakref +from collections.abc import Iterable +from typing import Dict, Iterator, List, Mapping, Tuple -from typing import List, Dict, Iterator, Tuple, Mapping from typing_extensions import Final method_descriptor_type: Final = type(object.__dir__) @@ -20,35 +20,16 @@ method_wrapper_type, ) -ATTR_BLACKLIST: Final = { - '__doc__', - '__name__', - '__class__', - '__dict__', -} +ATTR_BLACKLIST: Final = {"__doc__", "__name__", "__class__", "__dict__"} # Instances of these types can't have references to other objects -ATOMIC_TYPE_BLACKLIST: Final = { - bool, - int, - float, - str, - type(None), - object, -} +ATOMIC_TYPE_BLACKLIST: Final = {bool, int, float, str, type(None), object} # Don't look at most attributes of these types -COLLECTION_TYPE_BLACKLIST: Final = { - list, - set, - dict, - tuple, -} +COLLECTION_TYPE_BLACKLIST: Final = {list, set, dict, tuple} # Don't return these objects -TYPE_BLACKLIST: Final = { - weakref.ReferenceType, -} +TYPE_BLACKLIST: Final = {weakref.ReferenceType} def isproperty(o: object, attr: str) -> bool: @@ -57,7 +38,7 @@ def isproperty(o: object, attr: str) -> bool: def get_edge_candidates(o: object) -> Iterator[Tuple[object, object]]: # use getattr because mypyc expects dict, not mappingproxy - if '__getattribute__' in getattr(type(o), '__dict__'): # noqa + if "__getattribute__" in getattr(type(o), "__dict__"): # noqa return if type(o) not in COLLECTION_TYPE_BLACKLIST: for attr in dir(o): @@ -77,23 +58,22 @@ def get_edge_candidates(o: object) -> Iterator[Tuple[object, object]]: def get_edges(o: object) -> Iterator[Tuple[object, object]]: for s, e in get_edge_candidates(o): - if (isinstance(e, FUNCTION_TYPES)): + if isinstance(e, FUNCTION_TYPES): # We don't want to collect methods, but do want to collect values # in closures and self pointers to other objects - if hasattr(e, '__closure__'): - yield (s, '__closure__'), e.__closure__ # type: ignore - if hasattr(e, '__self__'): + if hasattr(e, "__closure__"): + yield (s, "__closure__"), e.__closure__ # type: ignore + if hasattr(e, "__self__"): se = e.__self__ # type: ignore - if se is not o and se is not type(o) and hasattr(s, '__self__'): + if se is not o and se is not type(o) and hasattr(s, "__self__"): yield s.__self__, se # type: ignore else: if not type(e) in TYPE_BLACKLIST: yield s, e -def get_reachable_graph(root: object) -> Tuple[Dict[int, object], - Dict[int, Tuple[int, object]]]: +def get_reachable_graph(root: object) -> Tuple[Dict[int, object], Dict[int, Tuple[int, object]]]: parents = {} seen = {id(root): root} worklist = [root] @@ -109,9 +89,9 @@ def get_reachable_graph(root: object) -> Tuple[Dict[int, object], return seen, parents -def get_path(o: object, - seen: Dict[int, object], - parents: Dict[int, Tuple[int, object]]) -> List[Tuple[object, object]]: +def get_path( + o: object, seen: Dict[int, object], parents: Dict[int, Tuple[int, object]] +) -> List[Tuple[object, object]]: path = [] while id(o) in parents: pid, attr = parents[id(o)] diff --git a/mypy/server/subexpr.py b/mypy/server/subexpr.py index 4078c4170fcfd..60ebe95e33b13 100644 --- a/mypy/server/subexpr.py +++ b/mypy/server/subexpr.py @@ -3,11 +3,35 @@ from typing import List from mypy.nodes import ( - Expression, Node, MemberExpr, YieldFromExpr, YieldExpr, CallExpr, OpExpr, ComparisonExpr, - SliceExpr, CastExpr, RevealExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, - IndexExpr, GeneratorExpr, ListComprehension, SetComprehension, DictionaryComprehension, - ConditionalExpr, TypeApplication, LambdaExpr, StarExpr, BackquoteExpr, AwaitExpr, - AssignmentExpr, AssertTypeExpr, + AssertTypeExpr, + AssignmentExpr, + AwaitExpr, + BackquoteExpr, + CallExpr, + CastExpr, + ComparisonExpr, + ConditionalExpr, + DictExpr, + DictionaryComprehension, + Expression, + GeneratorExpr, + IndexExpr, + LambdaExpr, + ListComprehension, + ListExpr, + MemberExpr, + Node, + OpExpr, + RevealExpr, + SetComprehension, + SetExpr, + SliceExpr, + StarExpr, + TupleExpr, + TypeApplication, + UnaryExpr, + YieldExpr, + YieldFromExpr, ) from mypy.traverser import TraverserVisitor diff --git a/mypy/server/target.py b/mypy/server/target.py index 1069a6703e77f..06987b551d6bd 100644 --- a/mypy/server/target.py +++ b/mypy/server/target.py @@ -1,8 +1,8 @@ def trigger_to_target(s: str) -> str: - assert s[0] == '<' + assert s[0] == "<" # Strip off the angle brackets s = s[1:-1] # If there is a [wildcard] or similar, strip that off too - if s[-1] == ']': - s = s.split('[')[0] + if s[-1] == "]": + s = s.split("[")[0] return s diff --git a/mypy/server/trigger.py b/mypy/server/trigger.py index bfd542a405373..3770780a458ba 100644 --- a/mypy/server/trigger.py +++ b/mypy/server/trigger.py @@ -9,7 +9,7 @@ def make_trigger(name: str) -> str: - return f'<{name}>' + return f"<{name}>" def make_wildcard_trigger(module: str) -> str: @@ -21,4 +21,4 @@ def make_wildcard_trigger(module: str) -> str: This is used for "from m import *" dependencies. """ - return f'<{module}{WILDCARD_TAG}>' + return f"<{module}{WILDCARD_TAG}>" diff --git a/mypy/server/update.py b/mypy/server/update.py index e50bb1d158a28..f40c47236b416 100644 --- a/mypy/server/update.py +++ b/mypy/server/update.py @@ -115,38 +115,48 @@ import os import sys import time -from typing import ( - Dict, List, Set, Tuple, Union, Optional, NamedTuple, Sequence, Callable -) +from typing import Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union + from typing_extensions import Final from mypy.build import ( - BuildManager, State, BuildResult, Graph, load_graph, - process_fresh_modules, DEBUG_FINE_GRAINED, + DEBUG_FINE_GRAINED, FAKE_ROOT_MODULE, + BuildManager, + BuildResult, + Graph, + State, + load_graph, + process_fresh_modules, ) -from mypy.modulefinder import BuildSource from mypy.checker import FineGrainedDeferredNode from mypy.errors import CompileError +from mypy.fscache import FileSystemCache +from mypy.modulefinder import BuildSource from mypy.nodes import ( - MypyFile, FuncDef, TypeInfo, SymbolNode, Decorator, - OverloadedFuncDef, SymbolTable, ImportFrom + Decorator, + FuncDef, + ImportFrom, + MypyFile, + OverloadedFuncDef, + SymbolNode, + SymbolTable, + TypeInfo, ) from mypy.options import Options -from mypy.fscache import FileSystemCache -from mypy.server.astdiff import ( - snapshot_symbol_table, compare_symbol_table_snapshots, SnapshotItem -) from mypy.semanal_main import ( - semantic_analysis_for_scc, semantic_analysis_for_targets, core_modules + core_modules, + semantic_analysis_for_scc, + semantic_analysis_for_targets, ) +from mypy.server.astdiff import SnapshotItem, compare_symbol_table_snapshots, snapshot_symbol_table from mypy.server.astmerge import merge_asts -from mypy.server.aststrip import strip_target, SavedAttributes +from mypy.server.aststrip import SavedAttributes, strip_target from mypy.server.deps import get_dependencies_of_target, merge_dependencies from mypy.server.target import trigger_to_target -from mypy.server.trigger import make_trigger, WILDCARD_TAG -from mypy.util import module_prefix, split_target +from mypy.server.trigger import WILDCARD_TAG, make_trigger from mypy.typestate import TypeState +from mypy.util import module_prefix, split_target MAX_ITER: Final = 1000 @@ -190,9 +200,9 @@ def __init__(self, result: BuildResult) -> None: # Targets processed during last update (for testing only). self.processed_targets: List[str] = [] - def update(self, - changed_modules: List[Tuple[str, str]], - removed_modules: List[Tuple[str, str]]) -> List[str]: + def update( + self, changed_modules: List[Tuple[str, str]], removed_modules: List[Tuple[str, str]] + ) -> List[str]: """Update previous build result by processing changed modules. Also propagate changes to other modules as needed, but only process @@ -226,17 +236,19 @@ def update(self, self.updated_modules = [] changed_modules = dedupe_modules(changed_modules + self.stale) initial_set = {id for id, _ in changed_modules} - self.manager.log_fine_grained('==== update %s ====' % ', '.join( - repr(id) for id, _ in changed_modules)) + self.manager.log_fine_grained( + "==== update %s ====" % ", ".join(repr(id) for id, _ in changed_modules) + ) if self.previous_targets_with_errors and is_verbose(self.manager): - self.manager.log_fine_grained('previous targets with errors: %s' % - sorted(self.previous_targets_with_errors)) + self.manager.log_fine_grained( + "previous targets with errors: %s" % sorted(self.previous_targets_with_errors) + ) blocking_error = None if self.blocking_error: # Handle blocking errors first. We'll exit as soon as we find a # module that still has blocking errors. - self.manager.log_fine_grained(f'existing blocker: {self.blocking_error[0]}') + self.manager.log_fine_grained(f"existing blocker: {self.blocking_error[0]}") changed_modules = dedupe_modules([self.blocking_error] + changed_modules) blocking_error = self.blocking_error[0] self.blocking_error = None @@ -262,8 +274,14 @@ def update(self, # when propagating changes from the errored targets, # which prevents us from reprocessing errors in it. changed_modules = propagate_changes_using_dependencies( - self.manager, self.graph, self.deps, set(), {next_id}, - self.previous_targets_with_errors, self.processed_targets) + self.manager, + self.graph, + self.deps, + set(), + {next_id}, + self.previous_targets_with_errors, + self.processed_targets, + ) changed_modules = dedupe_modules(changed_modules) if not changed_modules: # Preserve state needed for the next update. @@ -281,8 +299,14 @@ def trigger(self, target: str) -> List[str]: """ self.manager.errors.reset() changed_modules = propagate_changes_using_dependencies( - self.manager, self.graph, self.deps, set(), set(), - self.previous_targets_with_errors | {target}, []) + self.manager, + self.graph, + self.deps, + set(), + set(), + self.previous_targets_with_errors | {target}, + [], + ) # Preserve state needed for the next update. self.previous_targets_with_errors = self.manager.errors.targets() self.previous_messages = self.manager.errors.new_messages()[:] @@ -296,13 +320,13 @@ def flush_cache(self) -> None: """ self.manager.ast_cache.clear() - def update_one(self, - changed_modules: List[Tuple[str, str]], - initial_set: Set[str], - removed_set: Set[str], - blocking_error: Optional[str]) -> Tuple[List[Tuple[str, str]], - Tuple[str, str], - Optional[List[str]]]: + def update_one( + self, + changed_modules: List[Tuple[str, str]], + initial_set: Set[str], + removed_set: Set[str], + blocking_error: Optional[str], + ) -> Tuple[List[Tuple[str, str]], Tuple[str, str], Optional[List[str]]]: """Process a module from the list of changed modules. Returns: @@ -318,31 +342,31 @@ def update_one(self, # If we have a module with a blocking error that is no longer # in the import graph, we must skip it as otherwise we'll be # stuck with the blocking error. - if (next_id == blocking_error - and next_id not in self.previous_modules - and next_id not in initial_set): + if ( + next_id == blocking_error + and next_id not in self.previous_modules + and next_id not in initial_set + ): self.manager.log_fine_grained( - f'skip {next_id!r} (module with blocking error not in import graph)') + f"skip {next_id!r} (module with blocking error not in import graph)" + ) return changed_modules, (next_id, next_path), None result = self.update_module(next_id, next_path, next_id in removed_set) remaining, (next_id, next_path), blocker_messages = result - changed_modules = [(id, path) for id, path in changed_modules - if id != next_id] + changed_modules = [(id, path) for id, path in changed_modules if id != next_id] changed_modules = dedupe_modules(remaining + changed_modules) t1 = time.time() self.manager.log_fine_grained( - f"update once: {next_id} in {t1 - t0:.3f}s - {len(changed_modules)} left") + f"update once: {next_id} in {t1 - t0:.3f}s - {len(changed_modules)} left" + ) return changed_modules, (next_id, next_path), blocker_messages - def update_module(self, - module: str, - path: str, - force_removed: bool) -> Tuple[List[Tuple[str, str]], - Tuple[str, str], - Optional[List[str]]]: + def update_module( + self, module: str, path: str, force_removed: bool + ) -> Tuple[List[Tuple[str, str]], Tuple[str, str], Optional[List[str]]]: """Update a single modified module. If the module contains imports of previously unseen modules, only process one of @@ -361,7 +385,7 @@ def update_module(self, - Module which was actually processed as (id, path) tuple - If there was a blocking error, the error messages from it """ - self.manager.log_fine_grained(f'--- update single {module!r} ---') + self.manager.log_fine_grained(f"--- update single {module!r} ---") self.updated_modules.append(module) # builtins and friends could potentially get triggered because @@ -389,8 +413,9 @@ def update_module(self, manager.errors.reset() self.processed_targets.append(module) - result = update_module_isolated(module, path, manager, previous_modules, graph, - force_removed) + result = update_module_isolated( + module, path, manager, previous_modules, graph, force_removed + ) if isinstance(result, BlockedUpdate): # Blocking error -- just give up module, path, remaining, errors = result @@ -403,21 +428,23 @@ def update_module(self, t1 = time.time() triggered = calculate_active_triggers(manager, old_snapshots, {module: tree}) if is_verbose(self.manager): - filtered = [trigger for trigger in triggered - if not trigger.endswith('__>')] - self.manager.log_fine_grained(f'triggered: {sorted(filtered)!r}') + filtered = [trigger for trigger in triggered if not trigger.endswith("__>")] + self.manager.log_fine_grained(f"triggered: {sorted(filtered)!r}") self.triggered.extend(triggered | self.previous_targets_with_errors) if module in graph: graph[module].update_fine_grained_deps(self.deps) graph[module].free_state() remaining += propagate_changes_using_dependencies( - manager, graph, self.deps, triggered, + manager, + graph, + self.deps, + triggered, {module}, - targets_with_errors=set(), processed_targets=self.processed_targets) + targets_with_errors=set(), + processed_targets=self.processed_targets, + ) t2 = time.time() - manager.add_stats( - update_isolated_time=t1 - t0, - propagate_time=t2 - t1) + manager.add_stats(update_isolated_time=t1 - t0, propagate_time=t2 - t1) # Preserve state needed for the next update. self.previous_targets_with_errors.update(manager.errors.targets()) @@ -426,8 +453,9 @@ def update_module(self, return remaining, (module, path), None -def find_unloaded_deps(manager: BuildManager, graph: Dict[str, State], - initial: Sequence[str]) -> List[str]: +def find_unloaded_deps( + manager: BuildManager, graph: Dict[str, State], initial: Sequence[str] +) -> List[str]: """Find all the deps of the nodes in initial that haven't had their tree loaded. The key invariant here is that if a module is loaded, so are all @@ -453,8 +481,7 @@ def find_unloaded_deps(manager: BuildManager, graph: Dict[str, State], return unloaded -def ensure_deps_loaded(module: str, - deps: Dict[str, Set[str]], graph: Dict[str, State]) -> None: +def ensure_deps_loaded(module: str, deps: Dict[str, Set[str]], graph: Dict[str, State]) -> None: """Ensure that the dependencies on a module are loaded. Dependencies are loaded into the 'deps' dictionary. @@ -465,22 +492,26 @@ def ensure_deps_loaded(module: str, """ if module in graph and graph[module].fine_grained_deps_loaded: return - parts = module.split('.') + parts = module.split(".") for i in range(len(parts)): - base = '.'.join(parts[:i + 1]) + base = ".".join(parts[: i + 1]) if base in graph and not graph[base].fine_grained_deps_loaded: merge_dependencies(graph[base].load_fine_grained_deps(), deps) graph[base].fine_grained_deps_loaded = True -def ensure_trees_loaded(manager: BuildManager, graph: Dict[str, State], - initial: Sequence[str]) -> None: +def ensure_trees_loaded( + manager: BuildManager, graph: Dict[str, State], initial: Sequence[str] +) -> None: """Ensure that the modules in initial and their deps have loaded trees.""" to_process = find_unloaded_deps(manager, graph, initial) if to_process: if is_verbose(manager): - manager.log_fine_grained("Calling process_fresh_modules on set of size {} ({})".format( - len(to_process), sorted(to_process))) + manager.log_fine_grained( + "Calling process_fresh_modules on set of size {} ({})".format( + len(to_process), sorted(to_process) + ) + ) process_fresh_modules(graph, to_process, manager) @@ -511,12 +542,14 @@ class BlockedUpdate(NamedTuple): UpdateResult = Union[NormalUpdate, BlockedUpdate] -def update_module_isolated(module: str, - path: str, - manager: BuildManager, - previous_modules: Dict[str, str], - graph: Graph, - force_removed: bool) -> UpdateResult: +def update_module_isolated( + module: str, + path: str, + manager: BuildManager, + previous_modules: Dict[str, str], + graph: Graph, + force_removed: bool, +) -> UpdateResult: """Build a new version of one changed module only. Don't propagate changes to elsewhere in the program. Raise CompileError on @@ -533,7 +566,7 @@ def update_module_isolated(module: str, Returns a named tuple describing the result (see above for details). """ if module not in graph: - manager.log_fine_grained(f'new module {module!r}') + manager.log_fine_grained(f"new module {module!r}") if not manager.fscache.isfile(path) or force_removed: delete_module(module, path, graph, manager) @@ -591,7 +624,7 @@ def restore(ids: List[str]) -> None: remaining_modules = changed_modules # The remaining modules haven't been processed yet so drop them. restore([id for id, _ in remaining_modules]) - manager.log_fine_grained(f'--> {module!r} (newly imported)') + manager.log_fine_grained(f"--> {module!r} (newly imported)") else: remaining_modules = [] @@ -620,10 +653,7 @@ def restore(ids: List[str]) -> None: t2 = time.time() state.finish_passes() t3 = time.time() - manager.add_stats( - semanal_time=t1 - t0, - typecheck_time=t2 - t1, - finish_passes_time=t3 - t2) + manager.add_stats(semanal_time=t1 - t0, typecheck_time=t2 - t1, finish_passes_time=t3 - t2) graph[module] = state @@ -657,20 +687,17 @@ def find_relative_leaf_module(modules: List[Tuple[str, str]], graph: Graph) -> T return modules[0] -def delete_module(module_id: str, - path: str, - graph: Graph, - manager: BuildManager) -> None: - manager.log_fine_grained(f'delete module {module_id!r}') +def delete_module(module_id: str, path: str, graph: Graph, manager: BuildManager) -> None: + manager.log_fine_grained(f"delete module {module_id!r}") # TODO: Remove deps for the module (this only affects memory use, not correctness) if module_id in graph: del graph[module_id] if module_id in manager.modules: del manager.modules[module_id] - components = module_id.split('.') + components = module_id.split(".") if len(components) > 1: # Delete reference to module in parent module. - parent_id = '.'.join(components[:-1]) + parent_id = ".".join(components[:-1]) # If parent module is ignored, it won't be included in the modules dictionary. if parent_id in manager.modules: parent = manager.modules[parent_id] @@ -693,13 +720,12 @@ def dedupe_modules(modules: List[Tuple[str, str]]) -> List[Tuple[str, str]]: def get_module_to_path_map(graph: Graph) -> Dict[str, str]: - return {module: node.xpath - for module, node in graph.items()} + return {module: node.xpath for module, node in graph.items()} -def get_sources(fscache: FileSystemCache, - modules: Dict[str, str], - changed_modules: List[Tuple[str, str]]) -> List[BuildSource]: +def get_sources( + fscache: FileSystemCache, modules: Dict[str, str], changed_modules: List[Tuple[str, str]] +) -> List[BuildSource]: sources = [] for id, path in changed_modules: if fscache.isfile(path): @@ -707,9 +733,11 @@ def get_sources(fscache: FileSystemCache, return sources -def calculate_active_triggers(manager: BuildManager, - old_snapshots: Dict[str, Dict[str, SnapshotItem]], - new_modules: Dict[str, Optional[MypyFile]]) -> Set[str]: +def calculate_active_triggers( + manager: BuildManager, + old_snapshots: Dict[str, Dict[str, SnapshotItem]], + new_modules: Dict[str, Optional[MypyFile]], +) -> Set[str]: """Determine activated triggers by comparing old and new symbol tables. For example, if only the signature of function m.f is different in the new @@ -728,14 +756,15 @@ def calculate_active_triggers(manager: BuildManager, else: snapshot2 = snapshot_symbol_table(id, new.names) diff = compare_symbol_table_snapshots(id, snapshot1, snapshot2) - package_nesting_level = id.count('.') + package_nesting_level = id.count(".") for item in diff.copy(): - if (item.count('.') <= package_nesting_level + 1 - and item.split('.')[-1] not in ('__builtins__', - '__file__', - '__name__', - '__package__', - '__doc__')): + if item.count(".") <= package_nesting_level + 1 and item.split(".")[-1] not in ( + "__builtins__", + "__file__", + "__name__", + "__package__", + "__doc__", + ): # Activate catch-all wildcard trigger for top-level module changes (used for # "from m import *"). This also gets triggered by changes to module-private # entries, but as these unneeded dependencies only result in extra processing, @@ -744,19 +773,20 @@ def calculate_active_triggers(manager: BuildManager, # TODO: Some __* names cause mistriggers. Fix the underlying issue instead of # special casing them here. diff.add(id + WILDCARD_TAG) - if item.count('.') > package_nesting_level + 1: + if item.count(".") > package_nesting_level + 1: # These are for changes within classes, used by protocols. - diff.add(item.rsplit('.', 1)[0] + WILDCARD_TAG) + diff.add(item.rsplit(".", 1)[0] + WILDCARD_TAG) names |= diff return {make_trigger(name) for name in names} def replace_modules_with_new_variants( - manager: BuildManager, - graph: Dict[str, State], - old_modules: Dict[str, Optional[MypyFile]], - new_modules: Dict[str, Optional[MypyFile]]) -> None: + manager: BuildManager, + graph: Dict[str, State], + old_modules: Dict[str, Optional[MypyFile]], + new_modules: Dict[str, Optional[MypyFile]], +) -> None: """Replace modules with newly builds versions. Retain the identities of externally visible AST nodes in the @@ -770,20 +800,20 @@ def replace_modules_with_new_variants( preserved_module = old_modules.get(id) new_module = new_modules[id] if preserved_module and new_module is not None: - merge_asts(preserved_module, preserved_module.names, - new_module, new_module.names) + merge_asts(preserved_module, preserved_module.names, new_module, new_module.names) manager.modules[id] = preserved_module graph[id].tree = preserved_module def propagate_changes_using_dependencies( - manager: BuildManager, - graph: Dict[str, State], - deps: Dict[str, Set[str]], - triggered: Set[str], - up_to_date_modules: Set[str], - targets_with_errors: Set[str], - processed_targets: List[str]) -> List[Tuple[str, str]]: + manager: BuildManager, + graph: Dict[str, State], + deps: Dict[str, Set[str]], + triggered: Set[str], + up_to_date_modules: Set[str], + targets_with_errors: Set[str], + processed_targets: List[str], +) -> List[Tuple[str, str]]: """Transitively rechecks targets based on triggers and the dependency map. Returns a list (module id, path) tuples representing modules that contain @@ -801,10 +831,11 @@ def propagate_changes_using_dependencies( while triggered or targets_with_errors: num_iter += 1 if num_iter > MAX_ITER: - raise RuntimeError('Max number of iterations (%d) reached (endless loop?)' % MAX_ITER) + raise RuntimeError("Max number of iterations (%d) reached (endless loop?)" % MAX_ITER) - todo, unloaded, stale_protos = find_targets_recursive(manager, graph, - triggered, deps, up_to_date_modules) + todo, unloaded, stale_protos = find_targets_recursive( + manager, graph, triggered, deps, up_to_date_modules + ) # TODO: we sort to make it deterministic, but this is *incredibly* ad hoc remaining_modules.extend((id, graph[id].xpath) for id in sorted(unloaded)) # Also process targets that used to have errors, as otherwise some @@ -814,7 +845,7 @@ def propagate_changes_using_dependencies( if id is not None and id not in up_to_date_modules: if id not in todo: todo[id] = set() - manager.log_fine_grained(f'process target with error: {target}') + manager.log_fine_grained(f"process target with error: {target}") more_nodes, _ = lookup_target(manager, target) todo[id].update(more_nodes) triggered = set() @@ -834,18 +865,18 @@ def propagate_changes_using_dependencies( up_to_date_modules = set() targets_with_errors = set() if is_verbose(manager): - manager.log_fine_grained(f'triggered: {list(triggered)!r}') + manager.log_fine_grained(f"triggered: {list(triggered)!r}") return remaining_modules def find_targets_recursive( - manager: BuildManager, - graph: Graph, - triggers: Set[str], - deps: Dict[str, Set[str]], - up_to_date_modules: Set[str]) -> Tuple[Dict[str, Set[FineGrainedDeferredNode]], - Set[str], Set[TypeInfo]]: + manager: BuildManager, + graph: Graph, + triggers: Set[str], + deps: Dict[str, Set[str]], + up_to_date_modules: Set[str], +) -> Tuple[Dict[str, Set[FineGrainedDeferredNode]], Set[str], Set[TypeInfo]]: """Find names of all targets that need to reprocessed, given some triggers. Returns: A tuple containing a: @@ -866,7 +897,7 @@ def find_targets_recursive( current = worklist worklist = set() for target in current: - if target.startswith('<'): + if target.startswith("<"): module_id = module_prefix(graph, trigger_to_target(target)) if module_id: ensure_deps_loaded(module_id, deps, graph) @@ -880,8 +911,10 @@ def find_targets_recursive( if module_id in up_to_date_modules: # Already processed. continue - if (module_id not in manager.modules - or manager.modules[module_id].is_cache_skeleton): + if ( + module_id not in manager.modules + or manager.modules[module_id].is_cache_skeleton + ): # We haven't actually parsed and checked the module, so we don't have # access to the actual nodes. # Add it to the queue of files that need to be processed fully. @@ -890,7 +923,7 @@ def find_targets_recursive( if module_id not in result: result[module_id] = set() - manager.log_fine_grained(f'process: {target}') + manager.log_fine_grained(f"process: {target}") deferred, stale_proto = lookup_target(manager, target) if stale_proto: stale_protos.add(stale_proto) @@ -899,19 +932,20 @@ def find_targets_recursive( return result, unloaded_files, stale_protos -def reprocess_nodes(manager: BuildManager, - graph: Dict[str, State], - module_id: str, - nodeset: Set[FineGrainedDeferredNode], - deps: Dict[str, Set[str]], - processed_targets: List[str]) -> Set[str]: +def reprocess_nodes( + manager: BuildManager, + graph: Dict[str, State], + module_id: str, + nodeset: Set[FineGrainedDeferredNode], + deps: Dict[str, Set[str]], + processed_targets: List[str], +) -> Set[str]: """Reprocess a set of nodes within a single module. Return fired triggers. """ if module_id not in graph: - manager.log_fine_grained('%s not in graph (blocking errors or deleted?)' % - module_id) + manager.log_fine_grained("%s not in graph (blocking errors or deleted?)" % module_id) return set() file_node = manager.modules[module_id] @@ -929,7 +963,8 @@ def key(node: FineGrainedDeferredNode) -> int: options = graph[module_id].options manager.errors.set_file_ignored_lines( - file_node.path, file_node.ignored_lines, options.ignore_errors) + file_node.path, file_node.ignored_lines, options.ignore_errors + ) targets = set() for node in nodes: @@ -976,9 +1011,9 @@ def key(node: FineGrainedDeferredNode) -> int: new_symbols_snapshot = snapshot_symbol_table(file_node.fullname, file_node.names) # Check if any attribute types were changed and need to be propagated further. - changed = compare_symbol_table_snapshots(file_node.fullname, - old_symbols_snapshot, - new_symbols_snapshot) + changed = compare_symbol_table_snapshots( + file_node.fullname, old_symbols_snapshot, new_symbols_snapshot + ) new_triggered = {make_trigger(name) for name in changed} # Dependencies may have changed. @@ -1005,41 +1040,45 @@ def find_symbol_tables_recursive(prefix: str, symbols: SymbolTable) -> Dict[str, result = {} result[prefix] = symbols for name, node in symbols.items(): - if isinstance(node.node, TypeInfo) and node.node.fullname.startswith(prefix + '.'): - more = find_symbol_tables_recursive(prefix + '.' + name, node.node.names) + if isinstance(node.node, TypeInfo) and node.node.fullname.startswith(prefix + "."): + more = find_symbol_tables_recursive(prefix + "." + name, node.node.names) result.update(more) return result -def update_deps(module_id: str, - nodes: List[FineGrainedDeferredNode], - graph: Dict[str, State], - deps: Dict[str, Set[str]], - options: Options) -> None: +def update_deps( + module_id: str, + nodes: List[FineGrainedDeferredNode], + graph: Dict[str, State], + deps: Dict[str, Set[str]], + options: Options, +) -> None: for deferred in nodes: node = deferred.node type_map = graph[module_id].type_map() tree = graph[module_id].tree assert tree is not None, "Tree must be processed at this stage" - new_deps = get_dependencies_of_target(module_id, tree, node, type_map, - options.python_version) + new_deps = get_dependencies_of_target( + module_id, tree, node, type_map, options.python_version + ) for trigger, targets in new_deps.items(): deps.setdefault(trigger, set()).update(targets) # Merge also the newly added protocol deps (if any). TypeState.update_protocol_deps(deps) -def lookup_target(manager: BuildManager, - target: str) -> Tuple[List[FineGrainedDeferredNode], Optional[TypeInfo]]: +def lookup_target( + manager: BuildManager, target: str +) -> Tuple[List[FineGrainedDeferredNode], Optional[TypeInfo]]: """Look up a target by fully-qualified name. The first item in the return tuple is a list of deferred nodes that needs to be reprocessed. If the target represents a TypeInfo corresponding to a protocol, return it as a second item in the return tuple, otherwise None. """ + def not_found() -> None: - manager.log_fine_grained( - f"Can't find matching target for {target} (stale dependency?)") + manager.log_fine_grained(f"Can't find matching target for {target} (stale dependency?)") modules = manager.modules items = split_target(modules, target) @@ -1048,7 +1087,7 @@ def not_found() -> None: return [], None module, rest = items if rest: - components = rest.split('.') + components = rest.split(".") else: components = [] node: Optional[SymbolNode] = modules[module] @@ -1059,8 +1098,7 @@ def not_found() -> None: active_class = node if isinstance(node, MypyFile): file = node - if (not isinstance(node, (MypyFile, TypeInfo)) - or c not in node.names): + if not isinstance(node, (MypyFile, TypeInfo)) or c not in node.names: not_found() # Stale dependency return [], None # Don't reprocess plugin generated targets. They should get @@ -1088,15 +1126,13 @@ def not_found() -> None: for name, symnode in node.names.items(): node = symnode.node if isinstance(node, FuncDef): - method, _ = lookup_target(manager, target + '.' + name) + method, _ = lookup_target(manager, target + "." + name) result.extend(method) return result, stale_info if isinstance(node, Decorator): # Decorator targets actually refer to the function definition only. node = node.func - if not isinstance(node, (FuncDef, - MypyFile, - OverloadedFuncDef)): + if not isinstance(node, (FuncDef, MypyFile, OverloadedFuncDef)): # The target can't be refreshed. It's possible that the target was # changed to another type and we have a stale dependency pointing to it. not_found() @@ -1113,9 +1149,9 @@ def is_verbose(manager: BuildManager) -> bool: return manager.options.verbosity >= 1 or DEBUG_FINE_GRAINED -def target_from_node(module: str, - node: Union[FuncDef, MypyFile, OverloadedFuncDef] - ) -> Optional[str]: +def target_from_node( + module: str, node: Union[FuncDef, MypyFile, OverloadedFuncDef] +) -> Optional[str]: """Return the target name corresponding to a deferred node. Args: @@ -1131,29 +1167,30 @@ def target_from_node(module: str, return module else: # OverloadedFuncDef or FuncDef if node.info: - return f'{node.info.fullname}.{node.name}' + return f"{node.info.fullname}.{node.name}" else: - return f'{module}.{node.name}' + return f"{module}.{node.name}" if sys.platform != "win32": INIT_SUFFIXES: Final = ("/__init__.py", "/__init__.pyi") else: INIT_SUFFIXES: Final = ( - os.sep + '__init__.py', - os.sep + '__init__.pyi', - os.altsep + '__init__.py', - os.altsep + '__init__.pyi', + os.sep + "__init__.py", + os.sep + "__init__.pyi", + os.altsep + "__init__.py", + os.altsep + "__init__.pyi", ) def refresh_suppressed_submodules( - module: str, - path: Optional[str], - deps: Dict[str, Set[str]], - graph: Graph, - fscache: FileSystemCache, - refresh_file: Callable[[str, str], List[str]]) -> Optional[List[str]]: + module: str, + path: Optional[str], + deps: Dict[str, Set[str]], + graph: Graph, + fscache: FileSystemCache, + refresh_file: Callable[[str, str], List[str]], +) -> Optional[List[str]]: """Look for submodules that are now suppressed in target package. If a submodule a.b gets added, we need to mark it as suppressed @@ -1181,12 +1218,14 @@ def refresh_suppressed_submodules( except FileNotFoundError: entries = [] for fnam in entries: - if (not fnam.endswith(('.py', '.pyi')) - or fnam.startswith("__init__.") - or fnam.count('.') != 1): + if ( + not fnam.endswith((".py", ".pyi")) + or fnam.startswith("__init__.") + or fnam.count(".") != 1 + ): continue - shortname = fnam.split('.')[0] - submodule = module + '.' + shortname + shortname = fnam.split(".")[0] + submodule = module + "." + shortname trigger = make_trigger(submodule) # We may be missing the required fine-grained deps. @@ -1212,9 +1251,11 @@ def refresh_suppressed_submodules( assert tree # Will be fine, due to refresh_file() above for imp in tree.imports: if isinstance(imp, ImportFrom): - if (imp.id == module - and any(name == shortname for name, _ in imp.names) - and submodule not in state.suppressed_set): + if ( + imp.id == module + and any(name == shortname for name, _ in imp.names) + and submodule not in state.suppressed_set + ): state.suppressed.append(submodule) state.suppressed_set.add(submodule) return messages diff --git a/mypy/sharedparse.py b/mypy/sharedparse.py index d8bde1bd253b8..a705cf7921b06 100644 --- a/mypy/sharedparse.py +++ b/mypy/sharedparse.py @@ -1,4 +1,5 @@ from typing import Optional + from typing_extensions import Final """Shared logic between our three mypy parser files.""" diff --git a/mypy/solve.py b/mypy/solve.py index 8a3280e33c0b1..2c3a5b5e3300e 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -1,17 +1,18 @@ """Type inference constraint solving""" -from typing import List, Dict, Optional from collections import defaultdict +from typing import Dict, List, Optional -from mypy.types import Type, AnyType, UninhabitedType, TypeVarId, TypeOfAny, get_proper_type -from mypy.constraints import Constraint, SUPERTYPE_OF +from mypy.constraints import SUPERTYPE_OF, Constraint from mypy.join import join_types from mypy.meet import meet_types from mypy.subtypes import is_subtype +from mypy.types import AnyType, Type, TypeOfAny, TypeVarId, UninhabitedType, get_proper_type -def solve_constraints(vars: List[TypeVarId], constraints: List[Constraint], - strict: bool = True) -> List[Optional[Type]]: +def solve_constraints( + vars: List[TypeVarId], constraints: List[Constraint], strict: bool = True +) -> List[Optional[Type]]: """Solve type constraints. Return the best type(s) for type variables; each type can be None if the value of the variable diff --git a/mypy/split_namespace.py b/mypy/split_namespace.py index 64a239c6a1c7f..e5cadb65de40e 100644 --- a/mypy/split_namespace.py +++ b/mypy/split_namespace.py @@ -8,27 +8,26 @@ # __getattr__/__setattr__ and has some issues with __dict__ import argparse - -from typing import Tuple, Any +from typing import Any, Tuple class SplitNamespace(argparse.Namespace): def __init__(self, standard_namespace: object, alt_namespace: object, alt_prefix: str) -> None: - self.__dict__['_standard_namespace'] = standard_namespace - self.__dict__['_alt_namespace'] = alt_namespace - self.__dict__['_alt_prefix'] = alt_prefix + self.__dict__["_standard_namespace"] = standard_namespace + self.__dict__["_alt_namespace"] = alt_namespace + self.__dict__["_alt_prefix"] = alt_prefix def _get(self) -> Tuple[Any, Any]: return (self._standard_namespace, self._alt_namespace) def __setattr__(self, name: str, value: Any) -> None: if name.startswith(self._alt_prefix): - setattr(self._alt_namespace, name[len(self._alt_prefix):], value) + setattr(self._alt_namespace, name[len(self._alt_prefix) :], value) else: setattr(self._standard_namespace, name, value) def __getattr__(self, name: str) -> Any: if name.startswith(self._alt_prefix): - return getattr(self._alt_namespace, name[len(self._alt_prefix):]) + return getattr(self._alt_namespace, name[len(self._alt_prefix) :]) else: return getattr(self._standard_namespace, name) diff --git a/mypy/state.py b/mypy/state.py index 8aba966a33c06..b289fcfe73aec 100644 --- a/mypy/state.py +++ b/mypy/state.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Optional, Tuple, Iterator +from typing import Iterator, Optional, Tuple from typing_extensions import Final diff --git a/mypy/stats.py b/mypy/stats.py index a9769b55e20d8..562bef294279b 100644 --- a/mypy/stats.py +++ b/mypy/stats.py @@ -1,28 +1,64 @@ """Utilities for calculating and reporting statistics about types.""" import os +import typing from collections import Counter from contextlib import contextmanager +from typing import Dict, Iterator, List, Optional, Union, cast -import typing -from typing import Dict, List, cast, Optional, Union, Iterator from typing_extensions import Final +from mypy import nodes +from mypy.argmap import map_formals_to_actuals +from mypy.nodes import ( + AssignmentExpr, + AssignmentStmt, + BreakStmt, + BytesExpr, + CallExpr, + ClassDef, + ComparisonExpr, + ComplexExpr, + ContinueStmt, + EllipsisExpr, + Expression, + ExpressionStmt, + FloatExpr, + FuncDef, + Import, + ImportAll, + ImportFrom, + IndexExpr, + IntExpr, + MemberExpr, + MypyFile, + NameExpr, + Node, + OpExpr, + PassStmt, + RefExpr, + StrExpr, + TypeApplication, + UnaryExpr, + UnicodeExpr, + YieldFromExpr, +) from mypy.traverser import TraverserVisitor from mypy.typeanal import collect_all_inner_types from mypy.types import ( - Type, AnyType, Instance, FunctionLike, TupleType, TypeVarType, TypeQuery, CallableType, - TypeOfAny, get_proper_type, get_proper_types -) -from mypy import nodes -from mypy.nodes import ( - Expression, FuncDef, TypeApplication, AssignmentStmt, NameExpr, CallExpr, MypyFile, - MemberExpr, OpExpr, ComparisonExpr, IndexExpr, UnaryExpr, YieldFromExpr, RefExpr, ClassDef, - AssignmentExpr, ImportFrom, Import, ImportAll, PassStmt, BreakStmt, ContinueStmt, StrExpr, - BytesExpr, UnicodeExpr, IntExpr, FloatExpr, ComplexExpr, EllipsisExpr, ExpressionStmt, Node + AnyType, + CallableType, + FunctionLike, + Instance, + TupleType, + Type, + TypeOfAny, + TypeQuery, + TypeVarType, + get_proper_type, + get_proper_types, ) from mypy.util import correct_relative_import -from mypy.argmap import map_formals_to_actuals TYPE_EMPTY: Final = 0 TYPE_UNANALYZED: Final = 1 # type of non-typechecked code @@ -30,23 +66,19 @@ TYPE_IMPRECISE: Final = 3 TYPE_ANY: Final = 4 -precision_names: Final = [ - 'empty', - 'unanalyzed', - 'precise', - 'imprecise', - 'any', -] +precision_names: Final = ["empty", "unanalyzed", "precise", "imprecise", "any"] class StatisticsVisitor(TraverserVisitor): - def __init__(self, - inferred: bool, - filename: str, - modules: Dict[str, MypyFile], - typemap: Optional[Dict[Expression, Type]] = None, - all_nodes: bool = False, - visit_untyped_defs: bool = True) -> None: + def __init__( + self, + inferred: bool, + filename: str, + modules: Dict[str, MypyFile], + typemap: Optional[Dict[Expression, Type]] = None, + all_nodes: bool = False, + visit_untyped_defs: bool = True, + ) -> None: self.inferred = inferred self.filename = filename self.modules = modules @@ -95,10 +127,9 @@ def visit_import_all(self, imp: ImportAll) -> None: self.process_import(imp) def process_import(self, imp: Union[ImportFrom, ImportAll]) -> None: - import_id, ok = correct_relative_import(self.cur_mod_id, - imp.relative, - imp.id, - self.cur_mod_node.is_package_init_file()) + import_id, ok = correct_relative_import( + self.cur_mod_id, imp.relative, imp.id, self.cur_mod_node.is_package_init_file() + ) if ok and import_id in self.modules: kind = TYPE_PRECISE else: @@ -117,9 +148,11 @@ def visit_func_def(self, o: FuncDef) -> None: self.line = o.line if len(o.expanded) > 1 and o.expanded != [o] * len(o.expanded): if o in o.expanded: - print('{}:{}: ERROR: cycle in function expansion; skipping'.format( - self.filename, - o.get_line())) + print( + "{}:{}: ERROR: cycle in function expansion; skipping".format( + self.filename, o.get_line() + ) + ) return for defn in o.expanded: self.visit_func_def(cast(FuncDef, defn)) @@ -127,8 +160,7 @@ def visit_func_def(self, o: FuncDef) -> None: if o.type: sig = cast(CallableType, o.type) arg_types = sig.arg_types - if (sig.arg_names and sig.arg_names[0] == 'self' and - not self.inferred): + if sig.arg_names and sig.arg_names[0] == "self" and not self.inferred: arg_types = arg_types[1:] for arg in arg_types: self.type(arg) @@ -165,8 +197,9 @@ def visit_type_application(self, o: TypeApplication) -> None: def visit_assignment_stmt(self, o: AssignmentStmt) -> None: self.line = o.line - if (isinstance(o.rvalue, nodes.CallExpr) and - isinstance(o.rvalue.analyzed, nodes.TypeVarExpr)): + if isinstance(o.rvalue, nodes.CallExpr) and isinstance( + o.rvalue.analyzed, nodes.TypeVarExpr + ): # Type variable definition -- not a real assignment. return if o.type: @@ -201,10 +234,7 @@ def visit_continue_stmt(self, o: ContinueStmt) -> None: self.record_precise_if_checked_scope(o) def visit_name_expr(self, o: NameExpr) -> None: - if o.fullname in ('builtins.None', - 'builtins.True', - 'builtins.False', - 'builtins.Ellipsis'): + if o.fullname in ("builtins.None", "builtins.True", "builtins.False", "builtins.Ellipsis"): self.record_precise_if_checked_scope(o) else: self.process_node(o) @@ -250,7 +280,8 @@ def record_callable_target_precision(self, o: CallExpr, callee: CallableType) -> o.arg_names, callee.arg_kinds, callee.arg_names, - lambda n: typemap[o.args[n]]) + lambda n: typemap[o.args[n]], + ) for formals in actual_to_formal: for n in formals: formal = get_proper_type(callee.arg_types[n]) @@ -337,12 +368,11 @@ def type(self, t: Optional[Type]) -> None: return if isinstance(t, AnyType): - self.log(' !! Any type around line %d' % self.line) + self.log(" !! Any type around line %d" % self.line) self.num_any_exprs += 1 self.record_line(self.line, TYPE_ANY) - elif ((not self.all_nodes and is_imprecise(t)) or - (self.all_nodes and is_imprecise2(t))): - self.log(' !! Imprecise type around line %d' % self.line) + elif (not self.all_nodes and is_imprecise(t)) or (self.all_nodes and is_imprecise2(t)): + self.log(" !! Imprecise type around line %d" % self.line) self.num_imprecise_exprs += 1 self.record_line(self.line, TYPE_IMPRECISE) else: @@ -382,41 +412,39 @@ def log(self, string: str) -> None: self.output.append(string) def record_line(self, line: int, precision: int) -> None: - self.line_map[line] = max(precision, - self.line_map.get(line, TYPE_EMPTY)) + self.line_map[line] = max(precision, self.line_map.get(line, TYPE_EMPTY)) -def dump_type_stats(tree: MypyFile, - path: str, - modules: Dict[str, MypyFile], - inferred: bool = False, - typemap: Optional[Dict[Expression, Type]] = None) -> None: +def dump_type_stats( + tree: MypyFile, + path: str, + modules: Dict[str, MypyFile], + inferred: bool = False, + typemap: Optional[Dict[Expression, Type]] = None, +) -> None: if is_special_module(path): return print(path) - visitor = StatisticsVisitor(inferred, - filename=tree.fullname, - modules=modules, - typemap=typemap) + visitor = StatisticsVisitor(inferred, filename=tree.fullname, modules=modules, typemap=typemap) tree.accept(visitor) for line in visitor.output: print(line) - print(' ** precision **') - print(' precise ', visitor.num_precise_exprs) - print(' imprecise', visitor.num_imprecise_exprs) - print(' any ', visitor.num_any_exprs) - print(' ** kinds **') - print(' simple ', visitor.num_simple_types) - print(' generic ', visitor.num_generic_types) - print(' function ', visitor.num_function_types) - print(' tuple ', visitor.num_tuple_types) - print(' TypeVar ', visitor.num_typevar_types) - print(' complex ', visitor.num_complex_types) - print(' any ', visitor.num_any_types) + print(" ** precision **") + print(" precise ", visitor.num_precise_exprs) + print(" imprecise", visitor.num_imprecise_exprs) + print(" any ", visitor.num_any_exprs) + print(" ** kinds **") + print(" simple ", visitor.num_simple_types) + print(" generic ", visitor.num_generic_types) + print(" function ", visitor.num_function_types) + print(" tuple ", visitor.num_tuple_types) + print(" TypeVar ", visitor.num_typevar_types) + print(" complex ", visitor.num_complex_types) + print(" any ", visitor.num_any_types) def is_special_module(path: str) -> bool: - return os.path.basename(path) in ('abc.pyi', 'typing.pyi', 'builtins.pyi') + return os.path.basename(path) in ("abc.pyi", "typing.pyi", "builtins.pyi") def is_imprecise(t: Type) -> bool: @@ -449,8 +477,7 @@ def is_generic(t: Type) -> bool: def is_complex(t: Type) -> bool: t = get_proper_type(t) - return is_generic(t) or isinstance(t, (FunctionLike, TupleType, - TypeVarType)) + return is_generic(t) or isinstance(t, (FunctionLike, TupleType, TypeVarType)) def ensure_dir_exists(dir: str) -> None: diff --git a/mypy/strconv.py b/mypy/strconv.py index 8d6cf92d8f2a0..cbb9bad2e9942 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -1,13 +1,13 @@ """Conversion of parse tree nodes to strings.""" -import re import os +import re +from typing import Any, List, Optional, Sequence, Tuple, Union -from typing import Any, List, Tuple, Optional, Union, Sequence from typing_extensions import TYPE_CHECKING -from mypy.util import short_type, IdMapper import mypy.nodes +from mypy.util import IdMapper, short_type from mypy.visitor import NodeVisitor if TYPE_CHECKING: @@ -39,24 +39,24 @@ def get_id(self, o: object) -> Optional[int]: def format_id(self, o: object) -> str: if self.id_mapper: - return f'<{self.get_id(o)}>' + return f"<{self.get_id(o)}>" else: - return '' + return "" - def dump(self, nodes: Sequence[object], obj: 'mypy.nodes.Context') -> str: + def dump(self, nodes: Sequence[object], obj: "mypy.nodes.Context") -> str: """Convert a list of items to a multiline pretty-printed string. The tag is produced from the type name of obj and its line number. See mypy.util.dump_tagged for a description of the nodes argument. """ - tag = short_type(obj) + ':' + str(obj.get_line()) + tag = short_type(obj) + ":" + str(obj.get_line()) if self.show_ids: assert self.id_mapper is not None - tag += f'<{self.get_id(obj)}>' + tag += f"<{self.get_id(obj)}>" return dump_tagged(nodes, tag, self) - def func_helper(self, o: 'mypy.nodes.FuncItem') -> List[object]: + def func_helper(self, o: "mypy.nodes.FuncItem") -> List[object]: """Return a list in a format suitable for dump() that represents the arguments and the body of a function. The caller can then decorate the array with information specific to methods, global functions or @@ -70,146 +70,144 @@ def func_helper(self, o: 'mypy.nodes.FuncItem') -> List[object]: args.append(arg.variable) elif kind.is_optional(): assert arg.initializer is not None - args.append(('default', [arg.variable, arg.initializer])) + args.append(("default", [arg.variable, arg.initializer])) elif kind == mypy.nodes.ARG_STAR: - extra.append(('VarArg', [arg.variable])) + extra.append(("VarArg", [arg.variable])) elif kind == mypy.nodes.ARG_STAR2: - extra.append(('DictVarArg', [arg.variable])) + extra.append(("DictVarArg", [arg.variable])) a: List[Any] = [] if args: - a.append(('Args', args)) + a.append(("Args", args)) if o.type: a.append(o.type) if o.is_generator: - a.append('Generator') + a.append("Generator") a.extend(extra) a.append(o.body) return a # Top-level structures - def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> str: + def visit_mypy_file(self, o: "mypy.nodes.MypyFile") -> str: # Skip implicit definitions. a: List[Any] = [o.defs] if o.is_bom: - a.insert(0, 'BOM') + a.insert(0, "BOM") # Omit path to special file with name "main". This is used to simplify # test case descriptions; the file "main" is used by default in many # test cases. - if o.path != 'main': + if o.path != "main": # Insert path. Normalize directory separators to / to unify test # case# output in all platforms. - a.insert(0, o.path.replace(os.sep, '/')) + a.insert(0, o.path.replace(os.sep, "/")) if o.ignored_lines: - a.append('IgnoredLines(%s)' % ', '.join(str(line) - for line in sorted(o.ignored_lines))) + a.append("IgnoredLines(%s)" % ", ".join(str(line) for line in sorted(o.ignored_lines))) return self.dump(a, o) - def visit_import(self, o: 'mypy.nodes.Import') -> str: + def visit_import(self, o: "mypy.nodes.Import") -> str: a = [] for id, as_id in o.ids: if as_id is not None: - a.append(f'{id} : {as_id}') + a.append(f"{id} : {as_id}") else: a.append(id) return f"Import:{o.line}({', '.join(a)})" - def visit_import_from(self, o: 'mypy.nodes.ImportFrom') -> str: + def visit_import_from(self, o: "mypy.nodes.ImportFrom") -> str: a = [] for name, as_name in o.names: if as_name is not None: - a.append(f'{name} : {as_name}') + a.append(f"{name} : {as_name}") else: a.append(name) return f"ImportFrom:{o.line}({'.' * o.relative + o.id}, [{', '.join(a)}])" - def visit_import_all(self, o: 'mypy.nodes.ImportAll') -> str: + def visit_import_all(self, o: "mypy.nodes.ImportAll") -> str: return f"ImportAll:{o.line}({'.' * o.relative + o.id})" # Definitions - def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> str: + def visit_func_def(self, o: "mypy.nodes.FuncDef") -> str: a = self.func_helper(o) a.insert(0, o.name) arg_kinds = {arg.kind for arg in o.arguments} if len(arg_kinds & {mypy.nodes.ARG_NAMED, mypy.nodes.ARG_NAMED_OPT}) > 0: - a.insert(1, f'MaxPos({o.max_pos})') + a.insert(1, f"MaxPos({o.max_pos})") if o.is_abstract: - a.insert(-1, 'Abstract') + a.insert(-1, "Abstract") if o.is_static: - a.insert(-1, 'Static') + a.insert(-1, "Static") if o.is_class: - a.insert(-1, 'Class') + a.insert(-1, "Class") if o.is_property: - a.insert(-1, 'Property') + a.insert(-1, "Property") return self.dump(a, o) - def visit_overloaded_func_def(self, o: 'mypy.nodes.OverloadedFuncDef') -> str: + def visit_overloaded_func_def(self, o: "mypy.nodes.OverloadedFuncDef") -> str: a: Any = o.items[:] if o.type: a.insert(0, o.type) if o.impl: a.insert(0, o.impl) if o.is_static: - a.insert(-1, 'Static') + a.insert(-1, "Static") if o.is_class: - a.insert(-1, 'Class') + a.insert(-1, "Class") return self.dump(a, o) - def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> str: + def visit_class_def(self, o: "mypy.nodes.ClassDef") -> str: a = [o.name, o.defs.body] # Display base types unless they are implicitly just builtins.object # (in this case base_type_exprs is empty). if o.base_type_exprs: if o.info and o.info.bases: - if (len(o.info.bases) != 1 - or o.info.bases[0].type.fullname != 'builtins.object'): - a.insert(1, ('BaseType', o.info.bases)) + if len(o.info.bases) != 1 or o.info.bases[0].type.fullname != "builtins.object": + a.insert(1, ("BaseType", o.info.bases)) else: - a.insert(1, ('BaseTypeExpr', o.base_type_exprs)) + a.insert(1, ("BaseTypeExpr", o.base_type_exprs)) if o.type_vars: - a.insert(1, ('TypeVars', o.type_vars)) + a.insert(1, ("TypeVars", o.type_vars)) if o.metaclass: - a.insert(1, f'Metaclass({o.metaclass})') + a.insert(1, f"Metaclass({o.metaclass})") if o.decorators: - a.insert(1, ('Decorators', o.decorators)) + a.insert(1, ("Decorators", o.decorators)) if o.info and o.info._promote: - a.insert(1, f'Promote({o.info._promote})') + a.insert(1, f"Promote({o.info._promote})") if o.info and o.info.tuple_type: - a.insert(1, ('TupleType', [o.info.tuple_type])) + a.insert(1, ("TupleType", [o.info.tuple_type])) if o.info and o.info.fallback_to_any: - a.insert(1, 'FallbackToAny') + a.insert(1, "FallbackToAny") return self.dump(a, o) - def visit_var(self, o: 'mypy.nodes.Var') -> str: - lst = '' + def visit_var(self, o: "mypy.nodes.Var") -> str: + lst = "" # Add :nil line number tag if no line number is specified to remain # compatible with old test case descriptions that assume this. if o.line < 0: - lst = ':nil' - return 'Var' + lst + '(' + o.name + ')' + lst = ":nil" + return "Var" + lst + "(" + o.name + ")" - def visit_global_decl(self, o: 'mypy.nodes.GlobalDecl') -> str: + def visit_global_decl(self, o: "mypy.nodes.GlobalDecl") -> str: return self.dump([o.names], o) - def visit_nonlocal_decl(self, o: 'mypy.nodes.NonlocalDecl') -> str: + def visit_nonlocal_decl(self, o: "mypy.nodes.NonlocalDecl") -> str: return self.dump([o.names], o) - def visit_decorator(self, o: 'mypy.nodes.Decorator') -> str: + def visit_decorator(self, o: "mypy.nodes.Decorator") -> str: return self.dump([o.var, o.decorators, o.func], o) # Statements - def visit_block(self, o: 'mypy.nodes.Block') -> str: + def visit_block(self, o: "mypy.nodes.Block") -> str: return self.dump(o.body, o) - def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> str: + def visit_expression_stmt(self, o: "mypy.nodes.ExpressionStmt") -> str: return self.dump([o.expr], o) - def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> str: + def visit_assignment_stmt(self, o: "mypy.nodes.AssignmentStmt") -> str: a: List[Any] = [] if len(o.lvalues) > 1: - a = [('Lvalues', o.lvalues)] + a = [("Lvalues", o.lvalues)] else: a = [o.lvalues[0]] a.append(o.rvalue) @@ -217,66 +215,66 @@ def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> str: a.append(o.type) return self.dump(a, o) - def visit_operator_assignment_stmt(self, o: 'mypy.nodes.OperatorAssignmentStmt') -> str: + def visit_operator_assignment_stmt(self, o: "mypy.nodes.OperatorAssignmentStmt") -> str: return self.dump([o.op, o.lvalue, o.rvalue], o) - def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> str: + def visit_while_stmt(self, o: "mypy.nodes.WhileStmt") -> str: a: List[Any] = [o.expr, o.body] if o.else_body: - a.append(('Else', o.else_body.body)) + a.append(("Else", o.else_body.body)) return self.dump(a, o) - def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> str: + def visit_for_stmt(self, o: "mypy.nodes.ForStmt") -> str: a: List[Any] = [] if o.is_async: - a.append(('Async', '')) + a.append(("Async", "")) a.append(o.index) if o.index_type: a.append(o.index_type) a.extend([o.expr, o.body]) if o.else_body: - a.append(('Else', o.else_body.body)) + a.append(("Else", o.else_body.body)) return self.dump(a, o) - def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> str: + def visit_return_stmt(self, o: "mypy.nodes.ReturnStmt") -> str: return self.dump([o.expr], o) - def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> str: + def visit_if_stmt(self, o: "mypy.nodes.IfStmt") -> str: a: List[Any] = [] for i in range(len(o.expr)): - a.append(('If', [o.expr[i]])) - a.append(('Then', o.body[i].body)) + a.append(("If", [o.expr[i]])) + a.append(("Then", o.body[i].body)) if not o.else_body: return self.dump(a, o) else: - return self.dump([a, ('Else', o.else_body.body)], o) + return self.dump([a, ("Else", o.else_body.body)], o) - def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> str: + def visit_break_stmt(self, o: "mypy.nodes.BreakStmt") -> str: return self.dump([], o) - def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> str: + def visit_continue_stmt(self, o: "mypy.nodes.ContinueStmt") -> str: return self.dump([], o) - def visit_pass_stmt(self, o: 'mypy.nodes.PassStmt') -> str: + def visit_pass_stmt(self, o: "mypy.nodes.PassStmt") -> str: return self.dump([], o) - def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> str: + def visit_raise_stmt(self, o: "mypy.nodes.RaiseStmt") -> str: return self.dump([o.expr, o.from_expr], o) - def visit_assert_stmt(self, o: 'mypy.nodes.AssertStmt') -> str: + def visit_assert_stmt(self, o: "mypy.nodes.AssertStmt") -> str: if o.msg is not None: return self.dump([o.expr, o.msg], o) else: return self.dump([o.expr], o) - def visit_await_expr(self, o: 'mypy.nodes.AwaitExpr') -> str: + def visit_await_expr(self, o: "mypy.nodes.AwaitExpr") -> str: return self.dump([o.expr], o) - def visit_del_stmt(self, o: 'mypy.nodes.DelStmt') -> str: + def visit_del_stmt(self, o: "mypy.nodes.DelStmt") -> str: return self.dump([o.expr], o) - def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> str: + def visit_try_stmt(self, o: "mypy.nodes.TryStmt") -> str: a: List[Any] = [o.body] for i in range(len(o.vars)): @@ -286,124 +284,128 @@ def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> str: a.append(o.handlers[i]) if o.else_body: - a.append(('Else', o.else_body.body)) + a.append(("Else", o.else_body.body)) if o.finally_body: - a.append(('Finally', o.finally_body.body)) + a.append(("Finally", o.finally_body.body)) return self.dump(a, o) - def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> str: + def visit_with_stmt(self, o: "mypy.nodes.WithStmt") -> str: a: List[Any] = [] if o.is_async: - a.append(('Async', '')) + a.append(("Async", "")) for i in range(len(o.expr)): - a.append(('Expr', [o.expr[i]])) + a.append(("Expr", [o.expr[i]])) if o.target[i]: - a.append(('Target', [o.target[i]])) + a.append(("Target", [o.target[i]])) if o.unanalyzed_type: a.append(o.unanalyzed_type) return self.dump(a + [o.body], o) - def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> str: + def visit_print_stmt(self, o: "mypy.nodes.PrintStmt") -> str: a: List[Any] = o.args[:] if o.target: - a.append(('Target', [o.target])) + a.append(("Target", [o.target])) if o.newline: - a.append('Newline') + a.append("Newline") return self.dump(a, o) - def visit_exec_stmt(self, o: 'mypy.nodes.ExecStmt') -> str: + def visit_exec_stmt(self, o: "mypy.nodes.ExecStmt") -> str: return self.dump([o.expr, o.globals, o.locals], o) - def visit_match_stmt(self, o: 'mypy.nodes.MatchStmt') -> str: + def visit_match_stmt(self, o: "mypy.nodes.MatchStmt") -> str: a: List[Any] = [o.subject] for i in range(len(o.patterns)): - a.append(('Pattern', [o.patterns[i]])) + a.append(("Pattern", [o.patterns[i]])) if o.guards[i] is not None: - a.append(('Guard', [o.guards[i]])) - a.append(('Body', o.bodies[i].body)) + a.append(("Guard", [o.guards[i]])) + a.append(("Body", o.bodies[i].body)) return self.dump(a, o) # Expressions # Simple expressions - def visit_int_expr(self, o: 'mypy.nodes.IntExpr') -> str: - return f'IntExpr({o.value})' + def visit_int_expr(self, o: "mypy.nodes.IntExpr") -> str: + return f"IntExpr({o.value})" - def visit_str_expr(self, o: 'mypy.nodes.StrExpr') -> str: - return f'StrExpr({self.str_repr(o.value)})' + def visit_str_expr(self, o: "mypy.nodes.StrExpr") -> str: + return f"StrExpr({self.str_repr(o.value)})" - def visit_bytes_expr(self, o: 'mypy.nodes.BytesExpr') -> str: - return f'BytesExpr({self.str_repr(o.value)})' + def visit_bytes_expr(self, o: "mypy.nodes.BytesExpr") -> str: + return f"BytesExpr({self.str_repr(o.value)})" - def visit_unicode_expr(self, o: 'mypy.nodes.UnicodeExpr') -> str: - return f'UnicodeExpr({self.str_repr(o.value)})' + def visit_unicode_expr(self, o: "mypy.nodes.UnicodeExpr") -> str: + return f"UnicodeExpr({self.str_repr(o.value)})" def str_repr(self, s: str) -> str: - s = re.sub(r'\\u[0-9a-fA-F]{4}', lambda m: '\\' + m.group(0), s) - return re.sub('[^\\x20-\\x7e]', - lambda m: r'\u%.4x' % ord(m.group(0)), s) + s = re.sub(r"\\u[0-9a-fA-F]{4}", lambda m: "\\" + m.group(0), s) + return re.sub("[^\\x20-\\x7e]", lambda m: r"\u%.4x" % ord(m.group(0)), s) - def visit_float_expr(self, o: 'mypy.nodes.FloatExpr') -> str: - return f'FloatExpr({o.value})' + def visit_float_expr(self, o: "mypy.nodes.FloatExpr") -> str: + return f"FloatExpr({o.value})" - def visit_complex_expr(self, o: 'mypy.nodes.ComplexExpr') -> str: - return f'ComplexExpr({o.value})' + def visit_complex_expr(self, o: "mypy.nodes.ComplexExpr") -> str: + return f"ComplexExpr({o.value})" - def visit_ellipsis(self, o: 'mypy.nodes.EllipsisExpr') -> str: - return 'Ellipsis' + def visit_ellipsis(self, o: "mypy.nodes.EllipsisExpr") -> str: + return "Ellipsis" - def visit_star_expr(self, o: 'mypy.nodes.StarExpr') -> str: + def visit_star_expr(self, o: "mypy.nodes.StarExpr") -> str: return self.dump([o.expr], o) - def visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> str: - pretty = self.pretty_name(o.name, o.kind, o.fullname, - o.is_inferred_def or o.is_special_form, - o.node) + def visit_name_expr(self, o: "mypy.nodes.NameExpr") -> str: + pretty = self.pretty_name( + o.name, o.kind, o.fullname, o.is_inferred_def or o.is_special_form, o.node + ) if isinstance(o.node, mypy.nodes.Var) and o.node.is_final: - pretty += f' = {o.node.final_value}' - return short_type(o) + '(' + pretty + ')' - - def pretty_name(self, name: str, kind: Optional[int], fullname: Optional[str], - is_inferred_def: bool, target_node: 'Optional[mypy.nodes.Node]' = None) -> str: + pretty += f" = {o.node.final_value}" + return short_type(o) + "(" + pretty + ")" + + def pretty_name( + self, + name: str, + kind: Optional[int], + fullname: Optional[str], + is_inferred_def: bool, + target_node: "Optional[mypy.nodes.Node]" = None, + ) -> str: n = name if is_inferred_def: - n += '*' + n += "*" if target_node: id = self.format_id(target_node) else: - id = '' + id = "" if isinstance(target_node, mypy.nodes.MypyFile) and name == fullname: n += id - elif kind == mypy.nodes.GDEF or (fullname != name and - fullname is not None): + elif kind == mypy.nodes.GDEF or (fullname != name and fullname is not None): # Append fully qualified name for global references. - n += f' [{fullname}{id}]' + n += f" [{fullname}{id}]" elif kind == mypy.nodes.LDEF: # Add tag to signify a local reference. - n += f' [l{id}]' + n += f" [l{id}]" elif kind == mypy.nodes.MDEF: # Add tag to signify a member reference. - n += f' [m{id}]' + n += f" [m{id}]" else: n += id return n - def visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> str: + def visit_member_expr(self, o: "mypy.nodes.MemberExpr") -> str: pretty = self.pretty_name(o.name, o.kind, o.fullname, o.is_inferred_def, o.node) return self.dump([o.expr, pretty], o) - def visit_yield_expr(self, o: 'mypy.nodes.YieldExpr') -> str: + def visit_yield_expr(self, o: "mypy.nodes.YieldExpr") -> str: return self.dump([o.expr], o) - def visit_yield_from_expr(self, o: 'mypy.nodes.YieldFromExpr') -> str: + def visit_yield_from_expr(self, o: "mypy.nodes.YieldFromExpr") -> str: if o.expr: return self.dump([o.expr.accept(self)], o) else: return self.dump([], o) - def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> str: + def visit_call_expr(self, o: "mypy.nodes.CallExpr") -> str: if o.analyzed: return o.analyzed.accept(self) args: List[mypy.nodes.Expression] = [] @@ -412,193 +414,193 @@ def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> str: if kind in [mypy.nodes.ARG_POS, mypy.nodes.ARG_STAR]: args.append(o.args[i]) if kind == mypy.nodes.ARG_STAR: - extra.append('VarArg') + extra.append("VarArg") elif kind == mypy.nodes.ARG_NAMED: - extra.append(('KwArgs', [o.arg_names[i], o.args[i]])) + extra.append(("KwArgs", [o.arg_names[i], o.args[i]])) elif kind == mypy.nodes.ARG_STAR2: - extra.append(('DictVarArg', [o.args[i]])) + extra.append(("DictVarArg", [o.args[i]])) else: raise RuntimeError(f"unknown kind {kind}") a: List[Any] = [o.callee, ("Args", args)] return self.dump(a + extra, o) - def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> str: + def visit_op_expr(self, o: "mypy.nodes.OpExpr") -> str: return self.dump([o.op, o.left, o.right], o) - def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> str: + def visit_comparison_expr(self, o: "mypy.nodes.ComparisonExpr") -> str: return self.dump([o.operators, o.operands], o) - def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> str: + def visit_cast_expr(self, o: "mypy.nodes.CastExpr") -> str: return self.dump([o.expr, o.type], o) - def visit_assert_type_expr(self, o: 'mypy.nodes.AssertTypeExpr') -> str: + def visit_assert_type_expr(self, o: "mypy.nodes.AssertTypeExpr") -> str: return self.dump([o.expr, o.type], o) - def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> str: + def visit_reveal_expr(self, o: "mypy.nodes.RevealExpr") -> str: if o.kind == mypy.nodes.REVEAL_TYPE: return self.dump([o.expr], o) else: # REVEAL_LOCALS return self.dump([o.local_nodes], o) - def visit_assignment_expr(self, o: 'mypy.nodes.AssignmentExpr') -> str: + def visit_assignment_expr(self, o: "mypy.nodes.AssignmentExpr") -> str: return self.dump([o.target, o.value], o) - def visit_unary_expr(self, o: 'mypy.nodes.UnaryExpr') -> str: + def visit_unary_expr(self, o: "mypy.nodes.UnaryExpr") -> str: return self.dump([o.op, o.expr], o) - def visit_list_expr(self, o: 'mypy.nodes.ListExpr') -> str: + def visit_list_expr(self, o: "mypy.nodes.ListExpr") -> str: return self.dump(o.items, o) - def visit_dict_expr(self, o: 'mypy.nodes.DictExpr') -> str: + def visit_dict_expr(self, o: "mypy.nodes.DictExpr") -> str: return self.dump([[k, v] for k, v in o.items], o) - def visit_set_expr(self, o: 'mypy.nodes.SetExpr') -> str: + def visit_set_expr(self, o: "mypy.nodes.SetExpr") -> str: return self.dump(o.items, o) - def visit_tuple_expr(self, o: 'mypy.nodes.TupleExpr') -> str: + def visit_tuple_expr(self, o: "mypy.nodes.TupleExpr") -> str: return self.dump(o.items, o) - def visit_index_expr(self, o: 'mypy.nodes.IndexExpr') -> str: + def visit_index_expr(self, o: "mypy.nodes.IndexExpr") -> str: if o.analyzed: return o.analyzed.accept(self) return self.dump([o.base, o.index], o) - def visit_super_expr(self, o: 'mypy.nodes.SuperExpr') -> str: + def visit_super_expr(self, o: "mypy.nodes.SuperExpr") -> str: return self.dump([o.name, o.call], o) - def visit_type_application(self, o: 'mypy.nodes.TypeApplication') -> str: - return self.dump([o.expr, ('Types', o.types)], o) + def visit_type_application(self, o: "mypy.nodes.TypeApplication") -> str: + return self.dump([o.expr, ("Types", o.types)], o) - def visit_type_var_expr(self, o: 'mypy.nodes.TypeVarExpr') -> str: + def visit_type_var_expr(self, o: "mypy.nodes.TypeVarExpr") -> str: import mypy.types a: List[Any] = [] if o.variance == mypy.nodes.COVARIANT: - a += ['Variance(COVARIANT)'] + a += ["Variance(COVARIANT)"] if o.variance == mypy.nodes.CONTRAVARIANT: - a += ['Variance(CONTRAVARIANT)'] + a += ["Variance(CONTRAVARIANT)"] if o.values: - a += [('Values', o.values)] - if not mypy.types.is_named_instance(o.upper_bound, 'builtins.object'): - a += [f'UpperBound({o.upper_bound})'] + a += [("Values", o.values)] + if not mypy.types.is_named_instance(o.upper_bound, "builtins.object"): + a += [f"UpperBound({o.upper_bound})"] return self.dump(a, o) - def visit_paramspec_expr(self, o: 'mypy.nodes.ParamSpecExpr') -> str: + def visit_paramspec_expr(self, o: "mypy.nodes.ParamSpecExpr") -> str: import mypy.types a: List[Any] = [] if o.variance == mypy.nodes.COVARIANT: - a += ['Variance(COVARIANT)'] + a += ["Variance(COVARIANT)"] if o.variance == mypy.nodes.CONTRAVARIANT: - a += ['Variance(CONTRAVARIANT)'] - if not mypy.types.is_named_instance(o.upper_bound, 'builtins.object'): - a += [f'UpperBound({o.upper_bound})'] + a += ["Variance(CONTRAVARIANT)"] + if not mypy.types.is_named_instance(o.upper_bound, "builtins.object"): + a += [f"UpperBound({o.upper_bound})"] return self.dump(a, o) - def visit_type_var_tuple_expr(self, o: 'mypy.nodes.TypeVarTupleExpr') -> str: + def visit_type_var_tuple_expr(self, o: "mypy.nodes.TypeVarTupleExpr") -> str: import mypy.types a: List[Any] = [] if o.variance == mypy.nodes.COVARIANT: - a += ['Variance(COVARIANT)'] + a += ["Variance(COVARIANT)"] if o.variance == mypy.nodes.CONTRAVARIANT: - a += ['Variance(CONTRAVARIANT)'] - if not mypy.types.is_named_instance(o.upper_bound, 'builtins.object'): - a += [f'UpperBound({o.upper_bound})'] + a += ["Variance(CONTRAVARIANT)"] + if not mypy.types.is_named_instance(o.upper_bound, "builtins.object"): + a += [f"UpperBound({o.upper_bound})"] return self.dump(a, o) - def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> str: - return f'TypeAliasExpr({o.type})' + def visit_type_alias_expr(self, o: "mypy.nodes.TypeAliasExpr") -> str: + return f"TypeAliasExpr({o.type})" - def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> str: - return f'NamedTupleExpr:{o.line}({o.info.name}, {o.info.tuple_type})' + def visit_namedtuple_expr(self, o: "mypy.nodes.NamedTupleExpr") -> str: + return f"NamedTupleExpr:{o.line}({o.info.name}, {o.info.tuple_type})" - def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> str: - return f'EnumCallExpr:{o.line}({o.info.name}, {o.items})' + def visit_enum_call_expr(self, o: "mypy.nodes.EnumCallExpr") -> str: + return f"EnumCallExpr:{o.line}({o.info.name}, {o.items})" - def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> str: - return f'TypedDictExpr:{o.line}({o.info.name})' + def visit_typeddict_expr(self, o: "mypy.nodes.TypedDictExpr") -> str: + return f"TypedDictExpr:{o.line}({o.info.name})" - def visit__promote_expr(self, o: 'mypy.nodes.PromoteExpr') -> str: - return f'PromoteExpr:{o.line}({o.type})' + def visit__promote_expr(self, o: "mypy.nodes.PromoteExpr") -> str: + return f"PromoteExpr:{o.line}({o.type})" - def visit_newtype_expr(self, o: 'mypy.nodes.NewTypeExpr') -> str: - return f'NewTypeExpr:{o.line}({o.name}, {self.dump([o.old_type], o)})' + def visit_newtype_expr(self, o: "mypy.nodes.NewTypeExpr") -> str: + return f"NewTypeExpr:{o.line}({o.name}, {self.dump([o.old_type], o)})" - def visit_lambda_expr(self, o: 'mypy.nodes.LambdaExpr') -> str: + def visit_lambda_expr(self, o: "mypy.nodes.LambdaExpr") -> str: a = self.func_helper(o) return self.dump(a, o) - def visit_generator_expr(self, o: 'mypy.nodes.GeneratorExpr') -> str: + def visit_generator_expr(self, o: "mypy.nodes.GeneratorExpr") -> str: condlists = o.condlists if any(o.condlists) else None return self.dump([o.left_expr, o.indices, o.sequences, condlists], o) - def visit_list_comprehension(self, o: 'mypy.nodes.ListComprehension') -> str: + def visit_list_comprehension(self, o: "mypy.nodes.ListComprehension") -> str: return self.dump([o.generator], o) - def visit_set_comprehension(self, o: 'mypy.nodes.SetComprehension') -> str: + def visit_set_comprehension(self, o: "mypy.nodes.SetComprehension") -> str: return self.dump([o.generator], o) - def visit_dictionary_comprehension(self, o: 'mypy.nodes.DictionaryComprehension') -> str: + def visit_dictionary_comprehension(self, o: "mypy.nodes.DictionaryComprehension") -> str: condlists = o.condlists if any(o.condlists) else None return self.dump([o.key, o.value, o.indices, o.sequences, condlists], o) - def visit_conditional_expr(self, o: 'mypy.nodes.ConditionalExpr') -> str: - return self.dump([('Condition', [o.cond]), o.if_expr, o.else_expr], o) + def visit_conditional_expr(self, o: "mypy.nodes.ConditionalExpr") -> str: + return self.dump([("Condition", [o.cond]), o.if_expr, o.else_expr], o) - def visit_slice_expr(self, o: 'mypy.nodes.SliceExpr') -> str: + def visit_slice_expr(self, o: "mypy.nodes.SliceExpr") -> str: a: List[Any] = [o.begin_index, o.end_index, o.stride] if not a[0]: - a[0] = '' + a[0] = "" if not a[1]: - a[1] = '' + a[1] = "" return self.dump(a, o) - def visit_backquote_expr(self, o: 'mypy.nodes.BackquoteExpr') -> str: + def visit_backquote_expr(self, o: "mypy.nodes.BackquoteExpr") -> str: return self.dump([o.expr], o) - def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> str: + def visit_temp_node(self, o: "mypy.nodes.TempNode") -> str: return self.dump([o.type], o) - def visit_as_pattern(self, o: 'mypy.patterns.AsPattern') -> str: + def visit_as_pattern(self, o: "mypy.patterns.AsPattern") -> str: return self.dump([o.pattern, o.name], o) - def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> str: + def visit_or_pattern(self, o: "mypy.patterns.OrPattern") -> str: return self.dump(o.patterns, o) - def visit_value_pattern(self, o: 'mypy.patterns.ValuePattern') -> str: + def visit_value_pattern(self, o: "mypy.patterns.ValuePattern") -> str: return self.dump([o.expr], o) - def visit_singleton_pattern(self, o: 'mypy.patterns.SingletonPattern') -> str: + def visit_singleton_pattern(self, o: "mypy.patterns.SingletonPattern") -> str: return self.dump([o.value], o) - def visit_sequence_pattern(self, o: 'mypy.patterns.SequencePattern') -> str: + def visit_sequence_pattern(self, o: "mypy.patterns.SequencePattern") -> str: return self.dump(o.patterns, o) - def visit_starred_pattern(self, o: 'mypy.patterns.StarredPattern') -> str: + def visit_starred_pattern(self, o: "mypy.patterns.StarredPattern") -> str: return self.dump([o.capture], o) - def visit_mapping_pattern(self, o: 'mypy.patterns.MappingPattern') -> str: + def visit_mapping_pattern(self, o: "mypy.patterns.MappingPattern") -> str: a: List[Any] = [] for i in range(len(o.keys)): - a.append(('Key', [o.keys[i]])) - a.append(('Value', [o.values[i]])) + a.append(("Key", [o.keys[i]])) + a.append(("Value", [o.values[i]])) if o.rest is not None: - a.append(('Rest', [o.rest])) + a.append(("Rest", [o.rest])) return self.dump(a, o) - def visit_class_pattern(self, o: 'mypy.patterns.ClassPattern') -> str: + def visit_class_pattern(self, o: "mypy.patterns.ClassPattern") -> str: a: List[Any] = [o.class_ref] if len(o.positionals) > 0: - a.append(('Positionals', o.positionals)) + a.append(("Positionals", o.positionals)) for i in range(len(o.keyword_keys)): - a.append(('Keyword', [o.keyword_keys[i], o.keyword_values[i]])) + a.append(("Keyword", [o.keyword_keys[i], o.keyword_values[i]])) return self.dump(a, o) -def dump_tagged(nodes: Sequence[object], tag: Optional[str], str_conv: 'StrConv') -> str: +def dump_tagged(nodes: Sequence[object], tag: Optional[str], str_conv: "StrConv") -> str: """Convert an array into a pretty-printed multiline string representation. The format is @@ -614,7 +616,7 @@ def dump_tagged(nodes: Sequence[object], tag: Optional[str], str_conv: 'StrConv' a: List[str] = [] if tag: - a.append(tag + '(') + a.append(tag + "(") for n in nodes: if isinstance(n, list): if n: @@ -629,12 +631,12 @@ def dump_tagged(nodes: Sequence[object], tag: Optional[str], str_conv: 'StrConv' elif n is not None: a.append(indent(str(n), 2)) if tag: - a[-1] += ')' - return '\n'.join(a) + a[-1] += ")" + return "\n".join(a) def indent(s: str, n: int) -> str: """Indent all the lines in s (separated by newlines) by n spaces.""" - s = ' ' * n + s - s = s.replace('\n', '\n' + ' ' * n) + s = " " * n + s + s = s.replace("\n", "\n" + " " * n) return s diff --git a/mypy/stubdoc.py b/mypy/stubdoc.py index 175b6f9f432cc..8ec975bd4a427 100644 --- a/mypy/stubdoc.py +++ b/mypy/stubdoc.py @@ -3,14 +3,21 @@ This module provides several functions to generate better stubs using docstrings and Sphinx docs (.rst files). """ -import re -import io import contextlib +import io +import re import tokenize - from typing import ( - Optional, MutableMapping, MutableSequence, List, Sequence, Tuple, NamedTuple, Any + Any, + List, + MutableMapping, + MutableSequence, + NamedTuple, + Optional, + Sequence, + Tuple, ) + from typing_extensions import Final # Type alias for signatures strings in format ('func_name', '(arg, opt_arg=False)'). @@ -23,9 +30,9 @@ def is_valid_type(s: str) -> bool: """Try to determine whether a string might be a valid type annotation.""" - if s in ('True', 'False', 'retval'): + if s in ("True", "False", "retval"): return False - if ',' in s and '[' not in s: + if "," in s and "[" not in s: return False return _TYPE_RE.match(s) is not None @@ -42,13 +49,17 @@ def __init__(self, name: str, type: Optional[str] = None, default: bool = False) self.default = default def __repr__(self) -> str: - return "ArgSig(name={}, type={}, default={})".format(repr(self.name), repr(self.type), - repr(self.default)) + return "ArgSig(name={}, type={}, default={})".format( + repr(self.name), repr(self.type), repr(self.default) + ) def __eq__(self, other: Any) -> bool: if isinstance(other, ArgSig): - return (self.name == other.name and self.type == other.type and - self.default == other.default) + return ( + self.name == other.name + and self.type == other.type + and self.default == other.default + ) return False @@ -87,12 +98,18 @@ def __init__(self, function_name: str) -> None: def add_token(self, token: tokenize.TokenInfo) -> None: """Process next token from the token stream.""" - if (token.type == tokenize.NAME and token.string == self.function_name and - self.state[-1] == STATE_INIT): + if ( + token.type == tokenize.NAME + and token.string == self.function_name + and self.state[-1] == STATE_INIT + ): self.state.append(STATE_FUNCTION_NAME) - elif (token.type == tokenize.OP and token.string == '(' and - self.state[-1] == STATE_FUNCTION_NAME): + elif ( + token.type == tokenize.OP + and token.string == "(" + and self.state[-1] == STATE_FUNCTION_NAME + ): self.state.pop() self.accumulator = "" self.found = True @@ -102,24 +119,36 @@ def add_token(self, token: tokenize.TokenInfo) -> None: # Reset state, function name not followed by '('. self.state.pop() - elif (token.type == tokenize.OP and token.string in ('[', '(', '{') and - self.state[-1] != STATE_INIT): + elif ( + token.type == tokenize.OP + and token.string in ("[", "(", "{") + and self.state[-1] != STATE_INIT + ): self.accumulator += token.string self.state.append(STATE_OPEN_BRACKET) - elif (token.type == tokenize.OP and token.string in (']', ')', '}') and - self.state[-1] == STATE_OPEN_BRACKET): + elif ( + token.type == tokenize.OP + and token.string in ("]", ")", "}") + and self.state[-1] == STATE_OPEN_BRACKET + ): self.accumulator += token.string self.state.pop() - elif (token.type == tokenize.OP and token.string == ':' and - self.state[-1] == STATE_ARGUMENT_LIST): + elif ( + token.type == tokenize.OP + and token.string == ":" + and self.state[-1] == STATE_ARGUMENT_LIST + ): self.arg_name = self.accumulator self.accumulator = "" self.state.append(STATE_ARGUMENT_TYPE) - elif (token.type == tokenize.OP and token.string == '=' and - self.state[-1] in (STATE_ARGUMENT_LIST, STATE_ARGUMENT_TYPE)): + elif ( + token.type == tokenize.OP + and token.string == "=" + and self.state[-1] in (STATE_ARGUMENT_LIST, STATE_ARGUMENT_TYPE) + ): if self.state[-1] == STATE_ARGUMENT_TYPE: self.arg_type = self.accumulator self.state.pop() @@ -128,9 +157,12 @@ def add_token(self, token: tokenize.TokenInfo) -> None: self.accumulator = "" self.state.append(STATE_ARGUMENT_DEFAULT) - elif (token.type == tokenize.OP and token.string in (',', ')') and - self.state[-1] in (STATE_ARGUMENT_LIST, STATE_ARGUMENT_DEFAULT, - STATE_ARGUMENT_TYPE)): + elif ( + token.type == tokenize.OP + and token.string in (",", ")") + and self.state[-1] + in (STATE_ARGUMENT_LIST, STATE_ARGUMENT_DEFAULT, STATE_ARGUMENT_TYPE) + ): if self.state[-1] == STATE_ARGUMENT_DEFAULT: self.arg_default = self.accumulator self.state.pop() @@ -139,36 +171,43 @@ def add_token(self, token: tokenize.TokenInfo) -> None: self.state.pop() elif self.state[-1] == STATE_ARGUMENT_LIST: self.arg_name = self.accumulator - if not (token.string == ')' and self.accumulator.strip() == '') \ - and not _ARG_NAME_RE.match(self.arg_name): + if not ( + token.string == ")" and self.accumulator.strip() == "" + ) and not _ARG_NAME_RE.match(self.arg_name): # Invalid argument name. self.reset() return - if token.string == ')': + if token.string == ")": self.state.pop() # arg_name is empty when there are no args. e.g. func() if self.arg_name: try: - self.args.append(ArgSig(name=self.arg_name, type=self.arg_type, - default=bool(self.arg_default))) + self.args.append( + ArgSig( + name=self.arg_name, type=self.arg_type, default=bool(self.arg_default) + ) + ) except ValueError: # wrong type, use Any - self.args.append(ArgSig(name=self.arg_name, type=None, - default=bool(self.arg_default))) + self.args.append( + ArgSig(name=self.arg_name, type=None, default=bool(self.arg_default)) + ) self.arg_name = "" self.arg_type = None self.arg_default = None self.accumulator = "" - elif token.type == tokenize.OP and token.string == '->' and self.state[-1] == STATE_INIT: + elif token.type == tokenize.OP and token.string == "->" and self.state[-1] == STATE_INIT: self.accumulator = "" self.state.append(STATE_RETURN_VALUE) # ENDMAKER is necessary for python 3.4 and 3.5. - elif (token.type in (tokenize.NEWLINE, tokenize.ENDMARKER) and - self.state[-1] in (STATE_INIT, STATE_RETURN_VALUE)): + elif token.type in (tokenize.NEWLINE, tokenize.ENDMARKER) and self.state[-1] in ( + STATE_INIT, + STATE_RETURN_VALUE, + ): if self.state[-1] == STATE_RETURN_VALUE: if not is_valid_type(self.accumulator): self.reset() @@ -178,11 +217,12 @@ def add_token(self, token: tokenize.TokenInfo) -> None: self.state.pop() if self.found: - self.signatures.append(FunctionSig(name=self.function_name, args=self.args, - ret_type=self.ret_type)) + self.signatures.append( + FunctionSig(name=self.function_name, args=self.args, ret_type=self.ret_type) + ) self.found = False self.args = [] - self.ret_type = 'Any' + self.ret_type = "Any" # Leave state as INIT. else: self.accumulator += token.string @@ -195,11 +235,12 @@ def reset(self) -> None: def get_signatures(self) -> List[FunctionSig]: """Return sorted copy of the list of signatures found so far.""" + def has_arg(name: str, signature: FunctionSig) -> bool: return any(x.name == name for x in signature.args) def args_kwargs(signature: FunctionSig) -> bool: - return has_arg('*args', signature) and has_arg('**kwargs', signature) + return has_arg("*args", signature) and has_arg("**kwargs", signature) # Move functions with (*args, **kwargs) in their signature to last place. return list(sorted(self.signatures, key=lambda x: 1 if args_kwargs(x) else 0)) @@ -227,7 +268,7 @@ def infer_sig_from_docstring(docstr: Optional[str], name: str) -> Optional[List[ # Return all found signatures, even if there is a parse error after some are found. with contextlib.suppress(tokenize.TokenError): try: - tokens = tokenize.tokenize(io.BytesIO(docstr.encode('utf-8')).readline) + tokens = tokenize.tokenize(io.BytesIO(docstr.encode("utf-8")).readline) for token in tokens: state.add_token(token) except IndentationError: @@ -263,63 +304,59 @@ def infer_ret_type_sig_from_anon_docstring(docstr: str) -> Optional[str]: return infer_ret_type_sig_from_docstring("stub" + docstr.strip(), "stub") -def parse_signature(sig: str) -> Optional[Tuple[str, - List[str], - List[str]]]: +def parse_signature(sig: str) -> Optional[Tuple[str, List[str], List[str]]]: """Split function signature into its name, positional an optional arguments. The expected format is "func_name(arg, opt_arg=False)". Return the name of function and lists of positional and optional argument names. """ - m = re.match(r'([.a-zA-Z0-9_]+)\(([^)]*)\)', sig) + m = re.match(r"([.a-zA-Z0-9_]+)\(([^)]*)\)", sig) if not m: return None name = m.group(1) - name = name.split('.')[-1] + name = name.split(".")[-1] arg_string = m.group(2) if not arg_string.strip(): # Simple case -- no arguments. return name, [], [] - args = [arg.strip() for arg in arg_string.split(',')] + args = [arg.strip() for arg in arg_string.split(",")] positional = [] optional = [] i = 0 while i < len(args): # Accept optional arguments as in both formats: x=None and [x]. - if args[i].startswith('[') or '=' in args[i]: + if args[i].startswith("[") or "=" in args[i]: break - positional.append(args[i].rstrip('[')) + positional.append(args[i].rstrip("[")) i += 1 - if args[i - 1].endswith('['): + if args[i - 1].endswith("["): break while i < len(args): arg = args[i] - arg = arg.strip('[]') - arg = arg.split('=')[0] + arg = arg.strip("[]") + arg = arg.split("=")[0] optional.append(arg) i += 1 return name, positional, optional -def build_signature(positional: Sequence[str], - optional: Sequence[str]) -> str: +def build_signature(positional: Sequence[str], optional: Sequence[str]) -> str: """Build function signature from lists of positional and optional argument names.""" args: MutableSequence[str] = [] args.extend(positional) for arg in optional: - if arg.startswith('*'): + if arg.startswith("*"): args.append(arg) else: - args.append(f'{arg}=...') + args.append(f"{arg}=...") sig = f"({', '.join(args)})" # Ad-hoc fixes. - sig = sig.replace('(self)', '') + sig = sig.replace("(self)", "") return sig -def parse_all_signatures(lines: Sequence[str]) -> Tuple[List[Sig], - List[Sig]]: +def parse_all_signatures(lines: Sequence[str]) -> Tuple[List[Sig], List[Sig]]: """Parse all signatures in a given reST document. Return lists of found signatures for functions and classes. @@ -328,13 +365,13 @@ def parse_all_signatures(lines: Sequence[str]) -> Tuple[List[Sig], class_sigs = [] for line in lines: line = line.strip() - m = re.match(r'\.\. *(function|method|class) *:: *[a-zA-Z_]', line) + m = re.match(r"\.\. *(function|method|class) *:: *[a-zA-Z_]", line) if m: - sig = line.split('::')[1].strip() + sig = line.split("::")[1].strip() parsed = parse_signature(sig) if parsed: name, fixed, optional = parsed - if m.group(1) != 'class': + if m.group(1) != "class": sigs.append((name, build_signature(fixed, optional))) else: class_sigs.append((name, build_signature(fixed, optional))) @@ -366,6 +403,6 @@ def infer_prop_type_from_docstring(docstr: Optional[str]) -> Optional[str]: """ if not docstr: return None - test_str = r'^([a-zA-Z0-9_, \.\[\]]*): ' + test_str = r"^([a-zA-Z0-9_, \.\[\]]*): " m = re.match(test_str, docstr) return m.group(1) if m else None diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 34d01b337bacb..f6e1fd6c23ce5 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -45,113 +45,149 @@ - we don't seem to always detect properties ('closed' in 'io', for example) """ +import argparse import glob import os import os.path import sys import traceback -import argparse from collections import defaultdict +from typing import Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union, cast -from typing import ( - List, Dict, Tuple, Iterable, Mapping, Optional, Set, Union, cast, -) from typing_extensions import Final import mypy.build +import mypy.mixedtraverser import mypy.parse import mypy.traverser -import mypy.mixedtraverser import mypy.util from mypy import defaults +from mypy.build import build +from mypy.errors import CompileError, Errors +from mypy.find_sources import InvalidSourceList, create_source_list from mypy.modulefinder import ( - ModuleNotFoundReason, FindModuleCache, SearchPaths, BuildSource, default_lib_path + BuildSource, + FindModuleCache, + ModuleNotFoundReason, + SearchPaths, + default_lib_path, ) +from mypy.moduleinspect import ModuleInspect from mypy.nodes import ( - Expression, IntExpr, UnaryExpr, StrExpr, BytesExpr, NameExpr, FloatExpr, MemberExpr, - TupleExpr, ListExpr, ComparisonExpr, CallExpr, IndexExpr, EllipsisExpr, - ClassDef, MypyFile, Decorator, AssignmentStmt, TypeInfo, - IfStmt, ImportAll, ImportFrom, Import, FuncDef, FuncBase, Block, - Statement, OverloadedFuncDef, ARG_POS, ARG_STAR, ARG_STAR2, ARG_NAMED, + ARG_NAMED, + ARG_POS, + ARG_STAR, + ARG_STAR2, + AssignmentStmt, + Block, + BytesExpr, + CallExpr, + ClassDef, + ComparisonExpr, + Decorator, + EllipsisExpr, + Expression, + FloatExpr, + FuncBase, + FuncDef, + IfStmt, + Import, + ImportAll, + ImportFrom, + IndexExpr, + IntExpr, + ListExpr, + MemberExpr, + MypyFile, + NameExpr, + OverloadedFuncDef, + Statement, + StrExpr, + TupleExpr, + TypeInfo, + UnaryExpr, ) +from mypy.options import Options as MypyOptions +from mypy.stubdoc import Sig, find_unique_signatures, parse_all_signatures from mypy.stubgenc import generate_stub_for_c_module from mypy.stubutil import ( - default_py2_interpreter, CantImport, generate_guarded, - walk_packages, find_module_path_and_all_py2, find_module_path_and_all_py3, - report_missing, fail_missing, remove_misplaced_type_comments, common_dir_prefix + CantImport, + common_dir_prefix, + default_py2_interpreter, + fail_missing, + find_module_path_and_all_py2, + find_module_path_and_all_py3, + generate_guarded, + remove_misplaced_type_comments, + report_missing, + walk_packages, ) -from mypy.stubdoc import parse_all_signatures, find_unique_signatures, Sig -from mypy.options import Options as MypyOptions +from mypy.traverser import all_yield_expressions, has_return_statement, has_yield_expression from mypy.types import ( - Type, TypeStrVisitor, CallableType, UnboundType, NoneType, TupleType, TypeList, Instance, - AnyType, get_proper_type, OVERLOAD_NAMES + OVERLOAD_NAMES, + AnyType, + CallableType, + Instance, + NoneType, + TupleType, + Type, + TypeList, + TypeStrVisitor, + UnboundType, + get_proper_type, ) from mypy.visitor import NodeVisitor -from mypy.find_sources import create_source_list, InvalidSourceList -from mypy.build import build -from mypy.errors import CompileError, Errors -from mypy.traverser import all_yield_expressions, has_return_statement, has_yield_expression -from mypy.moduleinspect import ModuleInspect -TYPING_MODULE_NAMES: Final = ( - 'typing', - 'typing_extensions', -) +TYPING_MODULE_NAMES: Final = ("typing", "typing_extensions") # Common ways of naming package containing vendored modules. -VENDOR_PACKAGES: Final = [ - 'packages', - 'vendor', - 'vendored', - '_vendor', - '_vendored_packages', -] +VENDOR_PACKAGES: Final = ["packages", "vendor", "vendored", "_vendor", "_vendored_packages"] # Avoid some file names that are unnecessary or likely to cause trouble (\n for end of path). BLACKLIST: Final = [ - '/six.py\n', # Likely vendored six; too dynamic for us to handle - '/vendored/', # Vendored packages - '/vendor/', # Vendored packages - '/_vendor/', - '/_vendored_packages/', + "/six.py\n", # Likely vendored six; too dynamic for us to handle + "/vendored/", # Vendored packages + "/vendor/", # Vendored packages + "/_vendor/", + "/_vendored_packages/", ] # Special-cased names that are implicitly exported from the stub (from m import y as y). EXTRA_EXPORTED: Final = { - 'pyasn1_modules.rfc2437.univ', - 'pyasn1_modules.rfc2459.char', - 'pyasn1_modules.rfc2459.univ', + "pyasn1_modules.rfc2437.univ", + "pyasn1_modules.rfc2459.char", + "pyasn1_modules.rfc2459.univ", } # These names should be omitted from generated stubs. IGNORED_DUNDERS: Final = { - '__all__', - '__author__', - '__version__', - '__about__', - '__copyright__', - '__email__', - '__license__', - '__summary__', - '__title__', - '__uri__', - '__str__', - '__repr__', - '__getstate__', - '__setstate__', - '__slots__', + "__all__", + "__author__", + "__version__", + "__about__", + "__copyright__", + "__email__", + "__license__", + "__summary__", + "__title__", + "__uri__", + "__str__", + "__repr__", + "__getstate__", + "__setstate__", + "__slots__", } # These methods are expected to always return a non-trivial value. METHODS_WITH_RETURN_VALUE: Final = { - '__ne__', - '__eq__', - '__lt__', - '__le__', - '__gt__', - '__ge__', - '__hash__', - '__iter__', + "__ne__", + "__eq__", + "__lt__", + "__le__", + "__gt__", + "__ge__", + "__hash__", + "__iter__", } @@ -160,22 +196,25 @@ class Options: This class is mutable to simplify testing. """ - def __init__(self, - pyversion: Tuple[int, int], - no_import: bool, - doc_dir: str, - search_path: List[str], - interpreter: str, - parse_only: bool, - ignore_errors: bool, - include_private: bool, - output_dir: str, - modules: List[str], - packages: List[str], - files: List[str], - verbose: bool, - quiet: bool, - export_less: bool) -> None: + + def __init__( + self, + pyversion: Tuple[int, int], + no_import: bool, + doc_dir: str, + search_path: List[str], + interpreter: str, + parse_only: bool, + ignore_errors: bool, + include_private: bool, + output_dir: str, + modules: List[str], + packages: List[str], + files: List[str], + verbose: bool, + quiet: bool, + export_less: bool, + ) -> None: # See parse_options for descriptions of the flags. self.pyversion = pyversion self.no_import = no_import @@ -201,8 +240,10 @@ class StubSource: A simple extension of BuildSource that also carries the AST and the value of __all__ detected at runtime. """ - def __init__(self, module: str, path: Optional[str] = None, - runtime_all: Optional[List[str]] = None) -> None: + + def __init__( + self, module: str, path: Optional[str] = None, runtime_all: Optional[List[str]] = None + ) -> None: self.source = BuildSource(path, module, None) self.runtime_all = runtime_all self.ast: Optional[MypyFile] = None @@ -244,9 +285,10 @@ class AnnotationPrinter(TypeStrVisitor): callable types) since it prints the same string that reveal_type() does. * For Instance types it prints the fully qualified names. """ + # TODO: Generate valid string representation for callable types. # TODO: Use short names for Instances. - def __init__(self, stubgen: 'StubGenerator') -> None: + def __init__(self, stubgen: "StubGenerator") -> None: super().__init__() self.stubgen = stubgen @@ -259,14 +301,14 @@ def visit_unbound_type(self, t: UnboundType) -> str: s = t.name self.stubgen.import_tracker.require_name(s) if t.args: - s += f'[{self.args_str(t.args)}]' + s += f"[{self.args_str(t.args)}]" return s def visit_none_type(self, t: NoneType) -> str: return "None" def visit_type_list(self, t: TypeList) -> str: - return f'[{self.list_str(t.items)}]' + return f"[{self.list_str(t.items)}]" def args_str(self, args: Iterable[Type]) -> str: """Convert an array of arguments to strings and join the results with commas. @@ -274,7 +316,7 @@ def args_str(self, args: Iterable[Type]) -> str: The main difference from list_str is the preservation of quotes for string arguments """ - types = ['builtins.bytes', 'builtins.unicode'] + types = ["builtins.bytes", "builtins.unicode"] res = [] for arg in args: arg_str = arg.accept(self) @@ -282,7 +324,7 @@ def args_str(self, args: Iterable[Type]) -> str: res.append(f"'{arg_str}'") else: res.append(arg_str) - return ', '.join(res) + return ", ".join(res) class AliasPrinter(NodeVisitor[str]): @@ -290,7 +332,8 @@ class AliasPrinter(NodeVisitor[str]): Visit r.h.s of the definition to get the string representation of type alias. """ - def __init__(self, stubgen: 'StubGenerator') -> None: + + def __init__(self, stubgen: "StubGenerator") -> None: self.stubgen = stubgen super().__init__() @@ -303,11 +346,11 @@ def visit_call_expr(self, node: CallExpr) -> str: if kind == ARG_POS: args.append(arg.accept(self)) elif kind == ARG_STAR: - args.append('*' + arg.accept(self)) + args.append("*" + arg.accept(self)) elif kind == ARG_STAR2: - args.append('**' + arg.accept(self)) + args.append("**" + arg.accept(self)) elif kind == ARG_NAMED: - args.append(f'{name}={arg.accept(self)}') + args.append(f"{name}={arg.accept(self)}") else: raise ValueError(f"Unknown argument kind {kind} in call") return f"{callee}({', '.join(args)})" @@ -318,9 +361,9 @@ def visit_name_expr(self, node: NameExpr) -> str: def visit_member_expr(self, o: MemberExpr) -> str: node: Expression = o - trailer = '' + trailer = "" while isinstance(node, MemberExpr): - trailer = '.' + node.name + trailer + trailer = "." + node.name + trailer node = node.expr if not isinstance(node, NameExpr): return ERROR_MARKER @@ -403,10 +446,10 @@ def add_import(self, module: str, alias: Optional[str] = None) -> None: self.module_for[name] = None self.direct_imports[name] = module self.reverse_alias.pop(name, None) - name = name.rpartition('.')[0] + name = name.rpartition(".")[0] def require_name(self, name: str) -> None: - self.required_names.add(name.split('.')[0]) + self.required_names.add(name.split(".")[0]) def reexport(self, name: str) -> None: """Mark a given non qualified name as needed in __all__. @@ -436,9 +479,9 @@ def import_lines(self) -> List[str]: # This name was found in a from ... import ... # Collect the name in the module_map if name in self.reverse_alias: - name = f'{self.reverse_alias[name]} as {name}' + name = f"{self.reverse_alias[name]} as {name}" elif name in self.reexports: - name = f'{name} as {name}' + name = f"{name} as {name}" module_map[m].append(name) else: # This name was found in an import ... @@ -447,7 +490,7 @@ def import_lines(self) -> List[str]: source = self.reverse_alias[name] result.append(f"import {source} as {name}\n") elif name in self.reexports: - assert '.' not in name # Because reexports only has nonqualified names + assert "." not in name # Because reexports only has nonqualified names result.append(f"import {name} as {name}\n") else: result.append(f"import {self.direct_imports[name]}\n") @@ -524,24 +567,27 @@ def visit_callable_type(self, t: CallableType) -> None: t.ret_type.accept(self) def add_ref(self, fullname: str) -> None: - self.refs.add(fullname.split('.')[-1]) + self.refs.add(fullname.split(".")[-1]) class StubGenerator(mypy.traverser.TraverserVisitor): """Generate stub text from a mypy AST.""" - def __init__(self, - _all_: Optional[List[str]], pyversion: Tuple[int, int], - include_private: bool = False, - analyzed: bool = False, - export_less: bool = False) -> None: + def __init__( + self, + _all_: Optional[List[str]], + pyversion: Tuple[int, int], + include_private: bool = False, + analyzed: bool = False, + export_less: bool = False, + ) -> None: # Best known value of __all__. self._all_ = _all_ self._output: List[str] = [] self._decorators: List[str] = [] self._import_lines: List[str] = [] # Current indent level (indent is hardcoded to 4 spaces). - self._indent = '' + self._indent = "" # Stack of defined variables (per scope). self._vars: List[List[str]] = [[]] # What was generated previously in the stub file. @@ -579,17 +625,16 @@ def visit_mypy_file(self, o: MypyFile) -> None: if t not in self.defined_names: alias = None else: - alias = '_' + t + alias = "_" + t self.import_tracker.add_import_from(pkg, [(t, alias)]) super().visit_mypy_file(o) - undefined_names = [name for name in self._all_ or [] - if name not in self._toplevel_names] + undefined_names = [name for name in self._all_ or [] if name not in self._toplevel_names] if undefined_names: if self._state != EMPTY: - self.add('\n') - self.add('# Names in __all__ with no definition:\n') + self.add("\n") + self.add("# Names in __all__ with no definition:\n") for name in sorted(undefined_names): - self.add(f'# {name}\n') + self.add(f"# {name}\n") def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None: """@property with setters and getters, or @overload chain""" @@ -613,15 +658,18 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None: # skip the overload implementation and clear the decorator we just processed self.clear_decorators() - def visit_func_def(self, o: FuncDef, is_abstract: bool = False, - is_overload: bool = False) -> None: - if (self.is_private_name(o.name, o.fullname) - or self.is_not_in_all(o.name) - or (self.is_recorded_name(o.name) and not is_overload)): + def visit_func_def( + self, o: FuncDef, is_abstract: bool = False, is_overload: bool = False + ) -> None: + if ( + self.is_private_name(o.name, o.fullname) + or self.is_not_in_all(o.name) + or (self.is_recorded_name(o.name) and not is_overload) + ): self.clear_decorators() return if not self._indent and self._state not in (EMPTY, FUNC) and not o.is_awaitable_coroutine: - self.add('\n') + self.add("\n") if not self.is_top_level(): self_inits = find_self_initializers(o) for init, value in self_inits: @@ -642,12 +690,15 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False, var = arg_.variable kind = arg_.kind name = var.name - annotated_type = (o.unanalyzed_type.arg_types[i] - if isinstance(o.unanalyzed_type, CallableType) else None) + annotated_type = ( + o.unanalyzed_type.arg_types[i] + if isinstance(o.unanalyzed_type, CallableType) + else None + ) # I think the name check is incorrect: there are libraries which # name their 0th argument other than self/cls - is_self_arg = i == 0 and name == 'self' - is_cls_arg = i == 0 and name == 'cls' + is_self_arg = i == 0 and name == "self" + is_cls_arg = i == 0 and name == "cls" annotation = "" if annotated_type and not is_self_arg and not is_cls_arg: # Luckily, an argument explicitly annotated with "Any" has @@ -655,28 +706,28 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False, if not isinstance(get_proper_type(annotated_type), AnyType): annotation = f": {self.print_annotation(annotated_type)}" - if kind.is_named() and not any(arg.startswith('*') for arg in args): - args.append('*') + if kind.is_named() and not any(arg.startswith("*") for arg in args): + args.append("*") if arg_.initializer: if not annotation: typename = self.get_str_type_of_node(arg_.initializer, True, False) - if typename == '': - annotation = '=...' + if typename == "": + annotation = "=..." else: - annotation = f': {typename} = ...' + annotation = f": {typename} = ..." else: - annotation += ' = ...' + annotation += " = ..." arg = name + annotation elif kind == ARG_STAR: - arg = f'*{name}{annotation}' + arg = f"*{name}{annotation}" elif kind == ARG_STAR2: - arg = f'**{name}{annotation}' + arg = f"**{name}{annotation}" else: arg = name + annotation args.append(arg) retname = None - if o.name != '__init__' and isinstance(o.unanalyzed_type, CallableType): + if o.name != "__init__" and isinstance(o.unanalyzed_type, CallableType): if isinstance(get_proper_type(o.unanalyzed_type.ret_type), AnyType): # Luckily, a return type explicitly annotated with "Any" has # type "UnboundType" and will enter the else branch. @@ -688,29 +739,29 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False, # some dunder methods should not have a None return type. retname = None # implicit Any elif has_yield_expression(o): - self.add_abc_import('Generator') - yield_name = 'None' - send_name = 'None' - return_name = 'None' + self.add_abc_import("Generator") + yield_name = "None" + send_name = "None" + return_name = "None" for expr, in_assignment in all_yield_expressions(o): if expr.expr is not None and not self.is_none_expr(expr.expr): - self.add_typing_import('Incomplete') - yield_name = 'Incomplete' + self.add_typing_import("Incomplete") + yield_name = "Incomplete" if in_assignment: - self.add_typing_import('Incomplete') - send_name = 'Incomplete' + self.add_typing_import("Incomplete") + send_name = "Incomplete" if has_return_statement(o): - self.add_typing_import('Incomplete') - return_name = 'Incomplete' - generator_name = self.typing_name('Generator') - retname = f'{generator_name}[{yield_name}, {send_name}, {return_name}]' + self.add_typing_import("Incomplete") + return_name = "Incomplete" + generator_name = self.typing_name("Generator") + retname = f"{generator_name}[{yield_name}, {send_name}, {return_name}]" elif not has_return_statement(o) and not is_abstract: - retname = 'None' - retfield = '' + retname = "None" + retfield = "" if retname is not None: - retfield = ' -> ' + retname + retfield = " -> " + retname - self.add(', '.join(args)) + self.add(", ".join(args)) self.add(f"){retfield}: ...\n") self._state = FUNC @@ -758,36 +809,39 @@ def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> Tup is_abstract = False is_overload = False name = expr.name - if name in ('property', 'staticmethod', 'classmethod'): + if name in ("property", "staticmethod", "classmethod"): self.add_decorator(name) - elif self.import_tracker.module_for.get(name) in ('asyncio', - 'asyncio.coroutines', - 'types'): + elif self.import_tracker.module_for.get(name) in ( + "asyncio", + "asyncio.coroutines", + "types", + ): self.add_coroutine_decorator(context.func, name, name) - elif self.refers_to_fullname(name, 'abc.abstractmethod'): + elif self.refers_to_fullname(name, "abc.abstractmethod"): self.add_decorator(name) self.import_tracker.require_name(name) is_abstract = True - elif self.refers_to_fullname(name, 'abc.abstractproperty'): - self.add_decorator('property') - self.add_decorator('abc.abstractmethod') + elif self.refers_to_fullname(name, "abc.abstractproperty"): + self.add_decorator("property") + self.add_decorator("abc.abstractmethod") is_abstract = True elif self.refers_to_fullname(name, OVERLOAD_NAMES): self.add_decorator(name) - self.add_typing_import('overload') + self.add_typing_import("overload") is_overload = True return is_abstract, is_overload def refers_to_fullname(self, name: str, fullname: Union[str, Tuple[str, ...]]) -> bool: if isinstance(fullname, tuple): return any(self.refers_to_fullname(name, fname) for fname in fullname) - module, short = fullname.rsplit('.', 1) - return (self.import_tracker.module_for.get(name) == module and - (name == short or - self.import_tracker.reverse_alias.get(name) == short)) - - def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) -> Tuple[bool, - bool]: + module, short = fullname.rsplit(".", 1) + return self.import_tracker.module_for.get(name) == module and ( + name == short or self.import_tracker.reverse_alias.get(name) == short + ) + + def process_member_expr_decorator( + self, expr: MemberExpr, context: Decorator + ) -> Tuple[bool, bool]: """Process a function decorator of form @foo.bar. Only preserve certain special decorators such as @abstractmethod. @@ -798,42 +852,55 @@ def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) -> """ is_abstract = False is_overload = False - if expr.name == 'setter' and isinstance(expr.expr, NameExpr): - self.add_decorator(f'{expr.expr.name}.setter') - elif (isinstance(expr.expr, NameExpr) and - (expr.expr.name == 'abc' or - self.import_tracker.reverse_alias.get(expr.expr.name) == 'abc') and - expr.name in ('abstractmethod', 'abstractproperty')): - if expr.name == 'abstractproperty': + if expr.name == "setter" and isinstance(expr.expr, NameExpr): + self.add_decorator(f"{expr.expr.name}.setter") + elif ( + isinstance(expr.expr, NameExpr) + and ( + expr.expr.name == "abc" + or self.import_tracker.reverse_alias.get(expr.expr.name) == "abc" + ) + and expr.name in ("abstractmethod", "abstractproperty") + ): + if expr.name == "abstractproperty": self.import_tracker.require_name(expr.expr.name) - self.add_decorator('%s' % ('property')) - self.add_decorator('{}.{}'.format(expr.expr.name, 'abstractmethod')) + self.add_decorator("%s" % ("property")) + self.add_decorator("{}.{}".format(expr.expr.name, "abstractmethod")) else: self.import_tracker.require_name(expr.expr.name) - self.add_decorator(f'{expr.expr.name}.{expr.name}') + self.add_decorator(f"{expr.expr.name}.{expr.name}") is_abstract = True - elif expr.name == 'coroutine': - if (isinstance(expr.expr, MemberExpr) and - expr.expr.name == 'coroutines' and - isinstance(expr.expr.expr, NameExpr) and - (expr.expr.expr.name == 'asyncio' or - self.import_tracker.reverse_alias.get(expr.expr.expr.name) == - 'asyncio')): - self.add_coroutine_decorator(context.func, - '%s.coroutines.coroutine' % - (expr.expr.expr.name,), - expr.expr.expr.name) - elif (isinstance(expr.expr, NameExpr) and - (expr.expr.name in ('asyncio', 'types') or - self.import_tracker.reverse_alias.get(expr.expr.name) in - ('asyncio', 'asyncio.coroutines', 'types'))): - self.add_coroutine_decorator(context.func, - expr.expr.name + '.coroutine', - expr.expr.name) - elif (isinstance(expr.expr, NameExpr) and - (expr.expr.name in TYPING_MODULE_NAMES or - self.import_tracker.reverse_alias.get(expr.expr.name) in TYPING_MODULE_NAMES) and - expr.name == 'overload'): + elif expr.name == "coroutine": + if ( + isinstance(expr.expr, MemberExpr) + and expr.expr.name == "coroutines" + and isinstance(expr.expr.expr, NameExpr) + and ( + expr.expr.expr.name == "asyncio" + or self.import_tracker.reverse_alias.get(expr.expr.expr.name) == "asyncio" + ) + ): + self.add_coroutine_decorator( + context.func, + "%s.coroutines.coroutine" % (expr.expr.expr.name,), + expr.expr.expr.name, + ) + elif isinstance(expr.expr, NameExpr) and ( + expr.expr.name in ("asyncio", "types") + or self.import_tracker.reverse_alias.get(expr.expr.name) + in ("asyncio", "asyncio.coroutines", "types") + ): + self.add_coroutine_decorator( + context.func, expr.expr.name + ".coroutine", expr.expr.name + ) + elif ( + isinstance(expr.expr, NameExpr) + and ( + expr.expr.name in TYPING_MODULE_NAMES + or self.import_tracker.reverse_alias.get(expr.expr.name) in TYPING_MODULE_NAMES + ) + and expr.name == "overload" + ): self.import_tracker.require_name(expr.expr.name) self.add_decorator(f"{expr.expr.name}.overload") is_overload = True @@ -844,8 +911,8 @@ def visit_class_def(self, o: ClassDef) -> None: sep: Optional[int] = None if not self._indent and self._state != EMPTY: sep = len(self._output) - self.add('\n') - self.add(f'{self._indent}class {o.name}') + self.add("\n") + self.add(f"{self._indent}class {o.name}") self.record_name(o.name) base_types = self.get_base_types(o) if base_types: @@ -853,22 +920,22 @@ def visit_class_def(self, o: ClassDef) -> None: self.import_tracker.require_name(base) if isinstance(o.metaclass, (NameExpr, MemberExpr)): meta = o.metaclass.accept(AliasPrinter(self)) - base_types.append('metaclass=' + meta) + base_types.append("metaclass=" + meta) elif self.analyzed and o.info.is_abstract: - base_types.append('metaclass=abc.ABCMeta') - self.import_tracker.add_import('abc') - self.import_tracker.require_name('abc') + base_types.append("metaclass=abc.ABCMeta") + self.import_tracker.add_import("abc") + self.import_tracker.require_name("abc") elif self.analyzed and o.info.is_protocol: - type_str = 'Protocol' + type_str = "Protocol" if o.info.type_vars: type_str += f'[{", ".join(o.info.type_vars)}]' base_types.append(type_str) - self.add_typing_import('Protocol') + self.add_typing_import("Protocol") if base_types: self.add(f"({', '.join(base_types)})") - self.add(':\n') + self.add(":\n") n = len(self._output) - self._indent += ' ' + self._indent += " " self._vars.append([]) super().visit_class_def(o) self._indent = self._indent[:-4] @@ -876,8 +943,8 @@ def visit_class_def(self, o: ClassDef) -> None: self._vars[-1].append(o.name) if len(self._output) == n: if self._state == EMPTY_CLASS and sep is not None: - self._output[sep] = '' - self._output[-1] = self._output[-1][:-1] + ' ...\n' + self._output[sep] = "" + self._output[-1] = self._output[-1][:-1] + " ...\n" self._state = EMPTY_CLASS else: self._state = CLASS @@ -888,11 +955,11 @@ def get_base_types(self, cdef: ClassDef) -> List[str]: base_types: List[str] = [] for base in cdef.base_type_exprs: if isinstance(base, NameExpr): - if base.name != 'object': + if base.name != "object": base_types.append(base.name) elif isinstance(base, MemberExpr): modname = get_qualified_name(base.expr) - base_types.append(f'{modname}.{base.name}') + base_types.append(f"{modname}.{base.name}") elif isinstance(base, IndexExpr): p = AliasPrinter(self) base_types.append(base.accept(p)) @@ -912,10 +979,15 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: assert isinstance(o.rvalue, CallExpr) self.process_namedtuple(lvalue, o.rvalue) continue - if (self.is_top_level() and - isinstance(lvalue, NameExpr) and not self.is_private_name(lvalue.name) and - # it is never an alias with explicit annotation - not o.unanalyzed_type and self.is_alias_expression(o.rvalue)): + if ( + self.is_top_level() + and isinstance(lvalue, NameExpr) + and not self.is_private_name(lvalue.name) + and + # it is never an alias with explicit annotation + not o.unanalyzed_type + and self.is_alias_expression(o.rvalue) + ): self.process_typealias(lvalue, o.rvalue) continue if isinstance(lvalue, TupleExpr) or isinstance(lvalue, ListExpr): @@ -934,9 +1006,8 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: init = self.get_init(item.name, o.rvalue, annotation) if init: found = True - if not sep and not self._indent and \ - self._state not in (EMPTY, VAR): - init = '\n' + init + if not sep and not self._indent and self._state not in (EMPTY, VAR): + init = "\n" + init sep = True self.add(init) self.record_name(item.name) @@ -949,30 +1020,31 @@ def is_namedtuple(self, expr: Expression) -> bool: if not isinstance(expr, CallExpr): return False callee = expr.callee - return ((isinstance(callee, NameExpr) and callee.name.endswith('namedtuple')) or - (isinstance(callee, MemberExpr) and callee.name == 'namedtuple')) + return (isinstance(callee, NameExpr) and callee.name.endswith("namedtuple")) or ( + isinstance(callee, MemberExpr) and callee.name == "namedtuple" + ) def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None: if self._state != EMPTY: - self.add('\n') + self.add("\n") if isinstance(rvalue.args[1], StrExpr): - items = rvalue.args[1].value.replace(',', ' ').split() + items = rvalue.args[1].value.replace(",", " ").split() elif isinstance(rvalue.args[1], (ListExpr, TupleExpr)): list_items = cast(List[StrExpr], rvalue.args[1].items) items = [item.value for item in list_items] else: - self.add(f'{self._indent}{lvalue.name}: Incomplete') - self.import_tracker.require_name('Incomplete') + self.add(f"{self._indent}{lvalue.name}: Incomplete") + self.import_tracker.require_name("Incomplete") return - self.import_tracker.require_name('NamedTuple') - self.add(f'{self._indent}class {lvalue.name}(NamedTuple):') + self.import_tracker.require_name("NamedTuple") + self.add(f"{self._indent}class {lvalue.name}(NamedTuple):") if len(items) == 0: - self.add(' ...\n') + self.add(" ...\n") else: - self.import_tracker.require_name('Incomplete') - self.add('\n') + self.import_tracker.require_name("Incomplete") + self.add("\n") for item in items: - self.add(f'{self._indent} {item}: Incomplete\n') + self.add(f"{self._indent} {item}: Incomplete\n") self._state = CLASS def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool: @@ -982,31 +1054,38 @@ def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool: or module alias. """ # Assignment of TypeVar(...) are passed through - if (isinstance(expr, CallExpr) and - isinstance(expr.callee, NameExpr) and - expr.callee.name == 'TypeVar'): + if ( + isinstance(expr, CallExpr) + and isinstance(expr.callee, NameExpr) + and expr.callee.name == "TypeVar" + ): return True elif isinstance(expr, EllipsisExpr): return not top_level elif isinstance(expr, NameExpr): - if expr.name in ('True', 'False'): + if expr.name in ("True", "False"): return False - elif expr.name == 'None': + elif expr.name == "None": return not top_level else: return not self.is_private_name(expr.name) elif isinstance(expr, MemberExpr) and self.analyzed: # Also add function and module aliases. - return ((top_level and isinstance(expr.node, (FuncDef, Decorator, MypyFile)) - or isinstance(expr.node, TypeInfo)) and - not self.is_private_member(expr.node.fullname)) - elif (isinstance(expr, IndexExpr) and isinstance(expr.base, NameExpr) and - not self.is_private_name(expr.base.name)): + return ( + top_level + and isinstance(expr.node, (FuncDef, Decorator, MypyFile)) + or isinstance(expr.node, TypeInfo) + ) and not self.is_private_member(expr.node.fullname) + elif ( + isinstance(expr, IndexExpr) + and isinstance(expr.base, NameExpr) + and not self.is_private_name(expr.base.name) + ): if isinstance(expr.index, TupleExpr): indices = expr.index.items else: indices = [expr.index] - if expr.base.name == 'Callable' and len(indices) == 2: + if expr.base.name == "Callable" and len(indices) == 2: args, ret = indices if isinstance(args, EllipsisExpr): indices = [ret] @@ -1027,11 +1106,13 @@ def process_typealias(self, lvalue: NameExpr, rvalue: Expression) -> None: def visit_if_stmt(self, o: IfStmt) -> None: # Ignore if __name__ == '__main__'. expr = o.expr[0] - if (isinstance(expr, ComparisonExpr) and - isinstance(expr.operands[0], NameExpr) and - isinstance(expr.operands[1], StrExpr) and - expr.operands[0].name == '__name__' and - '__main__' in expr.operands[1].value): + if ( + isinstance(expr, ComparisonExpr) + and isinstance(expr.operands[0], NameExpr) + and isinstance(expr.operands[1], StrExpr) + and expr.operands[0].name == "__name__" + and "__main__" in expr.operands[1].value + ): return super().visit_if_stmt(o) @@ -1044,41 +1125,44 @@ def visit_import_from(self, o: ImportFrom) -> None: module, relative = translate_module_name(o.id, o.relative) if self.module: full_module, ok = mypy.util.correct_relative_import( - self.module, relative, module, self.path.endswith('.__init__.py') + self.module, relative, module, self.path.endswith(".__init__.py") ) if not ok: full_module = module else: full_module = module - if module == '__future__': + if module == "__future__": return # Not preserved for name, as_name in o.names: - if name == 'six': + if name == "six": # Vendored six -- translate into plain 'import six'. - self.visit_import(Import([('six', None)])) + self.visit_import(Import([("six", None)])) continue exported = False - if as_name is None and self.module and (self.module + '.' + name) in EXTRA_EXPORTED: + if as_name is None and self.module and (self.module + "." + name) in EXTRA_EXPORTED: # Special case certain names that should be exported, against our general rules. exported = True - is_private = self.is_private_name(name, full_module + '.' + name) - if (as_name is None - and name not in self.referenced_names - and (not self._all_ or name in IGNORED_DUNDERS) - and not is_private - and module not in ('abc', 'asyncio') + TYPING_MODULE_NAMES): + is_private = self.is_private_name(name, full_module + "." + name) + if ( + as_name is None + and name not in self.referenced_names + and (not self._all_ or name in IGNORED_DUNDERS) + and not is_private + and module not in ("abc", "asyncio") + TYPING_MODULE_NAMES + ): # An imported name that is never referenced in the module is assumed to be # exported, unless there is an explicit __all__. Note that we need to special # case 'abc' since some references are deleted during semantic analysis. exported = True - top_level = full_module.split('.')[0] - if (as_name is None - and not self.export_less - and (not self._all_ or name in IGNORED_DUNDERS) - and self.module - and not is_private - and top_level in (self.module.split('.')[0], - '_' + self.module.split('.')[0])): + top_level = full_module.split(".")[0] + if ( + as_name is None + and not self.export_less + and (not self._all_ or name in IGNORED_DUNDERS) + and self.module + and not is_private + and top_level in (self.module.split(".")[0], "_" + self.module.split(".")[0]) + ): # Export imports from the same package, since we can't reliably tell whether they # are part of the public API. exported = True @@ -1086,29 +1170,33 @@ def visit_import_from(self, o: ImportFrom) -> None: self.import_tracker.reexport(name) as_name = name import_names.append((name, as_name)) - self.import_tracker.add_import_from('.' * relative + module, import_names) + self.import_tracker.add_import_from("." * relative + module, import_names) self._vars[-1].extend(alias or name for name, alias in import_names) for name, alias in import_names: self.record_name(alias or name) if self._all_: # Include "import from"s that import names defined in __all__. - names = [name for name, alias in o.names - if name in self._all_ and alias is None and name not in IGNORED_DUNDERS] + names = [ + name + for name, alias in o.names + if name in self._all_ and alias is None and name not in IGNORED_DUNDERS + ] exported_names.update(names) def visit_import(self, o: Import) -> None: for id, as_id in o.ids: self.import_tracker.add_import(id, as_id) if as_id is None: - target_name = id.split('.')[0] + target_name = id.split(".")[0] else: target_name = as_id self._vars[-1].append(target_name) self.record_name(target_name) - def get_init(self, lvalue: str, rvalue: Expression, - annotation: Optional[Type] = None) -> Optional[str]: + def get_init( + self, lvalue: str, rvalue: Expression, annotation: Optional[Type] = None + ) -> Optional[str]: """Return initializer for a variable. Return None if we've generated one already or if the variable is internal. @@ -1122,15 +1210,18 @@ def get_init(self, lvalue: str, rvalue: Expression, self._vars[-1].append(lvalue) if annotation is not None: typename = self.print_annotation(annotation) - if (isinstance(annotation, UnboundType) and not annotation.args and - annotation.name == 'Final' and - self.import_tracker.module_for.get('Final') in TYPING_MODULE_NAMES): + if ( + isinstance(annotation, UnboundType) + and not annotation.args + and annotation.name == "Final" + and self.import_tracker.module_for.get("Final") in TYPING_MODULE_NAMES + ): # Final without type argument is invalid in stubs. final_arg = self.get_str_type_of_node(rvalue) - typename += f'[{final_arg}]' + typename += f"[{final_arg}]" else: typename = self.get_str_type_of_node(rvalue) - return f'{self._indent}{lvalue}: {typename}\n' + return f"{self._indent}{lvalue}: {typename}\n" def add(self, string: str) -> None: """Add text to generated stub.""" @@ -1138,8 +1229,8 @@ def add(self, string: str) -> None: def add_decorator(self, name: str) -> None: if not self._indent and self._state not in (EMPTY, FUNC): - self._decorators.append('\n') - self._decorators.append(f'{self._indent}@{name}\n') + self._decorators.append("\n") + self._decorators.append(f"{self._indent}@{name}\n") def clear_decorators(self) -> None: self._decorators.clear() @@ -1147,7 +1238,7 @@ def clear_decorators(self) -> None: def typing_name(self, name: str) -> str: if name in self.defined_names: # Avoid name clash between name from typing and a name defined in stub. - return '_' + name + return "_" + name else: return name @@ -1179,13 +1270,13 @@ def add_coroutine_decorator(self, func: FuncDef, name: str, require_name: str) - def output(self) -> str: """Return the text for the stub.""" - imports = '' + imports = "" if self._import_lines: - imports += ''.join(self._import_lines) - imports += ''.join(self.import_tracker.import_lines()) + imports += "".join(self._import_lines) + imports += "".join(self.import_tracker.import_lines()) if imports and self._output: - imports += '\n' - return imports + ''.join(self._output) + imports += "\n" + return imports + "".join(self._output) def is_not_in_all(self, name: str) -> bool: if self.is_private_name(name): @@ -1199,40 +1290,38 @@ def is_private_name(self, name: str, fullname: Optional[str] = None) -> bool: return False if fullname in EXTRA_EXPORTED: return False - return name.startswith('_') and (not name.endswith('__') - or name in IGNORED_DUNDERS) + return name.startswith("_") and (not name.endswith("__") or name in IGNORED_DUNDERS) def is_private_member(self, fullname: str) -> bool: - parts = fullname.split('.') + parts = fullname.split(".") for part in parts: if self.is_private_name(part): return True return False - def get_str_type_of_node(self, rvalue: Expression, - can_infer_optional: bool = False, - can_be_any: bool = True) -> str: + def get_str_type_of_node( + self, rvalue: Expression, can_infer_optional: bool = False, can_be_any: bool = True + ) -> str: if isinstance(rvalue, IntExpr): - return 'int' + return "int" if isinstance(rvalue, StrExpr): - return 'str' + return "str" if isinstance(rvalue, BytesExpr): - return 'bytes' + return "bytes" if isinstance(rvalue, FloatExpr): - return 'float' + return "float" if isinstance(rvalue, UnaryExpr) and isinstance(rvalue.expr, IntExpr): - return 'int' - if isinstance(rvalue, NameExpr) and rvalue.name in ('True', 'False'): - return 'bool' - if can_infer_optional and \ - isinstance(rvalue, NameExpr) and rvalue.name == 'None': - self.add_typing_import('Incomplete') + return "int" + if isinstance(rvalue, NameExpr) and rvalue.name in ("True", "False"): + return "bool" + if can_infer_optional and isinstance(rvalue, NameExpr) and rvalue.name == "None": + self.add_typing_import("Incomplete") return f"{self.typing_name('Incomplete')} | None" if can_be_any: - self.add_typing_import('Incomplete') - return self.typing_name('Incomplete') + self.add_typing_import("Incomplete") + return self.typing_name("Incomplete") else: - return '' + return "" def print_annotation(self, t: Type) -> str: printer = AnnotationPrinter(self) @@ -1240,7 +1329,7 @@ def print_annotation(self, t: Type) -> str: def is_top_level(self) -> bool: """Are we processing the top level of a file?""" - return self._indent == '' + return self._indent == "" def record_name(self, name: str) -> None: """Mark a name as defined. @@ -1275,9 +1364,11 @@ def __init__(self) -> None: def visit_assignment_stmt(self, o: AssignmentStmt) -> None: lvalue = o.lvalues[0] - if (isinstance(lvalue, MemberExpr) and - isinstance(lvalue.expr, NameExpr) and - lvalue.expr.name == 'self'): + if ( + isinstance(lvalue, MemberExpr) + and isinstance(lvalue.expr, NameExpr) + and lvalue.expr.name == "self" + ): self.results.append((lvalue.name, o.rvalue)) @@ -1295,48 +1386,50 @@ def get_qualified_name(o: Expression) -> str: if isinstance(o, NameExpr): return o.name elif isinstance(o, MemberExpr): - return f'{get_qualified_name(o.expr)}.{o.name}' + return f"{get_qualified_name(o.expr)}.{o.name}" else: return ERROR_MARKER def remove_blacklisted_modules(modules: List[StubSource]) -> List[StubSource]: - return [module for module in modules - if module.path is None or not is_blacklisted_path(module.path)] + return [ + module for module in modules if module.path is None or not is_blacklisted_path(module.path) + ] def is_blacklisted_path(path: str) -> bool: - return any(substr in (normalize_path_separators(path) + '\n') - for substr in BLACKLIST) + return any(substr in (normalize_path_separators(path) + "\n") for substr in BLACKLIST) def normalize_path_separators(path: str) -> str: - if sys.platform == 'win32': - return path.replace('\\', '/') + if sys.platform == "win32": + return path.replace("\\", "/") return path -def collect_build_targets(options: Options, mypy_opts: MypyOptions) -> Tuple[List[StubSource], - List[StubSource]]: +def collect_build_targets( + options: Options, mypy_opts: MypyOptions +) -> Tuple[List[StubSource], List[StubSource]]: """Collect files for which we need to generate stubs. Return list of Python modules and C modules. """ if options.packages or options.modules: if options.no_import: - py_modules = find_module_paths_using_search(options.modules, - options.packages, - options.search_path, - options.pyversion) + py_modules = find_module_paths_using_search( + options.modules, options.packages, options.search_path, options.pyversion + ) c_modules: List[StubSource] = [] else: # Using imports is the default, since we can also find C modules. - py_modules, c_modules = find_module_paths_using_imports(options.modules, - options.packages, - options.interpreter, - options.pyversion, - options.verbose, - options.quiet) + py_modules, c_modules = find_module_paths_using_imports( + options.modules, + options.packages, + options.interpreter, + options.pyversion, + options.verbose, + options.quiet, + ) else: # Use mypy native source collection for files and directories. try: @@ -1351,13 +1444,14 @@ def collect_build_targets(options: Options, mypy_opts: MypyOptions) -> Tuple[Lis return py_modules, c_modules -def find_module_paths_using_imports(modules: List[str], - packages: List[str], - interpreter: str, - pyversion: Tuple[int, int], - verbose: bool, - quiet: bool) -> Tuple[List[StubSource], - List[StubSource]]: +def find_module_paths_using_imports( + modules: List[str], + packages: List[str], + interpreter: str, + pyversion: Tuple[int, int], + verbose: bool, + quiet: bool, +) -> Tuple[List[StubSource], List[StubSource]]: """Find path and runtime value of __all__ (if possible) for modules and packages. This function uses runtime Python imports to get the information. @@ -1367,9 +1461,9 @@ def find_module_paths_using_imports(modules: List[str], c_modules: List[StubSource] = [] found = list(walk_packages(inspect, packages, verbose)) modules = modules + found - modules = [mod - for mod in modules - if not is_non_library_module(mod)] # We don't want to run any tests or scripts + modules = [ + mod for mod in modules if not is_non_library_module(mod) + ] # We don't want to run any tests or scripts for mod in modules: try: if pyversion[0] == 2: @@ -1393,45 +1487,48 @@ def find_module_paths_using_imports(modules: List[str], def is_non_library_module(module: str) -> bool: """Does module look like a test module or a script?""" - if module.endswith(( - '.tests', - '.test', - '.testing', - '_tests', - '_test_suite', - 'test_util', - 'test_utils', - 'test_base', - '.__main__', - '.conftest', # Used by pytest - '.setup', # Typically an install script - )): + if module.endswith( + ( + ".tests", + ".test", + ".testing", + "_tests", + "_test_suite", + "test_util", + "test_utils", + "test_base", + ".__main__", + ".conftest", # Used by pytest + ".setup", # Typically an install script + ) + ): return True - if module.split('.')[-1].startswith('test_'): + if module.split(".")[-1].startswith("test_"): return True - if ('.tests.' in module - or '.test.' in module - or '.testing.' in module - or '.SelfTest.' in module): + if ( + ".tests." in module + or ".test." in module + or ".testing." in module + or ".SelfTest." in module + ): return True return False def translate_module_name(module: str, relative: int) -> Tuple[str, int]: for pkg in VENDOR_PACKAGES: - for alt in 'six.moves', 'six': - substr = f'{pkg}.{alt}' - if (module.endswith('.' + substr) - or (module == substr and relative)): + for alt in "six.moves", "six": + substr = f"{pkg}.{alt}" + if module.endswith("." + substr) or (module == substr and relative): return alt, 0 - if '.' + substr + '.' in module: - return alt + '.' + module.partition('.' + substr + '.')[2], 0 + if "." + substr + "." in module: + return alt + "." + module.partition("." + substr + ".")[2], 0 return module, relative -def find_module_paths_using_search(modules: List[str], packages: List[str], - search_path: List[str], - pyversion: Tuple[int, int]) -> List[StubSource]: +def find_module_paths_using_search( + modules: List[str], packages: List[str], search_path: List[str], pyversion: Tuple[int, int] +) -> List[StubSource]: """Find sources for modules and packages requested. This function just looks for source files at the file system level. @@ -1440,7 +1537,7 @@ def find_module_paths_using_search(modules: List[str], packages: List[str], """ result: List[StubSource] = [] typeshed_path = default_lib_path(mypy.build.default_data_dir(), pyversion, None) - search_paths = SearchPaths(('.',) + tuple(search_path), (), (), tuple(typeshed_path)) + search_paths = SearchPaths((".",) + tuple(search_path), (), (), tuple(typeshed_path)) cache = FindModuleCache(search_paths, fscache=None, options=None) for module in modules: m_result = cache.find_module(module) @@ -1465,7 +1562,7 @@ def find_module_paths_using_search(modules: List[str], packages: List[str], def mypy_options(stubgen_options: Options) -> MypyOptions: """Generate mypy options using the flag passed by user.""" options = MypyOptions() - options.follow_imports = 'skip' + options.follow_imports = "skip" options.incremental = False options.ignore_errors = True options.semantic_analysis_only = True @@ -1482,29 +1579,29 @@ def parse_source_file(mod: StubSource, mypy_options: MypyOptions) -> None: If there are syntax errors, print them and exit. """ assert mod.path is not None, "Not found module was not skipped" - with open(mod.path, 'rb') as f: + with open(mod.path, "rb") as f: data = f.read() source = mypy.util.decode_python_encoding(data, mypy_options.python_version) errors = Errors() - mod.ast = mypy.parse.parse(source, fnam=mod.path, module=mod.module, - errors=errors, options=mypy_options) + mod.ast = mypy.parse.parse( + source, fnam=mod.path, module=mod.module, errors=errors, options=mypy_options + ) mod.ast._fullname = mod.module if errors.is_blockers(): # Syntax error! for m in errors.new_messages(): - sys.stderr.write(f'{m}\n') + sys.stderr.write(f"{m}\n") sys.exit(1) -def generate_asts_for_modules(py_modules: List[StubSource], - parse_only: bool, - mypy_options: MypyOptions, - verbose: bool) -> None: +def generate_asts_for_modules( + py_modules: List[StubSource], parse_only: bool, mypy_options: MypyOptions, verbose: bool +) -> None: """Use mypy to parse (and optionally analyze) source files.""" if not py_modules: return # Nothing to do here, but there may be C modules if verbose: - print(f'Processing {len(py_modules)} files...') + print(f"Processing {len(py_modules)} files...") if parse_only: for mod in py_modules: parse_source_file(mod, mypy_options) @@ -1522,22 +1619,26 @@ def generate_asts_for_modules(py_modules: List[StubSource], mod.runtime_all = res.manager.semantic_analyzer.export_map[mod.module] -def generate_stub_from_ast(mod: StubSource, - target: str, - parse_only: bool = False, - pyversion: Tuple[int, int] = defaults.PYTHON3_VERSION, - include_private: bool = False, - export_less: bool = False) -> None: +def generate_stub_from_ast( + mod: StubSource, + target: str, + parse_only: bool = False, + pyversion: Tuple[int, int] = defaults.PYTHON3_VERSION, + include_private: bool = False, + export_less: bool = False, +) -> None: """Use analysed (or just parsed) AST to generate type stub for single file. If directory for target doesn't exist it will created. Existing stub will be overwritten. """ - gen = StubGenerator(mod.runtime_all, - pyversion=pyversion, - include_private=include_private, - analyzed=not parse_only, - export_less=export_less) + gen = StubGenerator( + mod.runtime_all, + pyversion=pyversion, + include_private=include_private, + analyzed=not parse_only, + export_less=export_less, + ) assert mod.ast is not None, "This function must be used only with analyzed modules" mod.ast.accept(gen) @@ -1545,8 +1646,8 @@ def generate_stub_from_ast(mod: StubSource, subdir = os.path.dirname(target) if subdir and not os.path.isdir(subdir): os.makedirs(subdir) - with open(target, 'w') as file: - file.write(''.join(gen.output())) + with open(target, "w") as file: + file.write("".join(gen.output())) def collect_docs_signatures(doc_dir: str) -> Tuple[Dict[str, str], Dict[str, str]]: @@ -1557,7 +1658,7 @@ def collect_docs_signatures(doc_dir: str) -> Tuple[Dict[str, str], Dict[str, str """ all_sigs: List[Sig] = [] all_class_sigs: List[Sig] = [] - for path in glob.glob(f'{doc_dir}/*.rst'): + for path in glob.glob(f"{doc_dir}/*.rst"): with open(path) as f: loc_sigs, loc_class_sigs = parse_all_signatures(f.readlines()) all_sigs += loc_sigs @@ -1582,37 +1683,40 @@ def generate_stubs(options: Options) -> None: files = [] for mod in py_modules: assert mod.path is not None, "Not found module was not skipped" - target = mod.module.replace('.', '/') - if os.path.basename(mod.path) == '__init__.py': - target += '/__init__.pyi' + target = mod.module.replace(".", "/") + if os.path.basename(mod.path) == "__init__.py": + target += "/__init__.pyi" else: - target += '.pyi' + target += ".pyi" target = os.path.join(options.output_dir, target) files.append(target) with generate_guarded(mod.module, target, options.ignore_errors, options.verbose): - generate_stub_from_ast(mod, target, - options.parse_only, options.pyversion, - options.include_private, - options.export_less) + generate_stub_from_ast( + mod, + target, + options.parse_only, + options.pyversion, + options.include_private, + options.export_less, + ) # Separately analyse C modules using different logic. for mod in c_modules: - if any(py_mod.module.startswith(mod.module + '.') - for py_mod in py_modules + c_modules): - target = mod.module.replace('.', '/') + '/__init__.pyi' + if any(py_mod.module.startswith(mod.module + ".") for py_mod in py_modules + c_modules): + target = mod.module.replace(".", "/") + "/__init__.pyi" else: - target = mod.module.replace('.', '/') + '.pyi' + target = mod.module.replace(".", "/") + ".pyi" target = os.path.join(options.output_dir, target) files.append(target) with generate_guarded(mod.module, target, options.ignore_errors, options.verbose): generate_stub_for_c_module(mod.module, target, sigs=sigs, class_sigs=class_sigs) num_modules = len(py_modules) + len(c_modules) if not options.quiet and num_modules > 0: - print('Processed %d modules' % num_modules) + print("Processed %d modules" % num_modules) if len(files) == 1: - print(f'Generated {files[0]}') + print(f"Generated {files[0]}") else: - print(f'Generated files under {common_dir_prefix(files)}' + os.sep) + print(f"Generated files under {common_dir_prefix(files)}" + os.sep) HEADER = """%(prog)s [-h] [--py2] [more options, see -h] @@ -1627,51 +1731,98 @@ def generate_stubs(options: Options) -> None: def parse_options(args: List[str]) -> Options: - parser = argparse.ArgumentParser(prog='stubgen', - usage=HEADER, - description=DESCRIPTION) - - parser.add_argument('--py2', action='store_true', - help="run in Python 2 mode (default: Python 3 mode)") - parser.add_argument('--ignore-errors', action='store_true', - help="ignore errors when trying to generate stubs for modules") - parser.add_argument('--no-import', action='store_true', - help="don't import the modules, just parse and analyze them " - "(doesn't work with C extension modules and might not " - "respect __all__)") - parser.add_argument('--parse-only', action='store_true', - help="don't perform semantic analysis of sources, just parse them " - "(only applies to Python modules, might affect quality of stubs)") - parser.add_argument('--include-private', action='store_true', - help="generate stubs for objects and members considered private " - "(single leading underscore and no trailing underscores)") - parser.add_argument('--export-less', action='store_true', - help=("don't implicitly export all names imported from other modules " - "in the same package")) - parser.add_argument('-v', '--verbose', action='store_true', - help="show more verbose messages") - parser.add_argument('-q', '--quiet', action='store_true', - help="show fewer messages") - parser.add_argument('--doc-dir', metavar='PATH', default='', - help="use .rst documentation in PATH (this may result in " - "better stubs in some cases; consider setting this to " - "DIR/Python-X.Y.Z/Doc/library)") - parser.add_argument('--search-path', metavar='PATH', default='', - help="specify module search directories, separated by ':' " - "(currently only used if --no-import is given)") - parser.add_argument('--python-executable', metavar='PATH', dest='interpreter', default='', - help="use Python interpreter at PATH (only works for " - "Python 2 right now)") - parser.add_argument('-o', '--output', metavar='PATH', dest='output_dir', default='out', - help="change the output directory [default: %(default)s]") - parser.add_argument('-m', '--module', action='append', metavar='MODULE', - dest='modules', default=[], - help="generate stub for module; can repeat for more modules") - parser.add_argument('-p', '--package', action='append', metavar='PACKAGE', - dest='packages', default=[], - help="generate stubs for package recursively; can be repeated") - parser.add_argument(metavar='files', nargs='*', dest='files', - help="generate stubs for given files or directories") + parser = argparse.ArgumentParser(prog="stubgen", usage=HEADER, description=DESCRIPTION) + + parser.add_argument( + "--py2", action="store_true", help="run in Python 2 mode (default: Python 3 mode)" + ) + parser.add_argument( + "--ignore-errors", + action="store_true", + help="ignore errors when trying to generate stubs for modules", + ) + parser.add_argument( + "--no-import", + action="store_true", + help="don't import the modules, just parse and analyze them " + "(doesn't work with C extension modules and might not " + "respect __all__)", + ) + parser.add_argument( + "--parse-only", + action="store_true", + help="don't perform semantic analysis of sources, just parse them " + "(only applies to Python modules, might affect quality of stubs)", + ) + parser.add_argument( + "--include-private", + action="store_true", + help="generate stubs for objects and members considered private " + "(single leading underscore and no trailing underscores)", + ) + parser.add_argument( + "--export-less", + action="store_true", + help=( + "don't implicitly export all names imported from other modules " "in the same package" + ), + ) + parser.add_argument("-v", "--verbose", action="store_true", help="show more verbose messages") + parser.add_argument("-q", "--quiet", action="store_true", help="show fewer messages") + parser.add_argument( + "--doc-dir", + metavar="PATH", + default="", + help="use .rst documentation in PATH (this may result in " + "better stubs in some cases; consider setting this to " + "DIR/Python-X.Y.Z/Doc/library)", + ) + parser.add_argument( + "--search-path", + metavar="PATH", + default="", + help="specify module search directories, separated by ':' " + "(currently only used if --no-import is given)", + ) + parser.add_argument( + "--python-executable", + metavar="PATH", + dest="interpreter", + default="", + help="use Python interpreter at PATH (only works for " "Python 2 right now)", + ) + parser.add_argument( + "-o", + "--output", + metavar="PATH", + dest="output_dir", + default="out", + help="change the output directory [default: %(default)s]", + ) + parser.add_argument( + "-m", + "--module", + action="append", + metavar="MODULE", + dest="modules", + default=[], + help="generate stub for module; can repeat for more modules", + ) + parser.add_argument( + "-p", + "--package", + action="append", + metavar="PACKAGE", + dest="packages", + default=[], + help="generate stubs for package recursively; can be repeated", + ) + parser.add_argument( + metavar="files", + nargs="*", + dest="files", + help="generate stubs for given files or directories", + ) ns = parser.parse_args(args) @@ -1681,39 +1832,41 @@ def parse_options(args: List[str]) -> Options: if ns.modules + ns.packages and ns.files: parser.error("May only specify one of: modules/packages or files.") if ns.quiet and ns.verbose: - parser.error('Cannot specify both quiet and verbose messages') + parser.error("Cannot specify both quiet and verbose messages") # Create the output folder if it doesn't already exist. if not os.path.exists(ns.output_dir): os.makedirs(ns.output_dir) - return Options(pyversion=pyversion, - no_import=ns.no_import, - doc_dir=ns.doc_dir, - search_path=ns.search_path.split(':'), - interpreter=ns.interpreter, - ignore_errors=ns.ignore_errors, - parse_only=ns.parse_only, - include_private=ns.include_private, - output_dir=ns.output_dir, - modules=ns.modules, - packages=ns.packages, - files=ns.files, - verbose=ns.verbose, - quiet=ns.quiet, - export_less=ns.export_less) + return Options( + pyversion=pyversion, + no_import=ns.no_import, + doc_dir=ns.doc_dir, + search_path=ns.search_path.split(":"), + interpreter=ns.interpreter, + ignore_errors=ns.ignore_errors, + parse_only=ns.parse_only, + include_private=ns.include_private, + output_dir=ns.output_dir, + modules=ns.modules, + packages=ns.packages, + files=ns.files, + verbose=ns.verbose, + quiet=ns.quiet, + export_less=ns.export_less, + ) def main(args: Optional[List[str]] = None) -> None: - mypy.util.check_python_version('stubgen') + mypy.util.check_python_version("stubgen") # Make sure that the current directory is in sys.path so that # stubgen can be run on packages in the current directory. - if not ('' in sys.path or '.' in sys.path): - sys.path.insert(0, '') + if not ("" in sys.path or "." in sys.path): + sys.path.insert(0, "") options = parse_options(sys.argv[1:] if args is None else args) generate_stubs(options) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mypy/stubgenc.py b/mypy/stubgenc.py index 7c6b8b95b78ea..66db4137fe501 100755 --- a/mypy/stubgenc.py +++ b/mypy/stubgenc.py @@ -8,36 +8,43 @@ import inspect import os.path import re -from typing import List, Dict, Tuple, Optional, Mapping, Any, Set from types import ModuleType +from typing import Any, Dict, List, Mapping, Optional, Set, Tuple + from typing_extensions import Final from mypy.moduleinspect import is_c_module from mypy.stubdoc import ( - infer_sig_from_docstring, infer_prop_type_from_docstring, ArgSig, - infer_arg_sig_from_anon_docstring, infer_ret_type_sig_from_anon_docstring, - infer_ret_type_sig_from_docstring, FunctionSig + ArgSig, + FunctionSig, + infer_arg_sig_from_anon_docstring, + infer_prop_type_from_docstring, + infer_ret_type_sig_from_anon_docstring, + infer_ret_type_sig_from_docstring, + infer_sig_from_docstring, ) # Members of the typing module to consider for importing by default. _DEFAULT_TYPING_IMPORTS: Final = ( - 'Any', - 'Callable', - 'ClassVar', - 'Dict', - 'Iterable', - 'Iterator', - 'List', - 'Optional', - 'Tuple', - 'Union', + "Any", + "Callable", + "ClassVar", + "Dict", + "Iterable", + "Iterator", + "List", + "Optional", + "Tuple", + "Union", ) -def generate_stub_for_c_module(module_name: str, - target: str, - sigs: Optional[Dict[str, str]] = None, - class_sigs: Optional[Dict[str, str]] = None) -> None: +def generate_stub_for_c_module( + module_name: str, + target: str, + sigs: Optional[Dict[str, str]] = None, + class_sigs: Optional[Dict[str, str]] = None, +) -> None: """Generate stub for C module. This combines simple runtime introspection (looking for docstrings and attributes @@ -47,7 +54,7 @@ def generate_stub_for_c_module(module_name: str, will be overwritten. """ module = importlib.import_module(module_name) - assert is_c_module(module), f'{module_name} is not a C module' + assert is_c_module(module), f"{module_name} is not a C module" subdir = os.path.dirname(target) if subdir and not os.path.isdir(subdir): os.makedirs(subdir) @@ -61,44 +68,45 @@ def generate_stub_for_c_module(module_name: str, done.add(name) types: List[str] = [] for name, obj in items: - if name.startswith('__') and name.endswith('__'): + if name.startswith("__") and name.endswith("__"): continue if is_c_type(obj): - generate_c_type_stub(module, name, obj, types, imports=imports, sigs=sigs, - class_sigs=class_sigs) + generate_c_type_stub( + module, name, obj, types, imports=imports, sigs=sigs, class_sigs=class_sigs + ) done.add(name) variables = [] for name, obj in items: - if name.startswith('__') and name.endswith('__'): + if name.startswith("__") and name.endswith("__"): continue if name not in done and not inspect.ismodule(obj): type_str = strip_or_import(get_type_fullname(type(obj)), module, imports) - variables.append(f'{name}: {type_str}') + variables.append(f"{name}: {type_str}") output = sorted(set(imports)) for line in variables: output.append(line) for line in types: - if line.startswith('class') and output and output[-1]: - output.append('') + if line.startswith("class") and output and output[-1]: + output.append("") output.append(line) if output and functions: - output.append('') + output.append("") for line in functions: output.append(line) output = add_typing_import(output) - with open(target, 'w') as file: + with open(target, "w") as file: for line in output: - file.write(f'{line}\n') + file.write(f"{line}\n") def add_typing_import(output: List[str]) -> List[str]: """Add typing imports for collections/types that occur in the generated stub.""" names = [] for name in _DEFAULT_TYPING_IMPORTS: - if any(re.search(r'\b%s\b' % name, line) for line in output): + if any(re.search(r"\b%s\b" % name, line) for line in output): names.append(name) if names: - return [f"from typing import {', '.join(names)}", ''] + output + return [f"from typing import {', '.join(names)}", ""] + output else: return output[:] @@ -108,22 +116,26 @@ def is_c_function(obj: object) -> bool: def is_c_method(obj: object) -> bool: - return inspect.ismethoddescriptor(obj) or type(obj) in (type(str.index), - type(str.__add__), - type(str.__new__)) + return inspect.ismethoddescriptor(obj) or type(obj) in ( + type(str.index), + type(str.__add__), + type(str.__new__), + ) def is_c_classmethod(obj: object) -> bool: - return inspect.isbuiltin(obj) or type(obj).__name__ in ('classmethod', - 'classmethod_descriptor') + return inspect.isbuiltin(obj) or type(obj).__name__ in ( + "classmethod", + "classmethod_descriptor", + ) def is_c_property(obj: object) -> bool: - return inspect.isdatadescriptor(obj) or hasattr(obj, 'fget') + return inspect.isdatadescriptor(obj) or hasattr(obj, "fget") def is_c_property_readonly(prop: Any) -> bool: - return hasattr(prop, 'fset') and prop.fset is None + return hasattr(prop, "fset") and prop.fset is None def is_c_type(obj: object) -> bool: @@ -131,19 +143,20 @@ def is_c_type(obj: object) -> bool: def is_pybind11_overloaded_function_docstring(docstr: str, name: str) -> bool: - return docstr.startswith(f"{name}(*args, **kwargs)\n" + - "Overloaded function.\n\n") - - -def generate_c_function_stub(module: ModuleType, - name: str, - obj: object, - output: List[str], - imports: List[str], - self_var: Optional[str] = None, - sigs: Optional[Dict[str, str]] = None, - class_name: Optional[str] = None, - class_sigs: Optional[Dict[str, str]] = None) -> None: + return docstr.startswith(f"{name}(*args, **kwargs)\n" + "Overloaded function.\n\n") + + +def generate_c_function_stub( + module: ModuleType, + name: str, + obj: object, + output: List[str], + imports: List[str], + self_var: Optional[str] = None, + sigs: Optional[Dict[str, str]] = None, + class_name: Optional[str] = None, + class_sigs: Optional[Dict[str, str]] = None, +) -> None: """Generate stub for a single function or method. The result (always a single line) will be appended to 'output'. @@ -156,7 +169,7 @@ def generate_c_function_stub(module: ModuleType, if class_sigs is None: class_sigs = {} - ret_type = 'None' if name == '__init__' and class_name else 'Any' + ret_type = "None" if name == "__init__" and class_name else "Any" if ( name in ("__new__", "__init__") @@ -172,7 +185,7 @@ def generate_c_function_stub(module: ModuleType, ) ] else: - docstr = getattr(obj, '__doc__', None) + docstr = getattr(obj, "__doc__", None) inferred = infer_sig_from_docstring(docstr, name) if inferred: assert docstr is not None @@ -181,13 +194,19 @@ def generate_c_function_stub(module: ModuleType, del inferred[-1] if not inferred: if class_name and name not in sigs: - inferred = [FunctionSig(name, args=infer_method_sig(name, self_var), - ret_type=ret_type)] + inferred = [ + FunctionSig(name, args=infer_method_sig(name, self_var), ret_type=ret_type) + ] else: - inferred = [FunctionSig(name=name, - args=infer_arg_sig_from_anon_docstring( - sigs.get(name, '(*args, **kwargs)')), - ret_type=ret_type)] + inferred = [ + FunctionSig( + name=name, + args=infer_arg_sig_from_anon_docstring( + sigs.get(name, "(*args, **kwargs)") + ), + ret_type=ret_type, + ) + ] elif class_name and self_var: args = inferred[0].args if not args or args[0].name != self_var: @@ -195,7 +214,7 @@ def generate_c_function_stub(module: ModuleType, is_overloaded = len(inferred) > 1 if inferred else False if is_overloaded: - imports.append('from typing import overload') + imports.append("from typing import overload") if inferred: for signature in inferred: sig = [] @@ -204,8 +223,8 @@ def generate_c_function_stub(module: ModuleType, arg_def = self_var else: arg_def = arg.name - if arg_def == 'None': - arg_def = '_none' # None is not a valid argument name + if arg_def == "None": + arg_def = "_none" # None is not a valid argument name if arg.type: arg_def += ": " + strip_or_import(arg.type, module, imports) @@ -216,12 +235,14 @@ def generate_c_function_stub(module: ModuleType, sig.append(arg_def) if is_overloaded: - output.append('@overload') - output.append('def {function}({args}) -> {ret}: ...'.format( - function=name, - args=", ".join(sig), - ret=strip_or_import(signature.ret_type, module, imports) - )) + output.append("@overload") + output.append( + "def {function}({args}) -> {ret}: ...".format( + function=name, + args=", ".join(sig), + ret=strip_or_import(signature.ret_type, module, imports), + ) + ) def strip_or_import(typ: str, module: ModuleType, imports: List[str]) -> str: @@ -236,38 +257,38 @@ def strip_or_import(typ: str, module: ModuleType, imports: List[str]) -> str: imports: list of import statements (may be modified during the call) """ stripped_type = typ - if any(c in typ for c in '[,'): - for subtyp in re.split(r'[\[,\]]', typ): + if any(c in typ for c in "[,"): + for subtyp in re.split(r"[\[,\]]", typ): strip_or_import(subtyp.strip(), module, imports) if module: - stripped_type = re.sub( - r'(^|[\[, ]+)' + re.escape(module.__name__ + '.'), - r'\1', - typ, - ) - elif module and typ.startswith(module.__name__ + '.'): - stripped_type = typ[len(module.__name__) + 1:] - elif '.' in typ: - arg_module = typ[:typ.rindex('.')] - if arg_module == 'builtins': - stripped_type = typ[len('builtins') + 1:] + stripped_type = re.sub(r"(^|[\[, ]+)" + re.escape(module.__name__ + "."), r"\1", typ) + elif module and typ.startswith(module.__name__ + "."): + stripped_type = typ[len(module.__name__) + 1 :] + elif "." in typ: + arg_module = typ[: typ.rindex(".")] + if arg_module == "builtins": + stripped_type = typ[len("builtins") + 1 :] else: - imports.append(f'import {arg_module}') - if stripped_type == 'NoneType': - stripped_type = 'None' + imports.append(f"import {arg_module}") + if stripped_type == "NoneType": + stripped_type = "None" return stripped_type def is_static_property(obj: object) -> bool: - return type(obj).__name__ == 'pybind11_static_property' - - -def generate_c_property_stub(name: str, obj: object, - static_properties: List[str], - rw_properties: List[str], - ro_properties: List[str], readonly: bool, - module: Optional[ModuleType] = None, - imports: Optional[List[str]] = None) -> None: + return type(obj).__name__ == "pybind11_static_property" + + +def generate_c_property_stub( + name: str, + obj: object, + static_properties: List[str], + rw_properties: List[str], + ro_properties: List[str], + readonly: bool, + module: Optional[ModuleType] = None, + imports: Optional[List[str]] = None, +) -> None: """Generate property stub using introspection of 'obj'. Try to infer type from docstring, append resulting lines to 'output'. @@ -289,36 +310,36 @@ def infer_prop_type(docstr: Optional[str]) -> Optional[str]: if is_skipped_attribute(name): return - inferred = infer_prop_type(getattr(obj, '__doc__', None)) + inferred = infer_prop_type(getattr(obj, "__doc__", None)) if not inferred: - fget = getattr(obj, 'fget', None) - inferred = infer_prop_type(getattr(fget, '__doc__', None)) + fget = getattr(obj, "fget", None) + inferred = infer_prop_type(getattr(fget, "__doc__", None)) if not inferred: - inferred = 'Any' + inferred = "Any" if module is not None and imports is not None: inferred = strip_or_import(inferred, module, imports) if is_static_property(obj): trailing_comment = " # read-only" if readonly else "" - static_properties.append( - f'{name}: ClassVar[{inferred}] = ...{trailing_comment}' - ) + static_properties.append(f"{name}: ClassVar[{inferred}] = ...{trailing_comment}") else: # regular property if readonly: - ro_properties.append('@property') - ro_properties.append(f'def {name}(self) -> {inferred}: ...') + ro_properties.append("@property") + ro_properties.append(f"def {name}(self) -> {inferred}: ...") else: - rw_properties.append(f'{name}: {inferred}') - - -def generate_c_type_stub(module: ModuleType, - class_name: str, - obj: type, - output: List[str], - imports: List[str], - sigs: Optional[Dict[str, str]] = None, - class_sigs: Optional[Dict[str, str]] = None) -> None: + rw_properties.append(f"{name}: {inferred}") + + +def generate_c_type_stub( + module: ModuleType, + class_name: str, + obj: type, + output: List[str], + imports: List[str], + sigs: Optional[Dict[str, str]] = None, + class_sigs: Optional[Dict[str, str]] = None, +) -> None: """Generate stub for a single class using runtime introspection. The result lines will be appended to 'output'. If necessary, any @@ -338,45 +359,64 @@ def generate_c_type_stub(module: ModuleType, if is_c_method(value) or is_c_classmethod(value): done.add(attr) if not is_skipped_attribute(attr): - if attr == '__new__': + if attr == "__new__": # TODO: We should support __new__. - if '__init__' in obj_dict: + if "__init__" in obj_dict: # Avoid duplicate functions if both are present. # But is there any case where .__new__() has a # better signature than __init__() ? continue - attr = '__init__' + attr = "__init__" if is_c_classmethod(value): - methods.append('@classmethod') - self_var = 'cls' + methods.append("@classmethod") + self_var = "cls" else: - self_var = 'self' - generate_c_function_stub(module, attr, value, methods, imports=imports, - self_var=self_var, sigs=sigs, class_name=class_name, - class_sigs=class_sigs) + self_var = "self" + generate_c_function_stub( + module, + attr, + value, + methods, + imports=imports, + self_var=self_var, + sigs=sigs, + class_name=class_name, + class_sigs=class_sigs, + ) elif is_c_property(value): done.add(attr) - generate_c_property_stub(attr, value, static_properties, rw_properties, ro_properties, - is_c_property_readonly(value), - module=module, imports=imports) + generate_c_property_stub( + attr, + value, + static_properties, + rw_properties, + ro_properties, + is_c_property_readonly(value), + module=module, + imports=imports, + ) elif is_c_type(value): - generate_c_type_stub(module, attr, value, types, imports=imports, sigs=sigs, - class_sigs=class_sigs) + generate_c_type_stub( + module, attr, value, types, imports=imports, sigs=sigs, class_sigs=class_sigs + ) done.add(attr) for attr, value in items: if is_skipped_attribute(attr): continue if attr not in done: - static_properties.append('{}: ClassVar[{}] = ...'.format( - attr, strip_or_import(get_type_fullname(type(value)), module, imports))) + static_properties.append( + "{}: ClassVar[{}] = ...".format( + attr, strip_or_import(get_type_fullname(type(value)), module, imports) + ) + ) all_bases = type.mro(obj) if all_bases[-1] is object: # TODO: Is this always object? del all_bases[-1] # remove pybind11_object. All classes generated by pybind11 have pybind11_object in their MRO, # which only overrides a few functions in object type - if all_bases and all_bases[-1].__name__ == 'pybind11_object': + if all_bases and all_bases[-1].__name__ == "pybind11_object": del all_bases[-1] # remove the class itself all_bases = all_bases[1:] @@ -386,32 +426,32 @@ def generate_c_type_stub(module: ModuleType, if not any(issubclass(b, base) for b in bases): bases.append(base) if bases: - bases_str = '(%s)' % ', '.join( - strip_or_import( - get_type_fullname(base), - module, - imports - ) for base in bases + bases_str = "(%s)" % ", ".join( + strip_or_import(get_type_fullname(base), module, imports) for base in bases ) else: - bases_str = '' + bases_str = "" if types or static_properties or rw_properties or methods or ro_properties: - output.append(f'class {class_name}{bases_str}:') + output.append(f"class {class_name}{bases_str}:") for line in types: - if output and output[-1] and \ - not output[-1].startswith('class') and line.startswith('class'): - output.append('') - output.append(' ' + line) + if ( + output + and output[-1] + and not output[-1].startswith("class") + and line.startswith("class") + ): + output.append("") + output.append(" " + line) for line in static_properties: - output.append(f' {line}') + output.append(f" {line}") for line in rw_properties: - output.append(f' {line}') + output.append(f" {line}") for line in methods: - output.append(f' {line}') + output.append(f" {line}") for line in ro_properties: - output.append(f' {line}') + output.append(f" {line}") else: - output.append(f'class {class_name}{bases_str}: ...') + output.append(f"class {class_name}{bases_str}: ...") def get_type_fullname(typ: type) -> str: @@ -423,9 +463,9 @@ def method_name_sort_key(name: str) -> Tuple[int, str]: I.e.: constructor, normal methods, special methods. """ - if name in ('__new__', '__init__'): + if name in ("__new__", "__init__"): return 0, name - if name.startswith('__') and name.endswith('__'): + if name.startswith("__") and name.endswith("__"): return 2, name return 1, name @@ -435,64 +475,118 @@ def is_pybind_skipped_attribute(attr: str) -> bool: def is_skipped_attribute(attr: str) -> bool: - return (attr in ('__getattribute__', - '__str__', - '__repr__', - '__doc__', - '__dict__', - '__module__', - '__weakref__') # For pickling - or is_pybind_skipped_attribute(attr) - ) + return attr in ( + "__getattribute__", + "__str__", + "__repr__", + "__doc__", + "__dict__", + "__module__", + "__weakref__", + ) or is_pybind_skipped_attribute( # For pickling + attr + ) def infer_method_sig(name: str, self_var: Optional[str] = None) -> List[ArgSig]: args: Optional[List[ArgSig]] = None - if name.startswith('__') and name.endswith('__'): + if name.startswith("__") and name.endswith("__"): name = name[2:-2] - if name in ('hash', 'iter', 'next', 'sizeof', 'copy', 'deepcopy', 'reduce', 'getinitargs', - 'int', 'float', 'trunc', 'complex', 'bool', 'abs', 'bytes', 'dir', 'len', - 'reversed', 'round', 'index', 'enter'): + if name in ( + "hash", + "iter", + "next", + "sizeof", + "copy", + "deepcopy", + "reduce", + "getinitargs", + "int", + "float", + "trunc", + "complex", + "bool", + "abs", + "bytes", + "dir", + "len", + "reversed", + "round", + "index", + "enter", + ): args = [] - elif name == 'getitem': - args = [ArgSig(name='index')] - elif name == 'setitem': - args = [ArgSig(name='index'), - ArgSig(name='object')] - elif name in ('delattr', 'getattr'): - args = [ArgSig(name='name')] - elif name == 'setattr': - args = [ArgSig(name='name'), - ArgSig(name='value')] - elif name == 'getstate': + elif name == "getitem": + args = [ArgSig(name="index")] + elif name == "setitem": + args = [ArgSig(name="index"), ArgSig(name="object")] + elif name in ("delattr", "getattr"): + args = [ArgSig(name="name")] + elif name == "setattr": + args = [ArgSig(name="name"), ArgSig(name="value")] + elif name == "getstate": args = [] - elif name == 'setstate': - args = [ArgSig(name='state')] - elif name in ('eq', 'ne', 'lt', 'le', 'gt', 'ge', - 'add', 'radd', 'sub', 'rsub', 'mul', 'rmul', - 'mod', 'rmod', 'floordiv', 'rfloordiv', 'truediv', 'rtruediv', - 'divmod', 'rdivmod', 'pow', 'rpow', - 'xor', 'rxor', 'or', 'ror', 'and', 'rand', 'lshift', 'rlshift', - 'rshift', 'rrshift', - 'contains', 'delitem', - 'iadd', 'iand', 'ifloordiv', 'ilshift', 'imod', 'imul', 'ior', - 'ipow', 'irshift', 'isub', 'itruediv', 'ixor'): - args = [ArgSig(name='other')] - elif name in ('neg', 'pos', 'invert'): + elif name == "setstate": + args = [ArgSig(name="state")] + elif name in ( + "eq", + "ne", + "lt", + "le", + "gt", + "ge", + "add", + "radd", + "sub", + "rsub", + "mul", + "rmul", + "mod", + "rmod", + "floordiv", + "rfloordiv", + "truediv", + "rtruediv", + "divmod", + "rdivmod", + "pow", + "rpow", + "xor", + "rxor", + "or", + "ror", + "and", + "rand", + "lshift", + "rlshift", + "rshift", + "rrshift", + "contains", + "delitem", + "iadd", + "iand", + "ifloordiv", + "ilshift", + "imod", + "imul", + "ior", + "ipow", + "irshift", + "isub", + "itruediv", + "ixor", + ): + args = [ArgSig(name="other")] + elif name in ("neg", "pos", "invert"): args = [] - elif name == 'get': - args = [ArgSig(name='instance'), - ArgSig(name='owner')] - elif name == 'set': - args = [ArgSig(name='instance'), - ArgSig(name='value')] - elif name == 'reduce_ex': - args = [ArgSig(name='protocol')] - elif name == 'exit': - args = [ArgSig(name='type'), - ArgSig(name='value'), - ArgSig(name='traceback')] + elif name == "get": + args = [ArgSig(name="instance"), ArgSig(name="owner")] + elif name == "set": + args = [ArgSig(name="instance"), ArgSig(name="value")] + elif name == "reduce_ex": + args = [ArgSig(name="protocol")] + elif name == "exit": + args = [ArgSig(name="type"), ArgSig(name="value"), ArgSig(name="traceback")] if args is None: - args = [ArgSig(name='*args'), - ArgSig(name='**kwargs')] - return [ArgSig(name=self_var or 'self')] + args + args = [ArgSig(name="*args"), ArgSig(name="**kwargs")] + return [ArgSig(name=self_var or "self")] + args diff --git a/mypy/stubinfo.py b/mypy/stubinfo.py index fb034162c7dc5..943623a6743b8 100644 --- a/mypy/stubinfo.py +++ b/mypy/stubinfo.py @@ -21,66 +21,66 @@ def is_legacy_bundled_package(prefix: str, py_version: int) -> bool: # # Package name can have one or two components ('a' or 'a.b'). legacy_bundled_packages = { - 'aiofiles': StubInfo('types-aiofiles', py_version=3), - 'atomicwrites': StubInfo('types-atomicwrites'), - 'attr': StubInfo('types-attrs'), - 'backports': StubInfo('types-backports'), - 'backports_abc': StubInfo('types-backports_abc'), - 'bleach': StubInfo('types-bleach'), - 'boto': StubInfo('types-boto'), - 'cachetools': StubInfo('types-cachetools'), - 'chardet': StubInfo('types-chardet'), - 'click_spinner': StubInfo('types-click-spinner'), - 'concurrent': StubInfo('types-futures', py_version=2), - 'contextvars': StubInfo('types-contextvars', py_version=3), - 'croniter': StubInfo('types-croniter'), - 'dataclasses': StubInfo('types-dataclasses', py_version=3), - 'dateparser': StubInfo('types-dateparser'), - 'datetimerange': StubInfo('types-DateTimeRange'), - 'dateutil': StubInfo('types-python-dateutil'), - 'decorator': StubInfo('types-decorator'), - 'deprecated': StubInfo('types-Deprecated'), - 'docutils': StubInfo('types-docutils', py_version=3), - 'emoji': StubInfo('types-emoji'), - 'enum': StubInfo('types-enum34', py_version=2), - 'fb303': StubInfo('types-fb303', py_version=2), - 'first': StubInfo('types-first'), - 'geoip2': StubInfo('types-geoip2'), - 'gflags': StubInfo('types-python-gflags'), - 'google.protobuf': StubInfo('types-protobuf'), - 'ipaddress': StubInfo('types-ipaddress', py_version=2), - 'kazoo': StubInfo('types-kazoo', py_version=2), - 'markdown': StubInfo('types-Markdown'), - 'maxminddb': StubInfo('types-maxminddb'), - 'mock': StubInfo('types-mock'), - 'OpenSSL': StubInfo('types-pyOpenSSL'), - 'paramiko': StubInfo('types-paramiko'), - 'pathlib2': StubInfo('types-pathlib2', py_version=2), - 'pkg_resources': StubInfo('types-setuptools', py_version=3), - 'polib': StubInfo('types-polib'), - 'pycurl': StubInfo('types-pycurl'), - 'pymssql': StubInfo('types-pymssql', py_version=2), - 'pymysql': StubInfo('types-PyMySQL'), - 'pyrfc3339': StubInfo('types-pyRFC3339', py_version=3), - 'python2': StubInfo('types-six'), - 'pytz': StubInfo('types-pytz'), - 'pyVmomi': StubInfo('types-pyvmomi'), - 'redis': StubInfo('types-redis'), - 'requests': StubInfo('types-requests'), - 'retry': StubInfo('types-retry'), - 'routes': StubInfo('types-Routes', py_version=2), - 'scribe': StubInfo('types-scribe', py_version=2), - 'simplejson': StubInfo('types-simplejson'), - 'singledispatch': StubInfo('types-singledispatch'), - 'six': StubInfo('types-six'), - 'slugify': StubInfo('types-python-slugify'), - 'tabulate': StubInfo('types-tabulate'), - 'termcolor': StubInfo('types-termcolor'), - 'toml': StubInfo('types-toml'), - 'tornado': StubInfo('types-tornado', py_version=2), - 'typed_ast': StubInfo('types-typed-ast', py_version=3), - 'tzlocal': StubInfo('types-tzlocal'), - 'ujson': StubInfo('types-ujson'), - 'waitress': StubInfo('types-waitress', py_version=3), - 'yaml': StubInfo('types-PyYAML'), + "aiofiles": StubInfo("types-aiofiles", py_version=3), + "atomicwrites": StubInfo("types-atomicwrites"), + "attr": StubInfo("types-attrs"), + "backports": StubInfo("types-backports"), + "backports_abc": StubInfo("types-backports_abc"), + "bleach": StubInfo("types-bleach"), + "boto": StubInfo("types-boto"), + "cachetools": StubInfo("types-cachetools"), + "chardet": StubInfo("types-chardet"), + "click_spinner": StubInfo("types-click-spinner"), + "concurrent": StubInfo("types-futures", py_version=2), + "contextvars": StubInfo("types-contextvars", py_version=3), + "croniter": StubInfo("types-croniter"), + "dataclasses": StubInfo("types-dataclasses", py_version=3), + "dateparser": StubInfo("types-dateparser"), + "datetimerange": StubInfo("types-DateTimeRange"), + "dateutil": StubInfo("types-python-dateutil"), + "decorator": StubInfo("types-decorator"), + "deprecated": StubInfo("types-Deprecated"), + "docutils": StubInfo("types-docutils", py_version=3), + "emoji": StubInfo("types-emoji"), + "enum": StubInfo("types-enum34", py_version=2), + "fb303": StubInfo("types-fb303", py_version=2), + "first": StubInfo("types-first"), + "geoip2": StubInfo("types-geoip2"), + "gflags": StubInfo("types-python-gflags"), + "google.protobuf": StubInfo("types-protobuf"), + "ipaddress": StubInfo("types-ipaddress", py_version=2), + "kazoo": StubInfo("types-kazoo", py_version=2), + "markdown": StubInfo("types-Markdown"), + "maxminddb": StubInfo("types-maxminddb"), + "mock": StubInfo("types-mock"), + "OpenSSL": StubInfo("types-pyOpenSSL"), + "paramiko": StubInfo("types-paramiko"), + "pathlib2": StubInfo("types-pathlib2", py_version=2), + "pkg_resources": StubInfo("types-setuptools", py_version=3), + "polib": StubInfo("types-polib"), + "pycurl": StubInfo("types-pycurl"), + "pymssql": StubInfo("types-pymssql", py_version=2), + "pymysql": StubInfo("types-PyMySQL"), + "pyrfc3339": StubInfo("types-pyRFC3339", py_version=3), + "python2": StubInfo("types-six"), + "pytz": StubInfo("types-pytz"), + "pyVmomi": StubInfo("types-pyvmomi"), + "redis": StubInfo("types-redis"), + "requests": StubInfo("types-requests"), + "retry": StubInfo("types-retry"), + "routes": StubInfo("types-Routes", py_version=2), + "scribe": StubInfo("types-scribe", py_version=2), + "simplejson": StubInfo("types-simplejson"), + "singledispatch": StubInfo("types-singledispatch"), + "six": StubInfo("types-six"), + "slugify": StubInfo("types-python-slugify"), + "tabulate": StubInfo("types-tabulate"), + "termcolor": StubInfo("types-termcolor"), + "toml": StubInfo("types-toml"), + "tornado": StubInfo("types-tornado", py_version=2), + "typed_ast": StubInfo("types-typed-ast", py_version=3), + "tzlocal": StubInfo("types-tzlocal"), + "ujson": StubInfo("types-ujson"), + "waitress": StubInfo("types-waitress", py_version=3), + "yaml": StubInfo("types-PyYAML"), } diff --git a/mypy/stubtest.py b/mypy/stubtest.py index b4447d798cdd2..d2f9cfcca974e 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -16,13 +16,13 @@ import traceback import types import typing -import typing_extensions import warnings -from contextlib import redirect_stdout, redirect_stderr +from contextlib import redirect_stderr, redirect_stdout from functools import singledispatch from pathlib import Path from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union, cast +import typing_extensions from typing_extensions import Type import mypy.build @@ -33,7 +33,7 @@ from mypy import nodes from mypy.config_parser import parse_config_file from mypy.options import Options -from mypy.util import FancyFormatter, bytes_to_human_readable_repr, plural_s, is_dunder +from mypy.util import FancyFormatter, bytes_to_human_readable_repr, is_dunder, plural_s class Missing: @@ -59,7 +59,7 @@ def _style(message: str, **kwargs: Any) -> str: def _truncate(message: str, length: int) -> str: if len(message) > length: - return message[:length - 3] + "..." + return message[: length - 3] + "..." return message @@ -76,7 +76,7 @@ def __init__( runtime_object: MaybeMissing[Any], *, stub_desc: Optional[str] = None, - runtime_desc: Optional[str] = None + runtime_desc: Optional[str] = None, ) -> None: """Represents an error found by stubtest. @@ -166,6 +166,7 @@ def get_description(self, concise: bool = False) -> str: # Core logic # ==================== + def silent_import_module(module_name: str) -> types.ModuleType: with open(os.devnull, "w") as devnull: with warnings.catch_warnings(), redirect_stdout(devnull), redirect_stderr(devnull): @@ -318,8 +319,10 @@ def verify_typeinfo( return try: + class SubClass(runtime): # type: ignore pass + except TypeError: # Enum classes are implicitly @final if not stub.is_final and not issubclass(runtime, enum.Enum): @@ -423,7 +426,7 @@ def _verify_arg_name( return def strip_prefix(s: str, prefix: str) -> str: - return s[len(prefix):] if s.startswith(prefix) else s + return s[len(prefix) :] if s.startswith(prefix) else s if strip_prefix(stub_arg.variable.name, "__") == runtime_arg.name: return @@ -701,7 +704,7 @@ def _verify_signature( # parameters and b) below, we don't enforce that the stub takes *args, since runtime logic # may prevent those arguments from actually being accepted. if runtime.varpos is None: - for stub_arg in stub.pos[len(runtime.pos):]: + for stub_arg in stub.pos[len(runtime.pos) :]: # If the variable is in runtime.kwonly, it's just mislabelled as not a # keyword-only argument if stub_arg.variable.name not in runtime.kwonly: @@ -711,7 +714,7 @@ def _verify_signature( if stub.varpos is not None: yield f'runtime does not have *args argument "{stub.varpos.variable.name}"' elif len(stub.pos) < len(runtime.pos): - for runtime_arg in runtime.pos[len(stub.pos):]: + for runtime_arg in runtime.pos[len(stub.pos) :]: if runtime_arg.name not in stub.kwonly: yield f'stub does not have argument "{runtime_arg.name}"' else: @@ -783,7 +786,7 @@ def verify_funcitem( stub_sig = Signature.from_funcitem(stub) runtime_sig = Signature.from_inspect_signature(signature) runtime_sig_desc = f'{"async " if runtime_is_coroutine else ""}def {signature}' - stub_desc = f'def {stub_sig!r}' + stub_desc = f"def {stub_sig!r}" else: runtime_sig_desc, stub_desc = None, None @@ -797,7 +800,7 @@ def verify_funcitem( stub, runtime, stub_desc=stub_desc, - runtime_desc=runtime_sig_desc + runtime_desc=runtime_sig_desc, ) if not signature: @@ -835,12 +838,7 @@ def verify_var( and is_read_only_property(runtime) and (stub.is_settable_property or not stub.is_property) ): - yield Error( - object_path, - "is read-only at runtime but not in the stub", - stub, - runtime - ) + yield Error(object_path, "is read-only at runtime but not in the stub", stub, runtime) runtime_type = get_mypy_type_of_runtime_value(runtime) if ( @@ -858,10 +856,7 @@ def verify_var( if should_error: yield Error( - object_path, - f"variable differs from runtime type {runtime_type}", - stub, - runtime, + object_path, f"variable differs from runtime type {runtime_type}", stub, runtime ) @@ -876,12 +871,7 @@ def verify_overloadedfuncdef( if stub.is_property: # Any property with a setter is represented as an OverloadedFuncDef if is_read_only_property(runtime): - yield Error( - object_path, - "is read-only at runtime but not in the stub", - stub, - runtime - ) + yield Error(object_path, "is read-only at runtime but not in the stub", stub, runtime) return if not is_probably_a_function(runtime): @@ -940,7 +930,8 @@ def verify_paramspecexpr( yield Error(object_path, "is not present at runtime", stub, runtime) return maybe_paramspec_types = ( - getattr(typing, "ParamSpec", None), getattr(typing_extensions, "ParamSpec", None) + getattr(typing, "ParamSpec", None), + getattr(typing_extensions, "ParamSpec", None), ) paramspec_types = tuple([t for t in maybe_paramspec_types if t is not None]) if not paramspec_types or not isinstance(runtime, paramspec_types): @@ -989,10 +980,10 @@ def apply_decorator_to_funcitem( if decorator.fullname is None: # Happens with namedtuple return None - if decorator.fullname in ( - "builtins.staticmethod", - "abc.abstractmethod", - ) or decorator.fullname in mypy.types.OVERLOAD_NAMES: + if ( + decorator.fullname in ("builtins.staticmethod", "abc.abstractmethod") + or decorator.fullname in mypy.types.OVERLOAD_NAMES + ): return func if decorator.fullname == "builtins.classmethod": if func.arguments[0].variable.name not in ("cls", "mcs", "metacls"): @@ -1042,8 +1033,11 @@ def verify_typealias( stub_target = mypy.types.get_proper_type(stub.target) if isinstance(runtime, Missing): yield Error( - object_path, "is not present at runtime", stub, runtime, - stub_desc=f"Type alias for: {stub_target}" + object_path, + "is not present at runtime", + stub, + runtime, + stub_desc=f"Type alias for: {stub_target}", ) return if isinstance(stub_target, mypy.types.Instance): @@ -1057,8 +1051,11 @@ def verify_typealias( if isinstance(stub_target, mypy.types.TupleType): if tuple not in getattr(runtime, "__mro__", ()): yield Error( - object_path, "is not a subclass of tuple", stub, runtime, - stub_desc=str(stub_target) + object_path, + "is not a subclass of tuple", + stub, + runtime, + stub_desc=str(stub_target), ) # could check Tuple contents here... return @@ -1208,8 +1205,7 @@ def anytype() -> mypy.types.AnyType: if isinstance( runtime, - (types.FunctionType, types.BuiltinFunctionType, - types.MethodType, types.BuiltinMethodType) + (types.FunctionType, types.BuiltinFunctionType, types.MethodType, types.BuiltinMethodType), ): builtins = get_stub("builtins") assert builtins is not None @@ -1361,8 +1357,7 @@ def get_stub(module: str) -> Optional[nodes.MypyFile]: def get_typeshed_stdlib_modules( - custom_typeshed_dir: Optional[str], - version_info: Optional[Tuple[int, int]] = None + custom_typeshed_dir: Optional[str], version_info: Optional[Tuple[int, int]] = None ) -> List[str]: """Returns a list of stdlib modules in typeshed (for current Python version).""" stdlib_py_versions = mypy.modulefinder.load_stdlib_py_versions(custom_typeshed_dir) @@ -1454,10 +1449,7 @@ def test_stubs(args: _Arguments, use_builtins_fixtures: bool = False) -> int: modules = [m for m in modules if m not in annoying_modules] if not modules: - print( - _style("error:", color="red", bold=True), - "no modules to check", - ) + print(_style("error:", color="red", bold=True), "no modules to check") return 1 options = Options() @@ -1467,8 +1459,10 @@ def test_stubs(args: _Arguments, use_builtins_fixtures: bool = False) -> int: options.use_builtins_fixtures = use_builtins_fixtures if options.config_file: + def set_strict_flags() -> None: # not needed yet return + parse_config_file(options, set_strict_flags, options.config_file, sys.stdout, sys.stderr) try: @@ -1530,14 +1524,16 @@ def set_strict_flags() -> None: # not needed yet _style( f"Found {error_count} error{plural_s(error_count)}" f" (checked {len(modules)} module{plural_s(modules)})", - color="red", bold=True + color="red", + bold=True, ) ) else: print( _style( f"Success: no issues found in {len(modules)} module{plural_s(modules)}", - color="green", bold=True + color="green", + bold=True, ) ) @@ -1591,10 +1587,7 @@ def parse_options(args: List[str]) -> _Arguments: parser.add_argument( "--mypy-config-file", metavar="FILE", - help=( - "Use specified mypy config file to determine mypy plugins " - "and mypy path" - ), + help=("Use specified mypy config file to determine mypy plugins " "and mypy path"), ) parser.add_argument( "--custom-typeshed-dir", metavar="DIR", help="Use the custom typeshed in DIR" diff --git a/mypy/stubutil.py b/mypy/stubutil.py index 55f8c0b29345f..81deec985371e 100644 --- a/mypy/stubutil.py +++ b/mypy/stubutil.py @@ -1,18 +1,17 @@ """Utilities for mypy.stubgen, mypy.stubgenc, and mypy.stubdoc modules.""" -import sys -import os.path import json -import subprocess +import os.path import re +import subprocess +import sys from contextlib import contextmanager +from typing import Iterator, List, Optional, Tuple, Union -from typing import Optional, Tuple, List, Iterator, Union from typing_extensions import overload -from mypy.moduleinspect import ModuleInspect, InspectError from mypy.modulefinder import ModuleNotFoundReason - +from mypy.moduleinspect import InspectError, ModuleInspect # Modules that may fail when imported, or that may have side effects (fully qualified). NOT_IMPORTABLE_MODULES = () @@ -30,20 +29,22 @@ def default_py2_interpreter() -> str: Return full path or exit if failed. """ # TODO: Make this do something reasonable in Windows. - for candidate in ('/usr/bin/python2', '/usr/bin/python'): + for candidate in ("/usr/bin/python2", "/usr/bin/python"): if not os.path.exists(candidate): continue - output = subprocess.check_output([candidate, '--version'], - stderr=subprocess.STDOUT).strip() - if b'Python 2' in output: + output = subprocess.check_output( + [candidate, "--version"], stderr=subprocess.STDOUT + ).strip() + if b"Python 2" in output: return candidate - raise SystemExit("Can't find a Python 2 interpreter -- " - "please use the --python-executable option") + raise SystemExit( + "Can't find a Python 2 interpreter -- " "please use the --python-executable option" + ) -def walk_packages(inspect: ModuleInspect, - packages: List[str], - verbose: bool = False) -> Iterator[str]: +def walk_packages( + inspect: ModuleInspect, packages: List[str], verbose: bool = False +) -> Iterator[str]: """Iterates through all packages and sub-packages in the given list. This uses runtime imports (in another process) to find both Python and C modules. @@ -54,10 +55,10 @@ def walk_packages(inspect: ModuleInspect, """ for package_name in packages: if package_name in NOT_IMPORTABLE_MODULES: - print(f'{package_name}: Skipped (blacklisted)') + print(f"{package_name}: Skipped (blacklisted)") continue if verbose: - print(f'Trying to import {package_name!r} for runtime introspection') + print(f"Trying to import {package_name!r} for runtime introspection") try: prop = inspect.get_package_properties(package_name) except InspectError: @@ -71,9 +72,9 @@ def walk_packages(inspect: ModuleInspect, yield from prop.subpackages -def find_module_path_and_all_py2(module: str, - interpreter: str) -> Optional[Tuple[Optional[str], - Optional[List[str]]]]: +def find_module_path_and_all_py2( + module: str, interpreter: str +) -> Optional[Tuple[Optional[str], Optional[List[str]]]]: """Return tuple (module path, module __all__) for a Python 2 module. The path refers to the .py/.py[co] file. The second tuple item is @@ -82,8 +83,10 @@ def find_module_path_and_all_py2(module: str, Raise CantImport if the module can't be imported, or exit if it's a C extension module. """ cmd_template = f'{interpreter} -c "%s"' - code = ("import importlib, json; mod = importlib.import_module('%s'); " - "print(mod.__file__); print(json.dumps(getattr(mod, '__all__', None)))") % module + code = ( + "import importlib, json; mod = importlib.import_module('%s'); " + "print(mod.__file__); print(json.dumps(getattr(mod, '__all__', None)))" + ) % module try: output_bytes = subprocess.check_output(cmd_template % code, shell=True) except subprocess.CalledProcessError as e: @@ -91,36 +94,34 @@ def find_module_path_and_all_py2(module: str, if path is None: raise CantImport(module, str(e)) from e return path, None - output = output_bytes.decode('ascii').strip().splitlines() + output = output_bytes.decode("ascii").strip().splitlines() module_path = output[0] - if not module_path.endswith(('.py', '.pyc', '.pyo')): - raise SystemExit('%s looks like a C module; they are not supported for Python 2' % - module) - if module_path.endswith(('.pyc', '.pyo')): + if not module_path.endswith((".py", ".pyc", ".pyo")): + raise SystemExit("%s looks like a C module; they are not supported for Python 2" % module) + if module_path.endswith((".pyc", ".pyo")): module_path = module_path[:-1] module_all = json.loads(output[1]) return module_path, module_all -def find_module_path_using_py2_sys_path(module: str, - interpreter: str) -> Optional[str]: +def find_module_path_using_py2_sys_path(module: str, interpreter: str) -> Optional[str]: """Try to find the path of a .py file for a module using Python 2 sys.path. Return None if no match was found. """ out = subprocess.run( - [interpreter, '-c', 'import sys; import json; print(json.dumps(sys.path))'], + [interpreter, "-c", "import sys; import json; print(json.dumps(sys.path))"], check=True, - stdout=subprocess.PIPE + stdout=subprocess.PIPE, ).stdout - sys_path = json.loads(out.decode('utf-8')) + sys_path = json.loads(out.decode("utf-8")) return find_module_path_using_sys_path(module, sys_path) def find_module_path_using_sys_path(module: str, sys_path: List[str]) -> Optional[str]: relative_candidates = ( - module.replace('.', '/') + '.py', - os.path.join(module.replace('.', '/'), '__init__.py') + module.replace(".", "/") + ".py", + os.path.join(module.replace(".", "/"), "__init__.py"), ) for base in sys_path: for relative_path in relative_candidates: @@ -130,21 +131,20 @@ def find_module_path_using_sys_path(module: str, sys_path: List[str]) -> Optiona return None -def find_module_path_and_all_py3(inspect: ModuleInspect, - module: str, - verbose: bool) -> Optional[Tuple[Optional[str], - Optional[List[str]]]]: +def find_module_path_and_all_py3( + inspect: ModuleInspect, module: str, verbose: bool +) -> Optional[Tuple[Optional[str], Optional[List[str]]]]: """Find module and determine __all__ for a Python 3 module. Return None if the module is a C module. Return (module_path, __all__) if it is a Python module. Raise CantImport if import failed. """ if module in NOT_IMPORTABLE_MODULES: - raise CantImport(module, '') + raise CantImport(module, "") # TODO: Support custom interpreters. if verbose: - print(f'Trying to import {module!r} for runtime introspection') + print(f"Trying to import {module!r} for runtime introspection") try: mod = inspect.get_package_properties(module) except InspectError as e: @@ -159,14 +159,15 @@ def find_module_path_and_all_py3(inspect: ModuleInspect, @contextmanager -def generate_guarded(mod: str, target: str, - ignore_errors: bool = True, verbose: bool = False) -> Iterator[None]: +def generate_guarded( + mod: str, target: str, ignore_errors: bool = True, verbose: bool = False +) -> Iterator[None]: """Ignore or report errors during stub generation. Optionally report success. """ if verbose: - print(f'Processing {mod}') + print(f"Processing {mod}") try: yield except Exception as e: @@ -177,21 +178,21 @@ def generate_guarded(mod: str, target: str, print("Stub generation failed for", mod, file=sys.stderr) else: if verbose: - print(f'Created {target}') + print(f"Created {target}") -PY2_MODULES = {'cStringIO', 'urlparse', 'collections.UserDict'} +PY2_MODULES = {"cStringIO", "urlparse", "collections.UserDict"} -def report_missing(mod: str, message: Optional[str] = '', traceback: str = '') -> None: +def report_missing(mod: str, message: Optional[str] = "", traceback: str = "") -> None: if message: - message = ' with error: ' + message - print(f'{mod}: Failed to import, skipping{message}') + message = " with error: " + message + print(f"{mod}: Failed to import, skipping{message}") m = re.search(r"ModuleNotFoundError: No module named '([^']*)'", traceback) if m: missing_module = m.group(1) if missing_module in PY2_MODULES: - print('note: Try --py2 for Python 2 mode') + print("note: Try --py2 for Python 2 mode") def fail_missing(mod: str, reason: ModuleNotFoundReason) -> None: @@ -205,11 +206,13 @@ def fail_missing(mod: str, reason: ModuleNotFoundReason) -> None: @overload -def remove_misplaced_type_comments(source: bytes) -> bytes: ... +def remove_misplaced_type_comments(source: bytes) -> bytes: + ... @overload -def remove_misplaced_type_comments(source: str) -> str: ... +def remove_misplaced_type_comments(source: str) -> str: + ... def remove_misplaced_type_comments(source: Union[str, bytes]) -> Union[str, bytes]: @@ -220,13 +223,13 @@ def remove_misplaced_type_comments(source: Union[str, bytes]) -> Union[str, byte """ if isinstance(source, bytes): # This gives us a 1-1 character code mapping, so it's roundtrippable. - text = source.decode('latin1') + text = source.decode("latin1") else: text = source # Remove something that looks like a variable type comment but that's by itself # on a line, as it will often generate a parse error (unless it's # type: ignore). - text = re.sub(r'^[ \t]*# +type: +["\'a-zA-Z_].*$', '', text, flags=re.MULTILINE) + text = re.sub(r'^[ \t]*# +type: +["\'a-zA-Z_].*$', "", text, flags=re.MULTILINE) # Remove something that looks like a function type comment after docstring, # which will result in a parse error. @@ -234,17 +237,17 @@ def remove_misplaced_type_comments(source: Union[str, bytes]) -> Union[str, byte text = re.sub(r"''' *\n[ \t\n]*# +type: +\(.*$", "'''\n", text, flags=re.MULTILINE) # Remove something that looks like a badly formed function type comment. - text = re.sub(r'^[ \t]*# +type: +\([^()]+(\)[ \t]*)?$', '', text, flags=re.MULTILINE) + text = re.sub(r"^[ \t]*# +type: +\([^()]+(\)[ \t]*)?$", "", text, flags=re.MULTILINE) if isinstance(source, bytes): - return text.encode('latin1') + return text.encode("latin1") else: return text def common_dir_prefix(paths: List[str]) -> str: if not paths: - return '.' + return "." cur = os.path.dirname(os.path.normpath(paths[0])) for path in paths[1:]: while True: @@ -252,4 +255,4 @@ def common_dir_prefix(paths: List[str]) -> str: if (cur + os.sep).startswith(path + os.sep): cur = path break - return cur or '.' + return cur or "." diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 64f1de2c68284..c5d2cc5e98c15 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1,33 +1,64 @@ from contextlib import contextmanager +from typing import Any, Callable, Iterator, List, Optional, Set, Tuple, TypeVar, Union, cast -from typing import Any, List, Optional, Callable, Tuple, Iterator, Set, Union, cast, TypeVar from typing_extensions import Final, TypeAlias as _TypeAlias -from mypy.types import ( - Type, AnyType, UnboundType, TypeVisitor, FormalArgument, NoneType, - Instance, TypeVarType, CallableType, TupleType, TypedDictType, UnionType, Overloaded, - ErasedType, PartialType, DeletedType, UninhabitedType, TypeType, is_named_instance, - FunctionLike, TypeOfAny, LiteralType, get_proper_type, TypeAliasType, ParamSpecType, - Parameters, UnpackType, TUPLE_LIKE_INSTANCE_NAMES, TYPED_NAMEDTUPLE_NAMES, - TypeVarTupleType, ProperType -) import mypy.applytype import mypy.constraints -import mypy.typeops import mypy.sametypes +import mypy.typeops from mypy.erasetype import erase_type +from mypy.expandtype import expand_type_by_instance +from mypy.maptype import map_instance_to_supertype + # Circular import; done in the function instead. # import mypy.solve from mypy.nodes import ( - FuncBase, Var, Decorator, OverloadedFuncDef, TypeInfo, CONTRAVARIANT, COVARIANT, - + CONTRAVARIANT, + COVARIANT, + Decorator, + FuncBase, + OverloadedFuncDef, + TypeInfo, + Var, ) -from mypy.maptype import map_instance_to_supertype -from mypy.expandtype import expand_type_by_instance -from mypy.typestate import TypeState, SubtypeKind from mypy.options import Options from mypy.state import state -from mypy.typevartuples import split_with_instance, extract_unpack +from mypy.types import ( + TUPLE_LIKE_INSTANCE_NAMES, + TYPED_NAMEDTUPLE_NAMES, + AnyType, + CallableType, + DeletedType, + ErasedType, + FormalArgument, + FunctionLike, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + PartialType, + ProperType, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeOfAny, + TypeType, + TypeVarTupleType, + TypeVarType, + TypeVisitor, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, + get_proper_type, + is_named_instance, +) +from mypy.typestate import SubtypeKind, TypeState +from mypy.typevartuples import extract_unpack, split_with_instance # Flags for detected protocol members IS_SETTABLE: Final = 1 @@ -50,13 +81,16 @@ def ignore_type_parameter(s: Type, t: Type, v: int) -> bool: return True -def is_subtype(left: Type, right: Type, - *, - ignore_type_params: bool = False, - ignore_pos_arg_names: bool = False, - ignore_declared_variance: bool = False, - ignore_promotions: bool = False, - options: Optional[Options] = None) -> bool: +def is_subtype( + left: Type, + right: Type, + *, + ignore_type_params: bool = False, + ignore_pos_arg_names: bool = False, + ignore_declared_variance: bool = False, + ignore_promotions: bool = False, + options: Optional[Options] = None, +) -> bool: """Is 'left' subtype of 'right'? Also consider Any to be a subtype of any type, and vice versa. This @@ -70,8 +104,12 @@ def is_subtype(left: Type, right: Type, """ if TypeState.is_assumed_subtype(left, right): return True - if (isinstance(left, TypeAliasType) and isinstance(right, TypeAliasType) and - left.is_recursive and right.is_recursive): + if ( + isinstance(left, TypeAliasType) + and isinstance(right, TypeAliasType) + and left.is_recursive + and right.is_recursive + ): # This case requires special care because it may cause infinite recursion. # Our view on recursive types is known under a fancy name of equirecursive mu-types. # Roughly this means that a recursive type is defined as an alias where right hand side @@ -90,62 +128,86 @@ def is_subtype(left: Type, right: Type, # When checking if A <: B we push pair (A, B) onto 'assuming' stack, then when after few # steps we come back to initial call is_subtype(A, B) and immediately return True. with pop_on_exit(TypeState._assuming, left, right): - return _is_subtype(left, right, - ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, - ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions, - options=options) - return _is_subtype(left, right, - ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, - ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions, - options=options) - - -def _is_subtype(left: Type, right: Type, - *, - ignore_type_params: bool = False, - ignore_pos_arg_names: bool = False, - ignore_declared_variance: bool = False, - ignore_promotions: bool = False, - options: Optional[Options] = None) -> bool: + return _is_subtype( + left, + right, + ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions, + options=options, + ) + return _is_subtype( + left, + right, + ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions, + options=options, + ) + + +def _is_subtype( + left: Type, + right: Type, + *, + ignore_type_params: bool = False, + ignore_pos_arg_names: bool = False, + ignore_declared_variance: bool = False, + ignore_promotions: bool = False, + options: Optional[Options] = None, +) -> bool: orig_right = right orig_left = left left = get_proper_type(left) right = get_proper_type(right) - if (isinstance(right, AnyType) or isinstance(right, UnboundType) - or isinstance(right, ErasedType)): + if ( + isinstance(right, AnyType) + or isinstance(right, UnboundType) + or isinstance(right, ErasedType) + ): return True elif isinstance(right, UnionType) and not isinstance(left, UnionType): # Normally, when 'left' is not itself a union, the only way # 'left' can be a subtype of the union 'right' is if it is a # subtype of one of the items making up the union. - is_subtype_of_item = any(is_subtype(orig_left, item, - ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, - ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions, - options=options) - for item in right.items) + is_subtype_of_item = any( + is_subtype( + orig_left, + item, + ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions, + options=options, + ) + for item in right.items + ) # Recombine rhs literal types, to make an enum type a subtype # of a union of all enum items as literal types. Only do it if # the previous check didn't succeed, since recombining can be # expensive. # `bool` is a special case, because `bool` is `Literal[True, False]`. - if (not is_subtype_of_item - and isinstance(left, Instance) - and (left.type.is_enum or left.type.fullname == 'builtins.bool')): + if ( + not is_subtype_of_item + and isinstance(left, Instance) + and (left.type.is_enum or left.type.fullname == "builtins.bool") + ): right = UnionType(mypy.typeops.try_contracting_literals_in_union(right.items)) - is_subtype_of_item = any(is_subtype(orig_left, item, - ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, - ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions, - options=options) - for item in right.items) + is_subtype_of_item = any( + is_subtype( + orig_left, + item, + ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions, + options=options, + ) + for item in right.items + ) # However, if 'left' is a type variable T, T might also have # an upper bound which is itself a union. This case will be # handled below by the SubtypeVisitor. We have to check both @@ -157,71 +219,96 @@ def _is_subtype(left: Type, right: Type, elif is_subtype_of_item: return True # otherwise, fall through - return left.accept(SubtypeVisitor(orig_right, - ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, - ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions, - options=options)) - - -def is_equivalent(a: Type, b: Type, - *, - ignore_type_params: bool = False, - ignore_pos_arg_names: bool = False, - options: Optional[Options] = None - ) -> bool: - return ( - is_subtype(a, b, ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, options=options) - and is_subtype(b, a, ignore_type_params=ignore_type_params, - ignore_pos_arg_names=ignore_pos_arg_names, options=options)) + return left.accept( + SubtypeVisitor( + orig_right, + ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions, + options=options, + ) + ) + + +def is_equivalent( + a: Type, + b: Type, + *, + ignore_type_params: bool = False, + ignore_pos_arg_names: bool = False, + options: Optional[Options] = None, +) -> bool: + return is_subtype( + a, + b, + ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names, + options=options, + ) and is_subtype( + b, + a, + ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names, + options=options, + ) class SubtypeVisitor(TypeVisitor[bool]): - - def __init__(self, right: Type, - *, - ignore_type_params: bool, - ignore_pos_arg_names: bool = False, - ignore_declared_variance: bool = False, - ignore_promotions: bool = False, - options: Optional[Options] = None) -> None: + def __init__( + self, + right: Type, + *, + ignore_type_params: bool, + ignore_pos_arg_names: bool = False, + ignore_declared_variance: bool = False, + ignore_promotions: bool = False, + options: Optional[Options] = None, + ) -> None: self.right = get_proper_type(right) self.orig_right = right self.ignore_type_params = ignore_type_params self.ignore_pos_arg_names = ignore_pos_arg_names self.ignore_declared_variance = ignore_declared_variance self.ignore_promotions = ignore_promotions - self.check_type_parameter = (ignore_type_parameter if ignore_type_params else - check_type_parameter) + self.check_type_parameter = ( + ignore_type_parameter if ignore_type_params else check_type_parameter + ) self.options = options self._subtype_kind = SubtypeVisitor.build_subtype_kind( ignore_type_params=ignore_type_params, ignore_pos_arg_names=ignore_pos_arg_names, ignore_declared_variance=ignore_declared_variance, - ignore_promotions=ignore_promotions) + ignore_promotions=ignore_promotions, + ) @staticmethod - def build_subtype_kind(*, - ignore_type_params: bool = False, - ignore_pos_arg_names: bool = False, - ignore_declared_variance: bool = False, - ignore_promotions: bool = False) -> SubtypeKind: - return (state.strict_optional, - False, # is proper subtype? - ignore_type_params, - ignore_pos_arg_names, - ignore_declared_variance, - ignore_promotions) + def build_subtype_kind( + *, + ignore_type_params: bool = False, + ignore_pos_arg_names: bool = False, + ignore_declared_variance: bool = False, + ignore_promotions: bool = False, + ) -> SubtypeKind: + return ( + state.strict_optional, + False, # is proper subtype? + ignore_type_params, + ignore_pos_arg_names, + ignore_declared_variance, + ignore_promotions, + ) def _is_subtype(self, left: Type, right: Type) -> bool: - return is_subtype(left, right, - ignore_type_params=self.ignore_type_params, - ignore_pos_arg_names=self.ignore_pos_arg_names, - ignore_declared_variance=self.ignore_declared_variance, - ignore_promotions=self.ignore_promotions, - options=self.options) + return is_subtype( + left, + right, + ignore_type_params=self.ignore_type_params, + ignore_pos_arg_names=self.ignore_pos_arg_names, + ignore_declared_variance=self.ignore_declared_variance, + ignore_promotions=self.ignore_promotions, + options=self.options, + ) # visit_x(left) means: is left (which is an instance of X) a subtype of # right? @@ -234,8 +321,9 @@ def visit_any(self, left: AnyType) -> bool: def visit_none_type(self, left: NoneType) -> bool: if state.strict_optional: - if isinstance(self.right, NoneType) or is_named_instance(self.right, - 'builtins.object'): + if isinstance(self.right, NoneType) or is_named_instance( + self.right, "builtins.object" + ): return True if isinstance(self.right, Instance) and self.right.type.is_protocol: members = self.right.type.protocol_members @@ -273,8 +361,9 @@ def visit_instance(self, left: Instance) -> bool: return True if not self.ignore_promotions: for base in left.type.mro: - if base._promote and any(self._is_subtype(p, self.right) - for p in base._promote): + if base._promote and any( + self._is_subtype(p, self.right) for p in base._promote + ): TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) return True # Special case: Low-level integer types are compatible with 'int'. We can't @@ -288,32 +377,27 @@ def visit_instance(self, left: Instance) -> bool: # NamedTuples are a special case, because `NamedTuple` is not listed # in `TypeInfo.mro`, so when `(a: NamedTuple) -> None` is used, # we need to check for `is_named_tuple` property - if ((left.type.has_base(rname) or rname == 'builtins.object' - or (rname in TYPED_NAMEDTUPLE_NAMES - and any(l.is_named_tuple for l in left.type.mro))) - and not self.ignore_declared_variance): + if ( + left.type.has_base(rname) + or rname == "builtins.object" + or ( + rname in TYPED_NAMEDTUPLE_NAMES + and any(l.is_named_tuple for l in left.type.mro) + ) + ) and not self.ignore_declared_variance: # Map left type to corresponding right instances. t = map_instance_to_supertype(left, right.type) nominal = True if right.type.has_type_var_tuple_type: - left_prefix, left_middle, left_suffix = ( - split_with_instance(left) - ) - right_prefix, right_middle, right_suffix = ( - split_with_instance(right) - ) + left_prefix, left_middle, left_suffix = split_with_instance(left) + right_prefix, right_middle, right_suffix = split_with_instance(right) - left_unpacked = extract_unpack( - left_middle - ) - right_unpacked = extract_unpack( - right_middle - ) + left_unpacked = extract_unpack(left_middle) + right_unpacked = extract_unpack(right_middle) # Helper for case 2 below so we can treat them the same. def check_mixed( - unpacked_type: ProperType, - compare_to: Tuple[Type, ...] + unpacked_type: ProperType, compare_to: Tuple[Type, ...] ) -> bool: if isinstance(unpacked_type, TypeVarTupleType): return False @@ -346,8 +430,7 @@ def check_mixed( and right_unpacked.type.fullname == "builtins.tuple" ): return all( - is_equivalent(l, right_unpacked.args[0]) - for l in left_middle + is_equivalent(l, right_unpacked.args[0]) for l in left_middle ) if not check_mixed(right_unpacked, left_middle): return False @@ -361,19 +444,19 @@ def check_mixed( if not is_equivalent(left_t, right_t): return False - left_items = t.args[:right.type.type_var_tuple_prefix] - right_items = right.args[:right.type.type_var_tuple_prefix] + left_items = t.args[: right.type.type_var_tuple_prefix] + right_items = right.args[: right.type.type_var_tuple_prefix] if right.type.type_var_tuple_suffix: - left_items += t.args[-right.type.type_var_tuple_suffix:] - right_items += right.args[-right.type.type_var_tuple_suffix:] + left_items += t.args[-right.type.type_var_tuple_suffix :] + right_items += right.args[-right.type.type_var_tuple_suffix :] unpack_index = right.type.type_var_tuple_prefix assert unpack_index is not None type_params = zip( left_prefix + right_suffix, right_prefix + right_suffix, - right.type.defn.type_vars[:unpack_index] + - right.type.defn.type_vars[unpack_index+1:] + right.type.defn.type_vars[:unpack_index] + + right.type.defn.type_vars[unpack_index + 1 :], ) else: type_params = zip(t.args, right.args, right.type.defn.type_vars) @@ -394,18 +477,18 @@ def check_mixed( item = right.item if isinstance(item, TupleType): item = mypy.typeops.tuple_fallback(item) - if is_named_instance(left, 'builtins.type'): + if is_named_instance(left, "builtins.type"): return self._is_subtype(TypeType(AnyType(TypeOfAny.special_form)), right) if left.type.is_metaclass(): if isinstance(item, AnyType): return True if isinstance(item, Instance): - return is_named_instance(item, 'builtins.object') + return is_named_instance(item, "builtins.object") if isinstance(right, LiteralType) and left.last_known_value is not None: return self._is_subtype(left.last_known_value, right) if isinstance(right, CallableType): # Special case: Instance can be a subtype of Callable. - call = find_member('__call__', left, left, is_operator=True) + call = find_member("__call__", left, left, is_operator=True) if call: return self._is_subtype(call, right) return False @@ -417,7 +500,8 @@ def visit_type_var(self, left: TypeVarType) -> bool: if isinstance(right, TypeVarType) and left.id == right.id: return True if left.values and self._is_subtype( - mypy.typeops.make_simplified_union(left.values), right): + mypy.typeops.make_simplified_union(left.values), right + ): return True return self._is_subtype(left.upper_bound, self.right) @@ -433,10 +517,7 @@ def visit_param_spec(self, left: ParamSpecType) -> bool: def visit_type_var_tuple(self, left: TypeVarTupleType) -> bool: right = self.right - if ( - isinstance(right, TypeVarTupleType) - and right.id == left.id - ): + if isinstance(right, TypeVarTupleType) and right.id == left.id: return True return self._is_subtype(left.upper_bound, self.right) @@ -449,9 +530,11 @@ def visit_parameters(self, left: Parameters) -> bool: right = self.right if isinstance(right, Parameters) or isinstance(right, CallableType): return are_parameters_compatible( - left, right, + left, + right, is_compat=self._is_subtype, - ignore_pos_arg_names=self.ignore_pos_arg_names) + ignore_pos_arg_names=self.ignore_pos_arg_names, + ) else: return False @@ -466,17 +549,19 @@ def visit_callable_type(self, left: CallableType) -> bool: # They are not compatible. See https://github.com/python/mypy/issues/11307 return False return is_callable_compatible( - left, right, + left, + right, is_compat=self._is_subtype, ignore_pos_arg_names=self.ignore_pos_arg_names, - strict_concatenate=self.options.strict_concatenate if self.options else True) + strict_concatenate=self.options.strict_concatenate if self.options else True, + ) elif isinstance(right, Overloaded): return all(self._is_subtype(left, item) for item in right.items) elif isinstance(right, Instance): - if right.type.is_protocol and right.type.protocol_members == ['__call__']: + if right.type.is_protocol and right.type.protocol_members == ["__call__"]: # OK, a callable can implement a protocol with a single `__call__` member. # TODO: we should probably explicitly exclude self-types in this case. - call = find_member('__call__', right, left, is_operator=True) + call = find_member("__call__", right, left, is_operator=True) assert call is not None if self._is_subtype(left, call): return True @@ -487,16 +572,18 @@ def visit_callable_type(self, left: CallableType) -> bool: elif isinstance(right, Parameters): # this doesn't check return types.... but is needed for is_equivalent return are_parameters_compatible( - left, right, + left, + right, is_compat=self._is_subtype, - ignore_pos_arg_names=self.ignore_pos_arg_names) + ignore_pos_arg_names=self.ignore_pos_arg_names, + ) else: return False def visit_tuple_type(self, left: TupleType) -> bool: right = self.right if isinstance(right, Instance): - if is_named_instance(right, 'typing.Sized'): + if is_named_instance(right, "typing.Sized"): return True elif is_named_instance(right, TUPLE_LIKE_INSTANCE_NAMES): if right.args: @@ -514,7 +601,7 @@ def visit_tuple_type(self, left: TupleType) -> bool: if not self._is_subtype(l, r): return False rfallback = mypy.typeops.tuple_fallback(right) - if is_named_instance(rfallback, 'builtins.tuple'): + if is_named_instance(rfallback, "builtins.tuple"): # No need to verify fallback. This is useful since the calculated fallback # may be inconsistent due to how we calculate joins between unions vs. # non-unions. For example, join(int, str) == object, whereas @@ -535,9 +622,9 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: if not left.names_are_wider_than(right): return False for name, l, r in left.zip(right): - if not is_equivalent(l, r, - ignore_type_params=self.ignore_type_params, - options=self.options): + if not is_equivalent( + l, r, ignore_type_params=self.ignore_type_params, options=self.options + ): return False # Non-required key is not compatible with a required key since # indexing may fail unexpectedly if a required key is missing. @@ -564,9 +651,9 @@ def visit_literal_type(self, left: LiteralType) -> bool: def visit_overloaded(self, left: Overloaded) -> bool: right = self.right if isinstance(right, Instance): - if right.type.is_protocol and right.type.protocol_members == ['__call__']: + if right.type.is_protocol and right.type.protocol_members == ["__call__"]: # same as for CallableType - call = find_member('__call__', right, left, is_operator=True) + call = find_member("__call__", right, left, is_operator=True) assert call is not None if self._is_subtype(left, call): return True @@ -605,14 +692,21 @@ def visit_overloaded(self, left: Overloaded) -> bool: # If this one overlaps with the supertype in any way, but it wasn't # an exact match, then it's a potential error. strict_concat = self.options.strict_concatenate if self.options else True - if (is_callable_compatible(left_item, right_item, - is_compat=self._is_subtype, ignore_return=True, - ignore_pos_arg_names=self.ignore_pos_arg_names, - strict_concatenate=strict_concat) or - is_callable_compatible(right_item, left_item, - is_compat=self._is_subtype, ignore_return=True, - ignore_pos_arg_names=self.ignore_pos_arg_names, - strict_concatenate=strict_concat)): + if is_callable_compatible( + left_item, + right_item, + is_compat=self._is_subtype, + ignore_return=True, + ignore_pos_arg_names=self.ignore_pos_arg_names, + strict_concatenate=strict_concat, + ) or is_callable_compatible( + right_item, + left_item, + is_compat=self._is_subtype, + ignore_return=True, + ignore_pos_arg_names=self.ignore_pos_arg_names, + strict_concatenate=strict_concat, + ): # If this is an overload that's already been matched, there's no # problem. if left_item not in matched_overloads: @@ -671,7 +765,7 @@ def visit_type_type(self, left: TypeType) -> bool: # This is unsound, we don't check the __init__ signature. return self._is_subtype(left.item, right.ret_type) if isinstance(right, Instance): - if right.type.fullname in ['builtins.object', 'builtins.type']: + if right.type.fullname in ["builtins.object", "builtins.type"]: return True item = left.item if isinstance(item, TypeVarType): @@ -685,19 +779,19 @@ def visit_type_alias_type(self, left: TypeAliasType) -> bool: assert False, f"This should be never called, got {left}" -T = TypeVar('T', Instance, TypeAliasType) +T = TypeVar("T", Instance, TypeAliasType) @contextmanager -def pop_on_exit(stack: List[Tuple[T, T]], - left: T, right: T) -> Iterator[None]: +def pop_on_exit(stack: List[Tuple[T, T]], left: T, right: T) -> Iterator[None]: stack.append((left, right)) yield stack.pop() -def is_protocol_implementation(left: Instance, right: Instance, - proper_subtype: bool = False) -> bool: +def is_protocol_implementation( + left: Instance, right: Instance, proper_subtype: bool = False +) -> bool: """Check whether 'left' implements the protocol 'right'. If 'proper_subtype' is True, then check for a proper subtype. @@ -719,7 +813,7 @@ def f(self) -> A: ... # We need to record this check to generate protocol fine-grained dependencies. TypeState.record_protocol_subtype_check(left.type, right.type) # nominal subtyping currently ignores '__init__' and '__new__' signatures - members_not_to_check = {'__init__', '__new__'} + members_not_to_check = {"__init__", "__new__"} # Trivial check that circumvents the bug described in issue 9771: if left.type.is_protocol: members_right = set(right.type.protocol_members) - members_not_to_check @@ -734,7 +828,7 @@ def f(self) -> A: ... for member in right.type.protocol_members: if member in members_not_to_check: continue - ignore_names = member != '__call__' # __call__ can be passed kwargs + ignore_names = member != "__call__" # __call__ can be passed kwargs # The third argument below indicates to what self type is bound. # We always bind self to the subtype. (Similarly to nominal types). supertype = get_proper_type(find_member(member, right, left)) @@ -746,8 +840,13 @@ def f(self) -> A: ... if not subtype: return False if isinstance(subtype, PartialType): - subtype = NoneType() if subtype.type is None else Instance( - subtype.type, [AnyType(TypeOfAny.unannotated)] * len(subtype.type.type_vars) + subtype = ( + NoneType() + if subtype.type is None + else Instance( + subtype.type, + [AnyType(TypeOfAny.unannotated)] * len(subtype.type.type_vars), + ) ) if not proper_subtype: # Nominal check currently ignores arg names @@ -777,7 +876,7 @@ def f(self) -> A: ... if not proper_subtype: # Nominal check currently ignores arg names, but __call__ is special for protocols - ignore_names = right.type.protocol_members != ['__call__'] + ignore_names = right.type.protocol_members != ["__call__"] subtype_kind = SubtypeVisitor.build_subtype_kind(ignore_pos_arg_names=ignore_names) else: subtype_kind = ProperSubtypeVisitor.build_subtype_kind() @@ -785,10 +884,9 @@ def f(self) -> A: ... return True -def find_member(name: str, - itype: Instance, - subtype: Type, - is_operator: bool = False) -> Optional[Type]: +def find_member( + name: str, itype: Instance, subtype: Type, is_operator: bool = False +) -> Optional[Type]: """Find the type of member by 'name' in 'itype's TypeInfo. Find the member type after applying type arguments from 'itype', and binding @@ -813,15 +911,18 @@ def find_member(name: str, v = node.node if node else None if isinstance(v, Var): return find_node_type(v, itype, subtype) - if (not v and name not in ['__getattr__', '__setattr__', '__getattribute__'] and - not is_operator): - for method_name in ('__getattribute__', '__getattr__'): + if ( + not v + and name not in ["__getattr__", "__setattr__", "__getattribute__"] + and not is_operator + ): + for method_name in ("__getattribute__", "__getattr__"): # Normally, mypy assumes that instances that define __getattr__ have all # attributes with the corresponding return type. If this will produce # many false negatives, then this could be prohibited for # structural subtyping. method = info.get_method(method_name) - if method and method.info.fullname != 'builtins.object': + if method and method.info.fullname != "builtins.object": if isinstance(method, Decorator): getattr_type = get_proper_type(find_node_type(method.var, itype, subtype)) else: @@ -846,7 +947,7 @@ def get_member_flags(name: str, info: TypeInfo) -> Set[int]: with @staticmethod. """ method = info.get_method(name) - setattr_meth = info.get_method('__setattr__') + setattr_meth = info.get_method("__setattr__") if method: if isinstance(method, Decorator): if method.var.is_staticmethod or method.var.is_classmethod: @@ -889,13 +990,13 @@ def find_node_type(node: Union[Var, FuncBase], itype: Instance, subtype: Type) - if typ is None: return AnyType(TypeOfAny.from_error) # We don't need to bind 'self' for static methods, since there is no 'self'. - if (isinstance(node, FuncBase) - or (isinstance(typ, FunctionLike) - and node.is_initialized_in_class - and not node.is_staticmethod)): + if isinstance(node, FuncBase) or ( + isinstance(typ, FunctionLike) and node.is_initialized_in_class and not node.is_staticmethod + ): assert isinstance(typ, FunctionLike) - signature = bind_self(typ, subtype, - is_classmethod=isinstance(node, Var) and node.is_classmethod) + signature = bind_self( + typ, subtype, is_classmethod=isinstance(node, Var) and node.is_classmethod + ) if node.is_property: assert isinstance(signature, CallableType) typ = signature.ret_type @@ -921,15 +1022,18 @@ def non_method_protocol_members(tp: TypeInfo) -> List[str]: return result -def is_callable_compatible(left: CallableType, right: CallableType, - *, - is_compat: Callable[[Type, Type], bool], - is_compat_return: Optional[Callable[[Type, Type], bool]] = None, - ignore_return: bool = False, - ignore_pos_arg_names: bool = False, - check_args_covariantly: bool = False, - allow_partial_overlap: bool = False, - strict_concatenate: bool = False) -> bool: +def is_callable_compatible( + left: CallableType, + right: CallableType, + *, + is_compat: Callable[[Type, Type], bool], + is_compat_return: Optional[Callable[[Type, Type], bool]] = None, + ignore_return: bool = False, + ignore_pos_arg_names: bool = False, + check_args_covariantly: bool = False, + allow_partial_overlap: bool = False, + strict_concatenate: bool = False, +) -> bool: """Is the left compatible with the right, using the provided compatibility check? is_compat: @@ -1070,21 +1174,27 @@ def g(x: int) -> int: ... else: strict_concatenate_check = True - return are_parameters_compatible(left, right, is_compat=is_compat, - ignore_pos_arg_names=ignore_pos_arg_names, - check_args_covariantly=check_args_covariantly, - allow_partial_overlap=allow_partial_overlap, - strict_concatenate_check=strict_concatenate_check) - - -def are_parameters_compatible(left: Union[Parameters, CallableType], - right: Union[Parameters, CallableType], - *, - is_compat: Callable[[Type, Type], bool], - ignore_pos_arg_names: bool = False, - check_args_covariantly: bool = False, - allow_partial_overlap: bool = False, - strict_concatenate_check: bool = True) -> bool: + return are_parameters_compatible( + left, + right, + is_compat=is_compat, + ignore_pos_arg_names=ignore_pos_arg_names, + check_args_covariantly=check_args_covariantly, + allow_partial_overlap=allow_partial_overlap, + strict_concatenate_check=strict_concatenate_check, + ) + + +def are_parameters_compatible( + left: Union[Parameters, CallableType], + right: Union[Parameters, CallableType], + *, + is_compat: Callable[[Type, Type], bool], + ignore_pos_arg_names: bool = False, + check_args_covariantly: bool = False, + allow_partial_overlap: bool = False, + strict_concatenate_check: bool = True, +) -> bool: """Helper function for is_callable_compatible, used for Parameter compatibility""" if right.is_ellipsis_args: return True @@ -1119,8 +1229,9 @@ def are_parameters_compatible(left: Union[Parameters, CallableType], # Furthermore, if we're checking for compatibility in all cases, # we confirm that if R accepts an infinite number of arguments, # L must accept the same. - def _incompatible(left_arg: Optional[FormalArgument], - right_arg: Optional[FormalArgument]) -> bool: + def _incompatible( + left_arg: Optional[FormalArgument], right_arg: Optional[FormalArgument] + ) -> bool: if right_arg is None: return False if left_arg is None: @@ -1139,8 +1250,9 @@ def _incompatible(left_arg: Optional[FormalArgument], if allow_partial_overlap and not right_arg.required: continue return False - if not are_args_compatible(left_arg, right_arg, ignore_pos_arg_names, - allow_partial_overlap, is_compat): + if not are_args_compatible( + left_arg, right_arg, ignore_pos_arg_names, allow_partial_overlap, is_compat + ): return False # Phase 1c: Check var args. Right has an infinite series of optional positional @@ -1160,9 +1272,13 @@ def _incompatible(left_arg: Optional[FormalArgument], left_by_position = left.argument_by_position(i) assert left_by_position is not None - if not are_args_compatible(left_by_position, right_by_position, - ignore_pos_arg_names, allow_partial_overlap, - is_compat): + if not are_args_compatible( + left_by_position, + right_by_position, + ignore_pos_arg_names, + allow_partial_overlap, + is_compat, + ): return False i += 1 @@ -1173,9 +1289,12 @@ def _incompatible(left_arg: Optional[FormalArgument], right_names = {name for name in right.arg_names if name is not None} left_only_names = set() for name, kind in zip(left.arg_names, left.arg_kinds): - if (name is None or kind.is_star() - or name in right_names - or not strict_concatenate_check): + if ( + name is None + or kind.is_star() + or name in right_names + or not strict_concatenate_check + ): continue left_only_names.add(name) @@ -1190,29 +1309,32 @@ def _incompatible(left_arg: Optional[FormalArgument], if allow_partial_overlap and not left_by_name.required: continue - if not are_args_compatible(left_by_name, right_by_name, ignore_pos_arg_names, - allow_partial_overlap, is_compat): + if not are_args_compatible( + left_by_name, right_by_name, ignore_pos_arg_names, allow_partial_overlap, is_compat + ): return False # Phase 2: Left must not impose additional restrictions. # (Every required argument in L must have a corresponding argument in R) # Note: we already checked the *arg and **kwarg arguments in phase 1a. for left_arg in left.formal_arguments(): - right_by_name = (right.argument_by_name(left_arg.name) - if left_arg.name is not None - else None) + right_by_name = ( + right.argument_by_name(left_arg.name) if left_arg.name is not None else None + ) - right_by_pos = (right.argument_by_position(left_arg.pos) - if left_arg.pos is not None - else None) + right_by_pos = ( + right.argument_by_position(left_arg.pos) if left_arg.pos is not None else None + ) # If the left hand argument corresponds to two right-hand arguments, # neither of them can be required. - if (right_by_name is not None - and right_by_pos is not None - and right_by_name != right_by_pos - and (right_by_pos.required or right_by_name.required) - and strict_concatenate_check): + if ( + right_by_name is not None + and right_by_pos is not None + and right_by_name != right_by_pos + and (right_by_pos.required or right_by_name.required) + and strict_concatenate_check + ): return False # All *required* left-hand arguments must have a corresponding @@ -1224,11 +1346,12 @@ def _incompatible(left_arg: Optional[FormalArgument], def are_args_compatible( - left: FormalArgument, - right: FormalArgument, - ignore_pos_arg_names: bool, - allow_partial_overlap: bool, - is_compat: Callable[[Type, Type], bool]) -> bool: + left: FormalArgument, + right: FormalArgument, + ignore_pos_arg_names: bool, + allow_partial_overlap: bool, + is_compat: Callable[[Type, Type], bool], +) -> bool: def is_different(left_item: Optional[object], right_item: Optional[object]) -> bool: """Checks if the left and right items are different. @@ -1272,13 +1395,16 @@ def is_different(left_item: Optional[object], right_item: Optional[object]) -> b def flip_compat_check(is_compat: Callable[[Type, Type], bool]) -> Callable[[Type, Type], bool]: def new_is_compat(left: Type, right: Type) -> bool: return is_compat(right, left) + return new_is_compat -def unify_generic_callable(type: CallableType, target: CallableType, - ignore_return: bool, - return_constraint_direction: Optional[int] = None, - ) -> Optional[CallableType]: +def unify_generic_callable( + type: CallableType, + target: CallableType, + ignore_return: bool, + return_constraint_direction: Optional[int] = None, +) -> Optional[CallableType]: """Try to unify a generic callable type with another callable type. Return unified CallableType if successful; otherwise, return None. @@ -1291,11 +1417,13 @@ def unify_generic_callable(type: CallableType, target: CallableType, constraints: List[mypy.constraints.Constraint] = [] for arg_type, target_arg_type in zip(type.arg_types, target.arg_types): c = mypy.constraints.infer_constraints( - arg_type, target_arg_type, mypy.constraints.SUPERTYPE_OF) + arg_type, target_arg_type, mypy.constraints.SUPERTYPE_OF + ) constraints.extend(c) if not ignore_return: c = mypy.constraints.infer_constraints( - type.ret_type, target.ret_type, return_constraint_direction) + type.ret_type, target.ret_type, return_constraint_direction + ) constraints.extend(c) type_var_ids = [tvar.id for tvar in type.variables] inferred_vars = mypy.solve.solve_constraints(type_var_ids, constraints) @@ -1308,8 +1436,9 @@ def report(*args: Any) -> None: nonlocal had_errors had_errors = True - applied = mypy.applytype.apply_generic_arguments(type, non_none_inferred_vars, report, - context=target) + applied = mypy.applytype.apply_generic_arguments( + type, non_none_inferred_vars, report, context=target + ) if had_errors: return None return applied @@ -1354,8 +1483,10 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) new_items = [ restrict_subtype_away(item, s, ignore_promotions=ignore_promotions) for item in t.relevant_items() - if (isinstance(get_proper_type(item), AnyType) or - not covers_at_runtime(item, s, ignore_promotions)) + if ( + isinstance(get_proper_type(item), AnyType) + or not covers_at_runtime(item, s, ignore_promotions) + ) ] return UnionType.make_union(new_items) elif covers_at_runtime(t, s, ignore_promotions): @@ -1371,8 +1502,9 @@ def covers_at_runtime(item: Type, supertype: Type, ignore_promotions: bool) -> b # Since runtime type checks will ignore type arguments, erase the types. supertype = erase_type(supertype) - if is_proper_subtype(erase_type(item), supertype, ignore_promotions=ignore_promotions, - erase_instances=True): + if is_proper_subtype( + erase_type(item), supertype, ignore_promotions=ignore_promotions, erase_instances=True + ): return True if isinstance(supertype, Instance) and supertype.type.is_protocol: # TODO: Implement more robust support for runtime isinstance() checks, see issue #3827. @@ -1380,16 +1512,20 @@ def covers_at_runtime(item: Type, supertype: Type, ignore_promotions: bool) -> b return True if isinstance(item, TypedDictType) and isinstance(supertype, Instance): # Special case useful for selecting TypedDicts from unions using isinstance(x, dict). - if supertype.type.fullname == 'builtins.dict': + if supertype.type.fullname == "builtins.dict": return True # TODO: Add more special cases. return False -def is_proper_subtype(left: Type, right: Type, *, - ignore_promotions: bool = False, - erase_instances: bool = False, - keep_erased_types: bool = False) -> bool: +def is_proper_subtype( + left: Type, + right: Type, + *, + ignore_promotions: bool = False, + erase_instances: bool = False, + keep_erased_types: bool = False, +) -> bool: """Is left a proper subtype of right? For proper subtypes, there's no need to rely on compatibility due to @@ -1401,47 +1537,74 @@ def is_proper_subtype(left: Type, right: Type, *, """ if TypeState.is_assumed_proper_subtype(left, right): return True - if (isinstance(left, TypeAliasType) and isinstance(right, TypeAliasType) and - left.is_recursive and right.is_recursive): + if ( + isinstance(left, TypeAliasType) + and isinstance(right, TypeAliasType) + and left.is_recursive + and right.is_recursive + ): # This case requires special care because it may cause infinite recursion. # See is_subtype() for more info. with pop_on_exit(TypeState._assuming_proper, left, right): - return _is_proper_subtype(left, right, - ignore_promotions=ignore_promotions, - erase_instances=erase_instances, - keep_erased_types=keep_erased_types) - return _is_proper_subtype(left, right, - ignore_promotions=ignore_promotions, - erase_instances=erase_instances, - keep_erased_types=keep_erased_types) - - -def _is_proper_subtype(left: Type, right: Type, *, - ignore_promotions: bool = False, - erase_instances: bool = False, - keep_erased_types: bool = False) -> bool: + return _is_proper_subtype( + left, + right, + ignore_promotions=ignore_promotions, + erase_instances=erase_instances, + keep_erased_types=keep_erased_types, + ) + return _is_proper_subtype( + left, + right, + ignore_promotions=ignore_promotions, + erase_instances=erase_instances, + keep_erased_types=keep_erased_types, + ) + + +def _is_proper_subtype( + left: Type, + right: Type, + *, + ignore_promotions: bool = False, + erase_instances: bool = False, + keep_erased_types: bool = False, +) -> bool: orig_left = left orig_right = right left = get_proper_type(left) right = get_proper_type(right) if isinstance(right, UnionType) and not isinstance(left, UnionType): - return any(is_proper_subtype(orig_left, item, - ignore_promotions=ignore_promotions, - erase_instances=erase_instances, - keep_erased_types=keep_erased_types) - for item in right.items) - return left.accept(ProperSubtypeVisitor(orig_right, - ignore_promotions=ignore_promotions, - erase_instances=erase_instances, - keep_erased_types=keep_erased_types)) + return any( + is_proper_subtype( + orig_left, + item, + ignore_promotions=ignore_promotions, + erase_instances=erase_instances, + keep_erased_types=keep_erased_types, + ) + for item in right.items + ) + return left.accept( + ProperSubtypeVisitor( + orig_right, + ignore_promotions=ignore_promotions, + erase_instances=erase_instances, + keep_erased_types=keep_erased_types, + ) + ) class ProperSubtypeVisitor(TypeVisitor[bool]): - def __init__(self, right: Type, *, - ignore_promotions: bool = False, - erase_instances: bool = False, - keep_erased_types: bool = False) -> None: + def __init__( + self, + right: Type, + *, + ignore_promotions: bool = False, + erase_instances: bool = False, + keep_erased_types: bool = False, + ) -> None: self.right = get_proper_type(right) self.orig_right = right self.ignore_promotions = ignore_promotions @@ -1450,25 +1613,26 @@ def __init__(self, right: Type, *, self._subtype_kind = ProperSubtypeVisitor.build_subtype_kind( ignore_promotions=ignore_promotions, erase_instances=erase_instances, - keep_erased_types=keep_erased_types + keep_erased_types=keep_erased_types, ) @staticmethod - def build_subtype_kind(*, - ignore_promotions: bool = False, - erase_instances: bool = False, - keep_erased_types: bool = False) -> SubtypeKind: - return (state.strict_optional, - True, - ignore_promotions, - erase_instances, - keep_erased_types) + def build_subtype_kind( + *, + ignore_promotions: bool = False, + erase_instances: bool = False, + keep_erased_types: bool = False, + ) -> SubtypeKind: + return (state.strict_optional, True, ignore_promotions, erase_instances, keep_erased_types) def _is_proper_subtype(self, left: Type, right: Type) -> bool: - return is_proper_subtype(left, right, - ignore_promotions=self.ignore_promotions, - erase_instances=self.erase_instances, - keep_erased_types=self.keep_erased_types) + return is_proper_subtype( + left, + right, + ignore_promotions=self.ignore_promotions, + erase_instances=self.erase_instances, + keep_erased_types=self.keep_erased_types, + ) def visit_unbound_type(self, left: UnboundType) -> bool: # This can be called if there is a bad type annotation. The result probably @@ -1481,8 +1645,9 @@ def visit_any(self, left: AnyType) -> bool: def visit_none_type(self, left: NoneType) -> bool: if state.strict_optional: - return (isinstance(self.right, NoneType) or - is_named_instance(self.right, 'builtins.object')) + return isinstance(self.right, NoneType) or is_named_instance( + self.right, "builtins.object" + ) return True def visit_uninhabited_type(self, left: UninhabitedType) -> bool: @@ -1506,8 +1671,9 @@ def visit_instance(self, left: Instance) -> bool: return True if not self.ignore_promotions: for base in left.type.mro: - if base._promote and any(self._is_proper_subtype(p, right) - for p in base._promote): + if base._promote and any( + self._is_proper_subtype(p, right) for p in base._promote + ): TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) return True @@ -1537,12 +1703,13 @@ def visit_instance(self, left: Instance) -> bool: if nominal: TypeState.record_subtype_cache_entry(self._subtype_kind, left, right) return nominal - if (right.type.is_protocol and - is_protocol_implementation(left, right, proper_subtype=True)): + if right.type.is_protocol and is_protocol_implementation( + left, right, proper_subtype=True + ): return True return False if isinstance(right, CallableType): - call = find_member('__call__', left, left, is_operator=True) + call = find_member("__call__", left, left, is_operator=True) if call: return self._is_proper_subtype(call, right) return False @@ -1552,7 +1719,8 @@ def visit_type_var(self, left: TypeVarType) -> bool: if isinstance(self.right, TypeVarType) and left.id == self.right.id: return True if left.values and self._is_proper_subtype( - mypy.typeops.make_simplified_union(left.values), self.right): + mypy.typeops.make_simplified_union(left.values), self.right + ): return True return self._is_proper_subtype(left.upper_bound, self.right) @@ -1568,10 +1736,7 @@ def visit_param_spec(self, left: ParamSpecType) -> bool: def visit_type_var_tuple(self, left: TypeVarTupleType) -> bool: right = self.right - if ( - isinstance(right, TypeVarTupleType) - and right.id == left.id - ): + if isinstance(right, TypeVarTupleType) and right.id == left.id: return True return self._is_proper_subtype(left.upper_bound, self.right) @@ -1592,8 +1757,7 @@ def visit_callable_type(self, left: CallableType) -> bool: if isinstance(right, CallableType): return is_callable_compatible(left, right, is_compat=self._is_proper_subtype) elif isinstance(right, Overloaded): - return all(self._is_proper_subtype(left, item) - for item in right.items) + return all(self._is_proper_subtype(left, item) for item in right.items) elif isinstance(right, Instance): return self._is_proper_subtype(left.fallback, right) elif isinstance(right, TypeType): @@ -1608,7 +1772,7 @@ def visit_tuple_type(self, left: TupleType) -> bool: if not right.args: return False iter_type = get_proper_type(right.args[0]) - if is_named_instance(right, 'builtins.tuple') and isinstance(iter_type, AnyType): + if is_named_instance(right, "builtins.tuple") and isinstance(iter_type, AnyType): # TODO: We shouldn't need this special case. This is currently needed # for isinstance(x, tuple), though it's unclear why. return True @@ -1620,16 +1784,16 @@ def visit_tuple_type(self, left: TupleType) -> bool: for l, r in zip(left.items, right.items): if not self._is_proper_subtype(l, r): return False - return self._is_proper_subtype(mypy.typeops.tuple_fallback(left), - mypy.typeops.tuple_fallback(right)) + return self._is_proper_subtype( + mypy.typeops.tuple_fallback(left), mypy.typeops.tuple_fallback(right) + ) return False def visit_typeddict_type(self, left: TypedDictType) -> bool: right = self.right if isinstance(right, TypedDictType): for name, typ in left.items.items(): - if (name in right.items - and not mypy.sametypes.is_same_type(typ, right.items[name])): + if name in right.items and not mypy.sametypes.is_same_type(typ, right.items[name]): return False for name, typ in right.items.items(): if name not in left.items: @@ -1663,13 +1827,13 @@ def visit_type_type(self, left: TypeType) -> bool: # This is also unsound because of __init__. return right.is_type_obj() and self._is_proper_subtype(left.item, right.ret_type) if isinstance(right, Instance): - if right.type.fullname == 'builtins.type': + if right.type.fullname == "builtins.type": # TODO: Strictly speaking, the type builtins.type is considered equivalent to # Type[Any]. However, this would break the is_proper_subtype check in # conditional_types for cases like isinstance(x, type) when the type # of x is Type[int]. It's unclear what's the right way to address this. return True - if right.type.fullname == 'builtins.object': + if right.type.fullname == "builtins.object": return True item = left.item if isinstance(item, TypeVarType): diff --git a/mypy/suggestions.py b/mypy/suggestions.py index d311d0edde638..3829fc26b84a1 100644 --- a/mypy/suggestions.py +++ b/mypy/suggestions.py @@ -22,44 +22,75 @@ * No understanding of type variables at *all* """ +import itertools +import json +import os +from contextlib import contextmanager from typing import ( - List, Optional, Tuple, Dict, Callable, Union, NamedTuple, TypeVar, Iterator, cast, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + TypeVar, + Union, + cast, ) + from typing_extensions import TypedDict -from mypy.state import state -from mypy.types import ( - Type, AnyType, TypeOfAny, CallableType, UnionType, NoneType, Instance, TupleType, - TypeVarType, FunctionLike, UninhabitedType, - TypeStrVisitor, TypeTranslator, - is_optional, remove_optional, ProperType, get_proper_type, - TypedDictType, TypeAliasType -) -from mypy.build import State, Graph +from mypy.build import Graph, State +from mypy.checkexpr import has_any_type, map_actuals_to_formals +from mypy.find_sources import InvalidSourceList, SourceFinder +from mypy.join import join_type_list +from mypy.meet import meet_type_list +from mypy.modulefinder import PYTHON_EXTENSIONS from mypy.nodes import ( - ArgKind, ARG_STAR, ARG_STAR2, FuncDef, MypyFile, SymbolTable, - Decorator, RefExpr, - SymbolNode, TypeInfo, Expression, ReturnStmt, CallExpr, + ARG_STAR, + ARG_STAR2, + ArgKind, + CallExpr, + Decorator, + Expression, + FuncDef, + MypyFile, + RefExpr, + ReturnStmt, + SymbolNode, + SymbolTable, + TypeInfo, reverse_builtin_aliases, ) +from mypy.plugin import FunctionContext, MethodContext, Plugin +from mypy.sametypes import is_same_type from mypy.server.update import FineGrainedBuildManager -from mypy.util import split_target -from mypy.find_sources import SourceFinder, InvalidSourceList -from mypy.modulefinder import PYTHON_EXTENSIONS -from mypy.plugin import Plugin, FunctionContext, MethodContext +from mypy.state import state from mypy.traverser import TraverserVisitor -from mypy.checkexpr import has_any_type, map_actuals_to_formals - -from mypy.join import join_type_list -from mypy.meet import meet_type_list -from mypy.sametypes import is_same_type from mypy.typeops import make_simplified_union - -from contextlib import contextmanager - -import itertools -import json -import os +from mypy.types import ( + AnyType, + CallableType, + FunctionLike, + Instance, + NoneType, + ProperType, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeOfAny, + TypeStrVisitor, + TypeTranslator, + TypeVarType, + UninhabitedType, + UnionType, + get_proper_type, + is_optional, + remove_optional, +) +from mypy.util import split_target class PyAnnotateSignature(TypedDict): @@ -80,36 +111,37 @@ class SuggestionPlugin(Plugin): """Plugin that records all calls to a given target.""" def __init__(self, target: str) -> None: - if target.endswith(('.__new__', '.__init__')): - target = target.rsplit('.', 1)[0] + if target.endswith((".__new__", ".__init__")): + target = target.rsplit(".", 1)[0] self.target = target # List of call sites found by dmypy suggest: # (path, line, , , ) self.mystery_hits: List[Callsite] = [] - def get_function_hook(self, fullname: str - ) -> Optional[Callable[[FunctionContext], Type]]: + def get_function_hook(self, fullname: str) -> Optional[Callable[[FunctionContext], Type]]: if fullname == self.target: return self.log else: return None - def get_method_hook(self, fullname: str - ) -> Optional[Callable[[MethodContext], Type]]: + def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], Type]]: if fullname == self.target: return self.log else: return None def log(self, ctx: Union[FunctionContext, MethodContext]) -> Type: - self.mystery_hits.append(Callsite( - ctx.api.path, - ctx.context.line, - ctx.arg_kinds, - ctx.callee_arg_names, - ctx.arg_names, - ctx.arg_types)) + self.mystery_hits.append( + Callsite( + ctx.api.path, + ctx.context.line, + ctx.arg_kinds, + ctx.callee_arg_names, + ctx.arg_names, + ctx.arg_types, + ) + ) return ctx.default_return_type @@ -117,6 +149,7 @@ def log(self, ctx: Union[FunctionContext, MethodContext]) -> Type: # traversing into expressions class ReturnFinder(TraverserVisitor): """Visitor for finding all types returned from a function.""" + def __init__(self, typemap: Dict[Expression, Type]) -> None: self.typemap = typemap self.return_types: List[Type] = [] @@ -142,6 +175,7 @@ class ArgUseFinder(TraverserVisitor): This is extremely simple minded but might be effective anyways. """ + def __init__(self, func: FuncDef, typemap: Dict[Expression, Type]) -> None: self.typemap = typemap self.arg_types: Dict[SymbolNode, List[Type]] = {arg.variable: [] for arg in func.arguments} @@ -155,8 +189,12 @@ def visit_call_expr(self, o: CallExpr) -> None: return formal_to_actual = map_actuals_to_formals( - o.arg_kinds, o.arg_names, typ.arg_kinds, typ.arg_names, - lambda n: AnyType(TypeOfAny.special_form)) + o.arg_kinds, + o.arg_names, + typ.arg_kinds, + typ.arg_names, + lambda n: AnyType(TypeOfAny.special_form), + ) for i, args in enumerate(formal_to_actual): for arg_idx in args: @@ -204,16 +242,18 @@ def is_implicit_any(typ: Type) -> bool: class SuggestionEngine: """Engine for finding call sites and suggesting signatures.""" - def __init__(self, fgmanager: FineGrainedBuildManager, - *, - json: bool, - no_errors: bool = False, - no_any: bool = False, - try_text: bool = False, - flex_any: Optional[float] = None, - use_fixme: Optional[str] = None, - max_guesses: Optional[int] = None - ) -> None: + def __init__( + self, + fgmanager: FineGrainedBuildManager, + *, + json: bool, + no_errors: bool = False, + no_any: bool = False, + try_text: bool = False, + flex_any: Optional[float] = None, + use_fixme: Optional[str] = None, + max_guesses: Optional[int] = None, + ) -> None: self.fgmanager = fgmanager self.manager = fgmanager.manager self.plugin = self.manager.plugin @@ -249,10 +289,14 @@ def suggest_callsites(self, function: str) -> str: with self.restore_after(mod): callsites, _ = self.get_callsites(node) - return '\n'.join(dedup( - [f"{path}:{line}: {self.format_args(arg_kinds, arg_names, arg_types)}" - for path, line, arg_kinds, _, arg_names, arg_types in callsites] - )) + return "\n".join( + dedup( + [ + f"{path}:{line}: {self.format_args(arg_kinds, arg_names, arg_types)}" + for path, line, arg_kinds, _, arg_names, arg_types in callsites + ] + ) + ) @contextmanager def restore_after(self, module: str) -> Iterator[None]: @@ -288,7 +332,8 @@ def get_trivial_type(self, fdef: FuncDef) -> CallableType: fdef.arg_kinds, fdef.arg_names, AnyType(TypeOfAny.suggestion_engine), - self.named_type('builtins.function')) + self.named_type("builtins.function"), + ) def get_starting_type(self, fdef: FuncDef) -> CallableType: if isinstance(fdef.type, CallableType): @@ -296,10 +341,14 @@ def get_starting_type(self, fdef: FuncDef) -> CallableType: else: return self.get_trivial_type(fdef) - def get_args(self, is_method: bool, - base: CallableType, defaults: List[Optional[Type]], - callsites: List[Callsite], - uses: List[List[Type]]) -> List[List[Type]]: + def get_args( + self, + is_method: bool, + base: CallableType, + defaults: List[Optional[Type]], + callsites: List[Callsite], + uses: List[List[Type]], + ) -> List[List[Type]]: """Produce a list of type suggestions for each argument type.""" types: List[List[Type]] = [] for i in range(len(base.arg_kinds)): @@ -328,10 +377,12 @@ def get_args(self, is_method: bool, arg_types = [] - if (all_arg_types - and all(isinstance(get_proper_type(tp), NoneType) for tp in all_arg_types)): + if all_arg_types and all( + isinstance(get_proper_type(tp), NoneType) for tp in all_arg_types + ): arg_types.append( - UnionType.make_union([all_arg_types[0], AnyType(TypeOfAny.explicit)])) + UnionType.make_union([all_arg_types[0], AnyType(TypeOfAny.explicit)]) + ) elif all_arg_types: arg_types.extend(generate_type_combinations(all_arg_types)) else: @@ -356,9 +407,14 @@ def add_adjustments(self, typs: List[Type]) -> List[Type]: translator = StrToText(self.named_type) return dedup(typs + [tp.accept(translator) for tp in typs]) - def get_guesses(self, is_method: bool, base: CallableType, defaults: List[Optional[Type]], - callsites: List[Callsite], - uses: List[List[Type]]) -> List[CallableType]: + def get_guesses( + self, + is_method: bool, + base: CallableType, + defaults: List[Optional[Type]], + callsites: List[Callsite], + uses: List[List[Type]], + ) -> List[CallableType]: """Compute a list of guesses for a function's type. This focuses just on the argument types, and doesn't change the provided return type. @@ -391,7 +447,8 @@ def filter_options( Currently the only option is filtering based on Any prevalance.""" return [ - t for t in guesses + t + for t in guesses if self.flex_any is None or any_score_callable(t, is_method, ignore_return) >= self.flex_any ] @@ -404,8 +461,7 @@ def find_best(self, func: FuncDef, guesses: List[CallableType]) -> Tuple[Callabl if not guesses: raise SuggestionFailure("No guesses that match criteria!") errors = {guess: self.try_type(func, guess) for guess in guesses} - best = min(guesses, - key=lambda s: (count_errors(errors[s]), self.score_callable(s))) + best = min(guesses, key=lambda s: (count_errors(errors[s]), self.score_callable(s))) return best, count_errors(errors[best]) def get_guesses_from_parent(self, node: FuncDef) -> List[CallableType]: @@ -469,18 +525,20 @@ def get_suggestion(self, mod: str, node: FuncDef) -> PyAnnotateSignature: return self.pyannotate_signature(mod, is_method, best) - def format_args(self, - arg_kinds: List[List[ArgKind]], - arg_names: List[List[Optional[str]]], - arg_types: List[List[Type]]) -> str: + def format_args( + self, + arg_kinds: List[List[ArgKind]], + arg_names: List[List[Optional[str]]], + arg_types: List[List[Type]], + ) -> str: args: List[str] = [] for i in range(len(arg_types)): for kind, name, typ in zip(arg_kinds[i], arg_names[i], arg_types[i]): arg = self.format_type(None, typ) if kind == ARG_STAR: - arg = '*' + arg + arg = "*" + arg elif kind == ARG_STAR2: - arg = '**' + arg + arg = "**" + arg elif kind.is_named(): if name: arg = f"{name}={arg}" @@ -497,17 +555,18 @@ def find_node(self, key: str) -> Tuple[str, str, FuncDef]: """ # TODO: Also return OverloadedFuncDef -- currently these are ignored. node: Optional[SymbolNode] = None - if ':' in key: - if key.count(':') > 1: + if ":" in key: + if key.count(":") > 1: raise SuggestionFailure( - 'Malformed location for function: {}. Must be either' - ' package.module.Class.method or path/to/file.py:line'.format(key)) - file, line = key.split(':') + "Malformed location for function: {}. Must be either" + " package.module.Class.method or path/to/file.py:line".format(key) + ) + file, line = key.split(":") if not line.isdigit(): - raise SuggestionFailure(f'Line number must be a number. Got {line}') + raise SuggestionFailure(f"Line number must be a number. Got {line}") line_number = int(line) modname, node = self.find_node_by_file_and_line(file, line_number) - tail = node.fullname[len(modname) + 1:] # add one to account for '.' + tail = node.fullname[len(modname) + 1 :] # add one to account for '.' else: target = split_target(self.fgmanager.graph, key) if not target: @@ -538,23 +597,26 @@ def find_node_by_module_and_name(self, modname: str, tail: str) -> Optional[Symb names: SymbolTable = tree.names # Look through any classes - components = tail.split('.') + components = tail.split(".") for i, component in enumerate(components[:-1]): if component not in names: - raise SuggestionFailure("Unknown class %s.%s" % - (modname, '.'.join(components[:i + 1]))) + raise SuggestionFailure( + "Unknown class %s.%s" % (modname, ".".join(components[: i + 1])) + ) node: Optional[SymbolNode] = names[component].node if not isinstance(node, TypeInfo): - raise SuggestionFailure("Object %s.%s is not a class" % - (modname, '.'.join(components[:i + 1]))) + raise SuggestionFailure( + "Object %s.%s is not a class" % (modname, ".".join(components[: i + 1])) + ) names = node.names # Look for the actual function/method funcname = components[-1] if funcname not in names: - key = modname + '.' + tail - raise SuggestionFailure("Unknown %s %s" % - ("method" if len(components) > 1 else "function", key)) + key = modname + "." + tail + raise SuggestionFailure( + "Unknown %s %s" % ("method" if len(components) > 1 else "function", key) + ) return names[funcname].node def find_node_by_file_and_line(self, file: str, line: int) -> Tuple[str, SymbolNode]: @@ -565,13 +627,13 @@ def find_node_by_file_and_line(self, file: str, line: int) -> Tuple[str, SymbolN Return module id and the node found. Raise SuggestionFailure if can't find one. """ if not any(file.endswith(ext) for ext in PYTHON_EXTENSIONS): - raise SuggestionFailure('Source file is not a Python file') + raise SuggestionFailure("Source file is not a Python file") try: modname, _ = self.finder.crawl_up(os.path.normpath(file)) except InvalidSourceList as e: - raise SuggestionFailure('Invalid source file name: ' + file) from e + raise SuggestionFailure("Invalid source file name: " + file) from e if modname not in self.graph: - raise SuggestionFailure('Unknown module: ' + modname) + raise SuggestionFailure("Unknown module: " + modname) # We must be sure about any edits in this file as this might affect the line numbers. tree = self.ensure_loaded(self.fgmanager.graph[modname], force=True) node: Optional[SymbolNode] = None @@ -589,27 +651,30 @@ def find_node_by_file_and_line(self, file: str, line: int) -> Tuple[str, SymbolN closest_line = sym_line node = sym.node if not node: - raise SuggestionFailure(f'Cannot find a function at line {line}') + raise SuggestionFailure(f"Cannot find a function at line {line}") return modname, node def extract_from_decorator(self, node: Decorator) -> Optional[FuncDef]: for dec in node.decorators: typ = None - if (isinstance(dec, RefExpr) - and isinstance(dec.node, FuncDef)): + if isinstance(dec, RefExpr) and isinstance(dec.node, FuncDef): typ = dec.node.type - elif (isinstance(dec, CallExpr) - and isinstance(dec.callee, RefExpr) - and isinstance(dec.callee.node, FuncDef) - and isinstance(dec.callee.node.type, CallableType)): + elif ( + isinstance(dec, CallExpr) + and isinstance(dec.callee, RefExpr) + and isinstance(dec.callee.node, FuncDef) + and isinstance(dec.callee.node.type, CallableType) + ): typ = get_proper_type(dec.callee.node.type.ret_type) if not isinstance(typ, FunctionLike): return None for ct in typ.items: - if not (len(ct.arg_types) == 1 - and isinstance(ct.arg_types[0], TypeVarType) - and ct.arg_types[0] == ct.ret_type): + if not ( + len(ct.arg_types) == 1 + and isinstance(ct.arg_types[0], TypeVarType) + and ct.arg_types[0] == ct.ret_type + ): return None return node.func @@ -650,12 +715,13 @@ def ensure_loaded(self, state: State, force: bool = False) -> MypyFile: def named_type(self, s: str) -> Instance: return self.manager.semantic_analyzer.named_type(s) - def json_suggestion(self, mod: str, func_name: str, node: FuncDef, - suggestion: PyAnnotateSignature) -> str: + def json_suggestion( + self, mod: str, func_name: str, node: FuncDef, suggestion: PyAnnotateSignature + ) -> str: """Produce a json blob for a suggestion suitable for application by pyannotate.""" # pyannotate irritatingly drops class names for class and static methods if node.is_class or node.is_static: - func_name = func_name.split('.', 1)[-1] + func_name = func_name.split(".", 1)[-1] # pyannotate works with either paths relative to where the # module is rooted or with absolute paths. We produce absolute @@ -663,25 +729,22 @@ def json_suggestion(self, mod: str, func_name: str, node: FuncDef, path = os.path.abspath(self.graph[mod].xpath) obj = { - 'signature': suggestion, - 'line': node.line, - 'path': path, - 'func_name': func_name, - 'samples': 0 + "signature": suggestion, + "line": node.line, + "path": path, + "func_name": func_name, + "samples": 0, } return json.dumps([obj], sort_keys=True) def pyannotate_signature( - self, - cur_module: Optional[str], - is_method: bool, - typ: CallableType + self, cur_module: Optional[str], is_method: bool, typ: CallableType ) -> PyAnnotateSignature: """Format a callable type as a pyannotate dict""" start = int(is_method) return { - 'arg_types': [self.format_type(cur_module, t) for t in typ.arg_types[start:]], - 'return_type': self.format_type(cur_module, typ.ret_type), + "arg_types": [self.format_type(cur_module, t) for t in typ.arg_types[start:]], + "return_type": self.format_type(cur_module, typ.ret_type), } def format_signature(self, sig: PyAnnotateSignature) -> str: @@ -712,13 +775,14 @@ def score_type(self, t: Type, arg_pos: bool) -> int: return 10 if isinstance(t, CallableType) and (has_any_type(t) or is_tricky_callable(t)): return 10 - if self.try_text and isinstance(t, Instance) and t.type.fullname == 'builtins.str': + if self.try_text and isinstance(t, Instance) and t.type.fullname == "builtins.str": return 1 return 0 def score_callable(self, t: CallableType) -> int: - return (sum(self.score_type(x, arg_pos=True) for x in t.arg_types) + - self.score_type(t.ret_type, arg_pos=False)) + return sum(self.score_type(x, arg_pos=True) for x in t.arg_types) + self.score_type( + t.ret_type, arg_pos=False + ) def any_score_type(ut: Type, arg_pos: bool) -> float: @@ -746,7 +810,7 @@ def any_score_type(ut: Type, arg_pos: bool) -> float: def any_score_callable(t: CallableType, is_method: bool, ignore_return: bool) -> float: # Ignore the first argument of methods - scores = [any_score_type(x, arg_pos=True) for x in t.arg_types[int(is_method):]] + scores = [any_score_type(x, arg_pos=True) for x in t.arg_types[int(is_method) :]] # Return type counts twice (since it spreads type information), unless it is # None in which case it does not count at all. (Though it *does* still count # if there are no arguments.) @@ -763,8 +827,8 @@ def is_tricky_callable(t: CallableType) -> bool: class TypeFormatter(TypeStrVisitor): - """Visitor used to format types - """ + """Visitor used to format types""" + # TODO: Probably a lot def __init__(self, module: Optional[str], graph: Graph) -> None: super().__init__() @@ -780,7 +844,7 @@ def visit_any(self, t: AnyType) -> str: def visit_instance(self, t: Instance) -> str: s = t.type.fullname or t.type.name or None if s is None: - return '' + return "" if s in reverse_builtin_aliases: s = reverse_builtin_aliases[s] @@ -792,31 +856,31 @@ def visit_instance(self, t: Instance) -> str: # to point to the current module. This helps the annotation tool avoid # inserting redundant imports when a type has been reexported. if self.module: - parts = obj.split('.') # need to split the object part if it is a nested class + parts = obj.split(".") # need to split the object part if it is a nested class tree = self.graph[self.module].tree if tree and parts[0] in tree.names: mod = self.module - if (mod, obj) == ('builtins', 'tuple'): - mod, obj = 'typing', 'Tuple[' + t.args[0].accept(self) + ', ...]' + if (mod, obj) == ("builtins", "tuple"): + mod, obj = "typing", "Tuple[" + t.args[0].accept(self) + ", ...]" elif t.args: - obj += f'[{self.list_str(t.args)}]' + obj += f"[{self.list_str(t.args)}]" - if mod_obj == ('builtins', 'unicode'): - return 'Text' - elif mod == 'builtins': + if mod_obj == ("builtins", "unicode"): + return "Text" + elif mod == "builtins": return obj else: - delim = '.' if '.' not in obj else ':' + delim = "." if "." not in obj else ":" return mod + delim + obj def visit_tuple_type(self, t: TupleType) -> str: if t.partial_fallback and t.partial_fallback.type: fallback_name = t.partial_fallback.type.fullname - if fallback_name != 'builtins.tuple': + if fallback_name != "builtins.tuple": return t.partial_fallback.accept(self) s = self.list_str(t.items) - return f'Tuple[{s}]' + return f"Tuple[{s}]" def visit_uninhabited_type(self, t: UninhabitedType) -> str: return "Any" @@ -847,22 +911,22 @@ def visit_callable_type(self, t: CallableType) -> str: class StrToText(TypeTranslator): def __init__(self, named_type: Callable[[str], Instance]) -> None: - self.text_type = named_type('builtins.unicode') + self.text_type = named_type("builtins.unicode") def visit_type_alias_type(self, t: TypeAliasType) -> Type: exp_t = get_proper_type(t) - if isinstance(exp_t, Instance) and exp_t.type.fullname == 'builtins.str': + if isinstance(exp_t, Instance) and exp_t.type.fullname == "builtins.str": return self.text_type return t.copy_modified(args=[a.accept(self) for a in t.args]) def visit_instance(self, t: Instance) -> Type: - if t.type.fullname == 'builtins.str': + if t.type.fullname == "builtins.str": return self.text_type else: return super().visit_instance(t) -TType = TypeVar('TType', bound=Type) +TType = TypeVar("TType", bound=Type) def make_suggestion_anys(t: TType) -> TType: @@ -900,7 +964,7 @@ def generate_type_combinations(types: List[Type]) -> List[Type]: def count_errors(msgs: List[str]) -> int: - return len([x for x in msgs if ' error: ' in x]) + return len([x for x in msgs if " error: " in x]) def refine_type(ti: Type, si: Type) -> Type: @@ -1009,7 +1073,7 @@ def refine_callable(t: CallableType, s: CallableType) -> CallableType: ) -T = TypeVar('T') +T = TypeVar("T") def dedup(old: List[T]) -> List[T]: diff --git a/mypy/test/config.py b/mypy/test/config.py index 0c2dfc9a21a9f..00e0edc2918eb 100644 --- a/mypy/test/config.py +++ b/mypy/test/config.py @@ -1,6 +1,6 @@ import os.path -provided_prefix = os.getenv('MYPY_TEST_PREFIX', None) +provided_prefix = os.getenv("MYPY_TEST_PREFIX", None) if provided_prefix: PREFIX = provided_prefix else: @@ -8,13 +8,13 @@ PREFIX = os.path.dirname(os.path.dirname(this_file_dir)) # Location of test data files such as test case descriptions. -test_data_prefix = os.path.join(PREFIX, 'test-data', 'unit') -package_path = os.path.join(PREFIX, 'test-data', 'packages') +test_data_prefix = os.path.join(PREFIX, "test-data", "unit") +package_path = os.path.join(PREFIX, "test-data", "packages") # Temp directory used for the temp files created when running test cases. # This is *within* the tempfile.TemporaryDirectory that is chroot'ed per testcase. # It is also hard-coded in numerous places, so don't change it. -test_temp_dir = 'tmp' +test_temp_dir = "tmp" # The PEP 561 tests do a bunch of pip installs which, even though they operate # on distinct temporary virtual environments, run into race conditions on shared @@ -22,5 +22,5 @@ # FileLock courtesy of the tox-dev/py-filelock package. # Ref. https://github.com/python/mypy/issues/12615 # Ref. mypy/test/testpep561.py -pip_lock = os.path.join(package_path, '.pip_lock') +pip_lock = os.path.join(package_path, ".pip_lock") pip_timeout = 60 diff --git a/mypy/test/data.py b/mypy/test/data.py index 18d25fc74c04b..de84736ac34c9 100644 --- a/mypy/test/data.py +++ b/mypy/test/data.py @@ -1,19 +1,19 @@ """Utilities for processing .test files containing test case descriptions.""" -import os.path import os -import tempfile +import os.path import posixpath import re import shutil -from abc import abstractmethod import sys +import tempfile +from abc import abstractmethod +from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Pattern, Set, Tuple, Union import pytest -from typing import List, Tuple, Set, Optional, Iterator, Any, Dict, NamedTuple, Union, Pattern from typing_extensions import Final -from mypy.test.config import test_data_prefix, test_temp_dir, PREFIX +from mypy.test.config import PREFIX, test_data_prefix, test_temp_dir root_dir = os.path.normpath(PREFIX) @@ -39,7 +39,7 @@ class DeleteFile(NamedTuple): FileOperation = Union[UpdateFile, DeleteFile] -def parse_test_case(case: 'DataDrivenTestCase') -> None: +def parse_test_case(case: "DataDrivenTestCase") -> None: """Parse and prepare a single case from suite with test case descriptions. This method is part of the setup phase, just before the test case is run. @@ -68,55 +68,55 @@ def parse_test_case(case: 'DataDrivenTestCase') -> None: # optionally followed by lines of text. item = first_item = test_items[0] for item in test_items[1:]: - if item.id in {'file', 'outfile', 'outfile-re'}: + if item.id in {"file", "outfile", "outfile-re"}: # Record an extra file needed for the test case. assert item.arg is not None - contents = expand_variables('\n'.join(item.data)) + contents = expand_variables("\n".join(item.data)) file_entry = (join(base_path, item.arg), contents) - if item.id == 'file': + if item.id == "file": files.append(file_entry) - elif item.id == 'outfile-re': + elif item.id == "outfile-re": output_files.append((file_entry[0], re.compile(file_entry[1].rstrip(), re.S))) else: output_files.append(file_entry) - elif item.id in ('builtins', 'builtins_py2'): + elif item.id in ("builtins", "builtins_py2"): # Use an alternative stub file for the builtins module. assert item.arg is not None mpath = join(os.path.dirname(case.file), item.arg) - fnam = 'builtins.pyi' if item.id == 'builtins' else '__builtin__.pyi' - with open(mpath, encoding='utf8') as f: + fnam = "builtins.pyi" if item.id == "builtins" else "__builtin__.pyi" + with open(mpath, encoding="utf8") as f: files.append((join(base_path, fnam), f.read())) - elif item.id == 'typing': + elif item.id == "typing": # Use an alternative stub file for the typing module. assert item.arg is not None src_path = join(os.path.dirname(case.file), item.arg) - with open(src_path, encoding='utf8') as f: - files.append((join(base_path, 'typing.pyi'), f.read())) - elif re.match(r'stale[0-9]*$', item.id): - passnum = 1 if item.id == 'stale' else int(item.id[len('stale'):]) + with open(src_path, encoding="utf8") as f: + files.append((join(base_path, "typing.pyi"), f.read())) + elif re.match(r"stale[0-9]*$", item.id): + passnum = 1 if item.id == "stale" else int(item.id[len("stale") :]) assert passnum > 0 - modules = (set() if item.arg is None else {t.strip() for t in item.arg.split(',')}) + modules = set() if item.arg is None else {t.strip() for t in item.arg.split(",")} stale_modules[passnum] = modules - elif re.match(r'rechecked[0-9]*$', item.id): - passnum = 1 if item.id == 'rechecked' else int(item.id[len('rechecked'):]) + elif re.match(r"rechecked[0-9]*$", item.id): + passnum = 1 if item.id == "rechecked" else int(item.id[len("rechecked") :]) assert passnum > 0 - modules = (set() if item.arg is None else {t.strip() for t in item.arg.split(',')}) + modules = set() if item.arg is None else {t.strip() for t in item.arg.split(",")} rechecked_modules[passnum] = modules - elif re.match(r'targets[0-9]*$', item.id): - passnum = 1 if item.id == 'targets' else int(item.id[len('targets'):]) + elif re.match(r"targets[0-9]*$", item.id): + passnum = 1 if item.id == "targets" else int(item.id[len("targets") :]) assert passnum > 0 - reprocessed = [] if item.arg is None else [t.strip() for t in item.arg.split(',')] + reprocessed = [] if item.arg is None else [t.strip() for t in item.arg.split(",")] targets[passnum] = reprocessed - elif item.id == 'delete': + elif item.id == "delete": # File/directory to delete during a multi-step test case assert item.arg is not None - m = re.match(r'(.*)\.([0-9]+)$', item.arg) - assert m, f'Invalid delete section: {item.arg}' + m = re.match(r"(.*)\.([0-9]+)$", item.arg) + assert m, f"Invalid delete section: {item.arg}" num = int(m.group(2)) assert num >= 2, f"Can't delete during step {num}" full = join(base_path, m.group(1)) deleted_paths.setdefault(num, set()).add(full) - elif re.match(r'out[0-9]*$', item.id): + elif re.match(r"out[0-9]*$", item.id): if item.arg is None: args = [] else: @@ -124,14 +124,13 @@ def parse_test_case(case: 'DataDrivenTestCase') -> None: version_check = True for arg in args: - if arg == 'skip-path-normalization': + if arg == "skip-path-normalization": normalize_output = False if arg.startswith("version"): compare_op = arg[7:9] if compare_op not in {">=", "=="}: raise ValueError( - "{}, line {}: Only >= and == version checks are currently supported" - .format( + "{}, line {}: Only >= and == version checks are currently supported".format( case.file, item.line ) ) @@ -141,55 +140,61 @@ def parse_test_case(case: 'DataDrivenTestCase') -> None: except ValueError: raise ValueError( '{}, line {}: "{}" is not a valid python version'.format( - case.file, item.line, version_str)) + case.file, item.line, version_str + ) + ) if compare_op == ">=": version_check = sys.version_info >= version elif compare_op == "==": if not 1 < len(version) < 4: raise ValueError( - '{}, line {}: Only minor or patch version checks ' + "{}, line {}: Only minor or patch version checks " 'are currently supported with "==": "{}"'.format( case.file, item.line, version_str ) ) - version_check = sys.version_info[:len(version)] == version + version_check = sys.version_info[: len(version)] == version if version_check: tmp_output = [expand_variables(line) for line in item.data] - if os.path.sep == '\\' and normalize_output: + if os.path.sep == "\\" and normalize_output: tmp_output = [fix_win_path(line) for line in tmp_output] - if item.id == 'out' or item.id == 'out1': + if item.id == "out" or item.id == "out1": output = tmp_output else: - passnum = int(item.id[len('out'):]) + passnum = int(item.id[len("out") :]) assert passnum > 1 output2[passnum] = tmp_output out_section_missing = False - elif item.id == 'triggered' and item.arg is None: + elif item.id == "triggered" and item.arg is None: triggered = item.data else: raise ValueError( - f'Invalid section header {item.id} in {case.file} at line {item.line}') + f"Invalid section header {item.id} in {case.file} at line {item.line}" + ) if out_section_missing: - raise ValueError( - f'{case.file}, line {first_item.line}: Required output section not found') + raise ValueError(f"{case.file}, line {first_item.line}: Required output section not found") for passnum in stale_modules.keys(): if passnum not in rechecked_modules: # If the set of rechecked modules isn't specified, make it the same as the set # of modules with a stale public interface. rechecked_modules[passnum] = stale_modules[passnum] - if (passnum in stale_modules - and passnum in rechecked_modules - and not stale_modules[passnum].issubset(rechecked_modules[passnum])): + if ( + passnum in stale_modules + and passnum in rechecked_modules + and not stale_modules[passnum].issubset(rechecked_modules[passnum]) + ): raise ValueError( - ('Stale modules after pass {} must be a subset of rechecked ' - 'modules ({}:{})').format(passnum, case.file, first_item.line)) + ( + "Stale modules after pass {} must be a subset of rechecked " "modules ({}:{})" + ).format(passnum, case.file, first_item.line) + ) input = first_item.data - expand_errors(input, output, 'main') + expand_errors(input, output, "main") for file_path, contents in files: - expand_errors(contents.split('\n'), output, file_path) + expand_errors(contents.split("\n"), output, file_path) case.input = input case.output = output @@ -216,7 +221,7 @@ class DataDrivenTestCase(pytest.Item): output2: Dict[int, List[str]] # Output for runs 2+, indexed by run number # full path of test suite - file = '' + file = "" line = 0 # (file path, file content) tuples @@ -235,25 +240,28 @@ class DataDrivenTestCase(pytest.Item): deleted_paths: Dict[int, Set[str]] # Mapping run number -> paths triggered: List[str] # Active triggers (one line per incremental step) - def __init__(self, - parent: 'DataSuiteCollector', - suite: 'DataSuite', - file: str, - name: str, - writescache: bool, - only_when: str, - platform: Optional[str], - skip: bool, - xfail: bool, - data: str, - line: int) -> None: + def __init__( + self, + parent: "DataSuiteCollector", + suite: "DataSuite", + file: str, + name: str, + writescache: bool, + only_when: str, + platform: Optional[str], + skip: bool, + xfail: bool, + data: str, + line: int, + ) -> None: super().__init__(name, parent) self.suite = suite self.file = file self.writescache = writescache self.only_when = only_when - if ((platform == 'windows' and sys.platform != 'win32') - or (platform == 'posix' and sys.platform == 'win32')): + if (platform == "windows" and sys.platform != "win32") or ( + platform == "posix" and sys.platform == "win32" + ): skip = True self.skip = skip self.xfail = xfail @@ -269,7 +277,7 @@ def runtest(self) -> None: elif self.xfail: self.add_marker(pytest.mark.xfail) parent = self.getparent(DataSuiteCollector) - assert parent is not None, 'Should not happen' + assert parent is not None, "Should not happen" suite = parent.obj() suite.setup() try: @@ -290,7 +298,7 @@ def runtest(self) -> None: def setup(self) -> None: parse_test_case(case=self) self.old_cwd = os.getcwd() - self.tmpdir = tempfile.TemporaryDirectory(prefix='mypy-test-') + self.tmpdir = tempfile.TemporaryDirectory(prefix="mypy-test-") os.chdir(self.tmpdir.name) os.mkdir(test_temp_dir) @@ -298,13 +306,13 @@ def setup(self) -> None: steps: Dict[int, List[FileOperation]] = {} for path, content in self.files: - m = re.match(r'.*\.([0-9]+)$', path) + m = re.match(r".*\.([0-9]+)$", path) if m: # Skip writing subsequent incremental steps - rather # store them as operations. num = int(m.group(1)) assert num >= 2 - target_path = re.sub(r'\.[0-9]+$', '', path) + target_path = re.sub(r"\.[0-9]+$", "", path) module = module_from_path(target_path) operation = UpdateFile(module, content, target_path) steps.setdefault(num, []).append(operation) @@ -312,7 +320,7 @@ def setup(self) -> None: # Write the first incremental steps dir = os.path.dirname(path) os.makedirs(dir, exist_ok=True) - with open(path, 'w', encoding='utf8') as f: + with open(path, "w", encoding="utf8") as f: f.write(content) for num, paths in self.deleted_paths.items(): @@ -324,8 +332,7 @@ def setup(self) -> None: self.steps = [steps.get(num, []) for num in range(2, max_step + 1)] def teardown(self) -> None: - assert self.old_cwd is not None and self.tmpdir is not None, \ - "test was not properly set up" + assert self.old_cwd is not None and self.tmpdir is not None, "test was not properly set up" os.chdir(self.old_cwd) try: self.tmpdir.cleanup() @@ -346,7 +353,7 @@ def repr_failure(self, excinfo: Any, style: Optional[Any] = None) -> str: excrepr = excinfo.exconly() else: self.parent._prunetraceback(excinfo) - excrepr = excinfo.getrepr(style='short') + excrepr = excinfo.getrepr(style="short") return f"data: {self.file}:{self.line}:\n{excrepr}" @@ -363,12 +370,12 @@ def find_steps(self) -> List[List[FileOperation]]: def module_from_path(path: str) -> str: - path = re.sub(r'\.pyi?$', '', path) + path = re.sub(r"\.pyi?$", "", path) # We can have a mix of Unix-style and Windows-style separators. - parts = re.split(r'[/\\]', path) + parts = re.split(r"[/\\]", path) del parts[0] - module = '.'.join(parts) - module = re.sub(r'\.__init__$', '', module) + module = ".".join(parts) + module = re.sub(r"\.__init__$", "", module) return module @@ -386,11 +393,10 @@ class TestItem: # Text data, array of 8-bit strings data: List[str] - file = '' + file = "" line = 0 # Line number in file - def __init__(self, id: str, arg: Optional[str], data: List[str], - line: int) -> None: + def __init__(self, id: str, arg: Optional[str], data: List[str], line: int) -> None: self.id = id self.arg = arg self.data = data @@ -400,7 +406,7 @@ def __init__(self, id: str, arg: Optional[str], data: List[str], def parse_test_data(raw_data: str, name: str) -> List[TestItem]: """Parse a list of lines that represent a sequence of test items.""" - lines = ['', '[case ' + name + ']'] + raw_data.split('\n') + lines = ["", "[case " + name + "]"] + raw_data.split("\n") ret: List[TestItem] = [] data: List[str] = [] @@ -412,7 +418,7 @@ def parse_test_data(raw_data: str, name: str) -> List[TestItem]: while i < len(lines): s = lines[i].strip() - if lines[i].startswith('[') and s.endswith(']'): + if lines[i].startswith("[") and s.endswith("]"): if id: data = collapse_line_continuation(data) data = strip_list(data) @@ -421,15 +427,15 @@ def parse_test_data(raw_data: str, name: str) -> List[TestItem]: i0 = i id = s[1:-1] arg = None - if ' ' in id: - arg = id[id.index(' ') + 1:] - id = id[:id.index(' ')] + if " " in id: + arg = id[id.index(" ") + 1 :] + id = id[: id.index(" ")] data = [] - elif lines[i].startswith('\\['): + elif lines[i].startswith("\\["): data.append(lines[i][1:]) - elif not lines[i].startswith('--'): + elif not lines[i].startswith("--"): data.append(lines[i]) - elif lines[i].startswith('----'): + elif lines[i].startswith("----"): data.append(lines[i][2:]) i += 1 @@ -452,9 +458,9 @@ def strip_list(l: List[str]) -> List[str]: r: List[str] = [] for s in l: # Strip spaces at end of line - r.append(re.sub(r'\s+$', '', s)) + r.append(re.sub(r"\s+$", "", s)) - while len(r) > 0 and r[-1] == '': + while len(r) > 0 and r[-1] == "": r.pop() return r @@ -464,17 +470,17 @@ def collapse_line_continuation(l: List[str]) -> List[str]: r: List[str] = [] cont = False for s in l: - ss = re.sub(r'\\$', '', s) + ss = re.sub(r"\\$", "", s) if cont: - r[-1] += re.sub('^ +', '', ss) + r[-1] += re.sub("^ +", "", ss) else: r.append(ss) - cont = s.endswith('\\') + cont = s.endswith("\\") return r def expand_variables(s: str) -> str: - return s.replace('', root_dir) + return s.replace("", root_dir) def expand_errors(input: List[str], output: List[str], fnam: str) -> None: @@ -486,25 +492,24 @@ def expand_errors(input: List[str], output: List[str], fnam: str) -> None: for i in range(len(input)): # The first in the split things isn't a comment - for possible_err_comment in input[i].split(' # ')[1:]: + for possible_err_comment in input[i].split(" # ")[1:]: m = re.search( - r'^([ENW]):((?P\d+):)? (?P.*)$', - possible_err_comment.strip()) + r"^([ENW]):((?P\d+):)? (?P.*)$", possible_err_comment.strip() + ) if m: - if m.group(1) == 'E': - severity = 'error' - elif m.group(1) == 'N': - severity = 'note' - elif m.group(1) == 'W': - severity = 'warning' - col = m.group('col') - message = m.group('message') - message = message.replace('\\#', '#') # adds back escaped # character + if m.group(1) == "E": + severity = "error" + elif m.group(1) == "N": + severity = "note" + elif m.group(1) == "W": + severity = "warning" + col = m.group("col") + message = m.group("message") + message = message.replace("\\#", "#") # adds back escaped # character if col is None: - output.append( - f'{fnam}:{i + 1}: {severity}: {message}') + output.append(f"{fnam}:{i + 1}: {severity}: {message}") else: - output.append(f'{fnam}:{i + 1}:{col}: {severity}: {message}') + output.append(f"{fnam}:{i + 1}:{col}: {severity}: {message}") def fix_win_path(line: str) -> str: @@ -512,14 +517,13 @@ def fix_win_path(line: str) -> str: E.g. foo\bar.py -> foo/bar.py. """ - line = line.replace(root_dir, root_dir.replace('\\', '/')) - m = re.match(r'^([\S/]+):(\d+:)?(\s+.*)', line) + line = line.replace(root_dir, root_dir.replace("\\", "/")) + m = re.match(r"^([\S/]+):(\d+:)?(\s+.*)", line) if not m: return line else: filename, lineno, message = m.groups() - return '{}:{}{}'.format(filename.replace('\\', '/'), - lineno or '', message) + return "{}:{}{}".format(filename.replace("\\", "/"), lineno or "", message) def fix_cobertura_filename(line: str) -> str: @@ -530,9 +534,9 @@ def fix_cobertura_filename(line: str) -> str: m = re.search(r' str: # This function name is special to pytest. See # https://docs.pytest.org/en/latest/reference.html#initialization-hooks def pytest_addoption(parser: Any) -> None: - group = parser.getgroup('mypy') - group.addoption('--update-data', action='store_true', default=False, - help='Update test data to reflect actual output' - ' (supported only for certain tests)') - group.addoption('--save-failures-to', default=None, - help='Copy the temp directories from failing tests to a target directory') - group.addoption('--mypy-verbose', action='count', - help='Set the verbose flag when creating mypy Options') - group.addoption('--mypyc-showc', action='store_true', default=False, - help='Display C code on mypyc test failures') + group = parser.getgroup("mypy") + group.addoption( + "--update-data", + action="store_true", + default=False, + help="Update test data to reflect actual output" " (supported only for certain tests)", + ) + group.addoption( + "--save-failures-to", + default=None, + help="Copy the temp directories from failing tests to a target directory", + ) + group.addoption( + "--mypy-verbose", action="count", help="Set the verbose flag when creating mypy Options" + ) + group.addoption( + "--mypyc-showc", + action="store_true", + default=False, + help="Display C code on mypyc test failures", + ) group.addoption( "--mypyc-debug", default=None, @@ -566,8 +581,7 @@ def pytest_addoption(parser: Any) -> None: # This function name is special to pytest. See # http://doc.pytest.org/en/latest/writing_plugins.html#collection-hooks -def pytest_pycollect_makeitem(collector: Any, name: str, - obj: object) -> 'Optional[Any]': +def pytest_pycollect_makeitem(collector: Any, name: str, obj: object) -> "Optional[Any]": """Called by pytest on each object in modules configured in conftest.py files. collector is pytest.Collector, returns Optional[pytest.Class] @@ -579,39 +593,44 @@ def pytest_pycollect_makeitem(collector: Any, name: str, # The collect method of the returned DataSuiteCollector instance will be called later, # with self.obj being obj. return DataSuiteCollector.from_parent( # type: ignore[no-untyped-call] - parent=collector, name=name, + parent=collector, name=name ) return None -def split_test_cases(parent: 'DataFileCollector', suite: 'DataSuite', - file: str) -> Iterator['DataDrivenTestCase']: +def split_test_cases( + parent: "DataFileCollector", suite: "DataSuite", file: str +) -> Iterator["DataDrivenTestCase"]: """Iterate over raw test cases in file, at collection time, ignoring sub items. The collection phase is slow, so any heavy processing should be deferred to after uninteresting tests are filtered (when using -k PATTERN switch). """ - with open(file, encoding='utf-8') as f: + with open(file, encoding="utf-8") as f: data = f.read() # number of groups in the below regex NUM_GROUPS = 7 - cases = re.split(r'^\[case ([a-zA-Z_0-9]+)' - r'(-writescache)?' - r'(-only_when_cache|-only_when_nocache)?' - r'(-posix|-windows)?' - r'(-skip)?' - r'(-xfail)?' - r'\][ \t]*$\n', - data, - flags=re.DOTALL | re.MULTILINE) - line_no = cases[0].count('\n') + 1 + cases = re.split( + r"^\[case ([a-zA-Z_0-9]+)" + r"(-writescache)?" + r"(-only_when_cache|-only_when_nocache)?" + r"(-posix|-windows)?" + r"(-skip)?" + r"(-xfail)?" + r"\][ \t]*$\n", + data, + flags=re.DOTALL | re.MULTILINE, + ) + line_no = cases[0].count("\n") + 1 test_names = set() for i in range(1, len(cases), NUM_GROUPS): - name, writescache, only_when, platform_flag, skip, xfail, data = cases[i:i + NUM_GROUPS] + name, writescache, only_when, platform_flag, skip, xfail, data = cases[i : i + NUM_GROUPS] if name in test_names: - raise RuntimeError('Found a duplicate test name "{}" in {} on line {}'.format( - name, parent.name, line_no, - )) + raise RuntimeError( + 'Found a duplicate test name "{}" in {} on line {}'.format( + name, parent.name, line_no + ) + ) platform = platform_flag[1:] if platform_flag else None yield DataDrivenTestCase.from_parent( parent=parent, @@ -626,21 +645,22 @@ def split_test_cases(parent: 'DataFileCollector', suite: 'DataSuite', data=data, line=line_no, ) - line_no += data.count('\n') + 1 + line_no += data.count("\n") + 1 # Record existing tests to prevent duplicates: test_names.update({name}) class DataSuiteCollector(pytest.Class): - def collect(self) -> Iterator['DataFileCollector']: + def collect(self) -> Iterator["DataFileCollector"]: """Called by pytest on each of the object returned from pytest_pycollect_makeitem""" # obj is the object for which pytest_pycollect_makeitem returned self. suite: DataSuite = self.obj - assert os.path.isdir(suite.data_prefix), \ - f'Test data prefix ({suite.data_prefix}) not set correctly' + assert os.path.isdir( + suite.data_prefix + ), f"Test data prefix ({suite.data_prefix}) not set correctly" for data_file in suite.files: yield DataFileCollector.from_parent(parent=self, name=data_file) @@ -651,18 +671,16 @@ class DataFileCollector(pytest.Collector): More context: https://github.com/python/mypy/issues/11662 """ + parent: DataSuiteCollector @classmethod # We have to fight with pytest here: def from_parent( # type: ignore[override] - cls, - parent: DataSuiteCollector, - *, - name: str, - ) -> 'DataFileCollector': + cls, parent: DataSuiteCollector, *, name: str + ) -> "DataFileCollector": return super().from_parent(parent, name=name) - def collect(self) -> Iterator['DataDrivenTestCase']: + def collect(self) -> Iterator["DataDrivenTestCase"]: yield from split_test_cases( parent=self, suite=self.parent.obj, @@ -672,26 +690,26 @@ def collect(self) -> Iterator['DataDrivenTestCase']: def add_test_name_suffix(name: str, suffix: str) -> str: # Find magic suffix of form "-foobar" (used for things like "-skip"). - m = re.search(r'-[-A-Za-z0-9]+$', name) + m = re.search(r"-[-A-Za-z0-9]+$", name) if m: # Insert suite-specific test name suffix before the magic suffix # which must be the last thing in the test case name since we # are using endswith() checks. magic_suffix = m.group(0) - return name[:-len(magic_suffix)] + suffix + magic_suffix + return name[: -len(magic_suffix)] + suffix + magic_suffix else: return name + suffix def is_incremental(testcase: DataDrivenTestCase) -> bool: - return 'incremental' in testcase.name.lower() or 'incremental' in testcase.file + return "incremental" in testcase.name.lower() or "incremental" in testcase.file def has_stable_flags(testcase: DataDrivenTestCase) -> bool: - if any(re.match(r'# flags[2-9]:', line) for line in testcase.input): + if any(re.match(r"# flags[2-9]:", line) for line in testcase.input): return False for filename, contents in testcase.files: - if os.path.basename(filename).startswith('mypy.ini.'): + if os.path.basename(filename).startswith("mypy.ini."): return False return True @@ -711,7 +729,7 @@ class DataSuite: # Name suffix automatically added to each test case in the suite (can be # used to distinguish test cases in suites that share data files) - test_name_suffix = '' + test_name_suffix = "" def setup(self) -> None: """Setup fixtures (ad-hoc)""" diff --git a/mypy/test/helpers.py b/mypy/test/helpers.py index 2f97a0851941c..a77cc38c7cb8c 100644 --- a/mypy/test/helpers.py +++ b/mypy/test/helpers.py @@ -1,29 +1,25 @@ +import contextlib import os import pathlib import re +import shutil import sys import time -import shutil -import contextlib - -from typing import List, Iterable, Dict, Tuple, Callable, Any, Iterator, Union, Pattern, Optional - -from mypy import defaults -import mypy.api as api - -import pytest +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Pattern, Tuple, Union # Exporting Suite as alias to TestCase for backwards compatibility # TODO: avoid aliasing - import and subclass TestCase directly from unittest import TestCase as Suite # noqa: F401 (re-exporting) +import pytest + +import mypy.api as api +import mypy.version +from mypy import defaults from mypy.main import process_options from mypy.options import Options -from mypy.test.data import ( - DataDrivenTestCase, fix_cobertura_filename, UpdateFile, DeleteFile -) -from mypy.test.config import test_temp_dir, test_data_prefix -import mypy.version +from mypy.test.config import test_data_prefix, test_temp_dir +from mypy.test.data import DataDrivenTestCase, DeleteFile, UpdateFile, fix_cobertura_filename skip = pytest.mark.skip @@ -36,16 +32,14 @@ def run_mypy(args: List[str]) -> None: __tracebackhide__ = True # We must enable site packages even though they could cause problems, # since stubs for typing_extensions live there. - outval, errval, status = api.run(args + ['--show-traceback', - '--no-silence-site-packages']) + outval, errval, status = api.run(args + ["--show-traceback", "--no-silence-site-packages"]) if status != 0: sys.stdout.write(outval) sys.stderr.write(errval) pytest.fail(msg="Sample check failed", pytrace=False) -def assert_string_arrays_equal(expected: List[str], actual: List[str], - msg: str) -> None: +def assert_string_arrays_equal(expected: List[str], actual: List[str], msg: str) -> None: """Assert that two string arrays are equal. We consider "can't" and "cannot" equivalent, by replacing the @@ -63,12 +57,12 @@ def assert_string_arrays_equal(expected: List[str], actual: List[str], num_skip_start = num_skipped_prefix_lines(expected, actual) num_skip_end = num_skipped_suffix_lines(expected, actual) - sys.stderr.write('Expected:\n') + sys.stderr.write("Expected:\n") # If omit some lines at the beginning, indicate it by displaying a line # with '...'. if num_skip_start > 0: - sys.stderr.write(' ...\n') + sys.stderr.write(" ...\n") # Keep track of the first different line. first_diff = -1 @@ -80,40 +74,41 @@ def assert_string_arrays_equal(expected: List[str], actual: List[str], if i >= len(actual) or expected[i] != actual[i]: if first_diff < 0: first_diff = i - sys.stderr.write(f' {expected[i]:<45} (diff)') + sys.stderr.write(f" {expected[i]:<45} (diff)") else: e = expected[i] - sys.stderr.write(' ' + e[:width]) + sys.stderr.write(" " + e[:width]) if len(e) > width: - sys.stderr.write('...') - sys.stderr.write('\n') + sys.stderr.write("...") + sys.stderr.write("\n") if num_skip_end > 0: - sys.stderr.write(' ...\n') + sys.stderr.write(" ...\n") - sys.stderr.write('Actual:\n') + sys.stderr.write("Actual:\n") if num_skip_start > 0: - sys.stderr.write(' ...\n') + sys.stderr.write(" ...\n") for j in range(num_skip_start, len(actual) - num_skip_end): if j >= len(expected) or expected[j] != actual[j]: - sys.stderr.write(f' {actual[j]:<45} (diff)') + sys.stderr.write(f" {actual[j]:<45} (diff)") else: a = actual[j] - sys.stderr.write(' ' + a[:width]) + sys.stderr.write(" " + a[:width]) if len(a) > width: - sys.stderr.write('...') - sys.stderr.write('\n') + sys.stderr.write("...") + sys.stderr.write("\n") if not actual: - sys.stderr.write(' (empty)\n') + sys.stderr.write(" (empty)\n") if num_skip_end > 0: - sys.stderr.write(' ...\n') + sys.stderr.write(" ...\n") - sys.stderr.write('\n') + sys.stderr.write("\n") if 0 <= first_diff < len(actual) and ( - len(expected[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT - or len(actual[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT): + len(expected[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT + or len(actual[first_diff]) >= MIN_LINE_LENGTH_FOR_ALIGNMENT + ): # Display message that helps visualize the differences between two # long lines. show_align_message(expected[first_diff], actual[first_diff]) @@ -121,46 +116,42 @@ def assert_string_arrays_equal(expected: List[str], actual: List[str], raise AssertionError(msg) -def assert_module_equivalence(name: str, - expected: Iterable[str], actual: Iterable[str]) -> None: +def assert_module_equivalence(name: str, expected: Iterable[str], actual: Iterable[str]) -> None: expected_normalized = sorted(expected) actual_normalized = sorted(set(actual).difference({"__main__"})) assert_string_arrays_equal( expected_normalized, actual_normalized, - ('Actual modules ({}) do not match expected modules ({}) ' - 'for "[{} ...]"').format( - ', '.join(actual_normalized), - ', '.join(expected_normalized), - name)) + ("Actual modules ({}) do not match expected modules ({}) " 'for "[{} ...]"').format( + ", ".join(actual_normalized), ", ".join(expected_normalized), name + ), + ) -def assert_target_equivalence(name: str, - expected: List[str], actual: List[str]) -> None: +def assert_target_equivalence(name: str, expected: List[str], actual: List[str]) -> None: """Compare actual and expected targets (order sensitive).""" assert_string_arrays_equal( expected, actual, - ('Actual targets ({}) do not match expected targets ({}) ' - 'for "[{} ...]"').format( - ', '.join(actual), - ', '.join(expected), - name)) + ("Actual targets ({}) do not match expected targets ({}) " 'for "[{} ...]"').format( + ", ".join(actual), ", ".join(expected), name + ), + ) def update_testcase_output(testcase: DataDrivenTestCase, output: List[str]) -> None: assert testcase.old_cwd is not None, "test was not properly set up" testcase_path = os.path.join(testcase.old_cwd, testcase.file) - with open(testcase_path, encoding='utf8') as f: + with open(testcase_path, encoding="utf8") as f: data_lines = f.read().splitlines() - test = '\n'.join(data_lines[testcase.line:testcase.last_line]) + test = "\n".join(data_lines[testcase.line : testcase.last_line]) mapping: Dict[str, List[str]] = {} for old, new in zip(testcase.output, output): - PREFIX = 'error:' + PREFIX = "error:" ind = old.find(PREFIX) if ind != -1 and old[:ind] == new[:ind]: - old, new = old[ind + len(PREFIX):], new[ind + len(PREFIX):] + old, new = old[ind + len(PREFIX) :], new[ind + len(PREFIX) :] mapping.setdefault(old, []).append(new) for old in mapping: @@ -169,13 +160,15 @@ def update_testcase_output(testcase: DataDrivenTestCase, output: List[str]) -> N # Interleave betweens and mapping[old] from itertools import chain - interleaved = [betweens[0]] + \ - list(chain.from_iterable(zip(mapping[old], betweens[1:]))) - test = ''.join(interleaved) - data_lines[testcase.line:testcase.last_line] = [test] - data = '\n'.join(data_lines) - with open(testcase_path, 'w', encoding='utf8') as f: + interleaved = [betweens[0]] + list( + chain.from_iterable(zip(mapping[old], betweens[1:])) + ) + test = "".join(interleaved) + + data_lines[testcase.line : testcase.last_line] = [test] + data = "\n".join(data_lines) + with open(testcase_path, "w", encoding="utf8") as f: print(data, file=f) @@ -200,7 +193,7 @@ def show_align_message(s1: str, s2: str) -> None: maxw = 72 # Maximum number of characters shown - sys.stderr.write('Alignment of first line difference:\n') + sys.stderr.write("Alignment of first line difference:\n") trunc = False while s1[:30] == s2[:30]: @@ -209,26 +202,26 @@ def show_align_message(s1: str, s2: str) -> None: trunc = True if trunc: - s1 = '...' + s1 - s2 = '...' + s2 + s1 = "..." + s1 + s2 = "..." + s2 max_len = max(len(s1), len(s2)) - extra = '' + extra = "" if max_len > maxw: - extra = '...' + extra = "..." # Write a chunk of both lines, aligned. - sys.stderr.write(f' E: {s1[:maxw]}{extra}\n') - sys.stderr.write(f' A: {s2[:maxw]}{extra}\n') + sys.stderr.write(f" E: {s1[:maxw]}{extra}\n") + sys.stderr.write(f" A: {s2[:maxw]}{extra}\n") # Write an indicator character under the different columns. - sys.stderr.write(' ') + sys.stderr.write(" ") for j in range(min(maxw, max(len(s1), len(s2)))): - if s1[j:j + 1] != s2[j:j + 1]: - sys.stderr.write('^') # Difference + if s1[j : j + 1] != s2[j : j + 1]: + sys.stderr.write("^") # Difference break else: - sys.stderr.write(' ') # Equal - sys.stderr.write('\n') + sys.stderr.write(" ") # Equal + sys.stderr.write("\n") def clean_up(a: List[str]) -> List[str]: @@ -239,18 +232,18 @@ def clean_up(a: List[str]) -> List[str]: """ res = [] pwd = os.getcwd() - driver = pwd + '/driver.py' + driver = pwd + "/driver.py" for s in a: prefix = os.sep ss = s - for p in prefix, prefix.replace(os.sep, '/'): - if p != '/' and p != '//' and p != '\\' and p != '\\\\': - ss = ss.replace(p, '') + for p in prefix, prefix.replace(os.sep, "/"): + if p != "/" and p != "//" and p != "\\" and p != "\\\\": + ss = ss.replace(p, "") # Ignore spaces at end of line. - ss = re.sub(' +$', '', ss) + ss = re.sub(" +$", "", ss) # Remove pwd from driver.py's path - ss = ss.replace(driver, 'driver.py') - res.append(re.sub('\\r$', '', ss)) + ss = ss.replace(driver, "driver.py") + res.append(re.sub("\\r$", "", ss)) return res @@ -262,8 +255,8 @@ def local_sys_path_set() -> Iterator[None]: by the stubgen tests. """ old_sys_path = sys.path[:] - if not ('' in sys.path or '.' in sys.path): - sys.path.insert(0, '') + if not ("" in sys.path or "." in sys.path): + sys.path.insert(0, "") try: yield finally: @@ -279,21 +272,20 @@ def num_skipped_prefix_lines(a1: List[str], a2: List[str]) -> int: def num_skipped_suffix_lines(a1: List[str], a2: List[str]) -> int: num_eq = 0 - while (num_eq < min(len(a1), len(a2)) - and a1[-num_eq - 1] == a2[-num_eq - 1]): + while num_eq < min(len(a1), len(a2)) and a1[-num_eq - 1] == a2[-num_eq - 1]: num_eq += 1 return max(0, num_eq - 4) def testfile_pyversion(path: str) -> Tuple[int, int]: - if path.endswith('python310.test'): + if path.endswith("python310.test"): return 3, 10 else: return defaults.PYTHON3_VERSION def testcase_pyversion(path: str, testcase_name: str) -> Tuple[int, int]: - if testcase_name.endswith('python2'): + if testcase_name.endswith("python2"): raise ValueError(testcase_name) return defaults.PYTHON2_VERSION else: @@ -305,7 +297,7 @@ def normalize_error_messages(messages: List[str]) -> List[str]: a = [] for m in messages: - a.append(m.replace(os.sep, '/')) + a.append(m.replace(os.sep, "/")) return a @@ -334,25 +326,25 @@ def retry_on_error(func: Callable[[], Any], max_wait: float = 1.0) -> None: def good_repr(obj: object) -> str: if isinstance(obj, str): - if obj.count('\n') > 1: + if obj.count("\n") > 1: bits = ["'''\\"] - for line in obj.split('\n'): + for line in obj.split("\n"): # force repr to use ' not ", then cut it off bits.append(repr('"' + line)[2:-1]) bits[-1] += "'''" - return '\n'.join(bits) + return "\n".join(bits) return repr(obj) -def assert_equal(a: object, b: object, fmt: str = '{} != {}') -> None: +def assert_equal(a: object, b: object, fmt: str = "{} != {}") -> None: __tracebackhide__ = True if a != b: raise AssertionError(fmt.format(good_repr(a), good_repr(b))) def typename(t: type) -> str: - if '.' in str(t): - return str(t).split('.')[-1].rstrip("'>") + if "." in str(t): + return str(t).split(".")[-1].rstrip("'>") else: return str(t)[8:-2] @@ -360,28 +352,29 @@ def typename(t: type) -> str: def assert_type(typ: type, value: object) -> None: __tracebackhide__ = True if type(value) != typ: - raise AssertionError('Invalid type {}, expected {}'.format( - typename(type(value)), typename(typ))) + raise AssertionError( + "Invalid type {}, expected {}".format(typename(type(value)), typename(typ)) + ) -def parse_options(program_text: str, testcase: DataDrivenTestCase, - incremental_step: int) -> Options: +def parse_options( + program_text: str, testcase: DataDrivenTestCase, incremental_step: int +) -> Options: """Parse comments like '# flags: --foo' in a test case.""" options = Options() - flags = re.search('# flags: (.*)$', program_text, flags=re.MULTILINE) + flags = re.search("# flags: (.*)$", program_text, flags=re.MULTILINE) if incremental_step > 1: - flags2 = re.search(f'# flags{incremental_step}: (.*)$', program_text, - flags=re.MULTILINE) + flags2 = re.search(f"# flags{incremental_step}: (.*)$", program_text, flags=re.MULTILINE) if flags2: flags = flags2 if flags: flag_list = flags.group(1).split() - flag_list.append('--no-site-packages') # the tests shouldn't need an installed Python + flag_list.append("--no-site-packages") # the tests shouldn't need an installed Python targets, options = process_options(flag_list, require_targets=False) if targets: # TODO: support specifying targets via the flags pragma - raise RuntimeError('Specifying targets via the flags pragma is not supported.') + raise RuntimeError("Specifying targets via the flags pragma is not supported.") else: flag_list = [] options = Options() @@ -390,22 +383,18 @@ def parse_options(program_text: str, testcase: DataDrivenTestCase, options.error_summary = False # Allow custom python version to override testcase_pyversion. - if all(flag.split('=')[0] not in ['--python-version', '-2', '--py2'] for flag in flag_list): + if all(flag.split("=")[0] not in ["--python-version", "-2", "--py2"] for flag in flag_list): options.python_version = testcase_pyversion(testcase.file, testcase.name) - if testcase.config.getoption('--mypy-verbose'): - options.verbosity = testcase.config.getoption('--mypy-verbose') + if testcase.config.getoption("--mypy-verbose"): + options.verbosity = testcase.config.getoption("--mypy-verbose") return options def split_lines(*streams: bytes) -> List[str]: """Returns a single list of string lines from the byte streams in args.""" - return [ - s - for stream in streams - for s in stream.decode('utf8').splitlines() - ] + return [s for stream in streams for s in stream.decode("utf8").splitlines()] def write_and_fudge_mtime(content: str, target_path: str) -> None: @@ -429,8 +418,7 @@ def write_and_fudge_mtime(content: str, target_path: str) -> None: os.utime(target_path, times=(new_time, new_time)) -def perform_file_operations( - operations: List[Union[UpdateFile, DeleteFile]]) -> None: +def perform_file_operations(operations: List[Union[UpdateFile, DeleteFile]]) -> None: for op in operations: if isinstance(op, UpdateFile): # Modify/create file @@ -439,7 +427,7 @@ def perform_file_operations( # Delete file/directory if os.path.isdir(op.path): # Sanity check to avoid unexpected deletions - assert op.path.startswith('tmp') + assert op.path.startswith("tmp") shutil.rmtree(op.path) else: # Use retries to work around potential flakiness on Windows (AppVeyor). @@ -447,52 +435,59 @@ def perform_file_operations( retry_on_error(lambda: os.remove(path)) -def check_test_output_files(testcase: DataDrivenTestCase, - step: int, - strip_prefix: str = '') -> None: +def check_test_output_files( + testcase: DataDrivenTestCase, step: int, strip_prefix: str = "" +) -> None: for path, expected_content in testcase.output_files: if path.startswith(strip_prefix): - path = path[len(strip_prefix):] + path = path[len(strip_prefix) :] if not os.path.exists(path): raise AssertionError( - 'Expected file {} was not produced by test case{}'.format( - path, ' on step %d' % step if testcase.output2 else '')) - with open(path, encoding='utf8') as output_file: + "Expected file {} was not produced by test case{}".format( + path, " on step %d" % step if testcase.output2 else "" + ) + ) + with open(path, encoding="utf8") as output_file: actual_output_content = output_file.read() if isinstance(expected_content, Pattern): if expected_content.fullmatch(actual_output_content) is not None: continue raise AssertionError( - 'Output file {} did not match its expected output pattern\n---\n{}\n---'.format( - path, actual_output_content) + "Output file {} did not match its expected output pattern\n---\n{}\n---".format( + path, actual_output_content + ) ) - normalized_output = normalize_file_output(actual_output_content.splitlines(), - os.path.abspath(test_temp_dir)) + normalized_output = normalize_file_output( + actual_output_content.splitlines(), os.path.abspath(test_temp_dir) + ) # We always normalize things like timestamp, but only handle operating-system # specific things if requested. if testcase.normalize_output: - if testcase.suite.native_sep and os.path.sep == '\\': - normalized_output = [fix_cobertura_filename(line) - for line in normalized_output] + if testcase.suite.native_sep and os.path.sep == "\\": + normalized_output = [fix_cobertura_filename(line) for line in normalized_output] normalized_output = normalize_error_messages(normalized_output) - assert_string_arrays_equal(expected_content.splitlines(), normalized_output, - 'Output file {} did not match its expected output{}'.format( - path, ' on step %d' % step if testcase.output2 else '')) + assert_string_arrays_equal( + expected_content.splitlines(), + normalized_output, + "Output file {} did not match its expected output{}".format( + path, " on step %d" % step if testcase.output2 else "" + ), + ) def normalize_file_output(content: List[str], current_abs_path: str) -> List[str]: """Normalize file output for comparison.""" - timestamp_regex = re.compile(r'\d{10}') - result = [x.replace(current_abs_path, '$PWD') for x in content] + timestamp_regex = re.compile(r"\d{10}") + result = [x.replace(current_abs_path, "$PWD") for x in content] version = mypy.version.__version__ - result = [re.sub(r'\b' + re.escape(version) + r'\b', '$VERSION', x) for x in result] + result = [re.sub(r"\b" + re.escape(version) + r"\b", "$VERSION", x) for x in result] # We generate a new mypy.version when building mypy wheels that # lacks base_version, so handle that case. - base_version = getattr(mypy.version, 'base_version', version) - result = [re.sub(r'\b' + re.escape(base_version) + r'\b', '$VERSION', x) for x in result] - result = [timestamp_regex.sub('$TIMESTAMP', x) for x in result] + base_version = getattr(mypy.version, "base_version", version) + result = [re.sub(r"\b" + re.escape(base_version) + r"\b", "$VERSION", x) for x in result] + result = [timestamp_regex.sub("$TIMESTAMP", x) for x in result] return result diff --git a/mypy/test/test_find_sources.py b/mypy/test/test_find_sources.py index e9e7432327e7f..ff0809d54183a 100644 --- a/mypy/test/test_find_sources.py +++ b/mypy/test/test_find_sources.py @@ -1,14 +1,15 @@ import os -import pytest import shutil import tempfile import unittest from typing import List, Optional, Set, Tuple +import pytest + from mypy.find_sources import InvalidSourceList, SourceFinder, create_source_list from mypy.fscache import FileSystemCache -from mypy.options import Options from mypy.modulefinder import BuildSource +from mypy.options import Options class FakeFSCache(FileSystemCache): @@ -26,7 +27,7 @@ def isdir(self, dir: str) -> bool: def listdir(self, dir: str) -> List[str]: if not dir.endswith(os.sep): dir += os.sep - return list({f[len(dir):].split(os.sep)[0] for f in self.files if f.startswith(dir)}) + return list({f[len(dir) :].split(os.sep)[0] for f in self.files if f.startswith(dir)}) def init_under_package_root(self, file: str) -> bool: return False @@ -87,18 +88,14 @@ def test_crawl_no_namespace(self) -> None: finder = SourceFinder(FakeFSCache({"/a/setup.py", "/a/__init__.py"}), options) assert crawl(finder, "/a/setup.py") == ("a.setup", "/") - finder = SourceFinder( - FakeFSCache({"/a/invalid-name/setup.py", "/a/__init__.py"}), - options, - ) + finder = SourceFinder(FakeFSCache({"/a/invalid-name/setup.py", "/a/__init__.py"}), options) assert crawl(finder, "/a/invalid-name/setup.py") == ("setup", "/a/invalid-name") finder = SourceFinder(FakeFSCache({"/a/b/setup.py", "/a/__init__.py"}), options) assert crawl(finder, "/a/b/setup.py") == ("setup", "/a/b") finder = SourceFinder( - FakeFSCache({"/a/b/c/setup.py", "/a/__init__.py", "/a/b/c/__init__.py"}), - options, + FakeFSCache({"/a/b/c/setup.py", "/a/__init__.py", "/a/b/c/__init__.py"}), options ) assert crawl(finder, "/a/b/c/setup.py") == ("c.setup", "/a/b") @@ -118,18 +115,14 @@ def test_crawl_namespace(self) -> None: finder = SourceFinder(FakeFSCache({"/a/setup.py", "/a/__init__.py"}), options) assert crawl(finder, "/a/setup.py") == ("a.setup", "/") - finder = SourceFinder( - FakeFSCache({"/a/invalid-name/setup.py", "/a/__init__.py"}), - options, - ) + finder = SourceFinder(FakeFSCache({"/a/invalid-name/setup.py", "/a/__init__.py"}), options) assert crawl(finder, "/a/invalid-name/setup.py") == ("setup", "/a/invalid-name") finder = SourceFinder(FakeFSCache({"/a/b/setup.py", "/a/__init__.py"}), options) assert crawl(finder, "/a/b/setup.py") == ("a.b.setup", "/") finder = SourceFinder( - FakeFSCache({"/a/b/c/setup.py", "/a/__init__.py", "/a/b/c/__init__.py"}), - options, + FakeFSCache({"/a/b/c/setup.py", "/a/__init__.py", "/a/b/c/__init__.py"}), options ) assert crawl(finder, "/a/b/c/setup.py") == ("a.b.c.setup", "/") @@ -150,18 +143,14 @@ def test_crawl_namespace_explicit_base(self) -> None: finder = SourceFinder(FakeFSCache({"/a/setup.py", "/a/__init__.py"}), options) assert crawl(finder, "/a/setup.py") == ("a.setup", "/") - finder = SourceFinder( - FakeFSCache({"/a/invalid-name/setup.py", "/a/__init__.py"}), - options, - ) + finder = SourceFinder(FakeFSCache({"/a/invalid-name/setup.py", "/a/__init__.py"}), options) assert crawl(finder, "/a/invalid-name/setup.py") == ("setup", "/a/invalid-name") finder = SourceFinder(FakeFSCache({"/a/b/setup.py", "/a/__init__.py"}), options) assert crawl(finder, "/a/b/setup.py") == ("a.b.setup", "/") finder = SourceFinder( - FakeFSCache({"/a/b/c/setup.py", "/a/__init__.py", "/a/b/c/__init__.py"}), - options, + FakeFSCache({"/a/b/c/setup.py", "/a/__init__.py", "/a/b/c/__init__.py"}), options ) assert crawl(finder, "/a/b/c/setup.py") == ("a.b.c.setup", "/") @@ -172,8 +161,7 @@ def test_crawl_namespace_explicit_base(self) -> None: assert crawl(finder, "/a/b/c/setup.py") == ("c.setup", "/a/b") finder = SourceFinder( - FakeFSCache({"/a/b/c/setup.py", "/a/__init__.py", "/a/b/c/__init__.py"}), - options, + FakeFSCache({"/a/b/c/setup.py", "/a/__init__.py", "/a/b/c/__init__.py"}), options ) assert crawl(finder, "/a/b/c/setup.py") == ("c.setup", "/a/b") @@ -305,8 +293,8 @@ def test_find_sources_exclude(self) -> None: ("a2.b.c.d.e", "/pkg"), ("e", "/pkg/a1/b/c/d"), ] - assert find_sources(["/pkg/a1/b/f.py"], options, fscache) == [('f', '/pkg/a1/b')] - assert find_sources(["/pkg/a2/b/f.py"], options, fscache) == [('a2.b.f', '/pkg')] + assert find_sources(["/pkg/a1/b/f.py"], options, fscache) == [("f", "/pkg/a1/b")] + assert find_sources(["/pkg/a2/b/f.py"], options, fscache) == [("a2.b.f", "/pkg")] # directory name options.exclude = ["/a1/"] @@ -325,7 +313,8 @@ def test_find_sources_exclude(self) -> None: options.exclude = ["/a1/$"] assert find_sources(["/pkg/a1"], options, fscache) == [ - ('e', '/pkg/a1/b/c/d'), ('f', '/pkg/a1/b') + ("e", "/pkg/a1/b/c/d"), + ("f", "/pkg/a1/b"), ] # paths @@ -359,8 +348,14 @@ def test_find_sources_exclude(self) -> None: # nothing should be ignored as a result of this big_exclude1 = [ - "/pkg/a/", "/2", "/1", "/pk/", "/kg", "/g.py", "/bc", "/xxx/pkg/a2/b/f.py" - "xxx/pkg/a2/b/f.py", + "/pkg/a/", + "/2", + "/1", + "/pk/", + "/kg", + "/g.py", + "/bc", + "/xxx/pkg/a2/b/f.py" "xxx/pkg/a2/b/f.py", ] big_exclude2 = ["|".join(big_exclude1)] for big_exclude in [big_exclude1, big_exclude2]: diff --git a/mypy/test/testapi.py b/mypy/test/testapi.py index 00f086c11eced..9b8787ed4af64 100644 --- a/mypy/test/testapi.py +++ b/mypy/test/testapi.py @@ -1,13 +1,11 @@ -from io import StringIO import sys +from io import StringIO import mypy.api - from mypy.test.helpers import Suite class APISuite(Suite): - def setUp(self) -> None: self.sys_stdout = sys.stdout self.sys_stderr = sys.stderr @@ -17,29 +15,29 @@ def setUp(self) -> None: def tearDown(self) -> None: sys.stdout = self.sys_stdout sys.stderr = self.sys_stderr - assert self.stdout.getvalue() == '' - assert self.stderr.getvalue() == '' + assert self.stdout.getvalue() == "" + assert self.stderr.getvalue() == "" def test_capture_bad_opt(self) -> None: """stderr should be captured when a bad option is passed.""" - _, stderr, _ = mypy.api.run(['--some-bad-option']) + _, stderr, _ = mypy.api.run(["--some-bad-option"]) assert isinstance(stderr, str) - assert stderr != '' + assert stderr != "" def test_capture_empty(self) -> None: """stderr should be captured when a bad option is passed.""" _, stderr, _ = mypy.api.run([]) assert isinstance(stderr, str) - assert stderr != '' + assert stderr != "" def test_capture_help(self) -> None: """stdout should be captured when --help is passed.""" - stdout, _, _ = mypy.api.run(['--help']) + stdout, _, _ = mypy.api.run(["--help"]) assert isinstance(stdout, str) - assert stdout != '' + assert stdout != "" def test_capture_version(self) -> None: """stdout should be captured when --version is passed.""" - stdout, _, _ = mypy.api.run(['--version']) + stdout, _, _ = mypy.api.run(["--version"]) assert isinstance(stdout, str) - assert stdout != '' + assert stdout != "" diff --git a/mypy/test/testargs.py b/mypy/test/testargs.py index 8d74207f353fd..686f7e132bc92 100644 --- a/mypy/test/testargs.py +++ b/mypy/test/testargs.py @@ -7,9 +7,9 @@ import argparse import sys -from mypy.test.helpers import Suite, assert_equal +from mypy.main import infer_python_executable, process_options from mypy.options import Options -from mypy.main import process_options, infer_python_executable +from mypy.test.helpers import Suite, assert_equal class ArgSuite(Suite): @@ -22,31 +22,32 @@ def test_coherence(self) -> None: def test_executable_inference(self) -> None: """Test the --python-executable flag with --python-version""" - sys_ver_str = '{ver.major}.{ver.minor}'.format(ver=sys.version_info) + sys_ver_str = "{ver.major}.{ver.minor}".format(ver=sys.version_info) - base = ['file.py'] # dummy file + base = ["file.py"] # dummy file # test inference given one (infer the other) - matching_version = base + [f'--python-version={sys_ver_str}'] + matching_version = base + [f"--python-version={sys_ver_str}"] _, options = process_options(matching_version) assert options.python_version == sys.version_info[:2] assert options.python_executable == sys.executable - matching_version = base + [f'--python-executable={sys.executable}'] + matching_version = base + [f"--python-executable={sys.executable}"] _, options = process_options(matching_version) assert options.python_version == sys.version_info[:2] assert options.python_executable == sys.executable # test inference given both - matching_version = base + [f'--python-version={sys_ver_str}', - f'--python-executable={sys.executable}'] + matching_version = base + [ + f"--python-version={sys_ver_str}", + f"--python-executable={sys.executable}", + ] _, options = process_options(matching_version) assert options.python_version == sys.version_info[:2] assert options.python_executable == sys.executable # test that --no-site-packages will disable executable inference - matching_version = base + [f'--python-version={sys_ver_str}', - '--no-site-packages'] + matching_version = base + [f"--python-version={sys_ver_str}", "--no-site-packages"] _, options = process_options(matching_version) assert options.python_version == sys.version_info[:2] assert options.python_executable is None diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index dc2a2ba070402..8e1f017b23360 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -3,24 +3,26 @@ import os import re import sys - from typing import Dict, List, Set, Tuple from mypy import build from mypy.build import Graph -from mypy.modulefinder import BuildSource, SearchPaths, FindModuleCache -from mypy.test.config import test_temp_dir, test_data_prefix -from mypy.test.data import ( - DataDrivenTestCase, DataSuite, FileOperation, module_from_path -) +from mypy.errors import CompileError +from mypy.modulefinder import BuildSource, FindModuleCache, SearchPaths +from mypy.semanal_main import core_modules +from mypy.test.config import test_data_prefix, test_temp_dir +from mypy.test.data import DataDrivenTestCase, DataSuite, FileOperation, module_from_path from mypy.test.helpers import ( - assert_string_arrays_equal, normalize_error_messages, assert_module_equivalence, - update_testcase_output, parse_options, - assert_target_equivalence, check_test_output_files, perform_file_operations, + assert_module_equivalence, + assert_string_arrays_equal, + assert_target_equivalence, + check_test_output_files, find_test_files, + normalize_error_messages, + parse_options, + perform_file_operations, + update_testcase_output, ) -from mypy.errors import CompileError -from mypy.semanal_main import core_modules try: import lxml # type: ignore @@ -35,26 +37,28 @@ # Tests that use Python 3.8-only AST features (like expression-scoped ignores): if sys.version_info < (3, 8): - typecheck_files.remove('check-python38.test') + typecheck_files.remove("check-python38.test") if sys.version_info < (3, 9): - typecheck_files.remove('check-python39.test') + typecheck_files.remove("check-python39.test") if sys.version_info < (3, 10): - typecheck_files.remove('check-python310.test') + typecheck_files.remove("check-python310.test") # Special tests for platforms with case-insensitive filesystems. -if sys.platform not in ('darwin', 'win32'): - typecheck_files.remove('check-modules-case.test') +if sys.platform not in ("darwin", "win32"): + typecheck_files.remove("check-modules-case.test") class TypeCheckSuite(DataSuite): files = typecheck_files def run_case(self, testcase: DataDrivenTestCase) -> None: - if lxml is None and os.path.basename(testcase.file) == 'check-reports.test': + if lxml is None and os.path.basename(testcase.file) == "check-reports.test": pytest.skip("Cannot import lxml. Is it installed?") - incremental = ('incremental' in testcase.name.lower() - or 'incremental' in testcase.file - or 'serialize' in testcase.file) + incremental = ( + "incremental" in testcase.name.lower() + or "incremental" in testcase.file + or "serialize" in testcase.file + ) if incremental: # Incremental tests are run once with a cold cache, once with a warm cache. # Expect success on first run, errors from testcase.output (if any) on second run. @@ -62,11 +66,13 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: # Check that there are no file changes beyond the last run (they would be ignored). for dn, dirs, files in os.walk(os.curdir): for file in files: - m = re.search(r'\.([2-9])$', file) + m = re.search(r"\.([2-9])$", file) if m and int(m.group(1)) > num_steps: raise ValueError( - 'Output file {} exists though test case only has {} runs'.format( - file, num_steps)) + "Output file {} exists though test case only has {} runs".format( + file, num_steps + ) + ) steps = testcase.find_steps() for step in range(1, num_steps + 1): idx = step - 2 @@ -75,22 +81,25 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: else: self.run_case_once(testcase) - def run_case_once(self, testcase: DataDrivenTestCase, - operations: List[FileOperation] = [], - incremental_step: int = 0) -> None: - original_program_text = '\n'.join(testcase.input) + def run_case_once( + self, + testcase: DataDrivenTestCase, + operations: List[FileOperation] = [], + incremental_step: int = 0, + ) -> None: + original_program_text = "\n".join(testcase.input) module_data = self.parse_module(original_program_text, incremental_step) # Unload already loaded plugins, they may be updated. for file, _ in testcase.files: module = module_from_path(file) - if module.endswith('_plugin') and module in sys.modules: + if module.endswith("_plugin") and module in sys.modules: del sys.modules[module] if incremental_step == 0 or incremental_step == 1: # In run 1, copy program text to program file. for module_name, program_path, program_text in module_data: - if module_name == '__main__': - with open(program_path, 'w', encoding='utf8') as f: + if module_name == "__main__": + with open(program_path, "w", encoding="utf8") as f: f.write(program_text) break elif incremental_step > 1: @@ -104,11 +113,11 @@ def run_case_once(self, testcase: DataDrivenTestCase, options.show_traceback = True # Enable some options automatically based on test file name. - if 'optional' in testcase.file: + if "optional" in testcase.file: options.strict_optional = True - if 'columns' in testcase.file: + if "columns" in testcase.file: options.show_column_numbers = True - if 'errorcodes' in testcase.file: + if "errorcodes" in testcase.file: options.show_error_codes = True if incremental_step and options.incremental: @@ -123,17 +132,16 @@ def run_case_once(self, testcase: DataDrivenTestCase, sources = [] for module_name, program_path, program_text in module_data: # Always set to none so we're forced to reread the module in incremental mode - sources.append(BuildSource(program_path, module_name, - None if incremental_step else program_text)) + sources.append( + BuildSource(program_path, module_name, None if incremental_step else program_text) + ) - plugin_dir = os.path.join(test_data_prefix, 'plugins') + plugin_dir = os.path.join(test_data_prefix, "plugins") sys.path.insert(0, plugin_dir) res = None try: - res = build.build(sources=sources, - options=options, - alt_lib_path=test_temp_dir) + res = build.build(sources=sources, options=options, alt_lib_path=test_temp_dir) a = res.errors except CompileError as e: a = e.messages @@ -147,19 +155,21 @@ def run_case_once(self, testcase: DataDrivenTestCase, # Make sure error messages match if incremental_step == 0: # Not incremental - msg = 'Unexpected type checker output ({}, line {})' + msg = "Unexpected type checker output ({}, line {})" output = testcase.output elif incremental_step == 1: - msg = 'Unexpected type checker output in incremental, run 1 ({}, line {})' + msg = "Unexpected type checker output in incremental, run 1 ({}, line {})" output = testcase.output elif incremental_step > 1: - msg = ('Unexpected type checker output in incremental, run {}'.format( - incremental_step) + ' ({}, line {})') + msg = ( + "Unexpected type checker output in incremental, run {}".format(incremental_step) + + " ({}, line {})" + ) output = testcase.output2.get(incremental_step, []) else: raise AssertionError() - if output != a and testcase.config.getoption('--update-data', False): + if output != a and testcase.config.getoption("--update-data", False): update_testcase_output(testcase, a) assert_string_arrays_equal(output, a, msg.format(testcase.file, testcase.line)) @@ -167,41 +177,47 @@ def run_case_once(self, testcase: DataDrivenTestCase, if options.cache_dir != os.devnull: self.verify_cache(module_data, res.errors, res.manager, res.graph) - name = 'targets' + name = "targets" if incremental_step: name += str(incremental_step + 1) expected = testcase.expected_fine_grained_targets.get(incremental_step + 1) actual = res.manager.processed_targets # Skip the initial builtin cycle. - actual = [t for t in actual - if not any(t.startswith(mod) - for mod in core_modules + ['mypy_extensions'])] + actual = [ + t + for t in actual + if not any(t.startswith(mod) for mod in core_modules + ["mypy_extensions"]) + ] if expected is not None: assert_target_equivalence(name, expected, actual) if incremental_step > 1: - suffix = '' if incremental_step == 2 else str(incremental_step - 1) + suffix = "" if incremental_step == 2 else str(incremental_step - 1) expected_rechecked = testcase.expected_rechecked_modules.get(incremental_step - 1) if expected_rechecked is not None: assert_module_equivalence( - 'rechecked' + suffix, - expected_rechecked, res.manager.rechecked_modules) + "rechecked" + suffix, expected_rechecked, res.manager.rechecked_modules + ) expected_stale = testcase.expected_stale_modules.get(incremental_step - 1) if expected_stale is not None: assert_module_equivalence( - 'stale' + suffix, - expected_stale, res.manager.stale_modules) + "stale" + suffix, expected_stale, res.manager.stale_modules + ) if testcase.output_files: - check_test_output_files(testcase, incremental_step, strip_prefix='tmp/') - - def verify_cache(self, module_data: List[Tuple[str, str, str]], a: List[str], - manager: build.BuildManager, graph: Graph) -> None: + check_test_output_files(testcase, incremental_step, strip_prefix="tmp/") + + def verify_cache( + self, + module_data: List[Tuple[str, str, str]], + a: List[str], + manager: build.BuildManager, + graph: Graph, + ) -> None: # There should be valid cache metadata for each module except # for those that had an error in themselves or one of their # dependencies. error_paths = self.find_error_message_paths(a) - busted_paths = {m.path for id, m in manager.modules.items() - if graph[id].transitive_error} + busted_paths = {m.path for id, m in manager.modules.items() if graph[id].transitive_error} modules = self.find_module_files(manager) modules.update({module_name: path for module_name, path, text in module_data}) missing_paths = self.find_missing_cache_files(modules, manager) @@ -211,8 +227,7 @@ def verify_cache(self, module_data: List[Tuple[str, str, str]], a: List[str], # just notes attached to other errors. assert error_paths or not busted_paths, "Some modules reported error despite no errors" if not missing_paths == busted_paths: - raise AssertionError("cache data discrepancy %s != %s" % - (missing_paths, busted_paths)) + raise AssertionError("cache data discrepancy %s != %s" % (missing_paths, busted_paths)) assert os.path.isfile(os.path.join(manager.options.cache_dir, ".gitignore")) cachedir_tag = os.path.join(manager.options.cache_dir, "CACHEDIR.TAG") assert os.path.isfile(cachedir_tag) @@ -222,7 +237,7 @@ def verify_cache(self, module_data: List[Tuple[str, str, str]], a: List[str], def find_error_message_paths(self, a: List[str]) -> Set[str]: hits = set() for line in a: - m = re.match(r'([^\s:]+):(\d+:)?(\d+:)? (error|warning|note):', line) + m = re.match(r"([^\s:]+):(\d+:)?(\d+:)? (error|warning|note):", line) if m: p = m.group(1) hits.add(p) @@ -231,8 +246,9 @@ def find_error_message_paths(self, a: List[str]) -> Set[str]: def find_module_files(self, manager: build.BuildManager) -> Dict[str, str]: return {id: module.path for id, module in manager.modules.items()} - def find_missing_cache_files(self, modules: Dict[str, str], - manager: build.BuildManager) -> Set[str]: + def find_missing_cache_files( + self, modules: Dict[str, str], manager: build.BuildManager + ) -> Set[str]: ignore_errors = True missing = {} for id, path in modules.items(): @@ -241,9 +257,9 @@ def find_missing_cache_files(self, modules: Dict[str, str], missing[id] = path return set(missing.values()) - def parse_module(self, - program_text: str, - incremental_step: int = 0) -> List[Tuple[str, str, str]]: + def parse_module( + self, program_text: str, incremental_step: int = 0 + ) -> List[Tuple[str, str, str]]: """Return the module and program names for a test case. Normally, the unit tests will parse the default ('__main__') @@ -258,9 +274,9 @@ def parse_module(self, Return a list of tuples (module name, file name, program text). """ - m = re.search('# cmd: mypy -m ([a-zA-Z0-9_. ]+)$', program_text, flags=re.MULTILINE) + m = re.search("# cmd: mypy -m ([a-zA-Z0-9_. ]+)$", program_text, flags=re.MULTILINE) if incremental_step > 1: - alt_regex = f'# cmd{incremental_step}: mypy -m ([a-zA-Z0-9_. ]+)$' + alt_regex = f"# cmd{incremental_step}: mypy -m ([a-zA-Z0-9_. ]+)$" alt_m = re.search(alt_regex, program_text, flags=re.MULTILINE) if alt_m is not None: # Optionally return a different command if in a later step @@ -276,12 +292,12 @@ def parse_module(self, out = [] search_paths = SearchPaths((test_temp_dir,), (), (), ()) cache = FindModuleCache(search_paths, fscache=None, options=None) - for module_name in module_names.split(' '): + for module_name in module_names.split(" "): path = cache.find_module(module_name) assert isinstance(path, str), f"Can't find ad hoc case file: {module_name}" - with open(path, encoding='utf8') as f: + with open(path, encoding="utf8") as f: program_text = f.read() out.append((module_name, path, program_text)) return out else: - return [('__main__', 'main', program_text)] + return [("__main__", "main", program_text)] diff --git a/mypy/test/testcmdline.py b/mypy/test/testcmdline.py index 9983dc554323a..dd0410746e90c 100644 --- a/mypy/test/testcmdline.py +++ b/mypy/test/testcmdline.py @@ -8,14 +8,14 @@ import re import subprocess import sys +from typing import List, Optional -from typing import List -from typing import Optional - -from mypy.test.config import test_temp_dir, PREFIX +from mypy.test.config import PREFIX, test_temp_dir from mypy.test.data import DataDrivenTestCase, DataSuite from mypy.test.helpers import ( - assert_string_arrays_equal, normalize_error_messages, check_test_output_files + assert_string_arrays_equal, + check_test_output_files, + normalize_error_messages, ) try: @@ -29,12 +29,7 @@ python3_path = sys.executable # Files containing test case descriptions. -cmdline_files = [ - 'cmdline.test', - 'cmdline.pyproject.test', - 'reports.test', - 'envvars.test', -] +cmdline_files = ["cmdline.test", "cmdline.pyproject.test", "reports.test", "envvars.test"] class PythonCmdlineSuite(DataSuite): @@ -42,7 +37,7 @@ class PythonCmdlineSuite(DataSuite): native_sep = True def run_case(self, testcase: DataDrivenTestCase) -> None: - if lxml is None and os.path.basename(testcase.file) == 'reports.test': + if lxml is None and os.path.basename(testcase.file) == "reports.test": pytest.skip("Cannot import lxml. Is it installed?") for step in [1] + sorted(testcase.output2): test_python_cmdline(testcase, step) @@ -51,43 +46,42 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: def test_python_cmdline(testcase: DataDrivenTestCase, step: int) -> None: assert testcase.old_cwd is not None, "test was not properly set up" # Write the program to a file. - program = '_program.py' + program = "_program.py" program_path = os.path.join(test_temp_dir, program) - with open(program_path, 'w', encoding='utf8') as file: + with open(program_path, "w", encoding="utf8") as file: for s in testcase.input: - file.write(f'{s}\n') + file.write(f"{s}\n") args = parse_args(testcase.input[0]) custom_cwd = parse_cwd(testcase.input[1]) if len(testcase.input) > 1 else None - args.append('--show-traceback') - if '--error-summary' not in args: - args.append('--no-error-summary') + args.append("--show-traceback") + if "--error-summary" not in args: + args.append("--no-error-summary") # Type check the program. - fixed = [python3_path, '-m', 'mypy'] + fixed = [python3_path, "-m", "mypy"] env = os.environ.copy() - env.pop('COLUMNS', None) - extra_path = os.path.join(os.path.abspath(test_temp_dir), 'pypath') - env['PYTHONPATH'] = PREFIX + env.pop("COLUMNS", None) + extra_path = os.path.join(os.path.abspath(test_temp_dir), "pypath") + env["PYTHONPATH"] = PREFIX if os.path.isdir(extra_path): - env['PYTHONPATH'] += os.pathsep + extra_path - process = subprocess.Popen(fixed + args, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - cwd=os.path.join( - test_temp_dir, - custom_cwd or "" - ), - env=env) + env["PYTHONPATH"] += os.pathsep + extra_path + process = subprocess.Popen( + fixed + args, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=os.path.join(test_temp_dir, custom_cwd or ""), + env=env, + ) outb, errb = process.communicate() result = process.returncode # Split output into lines. - out = [s.rstrip('\n\r') for s in str(outb, 'utf8').splitlines()] - err = [s.rstrip('\n\r') for s in str(errb, 'utf8').splitlines()] + out = [s.rstrip("\n\r") for s in str(outb, "utf8").splitlines()] + err = [s.rstrip("\n\r") for s in str(errb, "utf8").splitlines()] if "PYCHARM_HOSTED" in os.environ: for pos, line in enumerate(err): - if line.startswith('pydev debugger: '): + if line.startswith("pydev debugger: "): # Delete the attaching debugger message itself, plus the extra newline added. - del err[pos:pos + 2] + del err[pos : pos + 2] break # Remove temp file. @@ -97,23 +91,26 @@ def test_python_cmdline(testcase: DataDrivenTestCase, step: int) -> None: # Ignore stdout, but we insist on empty stderr and zero status. if err or result: raise AssertionError( - 'Expected zero status and empty stderr%s, got %d and\n%s' % - (' on step %d' % step if testcase.output2 else '', - result, '\n'.join(err + out))) + "Expected zero status and empty stderr%s, got %d and\n%s" + % (" on step %d" % step if testcase.output2 else "", result, "\n".join(err + out)) + ) check_test_output_files(testcase, step) else: if testcase.normalize_output: out = normalize_error_messages(err + out) obvious_result = 1 if out else 0 if obvious_result != result: - out.append(f'== Return code: {result}') + out.append(f"== Return code: {result}") expected_out = testcase.output if step == 1 else testcase.output2[step] # Strip "tmp/" out of the test so that # E: works... expected_out = [s.replace("tmp" + os.sep, "") for s in expected_out] - assert_string_arrays_equal(expected_out, out, - 'Invalid output ({}, line {}){}'.format( - testcase.file, testcase.line, - ' on step %d' % step if testcase.output2 else '')) + assert_string_arrays_equal( + expected_out, + out, + "Invalid output ({}, line {}){}".format( + testcase.file, testcase.line, " on step %d" % step if testcase.output2 else "" + ), + ) def parse_args(line: str) -> List[str]: @@ -127,7 +124,7 @@ def parse_args(line: str) -> List[str]: # cmd: mypy pkg/ """ - m = re.match('# cmd: mypy (.*)$', line) + m = re.match("# cmd: mypy (.*)$", line) if not m: return [] # No args; mypy will spit out an error. return m.group(1).split() @@ -144,5 +141,5 @@ def parse_cwd(line: str) -> Optional[str]: # cwd: main/subdir """ - m = re.match('# cwd: (.*)$', line) + m = re.match("# cwd: (.*)$", line) return m.group(1) if m else None diff --git a/mypy/test/testconstraints.py b/mypy/test/testconstraints.py index f8af9ec140b5d..c3930b0475b5c 100644 --- a/mypy/test/testconstraints.py +++ b/mypy/test/testconstraints.py @@ -1,6 +1,6 @@ +from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint, infer_constraints from mypy.test.helpers import Suite from mypy.test.typefixture import TypeFixture -from mypy.constraints import infer_constraints, SUBTYPE_OF, SUPERTYPE_OF, Constraint class ConstraintsSuite(Suite): @@ -14,9 +14,5 @@ def test_basic_type_variable(self) -> None: fx = self.fx for direction in [SUBTYPE_OF, SUPERTYPE_OF]: assert infer_constraints(fx.gt, fx.ga, direction) == [ - Constraint( - type_var=fx.t.id, - op=direction, - target=fx.a, - ) + Constraint(type_var=fx.t.id, op=direction, target=fx.a) ] diff --git a/mypy/test/testdaemon.py b/mypy/test/testdaemon.py index 804a562e71f16..87a9877267e24 100644 --- a/mypy/test/testdaemon.py +++ b/mypy/test/testdaemon.py @@ -12,18 +12,15 @@ import unittest from typing import List, Tuple -from mypy.modulefinder import SearchPaths -from mypy.fscache import FileSystemCache from mypy.dmypy_server import filter_out_missing_top_level_packages - -from mypy.test.config import test_temp_dir, PREFIX +from mypy.fscache import FileSystemCache +from mypy.modulefinder import SearchPaths +from mypy.test.config import PREFIX, test_temp_dir from mypy.test.data import DataDrivenTestCase, DataSuite from mypy.test.helpers import assert_string_arrays_equal, normalize_error_messages # Files containing test cases descriptions. -daemon_files = [ - 'daemon.test', -] +daemon_files = ["daemon.test"] class DaemonSuite(DataSuite): @@ -34,7 +31,7 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: test_daemon(testcase) finally: # Kill the daemon if it's still running. - run_cmd('dmypy kill') + run_cmd("dmypy kill") def test_daemon(testcase: DataDrivenTestCase) -> None: @@ -42,18 +39,19 @@ def test_daemon(testcase: DataDrivenTestCase) -> None: for i, step in enumerate(parse_script(testcase.input)): cmd = step[0] expected_lines = step[1:] - assert cmd.startswith('$') + assert cmd.startswith("$") cmd = cmd[1:].strip() - cmd = cmd.replace('{python}', sys.executable) + cmd = cmd.replace("{python}", sys.executable) sts, output = run_cmd(cmd) output_lines = output.splitlines() output_lines = normalize_error_messages(output_lines) if sts: - output_lines.append('== Return code: %d' % sts) - assert_string_arrays_equal(expected_lines, - output_lines, - "Command %d (%s) did not give expected output" % - (i + 1, cmd)) + output_lines.append("== Return code: %d" % sts) + assert_string_arrays_equal( + expected_lines, + output_lines, + "Command %d (%s) did not give expected output" % (i + 1, cmd), + ) def parse_script(input: List[str]) -> List[List[str]]: @@ -66,9 +64,9 @@ def parse_script(input: List[str]) -> List[List[str]]: steps = [] step: List[str] = [] for line in input: - if line.startswith('$'): + if line.startswith("$"): if step: - assert step[0].startswith('$') + assert step[0].startswith("$") steps.append(step) step = [] step.append(line) @@ -78,19 +76,21 @@ def parse_script(input: List[str]) -> List[List[str]]: def run_cmd(input: str) -> Tuple[int, str]: - if input.startswith('dmypy '): - input = sys.executable + ' -m mypy.' + input - if input.startswith('mypy '): - input = sys.executable + ' -m' + input + if input.startswith("dmypy "): + input = sys.executable + " -m mypy." + input + if input.startswith("mypy "): + input = sys.executable + " -m" + input env = os.environ.copy() - env['PYTHONPATH'] = PREFIX + env["PYTHONPATH"] = PREFIX try: - output = subprocess.check_output(input, - shell=True, - stderr=subprocess.STDOUT, - universal_newlines=True, - cwd=test_temp_dir, - env=env) + output = subprocess.check_output( + input, + shell=True, + stderr=subprocess.STDOUT, + universal_newlines=True, + cwd=test_temp_dir, + env=env, + ) return 0, output except subprocess.CalledProcessError as err: return err.returncode, err.output @@ -101,33 +101,34 @@ class DaemonUtilitySuite(unittest.TestCase): def test_filter_out_missing_top_level_packages(self) -> None: with tempfile.TemporaryDirectory() as td: - self.make_file(td, 'base/a/') - self.make_file(td, 'base/b.py') - self.make_file(td, 'base/c.pyi') - self.make_file(td, 'base/missing.txt') - self.make_file(td, 'typeshed/d.pyi') - self.make_file(td, 'typeshed/@python2/e') - self.make_file(td, 'pkg1/f-stubs') - self.make_file(td, 'pkg2/g-python2-stubs') - self.make_file(td, 'mpath/sub/long_name/') + self.make_file(td, "base/a/") + self.make_file(td, "base/b.py") + self.make_file(td, "base/c.pyi") + self.make_file(td, "base/missing.txt") + self.make_file(td, "typeshed/d.pyi") + self.make_file(td, "typeshed/@python2/e") + self.make_file(td, "pkg1/f-stubs") + self.make_file(td, "pkg2/g-python2-stubs") + self.make_file(td, "mpath/sub/long_name/") def makepath(p: str) -> str: return os.path.join(td, p) - search = SearchPaths(python_path=(makepath('base'),), - mypy_path=(makepath('mpath/sub'),), - package_path=(makepath('pkg1'), makepath('pkg2')), - typeshed_path=(makepath('typeshed'),)) + search = SearchPaths( + python_path=(makepath("base"),), + mypy_path=(makepath("mpath/sub"),), + package_path=(makepath("pkg1"), makepath("pkg2")), + typeshed_path=(makepath("typeshed"),), + ) fscache = FileSystemCache() res = filter_out_missing_top_level_packages( - {'a', 'b', 'c', 'd', 'e', 'f', 'g', 'long_name', 'ff', 'missing'}, - search, - fscache) - assert res == {'a', 'b', 'c', 'd', 'e', 'f', 'g', 'long_name'} + {"a", "b", "c", "d", "e", "f", "g", "long_name", "ff", "missing"}, search, fscache + ) + assert res == {"a", "b", "c", "d", "e", "f", "g", "long_name"} def make_file(self, base: str, path: str) -> None: fullpath = os.path.join(base, path) os.makedirs(os.path.dirname(fullpath), exist_ok=True) - if not path.endswith('/'): - with open(fullpath, 'w') as f: - f.write('# test file') + if not path.endswith("/"): + with open(fullpath, "w") as f: + f.write("# test file") diff --git a/mypy/test/testdeps.py b/mypy/test/testdeps.py index 1ca36b6ca4bbb..657982eef4673 100644 --- a/mypy/test/testdeps.py +++ b/mypy/test/testdeps.py @@ -2,32 +2,32 @@ import os from collections import defaultdict +from typing import Dict, List, Optional, Set, Tuple -from typing import List, Tuple, Dict, Optional, Set from typing_extensions import DefaultDict from mypy import build -from mypy.modulefinder import BuildSource from mypy.errors import CompileError -from mypy.nodes import MypyFile, Expression +from mypy.modulefinder import BuildSource +from mypy.nodes import Expression, MypyFile from mypy.options import Options from mypy.server.deps import get_dependencies from mypy.test.config import test_temp_dir from mypy.test.data import DataDrivenTestCase, DataSuite -from mypy.test.helpers import assert_string_arrays_equal, parse_options, find_test_files +from mypy.test.helpers import assert_string_arrays_equal, find_test_files, parse_options from mypy.types import Type from mypy.typestate import TypeState # Only dependencies in these modules are dumped -dumped_modules = ['__main__', 'pkg', 'pkg.mod'] +dumped_modules = ["__main__", "pkg", "pkg.mod"] class GetDependenciesSuite(DataSuite): files = find_test_files(pattern="deps*.test") def run_case(self, testcase: DataDrivenTestCase) -> None: - src = '\n'.join(testcase.input) - dump_all = '# __dump_all__' in src + src = "\n".join(testcase.input) + dump_all = "# __dump_all__" in src options = parse_options(src, testcase, incremental_step=1) options.use_builtins_fixtures = True options.show_traceback = True @@ -38,44 +38,46 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: a = messages if files is None or type_map is None: if not a: - a = ['Unknown compile error (likely syntax error in test case or fixture)'] + a = ["Unknown compile error (likely syntax error in test case or fixture)"] else: deps: DefaultDict[str, Set[str]] = defaultdict(set) for module in files: - if module in dumped_modules or dump_all and module not in ('abc', - 'typing', - 'mypy_extensions', - 'typing_extensions', - 'enum'): - new_deps = get_dependencies(files[module], type_map, options.python_version, - options) + if ( + module in dumped_modules + or dump_all + and module + not in ("abc", "typing", "mypy_extensions", "typing_extensions", "enum") + ): + new_deps = get_dependencies( + files[module], type_map, options.python_version, options + ) for source in new_deps: deps[source].update(new_deps[source]) TypeState.add_all_protocol_deps(deps) for source, targets in sorted(deps.items()): - if source.startswith((' {', '.join(sorted(targets))}" # Clean up output a bit - line = line.replace('__main__', 'm') + line = line.replace("__main__", "m") a.append(line) assert_string_arrays_equal( - testcase.output, a, - f'Invalid output ({testcase.file}, line {testcase.line})') + testcase.output, a, f"Invalid output ({testcase.file}, line {testcase.line})" + ) - def build(self, - source: str, - options: Options) -> Tuple[List[str], - Optional[Dict[str, MypyFile]], - Optional[Dict[Expression, Type]]]: + def build( + self, source: str, options: Options + ) -> Tuple[List[str], Optional[Dict[str, MypyFile]], Optional[Dict[Expression, Type]]]: try: - result = build.build(sources=[BuildSource('main', None, source)], - options=options, - alt_lib_path=test_temp_dir) + result = build.build( + sources=[BuildSource("main", None, source)], + options=options, + alt_lib_path=test_temp_dir, + ) except CompileError as e: # TODO: Should perhaps not return None here. return e.messages, None, None diff --git a/mypy/test/testdiff.py b/mypy/test/testdiff.py index 56f4564e91d3a..54a688bd00e60 100644 --- a/mypy/test/testdiff.py +++ b/mypy/test/testdiff.py @@ -1,29 +1,27 @@ """Test cases for AST diff (used for fine-grained incremental checking)""" import os -from typing import List, Tuple, Dict, Optional +from typing import Dict, List, Optional, Tuple from mypy import build -from mypy.modulefinder import BuildSource from mypy.defaults import PYTHON3_VERSION from mypy.errors import CompileError +from mypy.modulefinder import BuildSource from mypy.nodes import MypyFile from mypy.options import Options -from mypy.server.astdiff import snapshot_symbol_table, compare_symbol_table_snapshots +from mypy.server.astdiff import compare_symbol_table_snapshots, snapshot_symbol_table from mypy.test.config import test_temp_dir from mypy.test.data import DataDrivenTestCase, DataSuite from mypy.test.helpers import assert_string_arrays_equal, parse_options class ASTDiffSuite(DataSuite): - files = [ - 'diff.test', - ] + files = ["diff.test"] def run_case(self, testcase: DataDrivenTestCase) -> None: - first_src = '\n'.join(testcase.input) + first_src = "\n".join(testcase.input) files_dict = dict(testcase.files) - second_src = files_dict['tmp/next.py'] + second_src = files_dict["tmp/next.py"] options = parse_options(first_src, testcase, 1) messages1, files1 = self.build(first_src, options) @@ -33,32 +31,36 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: if messages1: a.extend(messages1) if messages2: - a.append('== next ==') + a.append("== next ==") a.extend(messages2) - assert files1 is not None and files2 is not None, ('cases where CompileError' - ' occurred should not be run') - prefix = '__main__' - snapshot1 = snapshot_symbol_table(prefix, files1['__main__'].names) - snapshot2 = snapshot_symbol_table(prefix, files2['__main__'].names) + assert files1 is not None and files2 is not None, ( + "cases where CompileError" " occurred should not be run" + ) + prefix = "__main__" + snapshot1 = snapshot_symbol_table(prefix, files1["__main__"].names) + snapshot2 = snapshot_symbol_table(prefix, files2["__main__"].names) diff = compare_symbol_table_snapshots(prefix, snapshot1, snapshot2) for trigger in sorted(diff): a.append(trigger) assert_string_arrays_equal( - testcase.output, a, - f'Invalid output ({testcase.file}, line {testcase.line})') + testcase.output, a, f"Invalid output ({testcase.file}, line {testcase.line})" + ) - def build(self, source: str, - options: Options) -> Tuple[List[str], Optional[Dict[str, MypyFile]]]: + def build( + self, source: str, options: Options + ) -> Tuple[List[str], Optional[Dict[str, MypyFile]]]: options.use_builtins_fixtures = True options.show_traceback = True options.cache_dir = os.devnull options.python_version = PYTHON3_VERSION try: - result = build.build(sources=[BuildSource('main', None, source)], - options=options, - alt_lib_path=test_temp_dir) + result = build.build( + sources=[BuildSource("main", None, source)], + options=options, + alt_lib_path=test_temp_dir, + ) except CompileError as e: # TODO: Is it okay to return None? return e.messages, None diff --git a/mypy/test/testerrorstream.py b/mypy/test/testerrorstream.py index 278fc1152504b..551f0cf18a932 100644 --- a/mypy/test/testerrorstream.py +++ b/mypy/test/testerrorstream.py @@ -2,17 +2,17 @@ from typing import List from mypy import build -from mypy.test.helpers import assert_string_arrays_equal -from mypy.test.data import DataDrivenTestCase, DataSuite -from mypy.modulefinder import BuildSource from mypy.errors import CompileError +from mypy.modulefinder import BuildSource from mypy.options import Options +from mypy.test.data import DataDrivenTestCase, DataSuite +from mypy.test.helpers import assert_string_arrays_equal class ErrorStreamSuite(DataSuite): required_out_section = True - base_path = '.' - files = ['errorstream.test'] + base_path = "." + files = ["errorstream.test"] def run_case(self, testcase: DataDrivenTestCase) -> None: test_error_stream(testcase) @@ -30,17 +30,17 @@ def test_error_stream(testcase: DataDrivenTestCase) -> None: def flush_errors(msgs: List[str], serious: bool) -> None: if msgs: - logged_messages.append('==== Errors flushed ====') + logged_messages.append("==== Errors flushed ====") logged_messages.extend(msgs) - sources = [BuildSource('main', '__main__', '\n'.join(testcase.input))] + sources = [BuildSource("main", "__main__", "\n".join(testcase.input))] try: - build.build(sources=sources, - options=options, - flush_errors=flush_errors) + build.build(sources=sources, options=options, flush_errors=flush_errors) except CompileError as e: assert e.messages == [] - assert_string_arrays_equal(testcase.output, logged_messages, - 'Invalid output ({}, line {})'.format( - testcase.file, testcase.line)) + assert_string_arrays_equal( + testcase.output, + logged_messages, + "Invalid output ({}, line {})".format(testcase.file, testcase.line), + ) diff --git a/mypy/test/testfinegrained.py b/mypy/test/testfinegrained.py index 97b97f2c078ea..c88ab595443fc 100644 --- a/mypy/test/testfinegrained.py +++ b/mypy/test/testfinegrained.py @@ -14,28 +14,29 @@ import os import re +from typing import Any, Dict, List, Tuple, Union, cast -from typing import List, Dict, Any, Tuple, Union, cast +import pytest from mypy import build -from mypy.modulefinder import BuildSource +from mypy.config_parser import parse_config_file +from mypy.dmypy_server import Server +from mypy.dmypy_util import DEFAULT_STATUS_FILE from mypy.errors import CompileError +from mypy.find_sources import create_source_list +from mypy.modulefinder import BuildSource from mypy.options import Options +from mypy.server.mergecheck import check_consistency from mypy.test.config import test_temp_dir -from mypy.test.data import ( - DataDrivenTestCase, DataSuite, UpdateFile, DeleteFile -) +from mypy.test.data import DataDrivenTestCase, DataSuite, DeleteFile, UpdateFile from mypy.test.helpers import ( - assert_string_arrays_equal, parse_options, assert_module_equivalence, - assert_target_equivalence, perform_file_operations, find_test_files, + assert_module_equivalence, + assert_string_arrays_equal, + assert_target_equivalence, + find_test_files, + parse_options, + perform_file_operations, ) -from mypy.server.mergecheck import check_consistency -from mypy.dmypy_util import DEFAULT_STATUS_FILE -from mypy.dmypy_server import Server -from mypy.config_parser import parse_config_file -from mypy.find_sources import create_source_list - -import pytest # Set to True to perform (somewhat expensive) checks for duplicate AST nodes after merge CHECK_CONSISTENCY = False @@ -55,14 +56,14 @@ def should_skip(self, testcase: DataDrivenTestCase) -> bool: # as a filter() classmethod also, but we want the tests reported # as skipped, not just elided. if self.use_cache: - if testcase.only_when == '-only_when_nocache': + if testcase.only_when == "-only_when_nocache": return True # TODO: In caching mode we currently don't well support # starting from cached states with errors in them. - if testcase.output and testcase.output[0] != '==': + if testcase.output and testcase.output[0] != "==": return True else: - if testcase.only_when == '-only_when_cache': + if testcase.only_when == "-only_when_cache": return True return False @@ -72,9 +73,9 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: pytest.skip() return - main_src = '\n'.join(testcase.input) - main_path = os.path.join(test_temp_dir, 'main') - with open(main_path, 'w', encoding='utf8') as f: + main_src = "\n".join(testcase.input) + main_path = os.path.join(test_temp_dir, "main") + with open(main_path, "w", encoding="utf8") as f: f.write(main_src) options = self.get_options(main_src, testcase, build_cache=False) @@ -115,27 +116,25 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: step, num_regular_incremental_steps, ) - a.append('==') + a.append("==") a.extend(output) all_triggered.extend(triggered) # Normalize paths in test output (for Windows). - a = [line.replace('\\', '/') for line in a] + a = [line.replace("\\", "/") for line in a] assert_string_arrays_equal( - testcase.output, a, - f'Invalid output ({testcase.file}, line {testcase.line})') + testcase.output, a, f"Invalid output ({testcase.file}, line {testcase.line})" + ) if testcase.triggered: assert_string_arrays_equal( testcase.triggered, self.format_triggered(all_triggered), - f'Invalid active triggers ({testcase.file}, line {testcase.line})') + f"Invalid active triggers ({testcase.file}, line {testcase.line})", + ) - def get_options(self, - source: str, - testcase: DataDrivenTestCase, - build_cache: bool,) -> Options: + def get_options(self, source: str, testcase: DataDrivenTestCase, build_cache: bool) -> Options: # This handles things like '# flags: --foo'. options = parse_options(source, testcase, incremental_step=1) options.incremental = True @@ -146,12 +145,12 @@ def get_options(self, options.use_fine_grained_cache = self.use_cache and not build_cache options.cache_fine_grained = self.use_cache options.local_partial_types = True - if re.search('flags:.*--follow-imports', source) is None: + if re.search("flags:.*--follow-imports", source) is None: # Override the default for follow_imports - options.follow_imports = 'error' + options.follow_imports = "error" for name, _ in testcase.files: - if 'mypy.ini' in name or 'pyproject.toml' in name: + if "mypy.ini" in name or "pyproject.toml" in name: parse_config_file(options, lambda: None, name) break @@ -159,15 +158,12 @@ def get_options(self, def run_check(self, server: Server, sources: List[BuildSource]) -> List[str]: response = server.check(sources, is_tty=False, terminal_width=-1) - out = cast(str, response['out'] or response['err']) + out = cast(str, response["out"] or response["err"]) return out.splitlines() - def build(self, - options: Options, - sources: List[BuildSource]) -> List[str]: + def build(self, options: Options, sources: List[BuildSource]) -> List[str]: try: - result = build.build(sources=sources, - options=options) + result = build.build(sources=sources, options=options) except CompileError as e: return e.messages return result.errors @@ -175,30 +171,31 @@ def build(self, def format_triggered(self, triggered: List[List[str]]) -> List[str]: result = [] for n, triggers in enumerate(triggered): - filtered = [trigger for trigger in triggers - if not trigger.endswith('__>')] + filtered = [trigger for trigger in triggers if not trigger.endswith("__>")] filtered = sorted(filtered) - result.append(('%d: %s' % (n + 2, ', '.join(filtered))).strip()) + result.append(("%d: %s" % (n + 2, ", ".join(filtered))).strip()) return result def get_build_steps(self, program_text: str) -> int: """Get the number of regular incremental steps to run, from the test source""" if not self.use_cache: return 0 - m = re.search('# num_build_steps: ([0-9]+)$', program_text, flags=re.MULTILINE) + m = re.search("# num_build_steps: ([0-9]+)$", program_text, flags=re.MULTILINE) if m is not None: return int(m.group(1)) return 1 - def perform_step(self, - operations: List[Union[UpdateFile, DeleteFile]], - server: Server, - options: Options, - build_options: Options, - testcase: DataDrivenTestCase, - main_src: str, - step: int, - num_regular_incremental_steps: int) -> Tuple[List[str], List[List[str]]]: + def perform_step( + self, + operations: List[Union[UpdateFile, DeleteFile]], + server: Server, + options: Options, + build_options: Options, + testcase: DataDrivenTestCase, + main_src: str, + step: int, + num_regular_incremental_steps: int, + ) -> Tuple[List[str], List[List[str]]]: """Perform one fine-grained incremental build step (after some file updates/deletions). Return (mypy output, triggered targets). @@ -226,21 +223,15 @@ def perform_step(self, expected_stale = testcase.expected_stale_modules.get(step - 1) if expected_stale is not None: - assert_module_equivalence( - 'stale' + str(step - 1), - expected_stale, changed) + assert_module_equivalence("stale" + str(step - 1), expected_stale, changed) expected_rechecked = testcase.expected_rechecked_modules.get(step - 1) if expected_rechecked is not None: - assert_module_equivalence( - 'rechecked' + str(step - 1), - expected_rechecked, updated) + assert_module_equivalence("rechecked" + str(step - 1), expected_rechecked, updated) expected = testcase.expected_fine_grained_targets.get(step) if expected: - assert_target_equivalence( - 'targets' + str(step), - expected, targets) + assert_target_equivalence("targets" + str(step), expected, targets) new_messages = normalize_messages(new_messages) @@ -250,9 +241,9 @@ def perform_step(self, return a, triggered - def parse_sources(self, program_text: str, - incremental_step: int, - options: Options) -> List[BuildSource]: + def parse_sources( + self, program_text: str, incremental_step: int, options: Options + ) -> List[BuildSource]: """Return target BuildSources for a test case. Normally, the unit tests will check all files included in the test @@ -269,8 +260,8 @@ def parse_sources(self, program_text: str, step N (2, 3, ...). """ - m = re.search('# cmd: mypy ([a-zA-Z0-9_./ ]+)$', program_text, flags=re.MULTILINE) - regex = f'# cmd{incremental_step}: mypy ([a-zA-Z0-9_./ ]+)$' + m = re.search("# cmd: mypy ([a-zA-Z0-9_./ ]+)$", program_text, flags=re.MULTILINE) + regex = f"# cmd{incremental_step}: mypy ([a-zA-Z0-9_./ ]+)$" alt_m = re.search(regex, program_text, flags=re.MULTILINE) if alt_m is not None: # Optionally return a different command if in a later step @@ -283,48 +274,54 @@ def parse_sources(self, program_text: str, paths = [os.path.join(test_temp_dir, path) for path in m.group(1).strip().split()] return create_source_list(paths, options) else: - base = BuildSource(os.path.join(test_temp_dir, 'main'), '__main__', None) + base = BuildSource(os.path.join(test_temp_dir, "main"), "__main__", None) # Use expand_dir instead of create_source_list to avoid complaints # when there aren't any .py files in an increment - return [base] + create_source_list([test_temp_dir], options, - allow_empty_dir=True) + return [base] + create_source_list([test_temp_dir], options, allow_empty_dir=True) def maybe_suggest(self, step: int, server: Server, src: str, tmp_dir: str) -> List[str]: output: List[str] = [] targets = self.get_suggest(src, step) for flags, target in targets: - json = '--json' in flags - callsites = '--callsites' in flags - no_any = '--no-any' in flags - no_errors = '--no-errors' in flags - try_text = '--try-text' in flags - m = re.match('--flex-any=([0-9.]+)', flags) + json = "--json" in flags + callsites = "--callsites" in flags + no_any = "--no-any" in flags + no_errors = "--no-errors" in flags + try_text = "--try-text" in flags + m = re.match("--flex-any=([0-9.]+)", flags) flex_any = float(m.group(1)) if m else None - m = re.match(r'--use-fixme=(\w+)', flags) + m = re.match(r"--use-fixme=(\w+)", flags) use_fixme = m.group(1) if m else None - m = re.match('--max-guesses=([0-9]+)', flags) + m = re.match("--max-guesses=([0-9]+)", flags) max_guesses = int(m.group(1)) if m else None - res = cast(Dict[str, Any], - server.cmd_suggest( - target.strip(), json=json, no_any=no_any, no_errors=no_errors, - try_text=try_text, flex_any=flex_any, use_fixme=use_fixme, - callsites=callsites, max_guesses=max_guesses)) - val = res['error'] if 'error' in res else res['out'] + res['err'] + res = cast( + Dict[str, Any], + server.cmd_suggest( + target.strip(), + json=json, + no_any=no_any, + no_errors=no_errors, + try_text=try_text, + flex_any=flex_any, + use_fixme=use_fixme, + callsites=callsites, + max_guesses=max_guesses, + ), + ) + val = res["error"] if "error" in res else res["out"] + res["err"] if json: # JSON contains already escaped \ on Windows, so requires a bit of care. - val = val.replace('\\\\', '\\') - val = val.replace(os.path.realpath(tmp_dir) + os.path.sep, '') - output.extend(val.strip().split('\n')) + val = val.replace("\\\\", "\\") + val = val.replace(os.path.realpath(tmp_dir) + os.path.sep, "") + output.extend(val.strip().split("\n")) return normalize_messages(output) - def get_suggest(self, program_text: str, - incremental_step: int) -> List[Tuple[str, str]]: - step_bit = '1?' if incremental_step == 1 else str(incremental_step) - regex = f'# suggest{step_bit}: (--[a-zA-Z0-9_\\-./=?^ ]+ )*([a-zA-Z0-9_.:/?^ ]+)$' + def get_suggest(self, program_text: str, incremental_step: int) -> List[Tuple[str, str]]: + step_bit = "1?" if incremental_step == 1 else str(incremental_step) + regex = f"# suggest{step_bit}: (--[a-zA-Z0-9_\\-./=?^ ]+ )*([a-zA-Z0-9_.:/?^ ]+)$" m = re.findall(regex, program_text, flags=re.MULTILINE) return m def normalize_messages(messages: List[str]) -> List[str]: - return [re.sub('^tmp' + re.escape(os.sep), '', message) - for message in messages] + return [re.sub("^tmp" + re.escape(os.sep), "", message) for message in messages] diff --git a/mypy/test/testfinegrainedcache.py b/mypy/test/testfinegrainedcache.py index ee03f0b688f4a..c6c4406a0353b 100644 --- a/mypy/test/testfinegrainedcache.py +++ b/mypy/test/testfinegrainedcache.py @@ -10,6 +10,7 @@ class FineGrainedCacheSuite(mypy.test.testfinegrained.FineGrainedSuite): use_cache = True - test_name_suffix = '_cached' - files = ( - mypy.test.testfinegrained.FineGrainedSuite.files + ['fine-grained-cache-incremental.test']) + test_name_suffix = "_cached" + files = mypy.test.testfinegrained.FineGrainedSuite.files + [ + "fine-grained-cache-incremental.test" + ] diff --git a/mypy/test/testformatter.py b/mypy/test/testformatter.py index 623c7a62753fd..2f209a26aebd6 100644 --- a/mypy/test/testformatter.py +++ b/mypy/test/testformatter.py @@ -1,51 +1,83 @@ from unittest import TestCase, main -from mypy.util import trim_source_line, split_words +from mypy.util import split_words, trim_source_line class FancyErrorFormattingTestCases(TestCase): def test_trim_source(self) -> None: - assert trim_source_line('0123456789abcdef', - max_len=16, col=5, min_width=2) == ('0123456789abcdef', 0) + assert trim_source_line("0123456789abcdef", max_len=16, col=5, min_width=2) == ( + "0123456789abcdef", + 0, + ) # Locations near start. - assert trim_source_line('0123456789abcdef', - max_len=7, col=0, min_width=2) == ('0123456...', 0) - assert trim_source_line('0123456789abcdef', - max_len=7, col=4, min_width=2) == ('0123456...', 0) + assert trim_source_line("0123456789abcdef", max_len=7, col=0, min_width=2) == ( + "0123456...", + 0, + ) + assert trim_source_line("0123456789abcdef", max_len=7, col=4, min_width=2) == ( + "0123456...", + 0, + ) # Middle locations. - assert trim_source_line('0123456789abcdef', - max_len=7, col=5, min_width=2) == ('...1234567...', -2) - assert trim_source_line('0123456789abcdef', - max_len=7, col=6, min_width=2) == ('...2345678...', -1) - assert trim_source_line('0123456789abcdef', - max_len=7, col=8, min_width=2) == ('...456789a...', 1) + assert trim_source_line("0123456789abcdef", max_len=7, col=5, min_width=2) == ( + "...1234567...", + -2, + ) + assert trim_source_line("0123456789abcdef", max_len=7, col=6, min_width=2) == ( + "...2345678...", + -1, + ) + assert trim_source_line("0123456789abcdef", max_len=7, col=8, min_width=2) == ( + "...456789a...", + 1, + ) # Locations near the end. - assert trim_source_line('0123456789abcdef', - max_len=7, col=11, min_width=2) == ('...789abcd...', 4) - assert trim_source_line('0123456789abcdef', - max_len=7, col=13, min_width=2) == ('...9abcdef', 6) - assert trim_source_line('0123456789abcdef', - max_len=7, col=15, min_width=2) == ('...9abcdef', 6) + assert trim_source_line("0123456789abcdef", max_len=7, col=11, min_width=2) == ( + "...789abcd...", + 4, + ) + assert trim_source_line("0123456789abcdef", max_len=7, col=13, min_width=2) == ( + "...9abcdef", + 6, + ) + assert trim_source_line("0123456789abcdef", max_len=7, col=15, min_width=2) == ( + "...9abcdef", + 6, + ) def test_split_words(self) -> None: - assert split_words('Simple message') == ['Simple', 'message'] - assert split_words('Message with "Some[Long, Types]"' - ' in it') == ['Message', 'with', - '"Some[Long, Types]"', 'in', 'it'] - assert split_words('Message with "Some[Long, Types]"' - ' and [error-code]') == ['Message', 'with', '"Some[Long, Types]"', - 'and', '[error-code]'] - assert split_words('"Type[Stands, First]" then words') == ['"Type[Stands, First]"', - 'then', 'words'] - assert split_words('First words "Then[Stands, Type]"') == ['First', 'words', - '"Then[Stands, Type]"'] + assert split_words("Simple message") == ["Simple", "message"] + assert split_words('Message with "Some[Long, Types]"' " in it") == [ + "Message", + "with", + '"Some[Long, Types]"', + "in", + "it", + ] + assert split_words('Message with "Some[Long, Types]"' " and [error-code]") == [ + "Message", + "with", + '"Some[Long, Types]"', + "and", + "[error-code]", + ] + assert split_words('"Type[Stands, First]" then words') == [ + '"Type[Stands, First]"', + "then", + "words", + ] + assert split_words('First words "Then[Stands, Type]"') == [ + "First", + "words", + '"Then[Stands, Type]"', + ] assert split_words('"Type[Only, Here]"') == ['"Type[Only, Here]"'] - assert split_words('OneWord') == ['OneWord'] - assert split_words(' ') == ['', ''] + assert split_words("OneWord") == ["OneWord"] + assert split_words(" ") == ["", ""] -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mypy/test/testfscache.py b/mypy/test/testfscache.py index 73f926ad748c5..4ba5cf713d3f2 100644 --- a/mypy/test/testfscache.py +++ b/mypy/test/testfscache.py @@ -22,79 +22,79 @@ def tearDown(self) -> None: shutil.rmtree(self.tempdir) def test_isfile_case_1(self) -> None: - self.make_file('bar.py') - self.make_file('pkg/sub_package/__init__.py') - self.make_file('pkg/sub_package/foo.py') + self.make_file("bar.py") + self.make_file("pkg/sub_package/__init__.py") + self.make_file("pkg/sub_package/foo.py") # Run twice to test both cached and non-cached code paths. for i in range(2): - assert self.isfile_case('bar.py') - assert self.isfile_case('pkg/sub_package/__init__.py') - assert self.isfile_case('pkg/sub_package/foo.py') - assert not self.isfile_case('non_existent.py') - assert not self.isfile_case('pkg/non_existent.py') - assert not self.isfile_case('pkg/') - assert not self.isfile_case('bar.py/') + assert self.isfile_case("bar.py") + assert self.isfile_case("pkg/sub_package/__init__.py") + assert self.isfile_case("pkg/sub_package/foo.py") + assert not self.isfile_case("non_existent.py") + assert not self.isfile_case("pkg/non_existent.py") + assert not self.isfile_case("pkg/") + assert not self.isfile_case("bar.py/") for i in range(2): - assert not self.isfile_case('Bar.py') - assert not self.isfile_case('pkg/sub_package/__init__.PY') - assert not self.isfile_case('pkg/Sub_Package/foo.py') - assert not self.isfile_case('Pkg/sub_package/foo.py') + assert not self.isfile_case("Bar.py") + assert not self.isfile_case("pkg/sub_package/__init__.PY") + assert not self.isfile_case("pkg/Sub_Package/foo.py") + assert not self.isfile_case("Pkg/sub_package/foo.py") def test_isfile_case_2(self) -> None: - self.make_file('bar.py') - self.make_file('pkg/sub_package/__init__.py') - self.make_file('pkg/sub_package/foo.py') + self.make_file("bar.py") + self.make_file("pkg/sub_package/__init__.py") + self.make_file("pkg/sub_package/foo.py") # Run twice to test both cached and non-cached code paths. # This reverses the order of checks from test_isfile_case_1. for i in range(2): - assert not self.isfile_case('Bar.py') - assert not self.isfile_case('pkg/sub_package/__init__.PY') - assert not self.isfile_case('pkg/Sub_Package/foo.py') - assert not self.isfile_case('Pkg/sub_package/foo.py') + assert not self.isfile_case("Bar.py") + assert not self.isfile_case("pkg/sub_package/__init__.PY") + assert not self.isfile_case("pkg/Sub_Package/foo.py") + assert not self.isfile_case("Pkg/sub_package/foo.py") for i in range(2): - assert self.isfile_case('bar.py') - assert self.isfile_case('pkg/sub_package/__init__.py') - assert self.isfile_case('pkg/sub_package/foo.py') - assert not self.isfile_case('non_existent.py') - assert not self.isfile_case('pkg/non_existent.py') + assert self.isfile_case("bar.py") + assert self.isfile_case("pkg/sub_package/__init__.py") + assert self.isfile_case("pkg/sub_package/foo.py") + assert not self.isfile_case("non_existent.py") + assert not self.isfile_case("pkg/non_existent.py") def test_isfile_case_3(self) -> None: - self.make_file('bar.py') - self.make_file('pkg/sub_package/__init__.py') - self.make_file('pkg/sub_package/foo.py') + self.make_file("bar.py") + self.make_file("pkg/sub_package/__init__.py") + self.make_file("pkg/sub_package/foo.py") # Run twice to test both cached and non-cached code paths. for i in range(2): - assert self.isfile_case('bar.py') - assert not self.isfile_case('non_existent.py') - assert not self.isfile_case('pkg/non_existent.py') - assert not self.isfile_case('Bar.py') - assert not self.isfile_case('pkg/sub_package/__init__.PY') - assert not self.isfile_case('pkg/Sub_Package/foo.py') - assert not self.isfile_case('Pkg/sub_package/foo.py') - assert self.isfile_case('pkg/sub_package/__init__.py') - assert self.isfile_case('pkg/sub_package/foo.py') + assert self.isfile_case("bar.py") + assert not self.isfile_case("non_existent.py") + assert not self.isfile_case("pkg/non_existent.py") + assert not self.isfile_case("Bar.py") + assert not self.isfile_case("pkg/sub_package/__init__.PY") + assert not self.isfile_case("pkg/Sub_Package/foo.py") + assert not self.isfile_case("Pkg/sub_package/foo.py") + assert self.isfile_case("pkg/sub_package/__init__.py") + assert self.isfile_case("pkg/sub_package/foo.py") def test_isfile_case_other_directory(self) -> None: - self.make_file('bar.py') + self.make_file("bar.py") with tempfile.TemporaryDirectory() as other: - self.make_file('other_dir.py', base=other) - self.make_file('pkg/other_dir.py', base=other) - assert self.isfile_case(os.path.join(other, 'other_dir.py')) - assert not self.isfile_case(os.path.join(other, 'Other_Dir.py')) - assert not self.isfile_case(os.path.join(other, 'bar.py')) - if sys.platform in ('win32', 'darwin'): + self.make_file("other_dir.py", base=other) + self.make_file("pkg/other_dir.py", base=other) + assert self.isfile_case(os.path.join(other, "other_dir.py")) + assert not self.isfile_case(os.path.join(other, "Other_Dir.py")) + assert not self.isfile_case(os.path.join(other, "bar.py")) + if sys.platform in ("win32", "darwin"): # We only check case for directories under our prefix, and since # this path is not under the prefix, case difference is fine. - assert self.isfile_case(os.path.join(other, 'PKG/other_dir.py')) + assert self.isfile_case(os.path.join(other, "PKG/other_dir.py")) def make_file(self, path: str, base: Optional[str] = None) -> None: if base is None: base = self.tempdir fullpath = os.path.join(base, path) os.makedirs(os.path.dirname(fullpath), exist_ok=True) - if not path.endswith('/'): - with open(fullpath, 'w') as f: - f.write('# test file') + if not path.endswith("/"): + with open(fullpath, "w") as f: + f.write("# test file") def isfile_case(self, path: str) -> bool: return self.fscache.isfile_case(os.path.join(self.tempdir, path), self.tempdir) diff --git a/mypy/test/testgraph.py b/mypy/test/testgraph.py index 7d32db2b1c1c4..fb22452ddac66 100644 --- a/mypy/test/testgraph.py +++ b/mypy/test/testgraph.py @@ -1,27 +1,33 @@ """Test cases for graph processing code in build.py.""" import sys -from typing import AbstractSet, Dict, Set, List +from typing import AbstractSet, Dict, List, Set -from mypy.test.helpers import assert_equal, Suite -from mypy.build import BuildManager, State, BuildSourceSet +from mypy.build import ( + BuildManager, + BuildSourceSet, + State, + order_ascc, + sorted_components, + strongly_connected_components, + topsort, +) +from mypy.errors import Errors +from mypy.fscache import FileSystemCache from mypy.modulefinder import SearchPaths -from mypy.build import topsort, strongly_connected_components, sorted_components, order_ascc -from mypy.version import __version__ from mypy.options import Options -from mypy.report import Reports from mypy.plugin import Plugin -from mypy.errors import Errors -from mypy.fscache import FileSystemCache +from mypy.report import Reports +from mypy.test.helpers import Suite, assert_equal +from mypy.version import __version__ class GraphSuite(Suite): - def test_topsort(self) -> None: - a = frozenset({'A'}) - b = frozenset({'B'}) - c = frozenset({'C'}) - d = frozenset({'D'}) + a = frozenset({"A"}) + b = frozenset({"B"}) + c = frozenset({"C"}) + d = frozenset({"D"}) data: Dict[AbstractSet[str], Set[AbstractSet[str]]] = {a: {b, c}, b: {d}, c: {d}} res = list(topsort(data)) assert_equal(res, [{d}, {b, c}, {a}]) @@ -30,10 +36,7 @@ def test_scc(self) -> None: vertices = {"A", "B", "C", "D"} edges: Dict[str, List[str]] = {"A": ["B", "C"], "B": ["C"], "C": ["B", "D"], "D": []} sccs = {frozenset(x) for x in strongly_connected_components(vertices, edges)} - assert_equal(sccs, - {frozenset({'A'}), - frozenset({'B', 'C'}), - frozenset({'D'})}) + assert_equal(sccs, {frozenset({"A"}), frozenset({"B", "C"}), frozenset({"D"})}) def _make_manager(self) -> BuildManager: errors = Errors() @@ -41,11 +44,11 @@ def _make_manager(self) -> BuildManager: fscache = FileSystemCache() search_paths = SearchPaths((), (), (), ()) manager = BuildManager( - data_dir='', + data_dir="", search_paths=search_paths, - ignore_prefix='', + ignore_prefix="", source_set=BuildSourceSet([]), - reports=Reports('', {}), + reports=Reports("", {}), options=options, version_id=__version__, plugin=Plugin(options), @@ -60,23 +63,25 @@ def _make_manager(self) -> BuildManager: def test_sorted_components(self) -> None: manager = self._make_manager() - graph = {'a': State('a', None, 'import b, c', manager), - 'd': State('d', None, 'pass', manager), - 'b': State('b', None, 'import c', manager), - 'c': State('c', None, 'import b, d', manager), - } + graph = { + "a": State("a", None, "import b, c", manager), + "d": State("d", None, "pass", manager), + "b": State("b", None, "import c", manager), + "c": State("c", None, "import b, d", manager), + } res = sorted_components(graph) - assert_equal(res, [frozenset({'d'}), frozenset({'c', 'b'}), frozenset({'a'})]) + assert_equal(res, [frozenset({"d"}), frozenset({"c", "b"}), frozenset({"a"})]) def test_order_ascc(self) -> None: manager = self._make_manager() - graph = {'a': State('a', None, 'import b, c', manager), - 'd': State('d', None, 'def f(): import a', manager), - 'b': State('b', None, 'import c', manager), - 'c': State('c', None, 'import b, d', manager), - } + graph = { + "a": State("a", None, "import b, c", manager), + "d": State("d", None, "def f(): import a", manager), + "b": State("b", None, "import c", manager), + "c": State("c", None, "import b, d", manager), + } res = sorted_components(graph) - assert_equal(res, [frozenset({'a', 'd', 'c', 'b'})]) + assert_equal(res, [frozenset({"a", "d", "c", "b"})]) ascc = res[0] scc = order_ascc(graph, ascc) - assert_equal(scc, ['d', 'c', 'b', 'a']) + assert_equal(scc, ["d", "c", "b", "a"]) diff --git a/mypy/test/testinfer.py b/mypy/test/testinfer.py index afb66a7d09e1b..ac2b231598410 100644 --- a/mypy/test/testinfer.py +++ b/mypy/test/testinfer.py @@ -1,14 +1,14 @@ """Test cases for type inference helper functions.""" -from typing import List, Optional, Tuple, Union, Dict, Set +from typing import Dict, List, Optional, Set, Tuple, Union -from mypy.test.helpers import Suite, assert_equal from mypy.argmap import map_actuals_to_formals -from mypy.checker import group_comparison_operands, DisjointDict +from mypy.checker import DisjointDict, group_comparison_operands from mypy.literals import Key -from mypy.nodes import ArgKind, ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, ARG_NAMED, NameExpr -from mypy.types import AnyType, TupleType, Type, TypeOfAny +from mypy.nodes import ARG_NAMED, ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, ArgKind, NameExpr +from mypy.test.helpers import Suite, assert_equal from mypy.test.typefixture import TypeFixture +from mypy.types import AnyType, TupleType, Type, TypeOfAny class MapActualsToFormalsSuite(Suite): @@ -18,162 +18,81 @@ def test_basic(self) -> None: self.assert_map([], [], []) def test_positional_only(self) -> None: - self.assert_map([ARG_POS], - [ARG_POS], - [[0]]) - self.assert_map([ARG_POS, ARG_POS], - [ARG_POS, ARG_POS], - [[0], [1]]) + self.assert_map([ARG_POS], [ARG_POS], [[0]]) + self.assert_map([ARG_POS, ARG_POS], [ARG_POS, ARG_POS], [[0], [1]]) def test_optional(self) -> None: - self.assert_map([], - [ARG_OPT], - [[]]) - self.assert_map([ARG_POS], - [ARG_OPT], - [[0]]) - self.assert_map([ARG_POS], - [ARG_OPT, ARG_OPT], - [[0], []]) + self.assert_map([], [ARG_OPT], [[]]) + self.assert_map([ARG_POS], [ARG_OPT], [[0]]) + self.assert_map([ARG_POS], [ARG_OPT, ARG_OPT], [[0], []]) def test_callee_star(self) -> None: - self.assert_map([], - [ARG_STAR], - [[]]) - self.assert_map([ARG_POS], - [ARG_STAR], - [[0]]) - self.assert_map([ARG_POS, ARG_POS], - [ARG_STAR], - [[0, 1]]) + self.assert_map([], [ARG_STAR], [[]]) + self.assert_map([ARG_POS], [ARG_STAR], [[0]]) + self.assert_map([ARG_POS, ARG_POS], [ARG_STAR], [[0, 1]]) def test_caller_star(self) -> None: - self.assert_map([ARG_STAR], - [ARG_STAR], - [[0]]) - self.assert_map([ARG_POS, ARG_STAR], - [ARG_STAR], - [[0, 1]]) - self.assert_map([ARG_STAR], - [ARG_POS, ARG_STAR], - [[0], [0]]) - self.assert_map([ARG_STAR], - [ARG_OPT, ARG_STAR], - [[0], [0]]) + self.assert_map([ARG_STAR], [ARG_STAR], [[0]]) + self.assert_map([ARG_POS, ARG_STAR], [ARG_STAR], [[0, 1]]) + self.assert_map([ARG_STAR], [ARG_POS, ARG_STAR], [[0], [0]]) + self.assert_map([ARG_STAR], [ARG_OPT, ARG_STAR], [[0], [0]]) def test_too_many_caller_args(self) -> None: - self.assert_map([ARG_POS], - [], - []) - self.assert_map([ARG_STAR], - [], - []) - self.assert_map([ARG_STAR], - [ARG_POS], - [[0]]) + self.assert_map([ARG_POS], [], []) + self.assert_map([ARG_STAR], [], []) + self.assert_map([ARG_STAR], [ARG_POS], [[0]]) def test_tuple_star(self) -> None: any_type = AnyType(TypeOfAny.special_form) + self.assert_vararg_map([ARG_STAR], [ARG_POS], [[0]], self.tuple(any_type)) self.assert_vararg_map( - [ARG_STAR], - [ARG_POS], - [[0]], - self.tuple(any_type)) - self.assert_vararg_map( - [ARG_STAR], - [ARG_POS, ARG_POS], - [[0], [0]], - self.tuple(any_type, any_type)) + [ARG_STAR], [ARG_POS, ARG_POS], [[0], [0]], self.tuple(any_type, any_type) + ) self.assert_vararg_map( - [ARG_STAR], - [ARG_POS, ARG_OPT, ARG_OPT], - [[0], [0], []], - self.tuple(any_type, any_type)) + [ARG_STAR], [ARG_POS, ARG_OPT, ARG_OPT], [[0], [0], []], self.tuple(any_type, any_type) + ) def tuple(self, *args: Type) -> TupleType: return TupleType(list(args), TypeFixture().std_tuple) def test_named_args(self) -> None: - self.assert_map( - ['x'], - [(ARG_POS, 'x')], - [[0]]) - self.assert_map( - ['y', 'x'], - [(ARG_POS, 'x'), (ARG_POS, 'y')], - [[1], [0]]) + self.assert_map(["x"], [(ARG_POS, "x")], [[0]]) + self.assert_map(["y", "x"], [(ARG_POS, "x"), (ARG_POS, "y")], [[1], [0]]) def test_some_named_args(self) -> None: - self.assert_map( - ['y'], - [(ARG_OPT, 'x'), (ARG_OPT, 'y'), (ARG_OPT, 'z')], - [[], [0], []]) + self.assert_map(["y"], [(ARG_OPT, "x"), (ARG_OPT, "y"), (ARG_OPT, "z")], [[], [0], []]) def test_missing_named_arg(self) -> None: - self.assert_map( - ['y'], - [(ARG_OPT, 'x')], - [[]]) + self.assert_map(["y"], [(ARG_OPT, "x")], [[]]) def test_duplicate_named_arg(self) -> None: - self.assert_map( - ['x', 'x'], - [(ARG_OPT, 'x')], - [[0, 1]]) + self.assert_map(["x", "x"], [(ARG_OPT, "x")], [[0, 1]]) def test_varargs_and_bare_asterisk(self) -> None: - self.assert_map( - [ARG_STAR], - [ARG_STAR, (ARG_NAMED, 'x')], - [[0], []]) - self.assert_map( - [ARG_STAR, 'x'], - [ARG_STAR, (ARG_NAMED, 'x')], - [[0], [1]]) + self.assert_map([ARG_STAR], [ARG_STAR, (ARG_NAMED, "x")], [[0], []]) + self.assert_map([ARG_STAR, "x"], [ARG_STAR, (ARG_NAMED, "x")], [[0], [1]]) def test_keyword_varargs(self) -> None: - self.assert_map( - ['x'], - [ARG_STAR2], - [[0]]) - self.assert_map( - ['x', ARG_STAR2], - [ARG_STAR2], - [[0, 1]]) - self.assert_map( - ['x', ARG_STAR2], - [(ARG_POS, 'x'), ARG_STAR2], - [[0], [1]]) - self.assert_map( - [ARG_POS, ARG_STAR2], - [(ARG_POS, 'x'), ARG_STAR2], - [[0], [1]]) + self.assert_map(["x"], [ARG_STAR2], [[0]]) + self.assert_map(["x", ARG_STAR2], [ARG_STAR2], [[0, 1]]) + self.assert_map(["x", ARG_STAR2], [(ARG_POS, "x"), ARG_STAR2], [[0], [1]]) + self.assert_map([ARG_POS, ARG_STAR2], [(ARG_POS, "x"), ARG_STAR2], [[0], [1]]) def test_both_kinds_of_varargs(self) -> None: - self.assert_map( - [ARG_STAR, ARG_STAR2], - [(ARG_POS, 'x'), (ARG_POS, 'y')], - [[0, 1], [0, 1]]) + self.assert_map([ARG_STAR, ARG_STAR2], [(ARG_POS, "x"), (ARG_POS, "y")], [[0, 1], [0, 1]]) def test_special_cases(self) -> None: - self.assert_map([ARG_STAR], - [ARG_STAR, ARG_STAR2], - [[0], []]) - self.assert_map([ARG_STAR, ARG_STAR2], - [ARG_STAR, ARG_STAR2], - [[0], [1]]) - self.assert_map([ARG_STAR2], - [(ARG_POS, 'x'), ARG_STAR2], - [[0], [0]]) - self.assert_map([ARG_STAR2], - [ARG_STAR2], - [[0]]) - - def assert_map(self, - caller_kinds_: List[Union[ArgKind, str]], - callee_kinds_: List[Union[ArgKind, Tuple[ArgKind, str]]], - expected: List[List[int]], - ) -> None: + self.assert_map([ARG_STAR], [ARG_STAR, ARG_STAR2], [[0], []]) + self.assert_map([ARG_STAR, ARG_STAR2], [ARG_STAR, ARG_STAR2], [[0], [1]]) + self.assert_map([ARG_STAR2], [(ARG_POS, "x"), ARG_STAR2], [[0], [0]]) + self.assert_map([ARG_STAR2], [ARG_STAR2], [[0]]) + + def assert_map( + self, + caller_kinds_: List[Union[ArgKind, str]], + callee_kinds_: List[Union[ArgKind, Tuple[ArgKind, str]]], + expected: List[List[int]], + ) -> None: caller_kinds, caller_names = expand_caller_kinds(caller_kinds_) callee_kinds, callee_names = expand_callee_kinds(callee_kinds_) result = map_actuals_to_formals( @@ -181,26 +100,24 @@ def assert_map(self, caller_names, callee_kinds, callee_names, - lambda i: AnyType(TypeOfAny.special_form)) + lambda i: AnyType(TypeOfAny.special_form), + ) assert_equal(result, expected) - def assert_vararg_map(self, - caller_kinds: List[ArgKind], - callee_kinds: List[ArgKind], - expected: List[List[int]], - vararg_type: Type, - ) -> None: - result = map_actuals_to_formals( - caller_kinds, - [], - callee_kinds, - [], - lambda i: vararg_type) + def assert_vararg_map( + self, + caller_kinds: List[ArgKind], + callee_kinds: List[ArgKind], + expected: List[List[int]], + vararg_type: Type, + ) -> None: + result = map_actuals_to_formals(caller_kinds, [], callee_kinds, [], lambda i: vararg_type) assert_equal(result, expected) -def expand_caller_kinds(kinds_or_names: List[Union[ArgKind, str]] - ) -> Tuple[List[ArgKind], List[Optional[str]]]: +def expand_caller_kinds( + kinds_or_names: List[Union[ArgKind, str]] +) -> Tuple[List[ArgKind], List[Optional[str]]]: kinds = [] names: List[Optional[str]] = [] for k in kinds_or_names: @@ -213,8 +130,9 @@ def expand_caller_kinds(kinds_or_names: List[Union[ArgKind, str]] return kinds, names -def expand_callee_kinds(kinds_and_names: List[Union[ArgKind, Tuple[ArgKind, str]]] - ) -> Tuple[List[ArgKind], List[Optional[str]]]: +def expand_callee_kinds( + kinds_and_names: List[Union[ArgKind, Tuple[ArgKind, str]]] +) -> Tuple[List[ArgKind], List[Optional[str]]]: kinds = [] names: List[Optional[str]] = [] for v in kinds_and_names: @@ -229,6 +147,7 @@ def expand_callee_kinds(kinds_and_names: List[Union[ArgKind, Tuple[ArgKind, str] class OperandDisjointDictSuite(Suite): """Test cases for checker.DisjointDict, which is used for type inference with operands.""" + def new(self) -> DisjointDict[int, str]: return DisjointDict() @@ -238,11 +157,9 @@ def test_independent_maps(self) -> None: d.add_mapping({2, 3, 4}, {"group2"}) d.add_mapping({5, 6, 7}, {"group3"}) - self.assertEqual(d.items(), [ - ({0, 1}, {"group1"}), - ({2, 3, 4}, {"group2"}), - ({5, 6, 7}, {"group3"}), - ]) + self.assertEqual( + d.items(), [({0, 1}, {"group1"}), ({2, 3, 4}, {"group2"}), ({5, 6, 7}, {"group3"})] + ) def test_partial_merging(self) -> None: d = self.new() @@ -253,10 +170,13 @@ def test_partial_merging(self) -> None: d.add_mapping({5, 6}, {"group5"}) d.add_mapping({4, 7}, {"group6"}) - self.assertEqual(d.items(), [ - ({0, 1, 2, 5, 6}, {"group1", "group2", "group4", "group5"}), - ({3, 4, 7}, {"group3", "group6"}), - ]) + self.assertEqual( + d.items(), + [ + ({0, 1, 2, 5, 6}, {"group1", "group2", "group4", "group5"}), + ({3, 4, 7}, {"group3", "group6"}), + ], + ) def test_full_merging(self) -> None: d = self.new() @@ -267,9 +187,10 @@ def test_full_merging(self) -> None: d.add_mapping({14, 10, 16}, {"e"}) d.add_mapping({0, 10}, {"f"}) - self.assertEqual(d.items(), [ - ({0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16}, {"a", "b", "c", "d", "e", "f"}), - ]) + self.assertEqual( + d.items(), + [({0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16}, {"a", "b", "c", "d", "e", "f"})], + ) def test_merge_with_multiple_overlaps(self) -> None: d = self.new() @@ -279,29 +200,28 @@ def test_merge_with_multiple_overlaps(self) -> None: d.add_mapping({6, 1, 2, 4, 5}, {"d"}) d.add_mapping({6, 1, 2, 4, 5}, {"e"}) - self.assertEqual(d.items(), [ - ({0, 1, 2, 3, 4, 5, 6}, {"a", "b", "c", "d", "e"}), - ]) + self.assertEqual(d.items(), [({0, 1, 2, 3, 4, 5, 6}, {"a", "b", "c", "d", "e"})]) class OperandComparisonGroupingSuite(Suite): """Test cases for checker.group_comparison_operands.""" + def literal_keymap(self, assignable_operands: Dict[int, NameExpr]) -> Dict[int, Key]: output: Dict[int, Key] = {} for index, expr in assignable_operands.items(): - output[index] = ('FakeExpr', expr.name) + output[index] = ("FakeExpr", expr.name) return output def test_basic_cases(self) -> None: # Note: the grouping function doesn't actually inspect the input exprs, so we # just default to using NameExprs for simplicity. - x0 = NameExpr('x0') - x1 = NameExpr('x1') - x2 = NameExpr('x2') - x3 = NameExpr('x3') - x4 = NameExpr('x4') + x0 = NameExpr("x0") + x1 = NameExpr("x1") + x2 = NameExpr("x2") + x3 = NameExpr("x3") + x4 = NameExpr("x4") - basic_input = [('==', x0, x1), ('==', x1, x2), ('<', x2, x3), ('==', x3, x4)] + basic_input = [("==", x0, x1), ("==", x1, x2), ("<", x2, x3), ("==", x3, x4)] none_assignable = self.literal_keymap({}) all_assignable = self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x4}) @@ -309,137 +229,129 @@ def test_basic_cases(self) -> None: for assignable in [none_assignable, all_assignable]: self.assertEqual( group_comparison_operands(basic_input, assignable, set()), - [('==', [0, 1]), ('==', [1, 2]), ('<', [2, 3]), ('==', [3, 4])], + [("==", [0, 1]), ("==", [1, 2]), ("<", [2, 3]), ("==", [3, 4])], ) self.assertEqual( - group_comparison_operands(basic_input, assignable, {'=='}), - [('==', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4])], + group_comparison_operands(basic_input, assignable, {"=="}), + [("==", [0, 1, 2]), ("<", [2, 3]), ("==", [3, 4])], ) self.assertEqual( - group_comparison_operands(basic_input, assignable, {'<'}), - [('==', [0, 1]), ('==', [1, 2]), ('<', [2, 3]), ('==', [3, 4])], + group_comparison_operands(basic_input, assignable, {"<"}), + [("==", [0, 1]), ("==", [1, 2]), ("<", [2, 3]), ("==", [3, 4])], ) self.assertEqual( - group_comparison_operands(basic_input, assignable, {'==', '<'}), - [('==', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4])], + group_comparison_operands(basic_input, assignable, {"==", "<"}), + [("==", [0, 1, 2]), ("<", [2, 3]), ("==", [3, 4])], ) def test_multiple_groups(self) -> None: - x0 = NameExpr('x0') - x1 = NameExpr('x1') - x2 = NameExpr('x2') - x3 = NameExpr('x3') - x4 = NameExpr('x4') - x5 = NameExpr('x5') + x0 = NameExpr("x0") + x1 = NameExpr("x1") + x2 = NameExpr("x2") + x3 = NameExpr("x3") + x4 = NameExpr("x4") + x5 = NameExpr("x5") self.assertEqual( group_comparison_operands( - [('==', x0, x1), ('==', x1, x2), ('is', x2, x3), ('is', x3, x4)], + [("==", x0, x1), ("==", x1, x2), ("is", x2, x3), ("is", x3, x4)], self.literal_keymap({}), - {'==', 'is'}, + {"==", "is"}, ), - [('==', [0, 1, 2]), ('is', [2, 3, 4])], + [("==", [0, 1, 2]), ("is", [2, 3, 4])], ) self.assertEqual( group_comparison_operands( - [('==', x0, x1), ('==', x1, x2), ('==', x2, x3), ('==', x3, x4)], + [("==", x0, x1), ("==", x1, x2), ("==", x2, x3), ("==", x3, x4)], self.literal_keymap({}), - {'==', 'is'}, + {"==", "is"}, ), - [('==', [0, 1, 2, 3, 4])], + [("==", [0, 1, 2, 3, 4])], ) self.assertEqual( group_comparison_operands( - [('is', x0, x1), ('==', x1, x2), ('==', x2, x3), ('==', x3, x4)], + [("is", x0, x1), ("==", x1, x2), ("==", x2, x3), ("==", x3, x4)], self.literal_keymap({}), - {'==', 'is'}, + {"==", "is"}, ), - [('is', [0, 1]), ('==', [1, 2, 3, 4])], + [("is", [0, 1]), ("==", [1, 2, 3, 4])], ) self.assertEqual( group_comparison_operands( - [('is', x0, x1), ('is', x1, x2), ('<', x2, x3), ('==', x3, x4), ('==', x4, x5)], + [("is", x0, x1), ("is", x1, x2), ("<", x2, x3), ("==", x3, x4), ("==", x4, x5)], self.literal_keymap({}), - {'==', 'is'}, + {"==", "is"}, ), - [('is', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4, 5])], + [("is", [0, 1, 2]), ("<", [2, 3]), ("==", [3, 4, 5])], ) def test_multiple_groups_coalescing(self) -> None: - x0 = NameExpr('x0') - x1 = NameExpr('x1') - x2 = NameExpr('x2') - x3 = NameExpr('x3') - x4 = NameExpr('x4') + x0 = NameExpr("x0") + x1 = NameExpr("x1") + x2 = NameExpr("x2") + x3 = NameExpr("x3") + x4 = NameExpr("x4") - nothing_combined = [('==', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4, 5])] - everything_combined = [('==', [0, 1, 2, 3, 4, 5]), ('<', [2, 3])] + nothing_combined = [("==", [0, 1, 2]), ("<", [2, 3]), ("==", [3, 4, 5])] + everything_combined = [("==", [0, 1, 2, 3, 4, 5]), ("<", [2, 3])] # Note: We do 'x4 == x0' at the very end! two_groups = [ - ('==', x0, x1), ('==', x1, x2), ('<', x2, x3), ('==', x3, x4), ('==', x4, x0), + ("==", x0, x1), + ("==", x1, x2), + ("<", x2, x3), + ("==", x3, x4), + ("==", x4, x0), ] self.assertEqual( group_comparison_operands( - two_groups, - self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x4, 5: x0}), - {'=='}, + two_groups, self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x4, 5: x0}), {"=="} ), everything_combined, - "All vars are assignable, everything is combined" + "All vars are assignable, everything is combined", ) self.assertEqual( group_comparison_operands( - two_groups, - self.literal_keymap({1: x1, 2: x2, 3: x3, 4: x4}), - {'=='}, + two_groups, self.literal_keymap({1: x1, 2: x2, 3: x3, 4: x4}), {"=="} ), nothing_combined, - "x0 is unassignable, so no combining" + "x0 is unassignable, so no combining", ) self.assertEqual( group_comparison_operands( - two_groups, - self.literal_keymap({0: x0, 1: x1, 3: x3, 5: x0}), - {'=='}, + two_groups, self.literal_keymap({0: x0, 1: x1, 3: x3, 5: x0}), {"=="} ), everything_combined, - "Some vars are unassignable but x0 is, so we combine" + "Some vars are unassignable but x0 is, so we combine", ) self.assertEqual( - group_comparison_operands( - two_groups, - self.literal_keymap({0: x0, 5: x0}), - {'=='}, - ), + group_comparison_operands(two_groups, self.literal_keymap({0: x0, 5: x0}), {"=="}), everything_combined, - "All vars are unassignable but x0 is, so we combine" + "All vars are unassignable but x0 is, so we combine", ) def test_multiple_groups_different_operators(self) -> None: - x0 = NameExpr('x0') - x1 = NameExpr('x1') - x2 = NameExpr('x2') - x3 = NameExpr('x3') + x0 = NameExpr("x0") + x1 = NameExpr("x1") + x2 = NameExpr("x2") + x3 = NameExpr("x3") - groups = [('==', x0, x1), ('==', x1, x2), ('is', x2, x3), ('is', x3, x0)] + groups = [("==", x0, x1), ("==", x1, x2), ("is", x2, x3), ("is", x3, x0)] keymap = self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x0}) self.assertEqual( - group_comparison_operands(groups, keymap, {'==', 'is'}), - [('==', [0, 1, 2]), ('is', [2, 3, 4])], - "Different operators can never be combined" + group_comparison_operands(groups, keymap, {"==", "is"}), + [("==", [0, 1, 2]), ("is", [2, 3, 4])], + "Different operators can never be combined", ) def test_single_pair(self) -> None: - x0 = NameExpr('x0') - x1 = NameExpr('x1') + x0 = NameExpr("x0") + x1 = NameExpr("x1") - single_comparison = [('==', x0, x1)] - expected_output = [('==', [0, 1])] + single_comparison = [("==", x0, x1)] + expected_output = [("==", [0, 1])] - assignable_combinations: List[Dict[int, NameExpr]] = [ - {}, {0: x0}, {1: x1}, {0: x0, 1: x1}, - ] + assignable_combinations: List[Dict[int, NameExpr]] = [{}, {0: x0}, {1: x1}, {0: x0, 1: x1}] to_group_by: List[Set[str]] = [set(), {"=="}, {"is"}] for combo in assignable_combinations: @@ -455,4 +367,4 @@ def test_empty_pair_list(self) -> None: # always contain at least one comparison. But in case it does... self.assertEqual(group_comparison_operands([], {}, set()), []) - self.assertEqual(group_comparison_operands([], {}, {'=='}), []) + self.assertEqual(group_comparison_operands([], {}, {"=="}), []) diff --git a/mypy/test/testipc.py b/mypy/test/testipc.py index 462fd44c88005..fde15a1a3f318 100644 --- a/mypy/test/testipc.py +++ b/mypy/test/testipc.py @@ -1,19 +1,19 @@ -from unittest import TestCase, main +import sys +import time from multiprocessing import Process, Queue - -from mypy.ipc import IPCClient, IPCServer +from unittest import TestCase, main import pytest -import sys -import time -CONNECTION_NAME = 'dmypy-test-ipc' +from mypy.ipc import IPCClient, IPCServer + +CONNECTION_NAME = "dmypy-test-ipc" -def server(msg: str, q: 'Queue[str]') -> None: +def server(msg: str, q: "Queue[str]") -> None: server = IPCServer(CONNECTION_NAME) q.put(server.connection_name) - data = b'' + data = b"" while not data: with server: server.write(msg.encode()) @@ -24,30 +24,30 @@ def server(msg: str, q: 'Queue[str]') -> None: class IPCTests(TestCase): def test_transaction_large(self) -> None: queue: Queue[str] = Queue() - msg = 't' * 200000 # longer than the max read size of 100_000 + msg = "t" * 200000 # longer than the max read size of 100_000 p = Process(target=server, args=(msg, queue), daemon=True) p.start() connection_name = queue.get() with IPCClient(connection_name, timeout=1) as client: assert client.read() == msg.encode() - client.write(b'test') + client.write(b"test") queue.close() queue.join_thread() p.join() def test_connect_twice(self) -> None: queue: Queue[str] = Queue() - msg = 'this is a test message' + msg = "this is a test message" p = Process(target=server, args=(msg, queue), daemon=True) p.start() connection_name = queue.get() with IPCClient(connection_name, timeout=1) as client: assert client.read() == msg.encode() - client.write(b'') # don't let the server hang up yet, we want to connect again. + client.write(b"") # don't let the server hang up yet, we want to connect again. with IPCClient(connection_name, timeout=1) as client: assert client.read() == msg.encode() - client.write(b'test') + client.write(b"test") queue.close() queue.join_thread() p.join() @@ -61,7 +61,7 @@ def test_connect_alot(self) -> None: t0 = time.time() for i in range(1000): try: - print(i, 'start') + print(i, "start") self.test_connect_twice() finally: t1 = time.time() @@ -70,5 +70,5 @@ def test_connect_alot(self) -> None: t0 = t1 -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mypy/test/testmerge.py b/mypy/test/testmerge.py index 3f07c39f856db..cd780d088b28c 100644 --- a/mypy/test/testmerge.py +++ b/mypy/test/testmerge.py @@ -2,15 +2,22 @@ import os import shutil -from typing import List, Tuple, Dict, Optional +from typing import Dict, List, Optional, Tuple from mypy import build from mypy.build import BuildResult -from mypy.modulefinder import BuildSource from mypy.errors import CompileError +from mypy.modulefinder import BuildSource from mypy.nodes import ( - Node, MypyFile, SymbolTable, SymbolTableNode, TypeInfo, Expression, Var, TypeVarExpr, - UNBOUND_IMPORTED + UNBOUND_IMPORTED, + Expression, + MypyFile, + Node, + SymbolTable, + SymbolTableNode, + TypeInfo, + TypeVarExpr, + Var, ) from mypy.server.subexpr import get_subexpressions from mypy.server.update import FineGrainedBuildManager @@ -18,31 +25,30 @@ from mypy.test.config import test_temp_dir from mypy.test.data import DataDrivenTestCase, DataSuite from mypy.test.helpers import assert_string_arrays_equal, normalize_error_messages, parse_options -from mypy.types import TypeStrVisitor, Type -from mypy.util import short_type, IdMapper - +from mypy.types import Type, TypeStrVisitor +from mypy.util import IdMapper, short_type # Which data structures to dump in a test case? -SYMTABLE = 'SYMTABLE' -TYPEINFO = ' TYPEINFO' -TYPES = 'TYPES' -AST = 'AST' +SYMTABLE = "SYMTABLE" +TYPEINFO = " TYPEINFO" +TYPES = "TYPES" +AST = "AST" NOT_DUMPED_MODULES = ( - 'builtins', - 'typing', - 'abc', - 'contextlib', - 'sys', - 'mypy_extensions', - 'typing_extensions', - 'enum', + "builtins", + "typing", + "abc", + "contextlib", + "sys", + "mypy_extensions", + "typing_extensions", + "enum", ) class ASTMergeSuite(DataSuite): - files = ['merge.test'] + files = ["merge.test"] def setup(self) -> None: super().setup() @@ -55,18 +61,18 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: name = testcase.name # We use the test case name to decide which data structures to dump. # Dumping everything would result in very verbose test cases. - if name.endswith('_symtable'): + if name.endswith("_symtable"): kind = SYMTABLE - elif name.endswith('_typeinfo'): + elif name.endswith("_typeinfo"): kind = TYPEINFO - elif name.endswith('_types'): + elif name.endswith("_types"): kind = TYPES else: kind = AST - main_src = '\n'.join(testcase.input) + main_src = "\n".join(testcase.input) result = self.build(main_src, testcase) - assert result is not None, 'cases where CompileError occurred should not be run' + assert result is not None, "cases where CompileError occurred should not be run" result.manager.fscache.flush() fine_grained_manager = FineGrainedBuildManager(result) @@ -74,15 +80,15 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: if result.errors: a.extend(result.errors) - target_path = os.path.join(test_temp_dir, 'target.py') - shutil.copy(os.path.join(test_temp_dir, 'target.py.next'), target_path) + target_path = os.path.join(test_temp_dir, "target.py") + shutil.copy(os.path.join(test_temp_dir, "target.py.next"), target_path) a.extend(self.dump(fine_grained_manager, kind)) - old_subexpr = get_subexpressions(result.manager.modules['target']) + old_subexpr = get_subexpressions(result.manager.modules["target"]) - a.append('==>') + a.append("==>") - new_file, new_types = self.build_increment(fine_grained_manager, 'target', target_path) + new_file, new_types = self.build_increment(fine_grained_manager, "target", target_path) a.extend(self.dump(fine_grained_manager, kind)) for expr in old_subexpr: @@ -96,8 +102,8 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: a = normalize_error_messages(a) assert_string_arrays_equal( - testcase.output, a, - f'Invalid output ({testcase.file}, line {testcase.line})') + testcase.output, a, f"Invalid output ({testcase.file}, line {testcase.line})" + ) def build(self, source: str, testcase: DataDrivenTestCase) -> Optional[BuildResult]: options = parse_options(source, testcase, incremental_step=1) @@ -106,30 +112,30 @@ def build(self, source: str, testcase: DataDrivenTestCase) -> Optional[BuildResu options.use_builtins_fixtures = True options.export_types = True options.show_traceback = True - main_path = os.path.join(test_temp_dir, 'main') - with open(main_path, 'w', encoding='utf8') as f: + main_path = os.path.join(test_temp_dir, "main") + with open(main_path, "w", encoding="utf8") as f: f.write(source) try: - result = build.build(sources=[BuildSource(main_path, None, None)], - options=options, - alt_lib_path=test_temp_dir) + result = build.build( + sources=[BuildSource(main_path, None, None)], + options=options, + alt_lib_path=test_temp_dir, + ) except CompileError: # TODO: Is it okay to return None? return None return result - def build_increment(self, manager: FineGrainedBuildManager, - module_id: str, path: str) -> Tuple[MypyFile, - Dict[Expression, Type]]: + def build_increment( + self, manager: FineGrainedBuildManager, module_id: str, path: str + ) -> Tuple[MypyFile, Dict[Expression, Type]]: manager.flush_cache() manager.update([(module_id, path)], []) module = manager.manager.modules[module_id] type_map = manager.graph[module_id].type_map() return module, type_map - def dump(self, - manager: FineGrainedBuildManager, - kind: str) -> List[str]: + def dump(self, manager: FineGrainedBuildManager, kind: str) -> List[str]: modules = manager.manager.modules if kind == AST: return self.dump_asts(modules) @@ -139,7 +145,7 @@ def dump(self, return self.dump_symbol_tables(modules) elif kind == TYPES: return self.dump_types(manager) - assert False, f'Invalid kind {kind}' + assert False, f"Invalid kind {kind}" def dump_asts(self, modules: Dict[str, MypyFile]) -> List[str]: a = [] @@ -161,26 +167,29 @@ def dump_symbol_tables(self, modules: Dict[str, MypyFile]) -> List[str]: return a def dump_symbol_table(self, module_id: str, symtable: SymbolTable) -> List[str]: - a = [f'{module_id}:'] + a = [f"{module_id}:"] for name in sorted(symtable): - if name.startswith('__'): + if name.startswith("__"): continue - a.append(f' {name}: {self.format_symbol_table_node(symtable[name])}') + a.append(f" {name}: {self.format_symbol_table_node(symtable[name])}") return a def format_symbol_table_node(self, node: SymbolTableNode) -> str: if node.node is None: if node.kind == UNBOUND_IMPORTED: - return 'UNBOUND_IMPORTED' - return 'None' + return "UNBOUND_IMPORTED" + return "None" if isinstance(node.node, Node): - s = f'{str(type(node.node).__name__)}<{self.id_mapper.id(node.node)}>' + s = f"{str(type(node.node).__name__)}<{self.id_mapper.id(node.node)}>" else: - s = f'? ({type(node.node)})' - if (isinstance(node.node, Var) and node.node.type and - not node.node.fullname.startswith('typing.')): + s = f"? ({type(node.node)})" + if ( + isinstance(node.node, Var) + and node.node.type + and not node.node.fullname.startswith("typing.") + ): typestr = self.format_type(node.node.type) - s += f'({typestr})' + s += f"({typestr})" return s def dump_typeinfos(self, modules: Dict[str, MypyFile]) -> List[str]: @@ -200,11 +209,10 @@ def dump_typeinfos_recursive(self, names: SymbolTable) -> List[str]: return a def dump_typeinfo(self, info: TypeInfo) -> List[str]: - if info.fullname == 'enum.Enum': + if info.fullname == "enum.Enum": # Avoid noise return [] - s = info.dump(str_conv=self.str_conv, - type_str_conv=self.type_str_conv) + s = info.dump(str_conv=self.str_conv, type_str_conv=self.type_str_conv) return s.splitlines() def dump_types(self, manager: FineGrainedBuildManager) -> List[str]: @@ -218,15 +226,16 @@ def dump_types(self, manager: FineGrainedBuildManager) -> List[str]: # Compute a module type map from the global type map tree = manager.graph[module_id].tree assert tree is not None - type_map = {node: all_types[node] - for node in get_subexpressions(tree) - if node in all_types} + type_map = { + node: all_types[node] for node in get_subexpressions(tree) if node in all_types + } if type_map: - a.append(f'## {module_id}') - for expr in sorted(type_map, key=lambda n: (n.line, short_type(n), - str(n) + str(type_map[n]))): + a.append(f"## {module_id}") + for expr in sorted( + type_map, key=lambda n: (n.line, short_type(n), str(n) + str(type_map[n])) + ): typ = type_map[expr] - a.append(f'{short_type(expr)}:{expr.line}: {self.format_type(typ)}') + a.append(f"{short_type(expr)}:{expr.line}: {self.format_type(typ)}") return a def format_type(self, typ: Type) -> str: @@ -234,4 +243,4 @@ def format_type(self, typ: Type) -> str: def is_dumped_module(id: str) -> bool: - return id not in NOT_DUMPED_MODULES and (not id.startswith('_') or id == '__main__') + return id not in NOT_DUMPED_MODULES and (not id.startswith("_") or id == "__main__") diff --git a/mypy/test/testmodulefinder.py b/mypy/test/testmodulefinder.py index fc80893659c2d..b9792828794f5 100644 --- a/mypy/test/testmodulefinder.py +++ b/mypy/test/testmodulefinder.py @@ -1,19 +1,14 @@ import os +from mypy.modulefinder import FindModuleCache, ModuleNotFoundReason, SearchPaths from mypy.options import Options -from mypy.modulefinder import ( - FindModuleCache, - SearchPaths, - ModuleNotFoundReason, -) - -from mypy.test.helpers import Suite, assert_equal from mypy.test.config import package_path +from mypy.test.helpers import Suite, assert_equal + data_path = os.path.relpath(os.path.join(package_path, "modulefinder")) class ModuleFinderSuite(Suite): - def setUp(self) -> None: self.search_paths = SearchPaths( python_path=(), @@ -141,12 +136,10 @@ def test__find_d_nowhere(self) -> None: class ModuleFinderSitePackagesSuite(Suite): - def setUp(self) -> None: - self.package_dir = os.path.relpath(os.path.join( - package_path, - "modulefinder-site-packages", - )) + self.package_dir = os.path.relpath( + os.path.join(package_path, "modulefinder-site-packages") + ) package_paths = ( os.path.join(self.package_dir, "baz"), @@ -180,51 +173,44 @@ def test__packages_with_ns(self) -> None: ("ns_pkg_typed.b", self.path("ns_pkg_typed", "b")), ("ns_pkg_typed.b.c", self.path("ns_pkg_typed", "b", "c.py")), ("ns_pkg_typed.a.a_var", ModuleNotFoundReason.NOT_FOUND), - # Namespace package without py.typed ("ns_pkg_untyped", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("ns_pkg_untyped.a", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("ns_pkg_untyped.b", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("ns_pkg_untyped.b.c", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("ns_pkg_untyped.a.a_var", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), - # Namespace package without stub package ("ns_pkg_w_stubs", self.path("ns_pkg_w_stubs")), ("ns_pkg_w_stubs.typed", self.path("ns_pkg_w_stubs-stubs", "typed", "__init__.pyi")), - ("ns_pkg_w_stubs.typed_inline", - self.path("ns_pkg_w_stubs", "typed_inline", "__init__.py")), + ( + "ns_pkg_w_stubs.typed_inline", + self.path("ns_pkg_w_stubs", "typed_inline", "__init__.py"), + ), ("ns_pkg_w_stubs.untyped", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), - # Regular package with py.typed ("pkg_typed", self.path("pkg_typed", "__init__.py")), ("pkg_typed.a", self.path("pkg_typed", "a.py")), ("pkg_typed.b", self.path("pkg_typed", "b", "__init__.py")), ("pkg_typed.b.c", self.path("pkg_typed", "b", "c.py")), ("pkg_typed.a.a_var", ModuleNotFoundReason.NOT_FOUND), - # Regular package without py.typed ("pkg_untyped", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("pkg_untyped.a", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("pkg_untyped.b", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("pkg_untyped.b.c", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("pkg_untyped.a.a_var", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), - # Top-level Python file in site-packages ("standalone", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("standalone.standalone_var", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), - # Packages found by following .pth files ("baz_pkg", self.path("baz", "baz_pkg", "__init__.py")), ("ns_baz_pkg.a", self.path("baz", "ns_baz_pkg", "a.py")), ("neighbor_pkg", self.path("..", "modulefinder-src", "neighbor_pkg", "__init__.py")), ("ns_neighbor_pkg.a", self.path("..", "modulefinder-src", "ns_neighbor_pkg", "a.py")), - # Something that doesn't exist ("does_not_exist", ModuleNotFoundReason.NOT_FOUND), - # A regular package with an installed set of stubs ("foo.bar", self.path("foo-stubs", "bar.pyi")), - # A regular, non-site-packages module ("a", os.path.join(data_path, "pkg1", "a.py")), ] @@ -242,51 +228,44 @@ def test__packages_without_ns(self) -> None: ("ns_pkg_typed.b", ModuleNotFoundReason.NOT_FOUND), ("ns_pkg_typed.b.c", ModuleNotFoundReason.NOT_FOUND), ("ns_pkg_typed.a.a_var", ModuleNotFoundReason.NOT_FOUND), - # Namespace package without py.typed ("ns_pkg_untyped", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("ns_pkg_untyped.a", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("ns_pkg_untyped.b", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("ns_pkg_untyped.b.c", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("ns_pkg_untyped.a.a_var", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), - # Namespace package without stub package ("ns_pkg_w_stubs", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("ns_pkg_w_stubs.typed", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), - ("ns_pkg_w_stubs.typed_inline", - self.path("ns_pkg_w_stubs", "typed_inline", "__init__.py")), + ( + "ns_pkg_w_stubs.typed_inline", + self.path("ns_pkg_w_stubs", "typed_inline", "__init__.py"), + ), ("ns_pkg_w_stubs.untyped", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), - # Regular package with py.typed ("pkg_typed", self.path("pkg_typed", "__init__.py")), ("pkg_typed.a", self.path("pkg_typed", "a.py")), ("pkg_typed.b", self.path("pkg_typed", "b", "__init__.py")), ("pkg_typed.b.c", self.path("pkg_typed", "b", "c.py")), ("pkg_typed.a.a_var", ModuleNotFoundReason.NOT_FOUND), - # Regular package without py.typed ("pkg_untyped", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("pkg_untyped.a", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("pkg_untyped.b", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("pkg_untyped.b.c", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("pkg_untyped.a.a_var", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), - # Top-level Python file in site-packages ("standalone", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), ("standalone.standalone_var", ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS), - # Packages found by following .pth files ("baz_pkg", self.path("baz", "baz_pkg", "__init__.py")), ("ns_baz_pkg.a", ModuleNotFoundReason.NOT_FOUND), ("neighbor_pkg", self.path("..", "modulefinder-src", "neighbor_pkg", "__init__.py")), ("ns_neighbor_pkg.a", ModuleNotFoundReason.NOT_FOUND), - # Something that doesn't exist ("does_not_exist", ModuleNotFoundReason.NOT_FOUND), - # A regular package with an installed set of stubs ("foo.bar", self.path("foo-stubs", "bar.pyi")), - # A regular, non-site-packages module ("a", os.path.join(data_path, "pkg1", "a.py")), ] diff --git a/mypy/test/testmypyc.py b/mypy/test/testmypyc.py index b66ec9e5ccf3f..7281bde79ecac 100644 --- a/mypy/test/testmypyc.py +++ b/mypy/test/testmypyc.py @@ -1,12 +1,12 @@ """A basic check to make sure that we are using a mypyc-compiled version when expected.""" -import mypy - -from unittest import TestCase import os +from unittest import TestCase + +import mypy class MypycTest(TestCase): def test_using_mypyc(self) -> None: - if os.getenv('TEST_MYPYC', None) == '1': - assert not mypy.__file__.endswith('.py'), "Expected to find a mypyc-compiled version" + if os.getenv("TEST_MYPYC", None) == "1": + assert not mypy.__file__.endswith(".py"), "Expected to find a mypyc-compiled version" diff --git a/mypy/test/testparse.py b/mypy/test/testparse.py index f75452c58860f..4a7ea86219fe0 100644 --- a/mypy/test/testparse.py +++ b/mypy/test/testparse.py @@ -5,20 +5,20 @@ from pytest import skip from mypy import defaults -from mypy.test.helpers import assert_string_arrays_equal, parse_options, find_test_files -from mypy.test.data import DataDrivenTestCase, DataSuite -from mypy.parse import parse from mypy.errors import CompileError from mypy.options import Options +from mypy.parse import parse +from mypy.test.data import DataDrivenTestCase, DataSuite +from mypy.test.helpers import assert_string_arrays_equal, find_test_files, parse_options class ParserSuite(DataSuite): required_out_section = True - base_path = '.' + base_path = "." files = find_test_files(pattern="parse*.test", exclude=["parse-errors.test"]) if sys.version_info < (3, 10): - files.remove('parse-python310.test') + files.remove("parse-python310.test") def run_case(self, testcase: DataDrivenTestCase) -> None: test_parser(testcase) @@ -31,34 +31,38 @@ def test_parser(testcase: DataDrivenTestCase) -> None: """ options = Options() - if testcase.file.endswith('python310.test'): + if testcase.file.endswith("python310.test"): options.python_version = (3, 10) else: options.python_version = defaults.PYTHON3_VERSION try: - n = parse(bytes('\n'.join(testcase.input), 'ascii'), - fnam='main', - module='__main__', - errors=None, - options=options) - a = str(n).split('\n') + n = parse( + bytes("\n".join(testcase.input), "ascii"), + fnam="main", + module="__main__", + errors=None, + options=options, + ) + a = str(n).split("\n") except CompileError as e: a = e.messages - assert_string_arrays_equal(testcase.output, a, - 'Invalid parser output ({}, line {})'.format( - testcase.file, testcase.line)) + assert_string_arrays_equal( + testcase.output, + a, + "Invalid parser output ({}, line {})".format(testcase.file, testcase.line), + ) # The file name shown in test case output. This is displayed in error # messages, and must match the file name in the test case descriptions. -INPUT_FILE_NAME = 'file' +INPUT_FILE_NAME = "file" class ParseErrorSuite(DataSuite): required_out_section = True - base_path = '.' - files = ['parse-errors.test'] + base_path = "." + files = ["parse-errors.test"] def run_case(self, testcase: DataDrivenTestCase) -> None: test_parse_error(testcase) @@ -66,18 +70,21 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: def test_parse_error(testcase: DataDrivenTestCase) -> None: try: - options = parse_options('\n'.join(testcase.input), testcase, 0) + options = parse_options("\n".join(testcase.input), testcase, 0) if options.python_version != sys.version_info[:2]: skip() # Compile temporary file. The test file contains non-ASCII characters. - parse(bytes('\n'.join(testcase.input), 'utf-8'), INPUT_FILE_NAME, '__main__', None, - options) - raise AssertionError('No errors reported') + parse( + bytes("\n".join(testcase.input), "utf-8"), INPUT_FILE_NAME, "__main__", None, options + ) + raise AssertionError("No errors reported") except CompileError as e: if e.module_with_blocker is not None: - assert e.module_with_blocker == '__main__' + assert e.module_with_blocker == "__main__" # Verify that there was a compile error and that the error messages # are equivalent. assert_string_arrays_equal( - testcase.output, e.messages, - f'Invalid compiler output ({testcase.file}, line {testcase.line})') + testcase.output, + e.messages, + f"Invalid compiler output ({testcase.file}, line {testcase.line})", + ) diff --git a/mypy/test/testpep561.py b/mypy/test/testpep561.py index 0f327658c8e0b..97229b03ed227 100644 --- a/mypy/test/testpep561.py +++ b/mypy/test/testpep561.py @@ -1,29 +1,26 @@ -from contextlib import contextmanager -import filelock import os import re import subprocess -from subprocess import PIPE import sys import tempfile -from typing import Tuple, List, Iterator +from contextlib import contextmanager +from subprocess import PIPE +from typing import Iterator, List, Tuple + +import filelock import mypy.api -from mypy.test.config import package_path, pip_lock, pip_timeout +from mypy.test.config import package_path, pip_lock, pip_timeout, test_temp_dir from mypy.test.data import DataDrivenTestCase, DataSuite -from mypy.test.config import test_temp_dir from mypy.test.helpers import assert_string_arrays_equal, perform_file_operations - # NOTE: options.use_builtins_fixtures should not be set in these # tests, otherwise mypy will ignore installed third-party packages. class PEP561Suite(DataSuite): - files = [ - 'pep561.test', - ] - base_path = '.' + files = ["pep561.test"] + base_path = "." def run_case(self, test_case: DataDrivenTestCase) -> None: test_pep561(test_case) @@ -37,52 +34,48 @@ def virtualenv(python_executable: str = sys.executable) -> Iterator[Tuple[str, s """ with tempfile.TemporaryDirectory() as venv_dir: proc = subprocess.run( - [python_executable, '-m', 'venv', venv_dir], - cwd=os.getcwd(), stdout=PIPE, stderr=PIPE + [python_executable, "-m", "venv", venv_dir], cwd=os.getcwd(), stdout=PIPE, stderr=PIPE ) if proc.returncode != 0: - err = proc.stdout.decode('utf-8') + proc.stderr.decode('utf-8') + err = proc.stdout.decode("utf-8") + proc.stderr.decode("utf-8") raise Exception("Failed to create venv.\n" + err) - if sys.platform == 'win32': - yield venv_dir, os.path.abspath(os.path.join(venv_dir, 'Scripts', 'python')) + if sys.platform == "win32": + yield venv_dir, os.path.abspath(os.path.join(venv_dir, "Scripts", "python")) else: - yield venv_dir, os.path.abspath(os.path.join(venv_dir, 'bin', 'python')) + yield venv_dir, os.path.abspath(os.path.join(venv_dir, "bin", "python")) -def install_package(pkg: str, - python_executable: str = sys.executable, - use_pip: bool = True, - editable: bool = False) -> None: +def install_package( + pkg: str, python_executable: str = sys.executable, use_pip: bool = True, editable: bool = False +) -> None: """Install a package from test-data/packages/pkg/""" working_dir = os.path.join(package_path, pkg) with tempfile.TemporaryDirectory() as dir: if use_pip: - install_cmd = [python_executable, '-m', 'pip', 'install'] + install_cmd = [python_executable, "-m", "pip", "install"] if editable: - install_cmd.append('-e') - install_cmd.append('.') + install_cmd.append("-e") + install_cmd.append(".") else: - install_cmd = [python_executable, 'setup.py'] + install_cmd = [python_executable, "setup.py"] if editable: - install_cmd.append('develop') + install_cmd.append("develop") else: - install_cmd.append('install') + install_cmd.append("install") # Note that newer versions of pip (21.3+) don't # follow this env variable, but this is for compatibility - env = {'PIP_BUILD': dir} + env = {"PIP_BUILD": dir} # Inherit environment for Windows env.update(os.environ) try: with filelock.FileLock(pip_lock, timeout=pip_timeout): - proc = subprocess.run(install_cmd, - cwd=working_dir, - stdout=PIPE, - stderr=PIPE, - env=env) + proc = subprocess.run( + install_cmd, cwd=working_dir, stdout=PIPE, stderr=PIPE, env=env + ) except filelock.Timeout as err: raise Exception("Failed to acquire {}".format(pip_lock)) from err if proc.returncode != 0: - raise Exception(proc.stdout.decode('utf-8') + proc.stderr.decode('utf-8')) + raise Exception(proc.stdout.decode("utf-8") + proc.stderr.decode("utf-8")) def test_pep561(testcase: DataDrivenTestCase) -> None: @@ -96,9 +89,9 @@ def test_pep561(testcase: DataDrivenTestCase) -> None: use_pip = True editable = False for arg in pip_args: - if arg == 'no-pip': + if arg == "no-pip": use_pip = False - elif arg == 'editable': + elif arg == "editable": editable = True assert pkgs != [], "No packages to install for PEP 561 test?" with virtualenv(python) as venv: @@ -107,17 +100,17 @@ def test_pep561(testcase: DataDrivenTestCase) -> None: install_package(pkg, python_executable, use_pip, editable) cmd_line = list(mypy_args) - has_program = not ('-p' in cmd_line or '--package' in cmd_line) + has_program = not ("-p" in cmd_line or "--package" in cmd_line) if has_program: - program = testcase.name + '.py' - with open(program, 'w', encoding='utf-8') as f: + program = testcase.name + ".py" + with open(program, "w", encoding="utf-8") as f: for s in testcase.input: - f.write(f'{s}\n') + f.write(f"{s}\n") cmd_line.append(program) - cmd_line.extend(['--no-error-summary']) + cmd_line.extend(["--no-error-summary"]) if python_executable != sys.executable: - cmd_line.append(f'--python-executable={python_executable}') + cmd_line.append(f"--python-executable={python_executable}") steps = testcase.find_steps() if steps != [[]]: @@ -133,32 +126,34 @@ def test_pep561(testcase: DataDrivenTestCase) -> None: # split lines, remove newlines, and remove directory of test case for line in (out + err).splitlines(): if line.startswith(test_temp_dir + os.sep): - output.append(line[len(test_temp_dir + os.sep):].rstrip("\r\n")) + output.append(line[len(test_temp_dir + os.sep) :].rstrip("\r\n")) else: # Normalize paths so that the output is the same on Windows and Linux/macOS. - line = line.replace(test_temp_dir + os.sep, test_temp_dir + '/') + line = line.replace(test_temp_dir + os.sep, test_temp_dir + "/") output.append(line.rstrip("\r\n")) - iter_count = '' if i == 0 else f' on iteration {i + 1}' + iter_count = "" if i == 0 else f" on iteration {i + 1}" expected = testcase.output if i == 0 else testcase.output2.get(i + 1, []) - assert_string_arrays_equal(expected, output, - 'Invalid output ({}, line {}){}'.format( - testcase.file, testcase.line, iter_count)) + assert_string_arrays_equal( + expected, + output, + "Invalid output ({}, line {}){}".format(testcase.file, testcase.line, iter_count), + ) if has_program: os.remove(program) def parse_pkgs(comment: str) -> Tuple[List[str], List[str]]: - if not comment.startswith('# pkgs:'): + if not comment.startswith("# pkgs:"): return ([], []) else: - pkgs_str, *args = comment[7:].split(';') - return ([pkg.strip() for pkg in pkgs_str.split(',')], [arg.strip() for arg in args]) + pkgs_str, *args = comment[7:].split(";") + return ([pkg.strip() for pkg in pkgs_str.split(",")], [arg.strip() for arg in args]) def parse_mypy_args(line: str) -> List[str]: - m = re.match('# flags: (.*)$', line) + m = re.match("# flags: (.*)$", line) if not m: return [] # No args; mypy will spit out an error. return m.group(1).split() @@ -166,8 +161,8 @@ def parse_mypy_args(line: str) -> List[str]: def test_mypy_path_is_respected() -> None: assert False - packages = 'packages' - pkg_name = 'a' + packages = "packages" + pkg_name = "a" with tempfile.TemporaryDirectory() as temp_dir: old_dir = os.getcwd() os.chdir(temp_dir) @@ -177,22 +172,21 @@ def test_mypy_path_is_respected() -> None: os.makedirs(full_pkg_name) # Create the empty __init__ file to declare a package - pkg_init_name = os.path.join(temp_dir, packages, pkg_name, '__init__.py') - open(pkg_init_name, 'w', encoding='utf8').close() + pkg_init_name = os.path.join(temp_dir, packages, pkg_name, "__init__.py") + open(pkg_init_name, "w", encoding="utf8").close() - mypy_config_path = os.path.join(temp_dir, 'mypy.ini') - with open(mypy_config_path, 'w') as mypy_file: - mypy_file.write('[mypy]\n') - mypy_file.write(f'mypy_path = ./{packages}\n') + mypy_config_path = os.path.join(temp_dir, "mypy.ini") + with open(mypy_config_path, "w") as mypy_file: + mypy_file.write("[mypy]\n") + mypy_file.write(f"mypy_path = ./{packages}\n") with virtualenv() as venv: venv_dir, python_executable = venv cmd_line_args = [] if python_executable != sys.executable: - cmd_line_args.append(f'--python-executable={python_executable}') - cmd_line_args.extend(['--config-file', mypy_config_path, - '--package', pkg_name]) + cmd_line_args.append(f"--python-executable={python_executable}") + cmd_line_args.extend(["--config-file", mypy_config_path, "--package", pkg_name]) out, err, returncode = mypy.api.run(cmd_line_args) assert returncode == 0 diff --git a/mypy/test/testpythoneval.py b/mypy/test/testpythoneval.py index 4fcf6e063268b..7238e427b1d45 100644 --- a/mypy/test/testpythoneval.py +++ b/mypy/test/testpythoneval.py @@ -14,30 +14,28 @@ import os.path import re import subprocess -from subprocess import PIPE import sys +from subprocess import PIPE from tempfile import TemporaryDirectory - from typing import List +from mypy import api from mypy.defaults import PYTHON3_VERSION from mypy.test.config import test_temp_dir from mypy.test.data import DataDrivenTestCase, DataSuite from mypy.test.helpers import assert_string_arrays_equal, split_lines -from mypy import api # Path to Python 3 interpreter python3_path = sys.executable -program_re = re.compile(r'\b_program.py\b') +program_re = re.compile(r"\b_program.py\b") class PythonEvaluationSuite(DataSuite): - files = ['pythoneval.test', - 'pythoneval-asyncio.test'] + files = ["pythoneval.test", "pythoneval-asyncio.test"] cache_dir = TemporaryDirectory() def run_case(self, testcase: DataDrivenTestCase) -> None: - test_python_evaluation(testcase, os.path.join(self.cache_dir.name, '.mypy_cache')) + test_python_evaluation(testcase, os.path.join(self.cache_dir.name, ".mypy_cache")) def test_python_evaluation(testcase: DataDrivenTestCase, cache_dir: str) -> None: @@ -50,53 +48,56 @@ def test_python_evaluation(testcase: DataDrivenTestCase, cache_dir: str) -> None # We must enable site packages to get access to installed stubs. # TODO: Enable strict optional for these tests mypy_cmdline = [ - '--show-traceback', - '--no-strict-optional', - '--no-silence-site-packages', - '--no-error-summary', + "--show-traceback", + "--no-strict-optional", + "--no-silence-site-packages", + "--no-error-summary", ] interpreter = python3_path mypy_cmdline.append(f"--python-version={'.'.join(map(str, PYTHON3_VERSION))}") - m = re.search('# flags: (.*)$', '\n'.join(testcase.input), re.MULTILINE) + m = re.search("# flags: (.*)$", "\n".join(testcase.input), re.MULTILINE) if m: mypy_cmdline.extend(m.group(1).split()) # Write the program to a file. - program = '_' + testcase.name + '.py' + program = "_" + testcase.name + ".py" program_path = os.path.join(test_temp_dir, program) mypy_cmdline.append(program_path) - with open(program_path, 'w', encoding='utf8') as file: + with open(program_path, "w", encoding="utf8") as file: for s in testcase.input: - file.write(f'{s}\n') - mypy_cmdline.append(f'--cache-dir={cache_dir}') + file.write(f"{s}\n") + mypy_cmdline.append(f"--cache-dir={cache_dir}") output = [] # Type check the program. out, err, returncode = api.run(mypy_cmdline) # split lines, remove newlines, and remove directory of test case for line in (out + err).splitlines(): if line.startswith(test_temp_dir + os.sep): - output.append(line[len(test_temp_dir + os.sep):].rstrip("\r\n")) + output.append(line[len(test_temp_dir + os.sep) :].rstrip("\r\n")) else: # Normalize paths so that the output is the same on Windows and Linux/macOS. - line = line.replace(test_temp_dir + os.sep, test_temp_dir + '/') + line = line.replace(test_temp_dir + os.sep, test_temp_dir + "/") output.append(line.rstrip("\r\n")) if returncode == 0: # Execute the program. - proc = subprocess.run([interpreter, '-Wignore', program], - cwd=test_temp_dir, stdout=PIPE, stderr=PIPE) + proc = subprocess.run( + [interpreter, "-Wignore", program], cwd=test_temp_dir, stdout=PIPE, stderr=PIPE + ) output.extend(split_lines(proc.stdout, proc.stderr)) # Remove temp file. os.remove(program_path) for i, line in enumerate(output): - if os.path.sep + 'typeshed' + os.path.sep in line: + if os.path.sep + "typeshed" + os.path.sep in line: output[i] = line.split(os.path.sep)[-1] - assert_string_arrays_equal(adapt_output(testcase), output, - 'Invalid output ({}, line {})'.format( - testcase.file, testcase.line)) + assert_string_arrays_equal( + adapt_output(testcase), + output, + "Invalid output ({}, line {})".format(testcase.file, testcase.line), + ) def adapt_output(testcase: DataDrivenTestCase) -> List[str]: """Translates the generic _program.py into the actual filename.""" - program = '_' + testcase.name + '.py' + program = "_" + testcase.name + ".py" return [program_re.sub(program, line) for line in testcase.output] diff --git a/mypy/test/testreports.py b/mypy/test/testreports.py index 37dc16a107d59..03f8ffd27b3b2 100644 --- a/mypy/test/testreports.py +++ b/mypy/test/testreports.py @@ -1,9 +1,8 @@ """Test cases for reports generated by mypy.""" import textwrap -from mypy.test.helpers import Suite, assert_equal from mypy.report import CoberturaPackage, get_line_rate - +from mypy.test.helpers import Suite, assert_equal try: import lxml # type: ignore @@ -16,25 +15,26 @@ class CoberturaReportSuite(Suite): @pytest.mark.skipif(lxml is None, reason="Cannot import lxml. Is it installed?") def test_get_line_rate(self) -> None: - assert_equal('1.0', get_line_rate(0, 0)) - assert_equal('0.3333', get_line_rate(1, 3)) + assert_equal("1.0", get_line_rate(0, 0)) + assert_equal("0.3333", get_line_rate(1, 3)) @pytest.mark.skipif(lxml is None, reason="Cannot import lxml. Is it installed?") def test_as_xml(self) -> None: import lxml.etree as etree # type: ignore - cobertura_package = CoberturaPackage('foobar') + cobertura_package = CoberturaPackage("foobar") cobertura_package.covered_lines = 21 cobertura_package.total_lines = 42 - child_package = CoberturaPackage('raz') + child_package = CoberturaPackage("raz") child_package.covered_lines = 10 child_package.total_lines = 10 - child_package.classes['class'] = etree.Element('class') + child_package.classes["class"] = etree.Element("class") - cobertura_package.packages['raz'] = child_package + cobertura_package.packages["raz"] = child_package - expected_output = textwrap.dedent('''\ + expected_output = textwrap.dedent( + """\ @@ -45,6 +45,8 @@ def test_as_xml(self) -> None: - ''').encode('ascii') - assert_equal(expected_output, - etree.tostring(cobertura_package.as_xml(), pretty_print=True)) + """ + ).encode("ascii") + assert_equal( + expected_output, etree.tostring(cobertura_package.as_xml(), pretty_print=True) + ) diff --git a/mypy/test/testsemanal.py b/mypy/test/testsemanal.py index d86c6ce2fe4ad..c05f34485f6dd 100644 --- a/mypy/test/testsemanal.py +++ b/mypy/test/testsemanal.py @@ -2,22 +2,23 @@ import os.path import sys - from typing import Dict, List from mypy import build -from mypy.modulefinder import BuildSource from mypy.defaults import PYTHON3_VERSION -from mypy.test.helpers import ( - assert_string_arrays_equal, normalize_error_messages, testfile_pyversion, parse_options, - find_test_files, -) -from mypy.test.data import DataDrivenTestCase, DataSuite -from mypy.test.config import test_temp_dir from mypy.errors import CompileError +from mypy.modulefinder import BuildSource from mypy.nodes import TypeInfo from mypy.options import Options - +from mypy.test.config import test_temp_dir +from mypy.test.data import DataDrivenTestCase, DataSuite +from mypy.test.helpers import ( + assert_string_arrays_equal, + find_test_files, + normalize_error_messages, + parse_options, + testfile_pyversion, +) # Semantic analyzer test cases: dump parse tree @@ -34,7 +35,7 @@ if sys.version_info < (3, 10): - semanal_files.remove('semanal-python310.test') + semanal_files.remove("semanal-python310.test") def get_semanal_options(program_text: str, testcase: DataDrivenTestCase) -> Options: @@ -63,12 +64,12 @@ def test_semanal(testcase: DataDrivenTestCase) -> None: """ try: - src = '\n'.join(testcase.input) + src = "\n".join(testcase.input) options = get_semanal_options(src, testcase) options.python_version = testfile_pyversion(testcase.file) - result = build.build(sources=[BuildSource('main', None, src)], - options=options, - alt_lib_path=test_temp_dir) + result = build.build( + sources=[BuildSource("main", None, src)], options=options, alt_lib_path=test_temp_dir + ) a = result.errors if a: raise CompileError(a) @@ -79,32 +80,40 @@ def test_semanal(testcase: DataDrivenTestCase) -> None: # Omit the builtins module and files with a special marker in the # path. # TODO the test is not reliable - if (not f.path.endswith((os.sep + 'builtins.pyi', - 'typing.pyi', - 'mypy_extensions.pyi', - 'typing_extensions.pyi', - 'abc.pyi', - 'collections.pyi', - 'sys.pyi')) - and not os.path.basename(f.path).startswith('_') - and not os.path.splitext( - os.path.basename(f.path))[0].endswith('_')): - a += str(f).split('\n') + if ( + not f.path.endswith( + ( + os.sep + "builtins.pyi", + "typing.pyi", + "mypy_extensions.pyi", + "typing_extensions.pyi", + "abc.pyi", + "collections.pyi", + "sys.pyi", + ) + ) + and not os.path.basename(f.path).startswith("_") + and not os.path.splitext(os.path.basename(f.path))[0].endswith("_") + ): + a += str(f).split("\n") except CompileError as e: a = e.messages if testcase.normalize_output: a = normalize_error_messages(a) assert_string_arrays_equal( - testcase.output, a, - f'Invalid semantic analyzer output ({testcase.file}, line {testcase.line})') + testcase.output, + a, + f"Invalid semantic analyzer output ({testcase.file}, line {testcase.line})", + ) # Semantic analyzer error test cases + class SemAnalErrorSuite(DataSuite): - files = ['semanal-errors.test'] + files = ["semanal-errors.test"] if sys.version_info >= (3, 10): - semanal_files.append('semanal-errors-python310.test') + semanal_files.append("semanal-errors-python310.test") def run_case(self, testcase: DataDrivenTestCase) -> None: test_semanal_error(testcase) @@ -114,12 +123,14 @@ def test_semanal_error(testcase: DataDrivenTestCase) -> None: """Perform a test case.""" try: - src = '\n'.join(testcase.input) - res = build.build(sources=[BuildSource('main', None, src)], - options=get_semanal_options(src, testcase), - alt_lib_path=test_temp_dir) + src = "\n".join(testcase.input) + res = build.build( + sources=[BuildSource("main", None, src)], + options=get_semanal_options(src, testcase), + alt_lib_path=test_temp_dir, + ) a = res.errors - assert a, f'No errors reported in {testcase.file}, line {testcase.line}' + assert a, f"No errors reported in {testcase.file}, line {testcase.line}" except CompileError as e: # Verify that there was a compile error and that the error messages # are equivalent. @@ -127,53 +138,60 @@ def test_semanal_error(testcase: DataDrivenTestCase) -> None: if testcase.normalize_output: a = normalize_error_messages(a) assert_string_arrays_equal( - testcase.output, a, - f'Invalid compiler output ({testcase.file}, line {testcase.line})') + testcase.output, a, f"Invalid compiler output ({testcase.file}, line {testcase.line})" + ) # SymbolNode table export test cases + class SemAnalSymtableSuite(DataSuite): required_out_section = True - files = ['semanal-symtable.test'] + files = ["semanal-symtable.test"] def run_case(self, testcase: DataDrivenTestCase) -> None: """Perform a test case.""" try: # Build test case input. - src = '\n'.join(testcase.input) - result = build.build(sources=[BuildSource('main', None, src)], - options=get_semanal_options(src, testcase), - alt_lib_path=test_temp_dir) + src = "\n".join(testcase.input) + result = build.build( + sources=[BuildSource("main", None, src)], + options=get_semanal_options(src, testcase), + alt_lib_path=test_temp_dir, + ) # The output is the symbol table converted into a string. a = result.errors if a: raise CompileError(a) for f in sorted(result.files.keys()): - if f not in ('builtins', 'typing', 'abc'): - a.append(f'{f}:') - for s in str(result.files[f].names).split('\n'): - a.append(' ' + s) + if f not in ("builtins", "typing", "abc"): + a.append(f"{f}:") + for s in str(result.files[f].names).split("\n"): + a.append(" " + s) except CompileError as e: a = e.messages assert_string_arrays_equal( - testcase.output, a, - f'Invalid semantic analyzer output ({testcase.file}, line {testcase.line})') + testcase.output, + a, + f"Invalid semantic analyzer output ({testcase.file}, line {testcase.line})", + ) # Type info export test cases class SemAnalTypeInfoSuite(DataSuite): required_out_section = True - files = ['semanal-typeinfo.test'] + files = ["semanal-typeinfo.test"] def run_case(self, testcase: DataDrivenTestCase) -> None: """Perform a test case.""" try: # Build test case input. - src = '\n'.join(testcase.input) - result = build.build(sources=[BuildSource('main', None, src)], - options=get_semanal_options(src, testcase), - alt_lib_path=test_temp_dir) + src = "\n".join(testcase.input) + result = build.build( + sources=[BuildSource("main", None, src)], + options=get_semanal_options(src, testcase), + alt_lib_path=test_temp_dir, + ) a = result.errors if a: raise CompileError(a) @@ -187,22 +205,26 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: typeinfos[n.fullname] = n.node # The output is the symbol table converted into a string. - a = str(typeinfos).split('\n') + a = str(typeinfos).split("\n") except CompileError as e: a = e.messages assert_string_arrays_equal( - testcase.output, a, - f'Invalid semantic analyzer output ({testcase.file}, line {testcase.line})') + testcase.output, + a, + f"Invalid semantic analyzer output ({testcase.file}, line {testcase.line})", + ) class TypeInfoMap(Dict[str, TypeInfo]): def __str__(self) -> str: a: List[str] = ["TypeInfoMap("] for x, y in sorted(self.items()): - if isinstance(x, str) and (not x.startswith('builtins.') and - not x.startswith('typing.') and - not x.startswith('abc.')): - ti = ('\n' + ' ').join(str(y).split('\n')) - a.append(f' {x} : {ti}') - a[-1] += ')' - return '\n'.join(a) + if isinstance(x, str) and ( + not x.startswith("builtins.") + and not x.startswith("typing.") + and not x.startswith("abc.") + ): + ti = ("\n" + " ").join(str(y).split("\n")) + a.append(f" {x} : {ti}") + a[-1] += ")" + return "\n".join(a) diff --git a/mypy/test/testsolve.py b/mypy/test/testsolve.py index fd41892779070..829eaf0727c75 100644 --- a/mypy/test/testsolve.py +++ b/mypy/test/testsolve.py @@ -1,12 +1,12 @@ """Test cases for the constraint solver used in type inference.""" -from typing import List, Union, Tuple, Optional +from typing import List, Optional, Tuple, Union -from mypy.test.helpers import Suite, assert_equal -from mypy.constraints import SUPERTYPE_OF, SUBTYPE_OF, Constraint +from mypy.constraints import SUBTYPE_OF, SUPERTYPE_OF, Constraint from mypy.solve import solve_constraints +from mypy.test.helpers import Suite, assert_equal from mypy.test.typefixture import TypeFixture -from mypy.types import Type, TypeVarType, TypeVarId +from mypy.types import Type, TypeVarId, TypeVarType class SolveSuite(Suite): @@ -17,80 +17,90 @@ def test_empty_input(self) -> None: self.assert_solve([], [], []) def test_simple_supertype_constraints(self) -> None: - self.assert_solve([self.fx.t.id], - [self.supc(self.fx.t, self.fx.a)], - [(self.fx.a, self.fx.o)]) - self.assert_solve([self.fx.t.id], - [self.supc(self.fx.t, self.fx.a), - self.supc(self.fx.t, self.fx.b)], - [(self.fx.a, self.fx.o)]) + self.assert_solve( + [self.fx.t.id], [self.supc(self.fx.t, self.fx.a)], [(self.fx.a, self.fx.o)] + ) + self.assert_solve( + [self.fx.t.id], + [self.supc(self.fx.t, self.fx.a), self.supc(self.fx.t, self.fx.b)], + [(self.fx.a, self.fx.o)], + ) def test_simple_subtype_constraints(self) -> None: - self.assert_solve([self.fx.t.id], - [self.subc(self.fx.t, self.fx.a)], - [self.fx.a]) - self.assert_solve([self.fx.t.id], - [self.subc(self.fx.t, self.fx.a), - self.subc(self.fx.t, self.fx.b)], - [self.fx.b]) + self.assert_solve([self.fx.t.id], [self.subc(self.fx.t, self.fx.a)], [self.fx.a]) + self.assert_solve( + [self.fx.t.id], + [self.subc(self.fx.t, self.fx.a), self.subc(self.fx.t, self.fx.b)], + [self.fx.b], + ) def test_both_kinds_of_constraints(self) -> None: - self.assert_solve([self.fx.t.id], - [self.supc(self.fx.t, self.fx.b), - self.subc(self.fx.t, self.fx.a)], - [(self.fx.b, self.fx.a)]) + self.assert_solve( + [self.fx.t.id], + [self.supc(self.fx.t, self.fx.b), self.subc(self.fx.t, self.fx.a)], + [(self.fx.b, self.fx.a)], + ) def test_unsatisfiable_constraints(self) -> None: # The constraints are impossible to satisfy. - self.assert_solve([self.fx.t.id], - [self.supc(self.fx.t, self.fx.a), - self.subc(self.fx.t, self.fx.b)], - [None]) + self.assert_solve( + [self.fx.t.id], + [self.supc(self.fx.t, self.fx.a), self.subc(self.fx.t, self.fx.b)], + [None], + ) def test_exactly_specified_result(self) -> None: - self.assert_solve([self.fx.t.id], - [self.supc(self.fx.t, self.fx.b), - self.subc(self.fx.t, self.fx.b)], - [(self.fx.b, self.fx.b)]) + self.assert_solve( + [self.fx.t.id], + [self.supc(self.fx.t, self.fx.b), self.subc(self.fx.t, self.fx.b)], + [(self.fx.b, self.fx.b)], + ) def test_multiple_variables(self) -> None: - self.assert_solve([self.fx.t.id, self.fx.s.id], - [self.supc(self.fx.t, self.fx.b), - self.supc(self.fx.s, self.fx.c), - self.subc(self.fx.t, self.fx.a)], - [(self.fx.b, self.fx.a), (self.fx.c, self.fx.o)]) + self.assert_solve( + [self.fx.t.id, self.fx.s.id], + [ + self.supc(self.fx.t, self.fx.b), + self.supc(self.fx.s, self.fx.c), + self.subc(self.fx.t, self.fx.a), + ], + [(self.fx.b, self.fx.a), (self.fx.c, self.fx.o)], + ) def test_no_constraints_for_var(self) -> None: - self.assert_solve([self.fx.t.id], - [], - [self.fx.uninhabited]) - self.assert_solve([self.fx.t.id, self.fx.s.id], - [], - [self.fx.uninhabited, self.fx.uninhabited]) - self.assert_solve([self.fx.t.id, self.fx.s.id], - [self.supc(self.fx.s, self.fx.a)], - [self.fx.uninhabited, (self.fx.a, self.fx.o)]) + self.assert_solve([self.fx.t.id], [], [self.fx.uninhabited]) + self.assert_solve( + [self.fx.t.id, self.fx.s.id], [], [self.fx.uninhabited, self.fx.uninhabited] + ) + self.assert_solve( + [self.fx.t.id, self.fx.s.id], + [self.supc(self.fx.s, self.fx.a)], + [self.fx.uninhabited, (self.fx.a, self.fx.o)], + ) def test_simple_constraints_with_dynamic_type(self) -> None: - self.assert_solve([self.fx.t.id], - [self.supc(self.fx.t, self.fx.anyt)], - [(self.fx.anyt, self.fx.anyt)]) - self.assert_solve([self.fx.t.id], - [self.supc(self.fx.t, self.fx.anyt), - self.supc(self.fx.t, self.fx.anyt)], - [(self.fx.anyt, self.fx.anyt)]) - self.assert_solve([self.fx.t.id], - [self.supc(self.fx.t, self.fx.anyt), - self.supc(self.fx.t, self.fx.a)], - [(self.fx.anyt, self.fx.anyt)]) - - self.assert_solve([self.fx.t.id], - [self.subc(self.fx.t, self.fx.anyt)], - [(self.fx.anyt, self.fx.anyt)]) - self.assert_solve([self.fx.t.id], - [self.subc(self.fx.t, self.fx.anyt), - self.subc(self.fx.t, self.fx.anyt)], - [(self.fx.anyt, self.fx.anyt)]) + self.assert_solve( + [self.fx.t.id], [self.supc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)] + ) + self.assert_solve( + [self.fx.t.id], + [self.supc(self.fx.t, self.fx.anyt), self.supc(self.fx.t, self.fx.anyt)], + [(self.fx.anyt, self.fx.anyt)], + ) + self.assert_solve( + [self.fx.t.id], + [self.supc(self.fx.t, self.fx.anyt), self.supc(self.fx.t, self.fx.a)], + [(self.fx.anyt, self.fx.anyt)], + ) + + self.assert_solve( + [self.fx.t.id], [self.subc(self.fx.t, self.fx.anyt)], [(self.fx.anyt, self.fx.anyt)] + ) + self.assert_solve( + [self.fx.t.id], + [self.subc(self.fx.t, self.fx.anyt), self.subc(self.fx.t, self.fx.anyt)], + [(self.fx.anyt, self.fx.anyt)], + ) # self.assert_solve([self.fx.t.id], # [self.subc(self.fx.t, self.fx.anyt), # self.subc(self.fx.t, self.fx.a)], @@ -100,21 +110,24 @@ def test_simple_constraints_with_dynamic_type(self) -> None: def test_both_normal_and_any_types_in_results(self) -> None: # If one of the bounds is any, we promote the other bound to # any as well, since otherwise the type range does not make sense. - self.assert_solve([self.fx.t.id], - [self.supc(self.fx.t, self.fx.a), - self.subc(self.fx.t, self.fx.anyt)], - [(self.fx.anyt, self.fx.anyt)]) - - self.assert_solve([self.fx.t.id], - [self.supc(self.fx.t, self.fx.anyt), - self.subc(self.fx.t, self.fx.a)], - [(self.fx.anyt, self.fx.anyt)]) - - def assert_solve(self, - vars: List[TypeVarId], - constraints: List[Constraint], - results: List[Union[None, Type, Tuple[Type, Type]]], - ) -> None: + self.assert_solve( + [self.fx.t.id], + [self.supc(self.fx.t, self.fx.a), self.subc(self.fx.t, self.fx.anyt)], + [(self.fx.anyt, self.fx.anyt)], + ) + + self.assert_solve( + [self.fx.t.id], + [self.supc(self.fx.t, self.fx.anyt), self.subc(self.fx.t, self.fx.a)], + [(self.fx.anyt, self.fx.anyt)], + ) + + def assert_solve( + self, + vars: List[TypeVarId], + constraints: List[Constraint], + results: List[Union[None, Type, Tuple[Type, Type]]], + ) -> None: res: List[Optional[Type]] = [] for r in results: if isinstance(r, tuple): diff --git a/mypy/test/teststubgen.py b/mypy/test/teststubgen.py index 3c2b2967fb3cb..783f31cf4eb84 100644 --- a/mypy/test/teststubgen.py +++ b/mypy/test/teststubgen.py @@ -1,81 +1,103 @@ import io import os.path +import re import shutil import sys import tempfile -import re import unittest from types import ModuleType +from typing import Any, List, Optional, Tuple -from typing import Any, List, Tuple, Optional - -from mypy.test.helpers import ( - assert_equal, assert_string_arrays_equal, local_sys_path_set -) -from mypy.test.data import DataSuite, DataDrivenTestCase from mypy.errors import CompileError +from mypy.moduleinspect import InspectError, ModuleInspect +from mypy.stubdoc import ( + ArgSig, + FunctionSig, + build_signature, + find_unique_signatures, + infer_arg_sig_from_anon_docstring, + infer_prop_type_from_docstring, + infer_sig_from_docstring, + is_valid_type, + parse_all_signatures, + parse_signature, +) from mypy.stubgen import ( - generate_stubs, parse_options, Options, collect_build_targets, - mypy_options, is_blacklisted_path, is_non_library_module + Options, + collect_build_targets, + generate_stubs, + is_blacklisted_path, + is_non_library_module, + mypy_options, + parse_options, ) -from mypy.stubutil import walk_packages, remove_misplaced_type_comments, common_dir_prefix from mypy.stubgenc import ( - generate_c_type_stub, infer_method_sig, generate_c_function_stub, generate_c_property_stub, - is_c_property_readonly + generate_c_function_stub, + generate_c_property_stub, + generate_c_type_stub, + infer_method_sig, + is_c_property_readonly, ) -from mypy.stubdoc import ( - parse_signature, parse_all_signatures, build_signature, find_unique_signatures, - infer_sig_from_docstring, infer_prop_type_from_docstring, FunctionSig, ArgSig, - infer_arg_sig_from_anon_docstring, is_valid_type -) -from mypy.moduleinspect import ModuleInspect, InspectError +from mypy.stubutil import common_dir_prefix, remove_misplaced_type_comments, walk_packages +from mypy.test.data import DataDrivenTestCase, DataSuite +from mypy.test.helpers import assert_equal, assert_string_arrays_equal, local_sys_path_set class StubgenCmdLineSuite(unittest.TestCase): """Test cases for processing command-line options and finding files.""" - @unittest.skipIf(sys.platform == 'win32', "clean up fails on Windows") + @unittest.skipIf(sys.platform == "win32", "clean up fails on Windows") def test_files_found(self) -> None: current = os.getcwd() with tempfile.TemporaryDirectory() as tmp: try: os.chdir(tmp) - os.mkdir('subdir') - self.make_file('subdir', 'a.py') - self.make_file('subdir', 'b.py') - os.mkdir(os.path.join('subdir', 'pack')) - self.make_file('subdir', 'pack', '__init__.py') - opts = parse_options(['subdir']) + os.mkdir("subdir") + self.make_file("subdir", "a.py") + self.make_file("subdir", "b.py") + os.mkdir(os.path.join("subdir", "pack")) + self.make_file("subdir", "pack", "__init__.py") + opts = parse_options(["subdir"]) py_mods, c_mods = collect_build_targets(opts, mypy_options(opts)) assert_equal(c_mods, []) files = {mod.path for mod in py_mods} - assert_equal(files, {os.path.join('subdir', 'pack', '__init__.py'), - os.path.join('subdir', 'a.py'), - os.path.join('subdir', 'b.py')}) + assert_equal( + files, + { + os.path.join("subdir", "pack", "__init__.py"), + os.path.join("subdir", "a.py"), + os.path.join("subdir", "b.py"), + }, + ) finally: os.chdir(current) - @unittest.skipIf(sys.platform == 'win32', "clean up fails on Windows") + @unittest.skipIf(sys.platform == "win32", "clean up fails on Windows") def test_packages_found(self) -> None: current = os.getcwd() with tempfile.TemporaryDirectory() as tmp: try: os.chdir(tmp) - os.mkdir('pack') - self.make_file('pack', '__init__.py', content='from . import a, b') - self.make_file('pack', 'a.py') - self.make_file('pack', 'b.py') - opts = parse_options(['-p', 'pack']) + os.mkdir("pack") + self.make_file("pack", "__init__.py", content="from . import a, b") + self.make_file("pack", "a.py") + self.make_file("pack", "b.py") + opts = parse_options(["-p", "pack"]) py_mods, c_mods = collect_build_targets(opts, mypy_options(opts)) assert_equal(c_mods, []) - files = {os.path.relpath(mod.path or 'FAIL') for mod in py_mods} - assert_equal(files, {os.path.join('pack', '__init__.py'), - os.path.join('pack', 'a.py'), - os.path.join('pack', 'b.py')}) + files = {os.path.relpath(mod.path or "FAIL") for mod in py_mods} + assert_equal( + files, + { + os.path.join("pack", "__init__.py"), + os.path.join("pack", "a.py"), + os.path.join("pack", "b.py"), + }, + ) finally: os.chdir(current) - @unittest.skipIf(sys.platform == 'win32', "clean up fails on Windows") + @unittest.skipIf(sys.platform == "win32", "clean up fails on Windows") def test_module_not_found(self) -> None: current = os.getcwd() captured_output = io.StringIO() @@ -83,17 +105,17 @@ def test_module_not_found(self) -> None: with tempfile.TemporaryDirectory() as tmp: try: os.chdir(tmp) - self.make_file(tmp, 'mymodule.py', content='import a') - opts = parse_options(['-m', 'mymodule']) + self.make_file(tmp, "mymodule.py", content="import a") + opts = parse_options(["-m", "mymodule"]) py_mods, c_mods = collect_build_targets(opts, mypy_options(opts)) - assert captured_output.getvalue() == '' + assert captured_output.getvalue() == "" finally: sys.stdout = sys.__stdout__ os.chdir(current) - def make_file(self, *path: str, content: str = '') -> None: + def make_file(self, *path: str, content: str = "") -> None: file = os.path.join(*path) - with open(file, 'w') as f: + with open(file, "w") as f: f.write(content) def run(self, result: Optional[Any] = None) -> Optional[Any]: @@ -104,206 +126,301 @@ def run(self, result: Optional[Any] = None) -> Optional[Any]: class StubgenCliParseSuite(unittest.TestCase): def test_walk_packages(self) -> None: with ModuleInspect() as m: - assert_equal( - set(walk_packages(m, ["mypy.errors"])), - {"mypy.errors"}) + assert_equal(set(walk_packages(m, ["mypy.errors"])), {"mypy.errors"}) assert_equal( set(walk_packages(m, ["mypy.errors", "mypy.stubgen"])), - {"mypy.errors", "mypy.stubgen"}) + {"mypy.errors", "mypy.stubgen"}, + ) all_mypy_packages = set(walk_packages(m, ["mypy"])) - self.assertTrue(all_mypy_packages.issuperset({ - "mypy", - "mypy.errors", - "mypy.stubgen", - "mypy.test", - "mypy.test.helpers", - })) + self.assertTrue( + all_mypy_packages.issuperset( + {"mypy", "mypy.errors", "mypy.stubgen", "mypy.test", "mypy.test.helpers"} + ) + ) class StubgenUtilSuite(unittest.TestCase): """Unit tests for stubgen utility functions.""" def test_parse_signature(self) -> None: - self.assert_parse_signature('func()', ('func', [], [])) + self.assert_parse_signature("func()", ("func", [], [])) def test_parse_signature_with_args(self) -> None: - self.assert_parse_signature('func(arg)', ('func', ['arg'], [])) - self.assert_parse_signature('do(arg, arg2)', ('do', ['arg', 'arg2'], [])) + self.assert_parse_signature("func(arg)", ("func", ["arg"], [])) + self.assert_parse_signature("do(arg, arg2)", ("do", ["arg", "arg2"], [])) def test_parse_signature_with_optional_args(self) -> None: - self.assert_parse_signature('func([arg])', ('func', [], ['arg'])) - self.assert_parse_signature('func(arg[, arg2])', ('func', ['arg'], ['arg2'])) - self.assert_parse_signature('func([arg[, arg2]])', ('func', [], ['arg', 'arg2'])) + self.assert_parse_signature("func([arg])", ("func", [], ["arg"])) + self.assert_parse_signature("func(arg[, arg2])", ("func", ["arg"], ["arg2"])) + self.assert_parse_signature("func([arg[, arg2]])", ("func", [], ["arg", "arg2"])) def test_parse_signature_with_default_arg(self) -> None: - self.assert_parse_signature('func(arg=None)', ('func', [], ['arg'])) - self.assert_parse_signature('func(arg, arg2=None)', ('func', ['arg'], ['arg2'])) - self.assert_parse_signature('func(arg=1, arg2="")', ('func', [], ['arg', 'arg2'])) + self.assert_parse_signature("func(arg=None)", ("func", [], ["arg"])) + self.assert_parse_signature("func(arg, arg2=None)", ("func", ["arg"], ["arg2"])) + self.assert_parse_signature('func(arg=1, arg2="")', ("func", [], ["arg", "arg2"])) def test_parse_signature_with_qualified_function(self) -> None: - self.assert_parse_signature('ClassName.func(arg)', ('func', ['arg'], [])) + self.assert_parse_signature("ClassName.func(arg)", ("func", ["arg"], [])) def test_parse_signature_with_kw_only_arg(self) -> None: - self.assert_parse_signature('ClassName.func(arg, *, arg2=1)', - ('func', ['arg', '*'], ['arg2'])) + self.assert_parse_signature( + "ClassName.func(arg, *, arg2=1)", ("func", ["arg", "*"], ["arg2"]) + ) def test_parse_signature_with_star_arg(self) -> None: - self.assert_parse_signature('ClassName.func(arg, *args)', - ('func', ['arg', '*args'], [])) + self.assert_parse_signature("ClassName.func(arg, *args)", ("func", ["arg", "*args"], [])) def test_parse_signature_with_star_star_arg(self) -> None: - self.assert_parse_signature('ClassName.func(arg, **args)', - ('func', ['arg', '**args'], [])) + self.assert_parse_signature("ClassName.func(arg, **args)", ("func", ["arg", "**args"], [])) def assert_parse_signature(self, sig: str, result: Tuple[str, List[str], List[str]]) -> None: assert_equal(parse_signature(sig), result) def test_build_signature(self) -> None: - assert_equal(build_signature([], []), '()') - assert_equal(build_signature(['arg'], []), '(arg)') - assert_equal(build_signature(['arg', 'arg2'], []), '(arg, arg2)') - assert_equal(build_signature(['arg'], ['arg2']), '(arg, arg2=...)') - assert_equal(build_signature(['arg'], ['arg2', '**x']), '(arg, arg2=..., **x)') + assert_equal(build_signature([], []), "()") + assert_equal(build_signature(["arg"], []), "(arg)") + assert_equal(build_signature(["arg", "arg2"], []), "(arg, arg2)") + assert_equal(build_signature(["arg"], ["arg2"]), "(arg, arg2=...)") + assert_equal(build_signature(["arg"], ["arg2", "**x"]), "(arg, arg2=..., **x)") def test_parse_all_signatures(self) -> None: - assert_equal(parse_all_signatures(['random text', - '.. function:: fn(arg', - '.. function:: fn()', - ' .. method:: fn2(arg)']), - ([('fn', '()'), - ('fn2', '(arg)')], [])) + assert_equal( + parse_all_signatures( + [ + "random text", + ".. function:: fn(arg", + ".. function:: fn()", + " .. method:: fn2(arg)", + ] + ), + ([("fn", "()"), ("fn2", "(arg)")], []), + ) def test_find_unique_signatures(self) -> None: - assert_equal(find_unique_signatures( - [('func', '()'), - ('func', '()'), - ('func2', '()'), - ('func2', '(arg)'), - ('func3', '(arg, arg2)')]), - [('func', '()'), - ('func3', '(arg, arg2)')]) + assert_equal( + find_unique_signatures( + [ + ("func", "()"), + ("func", "()"), + ("func2", "()"), + ("func2", "(arg)"), + ("func3", "(arg, arg2)"), + ] + ), + [("func", "()"), ("func3", "(arg, arg2)")], + ) def test_infer_sig_from_docstring(self) -> None: - assert_equal(infer_sig_from_docstring('\nfunc(x) - y', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x')], ret_type='Any')]) - assert_equal(infer_sig_from_docstring('\nfunc(x)', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x')], ret_type='Any')]) - - assert_equal(infer_sig_from_docstring('\nfunc(x, Y_a=None)', 'func'), - [FunctionSig(name='func', - args=[ArgSig(name='x'), ArgSig(name='Y_a', default=True)], - ret_type='Any')]) - - assert_equal(infer_sig_from_docstring('\nfunc(x, Y_a=3)', 'func'), - [FunctionSig(name='func', - args=[ArgSig(name='x'), ArgSig(name='Y_a', default=True)], - ret_type='Any')]) - - assert_equal(infer_sig_from_docstring('\nfunc(x, Y_a=[1, 2, 3])', 'func'), - [FunctionSig(name='func', - args=[ArgSig(name='x'), ArgSig(name='Y_a', default=True)], - ret_type='Any')]) - - assert_equal(infer_sig_from_docstring('\nafunc(x) - y', 'func'), []) - assert_equal(infer_sig_from_docstring('\nfunc(x, y', 'func'), []) - assert_equal(infer_sig_from_docstring('\nfunc(x=z(y))', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x', default=True)], - ret_type='Any')]) - - assert_equal(infer_sig_from_docstring('\nfunc x', 'func'), []) + assert_equal( + infer_sig_from_docstring("\nfunc(x) - y", "func"), + [FunctionSig(name="func", args=[ArgSig(name="x")], ret_type="Any")], + ) + assert_equal( + infer_sig_from_docstring("\nfunc(x)", "func"), + [FunctionSig(name="func", args=[ArgSig(name="x")], ret_type="Any")], + ) + + assert_equal( + infer_sig_from_docstring("\nfunc(x, Y_a=None)", "func"), + [ + FunctionSig( + name="func", + args=[ArgSig(name="x"), ArgSig(name="Y_a", default=True)], + ret_type="Any", + ) + ], + ) + + assert_equal( + infer_sig_from_docstring("\nfunc(x, Y_a=3)", "func"), + [ + FunctionSig( + name="func", + args=[ArgSig(name="x"), ArgSig(name="Y_a", default=True)], + ret_type="Any", + ) + ], + ) + + assert_equal( + infer_sig_from_docstring("\nfunc(x, Y_a=[1, 2, 3])", "func"), + [ + FunctionSig( + name="func", + args=[ArgSig(name="x"), ArgSig(name="Y_a", default=True)], + ret_type="Any", + ) + ], + ) + + assert_equal(infer_sig_from_docstring("\nafunc(x) - y", "func"), []) + assert_equal(infer_sig_from_docstring("\nfunc(x, y", "func"), []) + assert_equal( + infer_sig_from_docstring("\nfunc(x=z(y))", "func"), + [FunctionSig(name="func", args=[ArgSig(name="x", default=True)], ret_type="Any")], + ) + + assert_equal(infer_sig_from_docstring("\nfunc x", "func"), []) # Try to infer signature from type annotation. - assert_equal(infer_sig_from_docstring('\nfunc(x: int)', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x', type='int')], - ret_type='Any')]) - assert_equal(infer_sig_from_docstring('\nfunc(x: int=3)', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x', type='int', default=True)], - ret_type='Any')]) + assert_equal( + infer_sig_from_docstring("\nfunc(x: int)", "func"), + [FunctionSig(name="func", args=[ArgSig(name="x", type="int")], ret_type="Any")], + ) + assert_equal( + infer_sig_from_docstring("\nfunc(x: int=3)", "func"), + [ + FunctionSig( + name="func", args=[ArgSig(name="x", type="int", default=True)], ret_type="Any" + ) + ], + ) - assert_equal(infer_sig_from_docstring('\nfunc(x=3)', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x', type=None, default=True)], - ret_type='Any')]) + assert_equal( + infer_sig_from_docstring("\nfunc(x=3)", "func"), + [ + FunctionSig( + name="func", args=[ArgSig(name="x", type=None, default=True)], ret_type="Any" + ) + ], + ) - assert_equal(infer_sig_from_docstring('\nfunc() -> int', 'func'), - [FunctionSig(name='func', args=[], ret_type='int')]) + assert_equal( + infer_sig_from_docstring("\nfunc() -> int", "func"), + [FunctionSig(name="func", args=[], ret_type="int")], + ) - assert_equal(infer_sig_from_docstring('\nfunc(x: int=3) -> int', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x', type='int', default=True)], - ret_type='int')]) + assert_equal( + infer_sig_from_docstring("\nfunc(x: int=3) -> int", "func"), + [ + FunctionSig( + name="func", args=[ArgSig(name="x", type="int", default=True)], ret_type="int" + ) + ], + ) - assert_equal(infer_sig_from_docstring('\nfunc(x: int=3) -> int \n', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x', type='int', default=True)], - ret_type='int')]) + assert_equal( + infer_sig_from_docstring("\nfunc(x: int=3) -> int \n", "func"), + [ + FunctionSig( + name="func", args=[ArgSig(name="x", type="int", default=True)], ret_type="int" + ) + ], + ) - assert_equal(infer_sig_from_docstring('\nfunc(x: Tuple[int, str]) -> str', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x', type='Tuple[int,str]')], - ret_type='str')]) + assert_equal( + infer_sig_from_docstring("\nfunc(x: Tuple[int, str]) -> str", "func"), + [ + FunctionSig( + name="func", args=[ArgSig(name="x", type="Tuple[int,str]")], ret_type="str" + ) + ], + ) assert_equal( - infer_sig_from_docstring('\nfunc(x: Tuple[int, Tuple[str, int], str], y: int) -> str', - 'func'), - [FunctionSig(name='func', - args=[ArgSig(name='x', type='Tuple[int,Tuple[str,int],str]'), - ArgSig(name='y', type='int')], - ret_type='str')]) + infer_sig_from_docstring( + "\nfunc(x: Tuple[int, Tuple[str, int], str], y: int) -> str", "func" + ), + [ + FunctionSig( + name="func", + args=[ + ArgSig(name="x", type="Tuple[int,Tuple[str,int],str]"), + ArgSig(name="y", type="int"), + ], + ret_type="str", + ) + ], + ) - assert_equal(infer_sig_from_docstring('\nfunc(x: foo.bar)', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x', type='foo.bar')], - ret_type='Any')]) + assert_equal( + infer_sig_from_docstring("\nfunc(x: foo.bar)", "func"), + [FunctionSig(name="func", args=[ArgSig(name="x", type="foo.bar")], ret_type="Any")], + ) - assert_equal(infer_sig_from_docstring('\nfunc(x: list=[1,2,[3,4]])', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x', type='list', default=True)], - ret_type='Any')]) + assert_equal( + infer_sig_from_docstring("\nfunc(x: list=[1,2,[3,4]])", "func"), + [ + FunctionSig( + name="func", args=[ArgSig(name="x", type="list", default=True)], ret_type="Any" + ) + ], + ) - assert_equal(infer_sig_from_docstring('\nfunc(x: str="nasty[")', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x', type='str', default=True)], - ret_type='Any')]) + assert_equal( + infer_sig_from_docstring('\nfunc(x: str="nasty[")', "func"), + [ + FunctionSig( + name="func", args=[ArgSig(name="x", type="str", default=True)], ret_type="Any" + ) + ], + ) - assert_equal(infer_sig_from_docstring('\nfunc[(x: foo.bar, invalid]', 'func'), []) + assert_equal(infer_sig_from_docstring("\nfunc[(x: foo.bar, invalid]", "func"), []) - assert_equal(infer_sig_from_docstring('\nfunc(x: invalid::type)', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x', type=None)], - ret_type='Any')]) + assert_equal( + infer_sig_from_docstring("\nfunc(x: invalid::type)", "func"), + [FunctionSig(name="func", args=[ArgSig(name="x", type=None)], ret_type="Any")], + ) - assert_equal(infer_sig_from_docstring('\nfunc(x: str="")', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x', type='str', default=True)], - ret_type='Any')]) + assert_equal( + infer_sig_from_docstring('\nfunc(x: str="")', "func"), + [ + FunctionSig( + name="func", args=[ArgSig(name="x", type="str", default=True)], ret_type="Any" + ) + ], + ) def test_infer_sig_from_docstring_duplicate_args(self) -> None: - assert_equal(infer_sig_from_docstring('\nfunc(x, x) -> str\nfunc(x, y) -> int', 'func'), - [FunctionSig(name='func', args=[ArgSig(name='x'), ArgSig(name='y')], - ret_type='int')]) + assert_equal( + infer_sig_from_docstring("\nfunc(x, x) -> str\nfunc(x, y) -> int", "func"), + [FunctionSig(name="func", args=[ArgSig(name="x"), ArgSig(name="y")], ret_type="int")], + ) def test_infer_sig_from_docstring_bad_indentation(self) -> None: - assert_equal(infer_sig_from_docstring(""" + assert_equal( + infer_sig_from_docstring( + """ x x x - """, 'func'), None) + """, + "func", + ), + None, + ) def test_infer_arg_sig_from_anon_docstring(self) -> None: - assert_equal(infer_arg_sig_from_anon_docstring("(*args, **kwargs)"), - [ArgSig(name='*args'), ArgSig(name='**kwargs')]) + assert_equal( + infer_arg_sig_from_anon_docstring("(*args, **kwargs)"), + [ArgSig(name="*args"), ArgSig(name="**kwargs")], + ) assert_equal( infer_arg_sig_from_anon_docstring( - "(x: Tuple[int, Tuple[str, int], str]=(1, ('a', 2), 'y'), y: int=4)"), - [ArgSig(name='x', type='Tuple[int,Tuple[str,int],str]', default=True), - ArgSig(name='y', type='int', default=True)]) + "(x: Tuple[int, Tuple[str, int], str]=(1, ('a', 2), 'y'), y: int=4)" + ), + [ + ArgSig(name="x", type="Tuple[int,Tuple[str,int],str]", default=True), + ArgSig(name="y", type="int", default=True), + ], + ) def test_infer_prop_type_from_docstring(self) -> None: - assert_equal(infer_prop_type_from_docstring('str: A string.'), 'str') - assert_equal(infer_prop_type_from_docstring('Optional[int]: An int.'), 'Optional[int]') - assert_equal(infer_prop_type_from_docstring('Tuple[int, int]: A tuple.'), - 'Tuple[int, int]') - assert_equal(infer_prop_type_from_docstring('\nstr: A string.'), None) + assert_equal(infer_prop_type_from_docstring("str: A string."), "str") + assert_equal(infer_prop_type_from_docstring("Optional[int]: An int."), "Optional[int]") + assert_equal( + infer_prop_type_from_docstring("Tuple[int, int]: A tuple."), "Tuple[int, int]" + ) + assert_equal(infer_prop_type_from_docstring("\nstr: A string."), None) def test_infer_sig_from_docstring_square_brackets(self) -> None: - assert infer_sig_from_docstring( - 'fetch_row([maxrows, how]) -- Fetches stuff', - 'fetch_row', - ) == [] + assert ( + infer_sig_from_docstring("fetch_row([maxrows, how]) -- Fetches stuff", "fetch_row") + == [] + ) def test_remove_misplaced_type_comments_1(self) -> None: good = """ @@ -455,82 +572,80 @@ def h(): assert_equal(remove_misplaced_type_comments(original), dest) - @unittest.skipIf(sys.platform == 'win32', - 'Tests building the paths common ancestor on *nix') + @unittest.skipIf(sys.platform == "win32", "Tests building the paths common ancestor on *nix") def test_common_dir_prefix_unix(self) -> None: - assert common_dir_prefix([]) == '.' - assert common_dir_prefix(['x.pyi']) == '.' - assert common_dir_prefix(['./x.pyi']) == '.' - assert common_dir_prefix(['foo/bar/x.pyi']) == 'foo/bar' - assert common_dir_prefix(['foo/bar/x.pyi', - 'foo/bar/y.pyi']) == 'foo/bar' - assert common_dir_prefix(['foo/bar/x.pyi', 'foo/y.pyi']) == 'foo' - assert common_dir_prefix(['foo/x.pyi', 'foo/bar/y.pyi']) == 'foo' - assert common_dir_prefix(['foo/bar/zar/x.pyi', 'foo/y.pyi']) == 'foo' - assert common_dir_prefix(['foo/x.pyi', 'foo/bar/zar/y.pyi']) == 'foo' - assert common_dir_prefix(['foo/bar/zar/x.pyi', 'foo/bar/y.pyi']) == 'foo/bar' - assert common_dir_prefix(['foo/bar/x.pyi', 'foo/bar/zar/y.pyi']) == 'foo/bar' - assert common_dir_prefix([r'foo/bar\x.pyi']) == 'foo' - assert common_dir_prefix([r'foo\bar/x.pyi']) == r'foo\bar' - - @unittest.skipIf(sys.platform != 'win32', - 'Tests building the paths common ancestor on Windows') + assert common_dir_prefix([]) == "." + assert common_dir_prefix(["x.pyi"]) == "." + assert common_dir_prefix(["./x.pyi"]) == "." + assert common_dir_prefix(["foo/bar/x.pyi"]) == "foo/bar" + assert common_dir_prefix(["foo/bar/x.pyi", "foo/bar/y.pyi"]) == "foo/bar" + assert common_dir_prefix(["foo/bar/x.pyi", "foo/y.pyi"]) == "foo" + assert common_dir_prefix(["foo/x.pyi", "foo/bar/y.pyi"]) == "foo" + assert common_dir_prefix(["foo/bar/zar/x.pyi", "foo/y.pyi"]) == "foo" + assert common_dir_prefix(["foo/x.pyi", "foo/bar/zar/y.pyi"]) == "foo" + assert common_dir_prefix(["foo/bar/zar/x.pyi", "foo/bar/y.pyi"]) == "foo/bar" + assert common_dir_prefix(["foo/bar/x.pyi", "foo/bar/zar/y.pyi"]) == "foo/bar" + assert common_dir_prefix([r"foo/bar\x.pyi"]) == "foo" + assert common_dir_prefix([r"foo\bar/x.pyi"]) == r"foo\bar" + + @unittest.skipIf( + sys.platform != "win32", "Tests building the paths common ancestor on Windows" + ) def test_common_dir_prefix_win(self) -> None: - assert common_dir_prefix(['x.pyi']) == '.' - assert common_dir_prefix([r'.\x.pyi']) == '.' - assert common_dir_prefix([r'foo\bar\x.pyi']) == r'foo\bar' - assert common_dir_prefix([r'foo\bar\x.pyi', - r'foo\bar\y.pyi']) == r'foo\bar' - assert common_dir_prefix([r'foo\bar\x.pyi', r'foo\y.pyi']) == 'foo' - assert common_dir_prefix([r'foo\x.pyi', r'foo\bar\y.pyi']) == 'foo' - assert common_dir_prefix([r'foo\bar\zar\x.pyi', r'foo\y.pyi']) == 'foo' - assert common_dir_prefix([r'foo\x.pyi', r'foo\bar\zar\y.pyi']) == 'foo' - assert common_dir_prefix([r'foo\bar\zar\x.pyi', r'foo\bar\y.pyi']) == r'foo\bar' - assert common_dir_prefix([r'foo\bar\x.pyi', r'foo\bar\zar\y.pyi']) == r'foo\bar' - assert common_dir_prefix([r'foo/bar\x.pyi']) == r'foo\bar' - assert common_dir_prefix([r'foo\bar/x.pyi']) == r'foo\bar' - assert common_dir_prefix([r'foo/bar/x.pyi']) == r'foo\bar' + assert common_dir_prefix(["x.pyi"]) == "." + assert common_dir_prefix([r".\x.pyi"]) == "." + assert common_dir_prefix([r"foo\bar\x.pyi"]) == r"foo\bar" + assert common_dir_prefix([r"foo\bar\x.pyi", r"foo\bar\y.pyi"]) == r"foo\bar" + assert common_dir_prefix([r"foo\bar\x.pyi", r"foo\y.pyi"]) == "foo" + assert common_dir_prefix([r"foo\x.pyi", r"foo\bar\y.pyi"]) == "foo" + assert common_dir_prefix([r"foo\bar\zar\x.pyi", r"foo\y.pyi"]) == "foo" + assert common_dir_prefix([r"foo\x.pyi", r"foo\bar\zar\y.pyi"]) == "foo" + assert common_dir_prefix([r"foo\bar\zar\x.pyi", r"foo\bar\y.pyi"]) == r"foo\bar" + assert common_dir_prefix([r"foo\bar\x.pyi", r"foo\bar\zar\y.pyi"]) == r"foo\bar" + assert common_dir_prefix([r"foo/bar\x.pyi"]) == r"foo\bar" + assert common_dir_prefix([r"foo\bar/x.pyi"]) == r"foo\bar" + assert common_dir_prefix([r"foo/bar/x.pyi"]) == r"foo\bar" class StubgenHelpersSuite(unittest.TestCase): def test_is_blacklisted_path(self) -> None: - assert not is_blacklisted_path('foo/bar.py') - assert not is_blacklisted_path('foo.py') - assert not is_blacklisted_path('foo/xvendor/bar.py') - assert not is_blacklisted_path('foo/vendorx/bar.py') - assert is_blacklisted_path('foo/vendor/bar.py') - assert is_blacklisted_path('foo/vendored/bar.py') - assert is_blacklisted_path('foo/vendored/bar/thing.py') - assert is_blacklisted_path('foo/six.py') + assert not is_blacklisted_path("foo/bar.py") + assert not is_blacklisted_path("foo.py") + assert not is_blacklisted_path("foo/xvendor/bar.py") + assert not is_blacklisted_path("foo/vendorx/bar.py") + assert is_blacklisted_path("foo/vendor/bar.py") + assert is_blacklisted_path("foo/vendored/bar.py") + assert is_blacklisted_path("foo/vendored/bar/thing.py") + assert is_blacklisted_path("foo/six.py") def test_is_non_library_module(self) -> None: - assert not is_non_library_module('foo') - assert not is_non_library_module('foo.bar') + assert not is_non_library_module("foo") + assert not is_non_library_module("foo.bar") # The following could be test modules, but we are very conservative and # don't treat them as such since they could plausibly be real modules. - assert not is_non_library_module('foo.bartest') - assert not is_non_library_module('foo.bartests') - assert not is_non_library_module('foo.testbar') + assert not is_non_library_module("foo.bartest") + assert not is_non_library_module("foo.bartests") + assert not is_non_library_module("foo.testbar") - assert is_non_library_module('foo.test') - assert is_non_library_module('foo.test.foo') - assert is_non_library_module('foo.tests') - assert is_non_library_module('foo.tests.foo') - assert is_non_library_module('foo.testing.foo') - assert is_non_library_module('foo.SelfTest.foo') + assert is_non_library_module("foo.test") + assert is_non_library_module("foo.test.foo") + assert is_non_library_module("foo.tests") + assert is_non_library_module("foo.tests.foo") + assert is_non_library_module("foo.testing.foo") + assert is_non_library_module("foo.SelfTest.foo") - assert is_non_library_module('foo.test_bar') - assert is_non_library_module('foo.bar_tests') - assert is_non_library_module('foo.testing') - assert is_non_library_module('foo.conftest') - assert is_non_library_module('foo.bar_test_util') - assert is_non_library_module('foo.bar_test_utils') - assert is_non_library_module('foo.bar_test_base') + assert is_non_library_module("foo.test_bar") + assert is_non_library_module("foo.bar_tests") + assert is_non_library_module("foo.testing") + assert is_non_library_module("foo.conftest") + assert is_non_library_module("foo.bar_test_util") + assert is_non_library_module("foo.bar_test_utils") + assert is_non_library_module("foo.bar_test_base") - assert is_non_library_module('foo.setup') + assert is_non_library_module("foo.setup") - assert is_non_library_module('foo.__main__') + assert is_non_library_module("foo.__main__") class StubgenPythonSuite(DataSuite): @@ -555,8 +670,8 @@ class StubgenPythonSuite(DataSuite): """ required_out_section = True - base_path = '.' - files = ['stubgen.test'] + base_path = "." + files = ["stubgen.test"] def run_case(self, testcase: DataDrivenTestCase) -> None: with local_sys_path_set(): @@ -565,26 +680,26 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: def run_case_inner(self, testcase: DataDrivenTestCase) -> None: extra = [] # Extra command-line args mods = [] # Module names to process - source = '\n'.join(testcase.input) - for file, content in testcase.files + [('./main.py', source)]: + source = "\n".join(testcase.input) + for file, content in testcase.files + [("./main.py", source)]: # Strip ./ prefix and .py suffix. - mod = file[2:-3].replace('/', '.') - if mod.endswith('.__init__'): - mod, _, _ = mod.rpartition('.') + mod = file[2:-3].replace("/", ".") + if mod.endswith(".__init__"): + mod, _, _ = mod.rpartition(".") mods.append(mod) - if '-p ' not in source: - extra.extend(['-m', mod]) - with open(file, 'w') as f: + if "-p " not in source: + extra.extend(["-m", mod]) + with open(file, "w") as f: f.write(content) options = self.parse_flags(source, extra) modules = self.parse_modules(source) - out_dir = 'out' + out_dir = "out" try: try: - if not testcase.name.endswith('_import'): + if not testcase.name.endswith("_import"): options.no_import = True - if not testcase.name.endswith('_semanal'): + if not testcase.name.endswith("_semanal"): options.parse_only = True generate_stubs(options) a: List[str] = [] @@ -593,9 +708,11 @@ def run_case_inner(self, testcase: DataDrivenTestCase) -> None: self.add_file(fnam, a, header=len(modules) > 1) except CompileError as e: a = e.messages - assert_string_arrays_equal(testcase.output, a, - 'Invalid output ({}, line {})'.format( - testcase.file, testcase.line)) + assert_string_arrays_equal( + testcase.output, + a, + "Invalid output ({}, line {})".format(testcase.file, testcase.line), + ) finally: for mod in mods: if mod in sys.modules: @@ -603,36 +720,36 @@ def run_case_inner(self, testcase: DataDrivenTestCase) -> None: shutil.rmtree(out_dir) def parse_flags(self, program_text: str, extra: List[str]) -> Options: - flags = re.search('# flags: (.*)$', program_text, flags=re.MULTILINE) + flags = re.search("# flags: (.*)$", program_text, flags=re.MULTILINE) if flags: flag_list = flags.group(1).split() else: flag_list = [] options = parse_options(flag_list + extra) - if '--verbose' not in flag_list: + if "--verbose" not in flag_list: options.quiet = True else: options.verbose = True return options def parse_modules(self, program_text: str) -> List[str]: - modules = re.search('# modules: (.*)$', program_text, flags=re.MULTILINE) + modules = re.search("# modules: (.*)$", program_text, flags=re.MULTILINE) if modules: return modules.group(1).split() else: - return ['main'] + return ["main"] def add_file(self, path: str, result: List[str], header: bool) -> None: if not os.path.exists(path): - result.append('<%s was not generated>' % path.replace('\\', '/')) + result.append("<%s was not generated>" % path.replace("\\", "/")) return if header: - result.append(f'# {path[4:]}') - with open(path, encoding='utf8') as file: + result.append(f"# {path[4:]}") + with open(path, encoding="utf8") as file: result.extend(file.read().splitlines()) -self_arg = ArgSig(name='self') +self_arg = ArgSig(name="self") class TestBaseClass: @@ -650,31 +767,45 @@ class StubgencSuite(unittest.TestCase): """ def test_infer_hash_sig(self) -> None: - assert_equal(infer_method_sig('__hash__'), [self_arg]) + assert_equal(infer_method_sig("__hash__"), [self_arg]) def test_infer_getitem_sig(self) -> None: - assert_equal(infer_method_sig('__getitem__'), [self_arg, ArgSig(name='index')]) + assert_equal(infer_method_sig("__getitem__"), [self_arg, ArgSig(name="index")]) def test_infer_setitem_sig(self) -> None: - assert_equal(infer_method_sig('__setitem__'), - [self_arg, ArgSig(name='index'), ArgSig(name='object')]) + assert_equal( + infer_method_sig("__setitem__"), + [self_arg, ArgSig(name="index"), ArgSig(name="object")], + ) def test_infer_binary_op_sig(self) -> None: - for op in ('eq', 'ne', 'lt', 'le', 'gt', 'ge', - 'add', 'radd', 'sub', 'rsub', 'mul', 'rmul'): - assert_equal(infer_method_sig(f'__{op}__'), [self_arg, ArgSig(name='other')]) + for op in ( + "eq", + "ne", + "lt", + "le", + "gt", + "ge", + "add", + "radd", + "sub", + "rsub", + "mul", + "rmul", + ): + assert_equal(infer_method_sig(f"__{op}__"), [self_arg, ArgSig(name="other")]) def test_infer_unary_op_sig(self) -> None: - for op in ('neg', 'pos'): - assert_equal(infer_method_sig(f'__{op}__'), [self_arg]) + for op in ("neg", "pos"): + assert_equal(infer_method_sig(f"__{op}__"), [self_arg]) def test_generate_c_type_stub_no_crash_for_object(self) -> None: output: List[str] = [] - mod = ModuleType('module', '') # any module is fine + mod = ModuleType("module", "") # any module is fine imports: List[str] = [] - generate_c_type_stub(mod, 'alias', object, output, imports) + generate_c_type_stub(mod, "alias", object, output, imports) assert_equal(imports, []) - assert_equal(output[0], 'class alias:') + assert_equal(output[0], "class alias:") def test_generate_c_type_stub_variable_type_annotation(self) -> None: # This class mimics the stubgen unit test 'testClassVariable' @@ -683,10 +814,10 @@ class TestClassVariableCls: output: List[str] = [] imports: List[str] = [] - mod = ModuleType('module', '') # any module is fine - generate_c_type_stub(mod, 'C', TestClassVariableCls, output, imports) + mod = ModuleType("module", "") # any module is fine + generate_c_type_stub(mod, "C", TestClassVariableCls, output, imports) assert_equal(imports, []) - assert_equal(output, ['class C:', ' x: ClassVar[int] = ...']) + assert_equal(output, ["class C:", " x: ClassVar[int] = ..."]) def test_generate_c_type_inheritance(self) -> None: class TestClass(KeyError): @@ -694,17 +825,17 @@ class TestClass(KeyError): output: List[str] = [] imports: List[str] = [] - mod = ModuleType('module, ') - generate_c_type_stub(mod, 'C', TestClass, output, imports) - assert_equal(output, ['class C(KeyError): ...', ]) + mod = ModuleType("module, ") + generate_c_type_stub(mod, "C", TestClass, output, imports) + assert_equal(output, ["class C(KeyError): ..."]) assert_equal(imports, []) def test_generate_c_type_inheritance_same_module(self) -> None: output: List[str] = [] imports: List[str] = [] - mod = ModuleType(TestBaseClass.__module__, '') - generate_c_type_stub(mod, 'C', TestClass, output, imports) - assert_equal(output, ['class C(TestBaseClass): ...', ]) + mod = ModuleType(TestBaseClass.__module__, "") + generate_c_type_stub(mod, "C", TestClass, output, imports) + assert_equal(output, ["class C(TestBaseClass): ..."]) assert_equal(imports, []) def test_generate_c_type_inheritance_other_module(self) -> None: @@ -715,10 +846,10 @@ class TestClass(argparse.Action): output: List[str] = [] imports: List[str] = [] - mod = ModuleType('module', '') - generate_c_type_stub(mod, 'C', TestClass, output, imports) - assert_equal(output, ['class C(argparse.Action): ...', ]) - assert_equal(imports, ['import argparse']) + mod = ModuleType("module", "") + generate_c_type_stub(mod, "C", TestClass, output, imports) + assert_equal(output, ["class C(argparse.Action): ..."]) + assert_equal(imports, ["import argparse"]) def test_generate_c_type_inheritance_builtin_type(self) -> None: class TestClass(type): @@ -726,9 +857,9 @@ class TestClass(type): output: List[str] = [] imports: List[str] = [] - mod = ModuleType('module', '') - generate_c_type_stub(mod, 'C', TestClass, output, imports) - assert_equal(output, ['class C(type): ...', ]) + mod = ModuleType("module", "") + generate_c_type_stub(mod, "C", TestClass, output, imports) + assert_equal(output, ["class C(type): ..."]) assert_equal(imports, []) def test_generate_c_type_with_docstring(self) -> None: @@ -741,10 +872,11 @@ def test(self, arg0: str) -> None: output: List[str] = [] imports: List[str] = [] - mod = ModuleType(TestClass.__module__, '') - generate_c_function_stub(mod, 'test', TestClass.test, output, imports, - self_var='self', class_name='TestClass') - assert_equal(output, ['def test(self, arg0: int) -> Any: ...']) + mod = ModuleType(TestClass.__module__, "") + generate_c_function_stub( + mod, "test", TestClass.test, output, imports, self_var="self", class_name="TestClass" + ) + assert_equal(output, ["def test(self, arg0: int) -> Any: ..."]) assert_equal(imports, []) def test_generate_c_type_with_docstring_no_self_arg(self) -> None: @@ -754,12 +886,14 @@ def test(self, arg0: str) -> None: test(arg0: int) """ pass + output = [] # type: List[str] imports = [] # type: List[str] - mod = ModuleType(TestClass.__module__, '') - generate_c_function_stub(mod, 'test', TestClass.test, output, imports, - self_var='self', class_name='TestClass') - assert_equal(output, ['def test(self, arg0: int) -> Any: ...']) + mod = ModuleType(TestClass.__module__, "") + generate_c_function_stub( + mod, "test", TestClass.test, output, imports, self_var="self", class_name="TestClass" + ) + assert_equal(output, ["def test(self, arg0: int) -> Any: ..."]) assert_equal(imports, []) def test_generate_c_type_classmethod(self) -> None: @@ -767,12 +901,14 @@ class TestClass: @classmethod def test(cls, arg0: str) -> None: pass + output = [] # type: List[str] imports = [] # type: List[str] - mod = ModuleType(TestClass.__module__, '') - generate_c_function_stub(mod, 'test', TestClass.test, output, imports, - self_var='cls', class_name='TestClass') - assert_equal(output, ['def test(cls, *args, **kwargs) -> Any: ...']) + mod = ModuleType(TestClass.__module__, "") + generate_c_function_stub( + mod, "test", TestClass.test, output, imports, self_var="cls", class_name="TestClass" + ) + assert_equal(output, ["def test(cls, *args, **kwargs) -> Any: ..."]) assert_equal(imports, []) def test_generate_c_type_with_docstring_empty_default(self) -> None: @@ -785,10 +921,11 @@ def test(self, arg0: str = "") -> None: output: List[str] = [] imports: List[str] = [] - mod = ModuleType(TestClass.__module__, '') - generate_c_function_stub(mod, 'test', TestClass.test, output, imports, - self_var='self', class_name='TestClass') - assert_equal(output, ['def test(self, arg0: str = ...) -> Any: ...']) + mod = ModuleType(TestClass.__module__, "") + generate_c_function_stub( + mod, "test", TestClass.test, output, imports, self_var="self", class_name="TestClass" + ) + assert_equal(output, ["def test(self, arg0: str = ...) -> Any: ..."]) assert_equal(imports, []) def test_generate_c_function_other_module_arg(self) -> None: @@ -803,10 +940,10 @@ def test(arg0: str) -> None: output: List[str] = [] imports: List[str] = [] - mod = ModuleType(self.__module__, '') - generate_c_function_stub(mod, 'test', test, output, imports) - assert_equal(output, ['def test(arg0: argparse.Action) -> Any: ...']) - assert_equal(imports, ['import argparse']) + mod = ModuleType(self.__module__, "") + generate_c_function_stub(mod, "test", test, output, imports) + assert_equal(output, ["def test(arg0: argparse.Action) -> Any: ..."]) + assert_equal(imports, ["import argparse"]) def test_generate_c_function_same_module_arg(self) -> None: """Test that if argument references type from same module but using full path, no module @@ -822,13 +959,14 @@ def test(arg0: str) -> None: output: List[str] = [] imports: List[str] = [] - mod = ModuleType('argparse', '') - generate_c_function_stub(mod, 'test', test, output, imports) - assert_equal(output, ['def test(arg0: Action) -> Any: ...']) + mod = ModuleType("argparse", "") + generate_c_function_stub(mod, "test", test, output, imports) + assert_equal(output, ["def test(arg0: Action) -> Any: ..."]) assert_equal(imports, []) def test_generate_c_function_other_module_ret(self) -> None: """Test that if return type references type from other module, module will be imported.""" + def test(arg0: str) -> None: """ test(arg0: str) -> argparse.Action @@ -837,15 +975,16 @@ def test(arg0: str) -> None: output: List[str] = [] imports: List[str] = [] - mod = ModuleType(self.__module__, '') - generate_c_function_stub(mod, 'test', test, output, imports) - assert_equal(output, ['def test(arg0: str) -> argparse.Action: ...']) - assert_equal(imports, ['import argparse']) + mod = ModuleType(self.__module__, "") + generate_c_function_stub(mod, "test", test, output, imports) + assert_equal(output, ["def test(arg0: str) -> argparse.Action: ..."]) + assert_equal(imports, ["import argparse"]) def test_generate_c_function_same_module_ret(self) -> None: """Test that if return type references type from same module but using full path, no module will be imported, and type specification will be striped to local reference. """ + def test(arg0: str) -> None: """ test(arg0: str) -> argparse.Action @@ -854,28 +993,35 @@ def test(arg0: str) -> None: output: List[str] = [] imports: List[str] = [] - mod = ModuleType('argparse', '') - generate_c_function_stub(mod, 'test', test, output, imports) - assert_equal(output, ['def test(arg0: str) -> Action: ...']) + mod = ModuleType("argparse", "") + generate_c_function_stub(mod, "test", test, output, imports) + assert_equal(output, ["def test(arg0: str) -> Action: ..."]) assert_equal(imports, []) def test_generate_c_property_with_pybind11(self) -> None: """Signatures included by PyBind11 inside property.fget are read.""" + class TestClass: def get_attribute(self) -> None: """ (self: TestClass) -> str """ pass + attribute = property(get_attribute, doc="") readwrite_properties: List[str] = [] readonly_properties: List[str] = [] - generate_c_property_stub('attribute', TestClass.attribute, [], - readwrite_properties, readonly_properties, - is_c_property_readonly(TestClass.attribute)) + generate_c_property_stub( + "attribute", + TestClass.attribute, + [], + readwrite_properties, + readonly_properties, + is_c_property_readonly(TestClass.attribute), + ) assert_equal(readwrite_properties, []) - assert_equal(readonly_properties, ['@property', 'def attribute(self) -> str: ...']) + assert_equal(readonly_properties, ["@property", "def attribute(self) -> str: ..."]) def test_generate_c_property_with_rw_property(self) -> None: class TestClass: @@ -892,10 +1038,15 @@ def attribute(self, value: int) -> None: readwrite_properties: List[str] = [] readonly_properties: List[str] = [] - generate_c_property_stub("attribute", type(TestClass.attribute), [], - readwrite_properties, readonly_properties, - is_c_property_readonly(TestClass.attribute)) - assert_equal(readwrite_properties, ['attribute: Any']) + generate_c_property_stub( + "attribute", + type(TestClass.attribute), + [], + readwrite_properties, + readonly_properties, + is_c_property_readonly(TestClass.attribute), + ) + assert_equal(readwrite_properties, ["attribute: Any"]) assert_equal(readonly_properties, []) def test_generate_c_type_with_single_arg_generic(self) -> None: @@ -908,10 +1059,11 @@ def test(self, arg0: str) -> None: output: List[str] = [] imports: List[str] = [] - mod = ModuleType(TestClass.__module__, '') - generate_c_function_stub(mod, 'test', TestClass.test, output, imports, - self_var='self', class_name='TestClass') - assert_equal(output, ['def test(self, arg0: List[int]) -> Any: ...']) + mod = ModuleType(TestClass.__module__, "") + generate_c_function_stub( + mod, "test", TestClass.test, output, imports, self_var="self", class_name="TestClass" + ) + assert_equal(output, ["def test(self, arg0: List[int]) -> Any: ..."]) assert_equal(imports, []) def test_generate_c_type_with_double_arg_generic(self) -> None: @@ -924,10 +1076,11 @@ def test(self, arg0: str) -> None: output: List[str] = [] imports: List[str] = [] - mod = ModuleType(TestClass.__module__, '') - generate_c_function_stub(mod, 'test', TestClass.test, output, imports, - self_var='self', class_name='TestClass') - assert_equal(output, ['def test(self, arg0: Dict[str,int]) -> Any: ...']) + mod = ModuleType(TestClass.__module__, "") + generate_c_function_stub( + mod, "test", TestClass.test, output, imports, self_var="self", class_name="TestClass" + ) + assert_equal(output, ["def test(self, arg0: Dict[str,int]) -> Any: ..."]) assert_equal(imports, []) def test_generate_c_type_with_nested_generic(self) -> None: @@ -940,10 +1093,11 @@ def test(self, arg0: str) -> None: output: List[str] = [] imports: List[str] = [] - mod = ModuleType(TestClass.__module__, '') - generate_c_function_stub(mod, 'test', TestClass.test, output, imports, - self_var='self', class_name='TestClass') - assert_equal(output, ['def test(self, arg0: Dict[str,List[int]]) -> Any: ...']) + mod = ModuleType(TestClass.__module__, "") + generate_c_function_stub( + mod, "test", TestClass.test, output, imports, self_var="self", class_name="TestClass" + ) + assert_equal(output, ["def test(self, arg0: Dict[str,List[int]]) -> Any: ..."]) assert_equal(imports, []) def test_generate_c_type_with_generic_using_other_module_first(self) -> None: @@ -956,11 +1110,12 @@ def test(self, arg0: str) -> None: output: List[str] = [] imports: List[str] = [] - mod = ModuleType(TestClass.__module__, '') - generate_c_function_stub(mod, 'test', TestClass.test, output, imports, - self_var='self', class_name='TestClass') - assert_equal(output, ['def test(self, arg0: Dict[argparse.Action,int]) -> Any: ...']) - assert_equal(imports, ['import argparse']) + mod = ModuleType(TestClass.__module__, "") + generate_c_function_stub( + mod, "test", TestClass.test, output, imports, self_var="self", class_name="TestClass" + ) + assert_equal(output, ["def test(self, arg0: Dict[argparse.Action,int]) -> Any: ..."]) + assert_equal(imports, ["import argparse"]) def test_generate_c_type_with_generic_using_other_module_last(self) -> None: class TestClass: @@ -972,11 +1127,12 @@ def test(self, arg0: str) -> None: output: List[str] = [] imports: List[str] = [] - mod = ModuleType(TestClass.__module__, '') - generate_c_function_stub(mod, 'test', TestClass.test, output, imports, - self_var='self', class_name='TestClass') - assert_equal(output, ['def test(self, arg0: Dict[str,argparse.Action]) -> Any: ...']) - assert_equal(imports, ['import argparse']) + mod = ModuleType(TestClass.__module__, "") + generate_c_function_stub( + mod, "test", TestClass.test, output, imports, self_var="self", class_name="TestClass" + ) + assert_equal(output, ["def test(self, arg0: Dict[str,argparse.Action]) -> Any: ..."]) + assert_equal(imports, ["import argparse"]) def test_generate_c_type_with_overload_pybind11(self) -> None: class TestClass: @@ -993,54 +1149,68 @@ def __init__(self, arg0: str) -> None: output: List[str] = [] imports: List[str] = [] - mod = ModuleType(TestClass.__module__, '') - generate_c_function_stub(mod, '__init__', TestClass.__init__, output, imports, - self_var='self', class_name='TestClass') - assert_equal(output, [ - '@overload', - 'def __init__(self, arg0: str) -> None: ...', - '@overload', - 'def __init__(self, arg0: str, arg1: str) -> None: ...', - '@overload', - 'def __init__(*args, **kwargs) -> Any: ...']) - assert_equal(set(imports), {'from typing import overload'}) + mod = ModuleType(TestClass.__module__, "") + generate_c_function_stub( + mod, + "__init__", + TestClass.__init__, + output, + imports, + self_var="self", + class_name="TestClass", + ) + assert_equal( + output, + [ + "@overload", + "def __init__(self, arg0: str) -> None: ...", + "@overload", + "def __init__(self, arg0: str, arg1: str) -> None: ...", + "@overload", + "def __init__(*args, **kwargs) -> Any: ...", + ], + ) + assert_equal(set(imports), {"from typing import overload"}) class ArgSigSuite(unittest.TestCase): def test_repr(self) -> None: - assert_equal(repr(ArgSig(name='asd"dsa')), - "ArgSig(name='asd\"dsa', type=None, default=False)") - assert_equal(repr(ArgSig(name="asd'dsa")), - 'ArgSig(name="asd\'dsa", type=None, default=False)') - assert_equal(repr(ArgSig("func", 'str')), - "ArgSig(name='func', type='str', default=False)") - assert_equal(repr(ArgSig("func", 'str', default=True)), - "ArgSig(name='func', type='str', default=True)") + assert_equal( + repr(ArgSig(name='asd"dsa')), "ArgSig(name='asd\"dsa', type=None, default=False)" + ) + assert_equal( + repr(ArgSig(name="asd'dsa")), 'ArgSig(name="asd\'dsa", type=None, default=False)' + ) + assert_equal(repr(ArgSig("func", "str")), "ArgSig(name='func', type='str', default=False)") + assert_equal( + repr(ArgSig("func", "str", default=True)), + "ArgSig(name='func', type='str', default=True)", + ) class IsValidTypeSuite(unittest.TestCase): def test_is_valid_type(self) -> None: - assert is_valid_type('int') - assert is_valid_type('str') - assert is_valid_type('Foo_Bar234') - assert is_valid_type('foo.bar') - assert is_valid_type('List[int]') - assert is_valid_type('Dict[str, int]') - assert is_valid_type('None') - assert not is_valid_type('foo-bar') - assert not is_valid_type('x->y') - assert not is_valid_type('True') - assert not is_valid_type('False') - assert not is_valid_type('x,y') - assert not is_valid_type('x, y') + assert is_valid_type("int") + assert is_valid_type("str") + assert is_valid_type("Foo_Bar234") + assert is_valid_type("foo.bar") + assert is_valid_type("List[int]") + assert is_valid_type("Dict[str, int]") + assert is_valid_type("None") + assert not is_valid_type("foo-bar") + assert not is_valid_type("x->y") + assert not is_valid_type("True") + assert not is_valid_type("False") + assert not is_valid_type("x,y") + assert not is_valid_type("x, y") class ModuleInspectSuite(unittest.TestCase): def test_python_module(self) -> None: with ModuleInspect() as m: - p = m.get_package_properties('inspect') + p = m.get_package_properties("inspect") assert p is not None - assert p.name == 'inspect' + assert p.name == "inspect" assert p.file assert p.path is None assert p.is_c_module is False @@ -1048,20 +1218,20 @@ def test_python_module(self) -> None: def test_python_package(self) -> None: with ModuleInspect() as m: - p = m.get_package_properties('unittest') + p = m.get_package_properties("unittest") assert p is not None - assert p.name == 'unittest' + assert p.name == "unittest" assert p.file assert p.path assert p.is_c_module is False assert p.subpackages - assert all(sub.startswith('unittest.') for sub in p.subpackages) + assert all(sub.startswith("unittest.") for sub in p.subpackages) def test_c_module(self) -> None: with ModuleInspect() as m: - p = m.get_package_properties('_socket') + p = m.get_package_properties("_socket") assert p is not None - assert p.name == '_socket' + assert p.name == "_socket" assert p.path is None assert p.is_c_module is True assert p.subpackages == [] @@ -1069,14 +1239,14 @@ def test_c_module(self) -> None: def test_non_existent(self) -> None: with ModuleInspect() as m: with self.assertRaises(InspectError) as e: - m.get_package_properties('foobar-non-existent') + m.get_package_properties("foobar-non-existent") assert str(e.exception) == "No module named 'foobar-non-existent'" def module_to_path(out_dir: str, module: str) -> str: fnam = os.path.join(out_dir, f"{module.replace('.', '/')}.pyi") if not os.path.exists(fnam): - alt_fnam = fnam.replace('.pyi', '/__init__.pyi') + alt_fnam = fnam.replace(".pyi", "/__init__.pyi") if os.path.exists(alt_fnam): return alt_fnam return fnam diff --git a/mypy/test/teststubinfo.py b/mypy/test/teststubinfo.py index e00a68a24df0d..36c6721453821 100644 --- a/mypy/test/teststubinfo.py +++ b/mypy/test/teststubinfo.py @@ -5,14 +5,14 @@ class TestStubInfo(unittest.TestCase): def test_is_legacy_bundled_packages(self) -> None: - assert not is_legacy_bundled_package('foobar_asdf', 2) - assert not is_legacy_bundled_package('foobar_asdf', 3) + assert not is_legacy_bundled_package("foobar_asdf", 2) + assert not is_legacy_bundled_package("foobar_asdf", 3) - assert is_legacy_bundled_package('pycurl', 2) - assert is_legacy_bundled_package('pycurl', 3) + assert is_legacy_bundled_package("pycurl", 2) + assert is_legacy_bundled_package("pycurl", 3) - assert is_legacy_bundled_package('scribe', 2) - assert not is_legacy_bundled_package('scribe', 3) + assert is_legacy_bundled_package("scribe", 2) + assert not is_legacy_bundled_package("scribe", 3) - assert not is_legacy_bundled_package('dataclasses', 2) - assert is_legacy_bundled_package('dataclasses', 3) + assert not is_legacy_bundled_package("dataclasses", 2) + assert is_legacy_bundled_package("dataclasses", 3) diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index 197669714ad36..ef06608a9c1b8 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -100,7 +100,7 @@ def staticmethod(f: T) -> T: ... def run_stubtest( - stub: str, runtime: str, options: List[str], config_file: Optional[str] = None, + stub: str, runtime: str, options: List[str], config_file: Optional[str] = None ) -> str: with use_tmp_dir(TEST_MODULE_NAME) as tmp_dir: with open("builtins.pyi", "w") as f: @@ -117,10 +117,7 @@ def run_stubtest( options = options + ["--mypy-config-file", f"{TEST_MODULE_NAME}_config.ini"] output = io.StringIO() with contextlib.redirect_stdout(output): - test_stubs( - parse_options([TEST_MODULE_NAME] + options), - use_builtins_fixtures=True - ) + test_stubs(parse_options([TEST_MODULE_NAME] + options), use_builtins_fixtures=True) # remove cwd as it's not available from outside return output.getvalue().replace(tmp_dir + os.sep, "") @@ -212,26 +209,12 @@ class X: @collect_cases def test_coroutines(self) -> Iterator[Case]: - yield Case( - stub="def bar() -> int: ...", - runtime="async def bar(): return 5", - error="bar", - ) + yield Case(stub="def bar() -> int: ...", runtime="async def bar(): return 5", error="bar") # Don't error for this one -- we get false positives otherwise + yield Case(stub="async def foo() -> int: ...", runtime="def foo(): return 5", error=None) + yield Case(stub="def baz() -> int: ...", runtime="def baz(): return 5", error=None) yield Case( - stub="async def foo() -> int: ...", - runtime="def foo(): return 5", - error=None, - ) - yield Case( - stub="def baz() -> int: ...", - runtime="def baz(): return 5", - error=None, - ) - yield Case( - stub="async def bingo() -> int: ...", - runtime="async def bingo(): return 5", - error=None, + stub="async def bingo() -> int: ...", runtime="async def bingo(): return 5", error=None ) @collect_cases @@ -717,17 +700,9 @@ class Y: ... error="A", ) # Error if an alias isn't present at runtime... - yield Case( - stub="B = str", - runtime="", - error="B" - ) + yield Case(stub="B = str", runtime="", error="B") # ... but only if the alias isn't private - yield Case( - stub="_C = int", - runtime="", - error=None - ) + yield Case(stub="_C = int", runtime="", error=None) @collect_cases def test_enum(self) -> Iterator[Case]: @@ -792,12 +767,12 @@ def h(x: str): ... yield Case( stub="class Y: ...", runtime="__all__ += ['Y']\nclass Y:\n def __or__(self, other): return self|other", - error="Y.__or__" + error="Y.__or__", ) yield Case( stub="class Z: ...", runtime="__all__ += ['Z']\nclass Z:\n def __reduce__(self): return (Z,)", - error=None + error=None, ) @collect_cases @@ -815,9 +790,7 @@ def test_non_public_1(self) -> Iterator[Case]: @collect_cases def test_non_public_2(self) -> Iterator[Case]: - yield Case( - stub="__all__: list[str] = ['f']", runtime="__all__ = ['f']", error=None - ) + yield Case(stub="__all__: list[str] = ['f']", runtime="__all__ = ['f']", error=None) yield Case(stub="f: int", runtime="def f(): ...", error="f") yield Case(stub="g: int", runtime="def g(): ...", error="g") @@ -853,9 +826,7 @@ def test_dunders(self) -> Iterator[Case]: @collect_cases def test_not_subclassable(self) -> Iterator[Case]: yield Case( - stub="class CanBeSubclassed: ...", - runtime="class CanBeSubclassed: ...", - error=None, + stub="class CanBeSubclassed: ...", runtime="class CanBeSubclassed: ...", error=None ) yield Case( stub="class CannotBeSubclassed:\n def __init_subclass__(cls) -> None: ...", @@ -876,7 +847,7 @@ class X: def __mangle_good(self, text): pass def __mangle_bad(self, text): pass """, - error="X.__mangle_bad" + error="X.__mangle_bad", ) @collect_cases @@ -898,7 +869,7 @@ def foo(self, x: int) -> None: ... class C(A): def foo(self, y: int) -> None: ... """, - error="C.foo" + error="C.foo", ) yield Case( stub=""" @@ -908,7 +879,7 @@ class X: ... class X: def __init__(self, x): pass """, - error="X.__init__" + error="X.__init__", ) @collect_cases @@ -953,16 +924,8 @@ def test_bad_literal(self) -> Iterator[Case]: runtime="INT_FLOAT_MISMATCH = 1.0", error="INT_FLOAT_MISMATCH", ) - yield Case( - stub="WRONG_INT: Literal[1]", - runtime="WRONG_INT = 2", - error="WRONG_INT", - ) - yield Case( - stub="WRONG_STR: Literal['a']", - runtime="WRONG_STR = 'b'", - error="WRONG_STR", - ) + yield Case(stub="WRONG_INT: Literal[1]", runtime="WRONG_INT = 2", error="WRONG_INT") + yield Case(stub="WRONG_STR: Literal['a']", runtime="WRONG_STR = 'b'", error="WRONG_STR") yield Case( stub="BYTES_STR_MISMATCH: Literal[b'value']", runtime="BYTES_STR_MISMATCH = 'value'", @@ -981,12 +944,12 @@ def test_bad_literal(self) -> Iterator[Case]: yield Case( stub="WRONG_BOOL_1: Literal[True]", runtime="WRONG_BOOL_1 = False", - error='WRONG_BOOL_1', + error="WRONG_BOOL_1", ) yield Case( stub="WRONG_BOOL_2: Literal[False]", runtime="WRONG_BOOL_2 = True", - error='WRONG_BOOL_2', + error="WRONG_BOOL_2", ) @collect_cases @@ -1043,7 +1006,7 @@ class X(Protocol): bar: int def foo(self, x: int, y: bytes = ...) -> str: ... """, - error=None + error=None, ) @collect_cases @@ -1051,27 +1014,15 @@ def test_type_var(self) -> Iterator[Case]: yield Case( stub="from typing import TypeVar", runtime="from typing import TypeVar", error=None ) - yield Case( - stub="A = TypeVar('A')", - runtime="A = TypeVar('A')", - error=None, - ) - yield Case( - stub="B = TypeVar('B')", - runtime="B = 5", - error="B", - ) + yield Case(stub="A = TypeVar('A')", runtime="A = TypeVar('A')", error=None) + yield Case(stub="B = TypeVar('B')", runtime="B = 5", error="B") if sys.version_info >= (3, 10): yield Case( stub="from typing import ParamSpec", runtime="from typing import ParamSpec", - error=None - ) - yield Case( - stub="C = ParamSpec('C')", - runtime="C = ParamSpec('C')", error=None, ) + yield Case(stub="C = ParamSpec('C')", runtime="C = ParamSpec('C')", error=None) def remove_color_code(s: str) -> str: @@ -1088,9 +1039,9 @@ def test_output(self) -> None: expected = ( f'error: {TEST_MODULE_NAME}.bad is inconsistent, stub argument "number" differs ' 'from runtime argument "num"\n' - 'Stub: at line 1\ndef (number: builtins.int, text: builtins.str)\n' + "Stub: at line 1\ndef (number: builtins.int, text: builtins.str)\n" f"Runtime: at line 1 in file {TEST_MODULE_NAME}.py\ndef (num, text)\n\n" - 'Found 1 error (checked 1 module)\n' + "Found 1 error (checked 1 module)\n" ) assert remove_color_code(output) == expected @@ -1109,17 +1060,15 @@ def test_ignore_flags(self) -> None: output = run_stubtest( stub="", runtime="__all__ = ['f']\ndef f(): pass", options=["--ignore-missing-stub"] ) - assert output == 'Success: no issues found in 1 module\n' + assert output == "Success: no issues found in 1 module\n" - output = run_stubtest( - stub="", runtime="def f(): pass", options=["--ignore-missing-stub"] - ) - assert output == 'Success: no issues found in 1 module\n' + output = run_stubtest(stub="", runtime="def f(): pass", options=["--ignore-missing-stub"]) + assert output == "Success: no issues found in 1 module\n" output = run_stubtest( stub="def f(__a): ...", runtime="def f(a): pass", options=["--ignore-positional-only"] ) - assert output == 'Success: no issues found in 1 module\n' + assert output == "Success: no issues found in 1 module\n" def test_allowlist(self) -> None: # Can't use this as a context because Windows @@ -1133,7 +1082,7 @@ def test_allowlist(self) -> None: runtime="def bad(asdf, text): pass", options=["--allowlist", allowlist.name], ) - assert output == 'Success: no issues found in 1 module\n' + assert output == "Success: no issues found in 1 module\n" # test unused entry detection output = run_stubtest(stub="", runtime="", options=["--allowlist", allowlist.name]) @@ -1147,7 +1096,7 @@ def test_allowlist(self) -> None: runtime="", options=["--allowlist", allowlist.name, "--ignore-unused-allowlist"], ) - assert output == 'Success: no issues found in 1 module\n' + assert output == "Success: no issues found in 1 module\n" # test regex matching with open(allowlist.name, mode="w+") as f: @@ -1161,20 +1110,23 @@ def test_allowlist(self) -> None: def good() -> None: ... def bad(number: int) -> None: ... def also_bad(number: int) -> None: ... - """.lstrip("\n") + """.lstrip( + "\n" + ) ), runtime=textwrap.dedent( """ def good(): pass def bad(asdf): pass def also_bad(asdf): pass - """.lstrip("\n") + """.lstrip( + "\n" + ) ), options=["--allowlist", allowlist.name, "--generate-allowlist"], ) assert output == ( - f"note: unused allowlist entry unused.*\n" - f"{TEST_MODULE_NAME}.also_bad\n" + f"note: unused allowlist entry unused.*\n" f"{TEST_MODULE_NAME}.also_bad\n" ) finally: os.unlink(allowlist.name) @@ -1188,7 +1140,7 @@ def test_mypy_build(self) -> None: output = run_stubtest(stub="def f(): ...\ndef f(): ...", runtime="", options=[]) assert remove_color_code(output) == ( - 'error: not checking stubs due to mypy build errors:\n{}.pyi:2: ' + "error: not checking stubs due to mypy build errors:\n{}.pyi:2: " 'error: Name "f" already defined on line 1\n'.format(TEST_MODULE_NAME) ) @@ -1212,7 +1164,7 @@ def test_only_py(self) -> None: with contextlib.redirect_stdout(output): test_stubs(parse_options([TEST_MODULE_NAME])) output_str = remove_color_code(output.getvalue()) - assert output_str == 'Success: no issues found in 1 module\n' + assert output_str == "Success: no issues found in 1 module\n" def test_get_typeshed_stdlib_modules(self) -> None: stdlib = mypy.stubtest.get_typeshed_stdlib_modules(None, (3, 6)) @@ -1241,9 +1193,7 @@ def f(a: int, b: int, *, c: int, d: int = 0, **kwargs: Any) -> None: def test_config_file(self) -> None: runtime = "temp = 5\n" stub = "from decimal import Decimal\ntemp: Decimal\n" - config_file = ( - f"[mypy]\nplugins={root_dir}/test-data/unit/plugins/decimal_to_int.py\n" - ) + config_file = f"[mypy]\nplugins={root_dir}/test-data/unit/plugins/decimal_to_int.py\n" output = run_stubtest(stub=stub, runtime=runtime, options=[]) assert remove_color_code(output) == ( f"error: {TEST_MODULE_NAME}.temp variable differs from runtime type Literal[5]\n" diff --git a/mypy/test/testsubtypes.py b/mypy/test/testsubtypes.py index 5b556a1dc16e3..c1eacb9fd8592 100644 --- a/mypy/test/testsubtypes.py +++ b/mypy/test/testsubtypes.py @@ -1,8 +1,8 @@ -from mypy.test.helpers import Suite, skip -from mypy.nodes import CONTRAVARIANT, INVARIANT, COVARIANT +from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT from mypy.subtypes import is_subtype -from mypy.test.typefixture import TypeFixture, InterfaceTypeFixture -from mypy.types import Type, Instance, UnpackType, TupleType +from mypy.test.helpers import Suite, skip +from mypy.test.typefixture import InterfaceTypeFixture, TypeFixture +from mypy.types import Instance, TupleType, Type, UnpackType class SubtypingSuite(Suite): @@ -78,110 +78,127 @@ def test_generic_interface_subtyping(self) -> None: self.assert_equivalent(fx2.gfa, fx2.gfa) def test_basic_callable_subtyping(self) -> None: - self.assert_strict_subtype(self.fx.callable(self.fx.o, self.fx.d), - self.fx.callable(self.fx.a, self.fx.d)) - self.assert_strict_subtype(self.fx.callable(self.fx.d, self.fx.b), - self.fx.callable(self.fx.d, self.fx.a)) + self.assert_strict_subtype( + self.fx.callable(self.fx.o, self.fx.d), self.fx.callable(self.fx.a, self.fx.d) + ) + self.assert_strict_subtype( + self.fx.callable(self.fx.d, self.fx.b), self.fx.callable(self.fx.d, self.fx.a) + ) - self.assert_strict_subtype(self.fx.callable(self.fx.a, self.fx.nonet), - self.fx.callable(self.fx.a, self.fx.a)) + self.assert_strict_subtype( + self.fx.callable(self.fx.a, self.fx.nonet), self.fx.callable(self.fx.a, self.fx.a) + ) self.assert_unrelated( self.fx.callable(self.fx.a, self.fx.a, self.fx.a), - self.fx.callable(self.fx.a, self.fx.a)) + self.fx.callable(self.fx.a, self.fx.a), + ) def test_default_arg_callable_subtyping(self) -> None: self.assert_strict_subtype( self.fx.callable_default(1, self.fx.a, self.fx.d, self.fx.a), - self.fx.callable(self.fx.a, self.fx.d, self.fx.a)) + self.fx.callable(self.fx.a, self.fx.d, self.fx.a), + ) self.assert_strict_subtype( self.fx.callable_default(1, self.fx.a, self.fx.d, self.fx.a), - self.fx.callable(self.fx.a, self.fx.a)) + self.fx.callable(self.fx.a, self.fx.a), + ) self.assert_strict_subtype( self.fx.callable_default(0, self.fx.a, self.fx.d, self.fx.a), - self.fx.callable_default(1, self.fx.a, self.fx.d, self.fx.a)) + self.fx.callable_default(1, self.fx.a, self.fx.d, self.fx.a), + ) self.assert_unrelated( self.fx.callable_default(1, self.fx.a, self.fx.d, self.fx.a), - self.fx.callable(self.fx.d, self.fx.d, self.fx.a)) + self.fx.callable(self.fx.d, self.fx.d, self.fx.a), + ) self.assert_unrelated( self.fx.callable_default(0, self.fx.a, self.fx.d, self.fx.a), - self.fx.callable_default(1, self.fx.a, self.fx.a, self.fx.a)) + self.fx.callable_default(1, self.fx.a, self.fx.a, self.fx.a), + ) self.assert_unrelated( self.fx.callable_default(1, self.fx.a, self.fx.a), - self.fx.callable(self.fx.a, self.fx.a, self.fx.a)) + self.fx.callable(self.fx.a, self.fx.a, self.fx.a), + ) def test_var_arg_callable_subtyping_1(self) -> None: self.assert_strict_subtype( self.fx.callable_var_arg(0, self.fx.a, self.fx.a), - self.fx.callable_var_arg(0, self.fx.b, self.fx.a)) + self.fx.callable_var_arg(0, self.fx.b, self.fx.a), + ) def test_var_arg_callable_subtyping_2(self) -> None: self.assert_strict_subtype( self.fx.callable_var_arg(0, self.fx.a, self.fx.a), - self.fx.callable(self.fx.b, self.fx.a)) + self.fx.callable(self.fx.b, self.fx.a), + ) def test_var_arg_callable_subtyping_3(self) -> None: self.assert_strict_subtype( - self.fx.callable_var_arg(0, self.fx.a, self.fx.a), - self.fx.callable(self.fx.a)) + self.fx.callable_var_arg(0, self.fx.a, self.fx.a), self.fx.callable(self.fx.a) + ) def test_var_arg_callable_subtyping_4(self) -> None: self.assert_strict_subtype( self.fx.callable_var_arg(1, self.fx.a, self.fx.d, self.fx.a), - self.fx.callable(self.fx.b, self.fx.a)) + self.fx.callable(self.fx.b, self.fx.a), + ) def test_var_arg_callable_subtyping_5(self) -> None: self.assert_strict_subtype( self.fx.callable_var_arg(0, self.fx.a, self.fx.d, self.fx.a), - self.fx.callable(self.fx.b, self.fx.a)) + self.fx.callable(self.fx.b, self.fx.a), + ) def test_var_arg_callable_subtyping_6(self) -> None: self.assert_strict_subtype( self.fx.callable_var_arg(0, self.fx.a, self.fx.f, self.fx.d), - self.fx.callable_var_arg(0, self.fx.b, self.fx.e, self.fx.d)) + self.fx.callable_var_arg(0, self.fx.b, self.fx.e, self.fx.d), + ) def test_var_arg_callable_subtyping_7(self) -> None: self.assert_not_subtype( self.fx.callable_var_arg(0, self.fx.b, self.fx.d), - self.fx.callable(self.fx.a, self.fx.d)) + self.fx.callable(self.fx.a, self.fx.d), + ) def test_var_arg_callable_subtyping_8(self) -> None: self.assert_not_subtype( self.fx.callable_var_arg(0, self.fx.b, self.fx.d), - self.fx.callable_var_arg(0, self.fx.a, self.fx.a, self.fx.d)) + self.fx.callable_var_arg(0, self.fx.a, self.fx.a, self.fx.d), + ) self.assert_subtype( self.fx.callable_var_arg(0, self.fx.a, self.fx.d), - self.fx.callable_var_arg(0, self.fx.b, self.fx.b, self.fx.d)) + self.fx.callable_var_arg(0, self.fx.b, self.fx.b, self.fx.d), + ) def test_var_arg_callable_subtyping_9(self) -> None: self.assert_not_subtype( self.fx.callable_var_arg(0, self.fx.b, self.fx.b, self.fx.d), - self.fx.callable_var_arg(0, self.fx.a, self.fx.d)) + self.fx.callable_var_arg(0, self.fx.a, self.fx.d), + ) self.assert_subtype( self.fx.callable_var_arg(0, self.fx.a, self.fx.a, self.fx.d), - self.fx.callable_var_arg(0, self.fx.b, self.fx.d)) + self.fx.callable_var_arg(0, self.fx.b, self.fx.d), + ) def test_type_callable_subtyping(self) -> None: - self.assert_subtype( - self.fx.callable_type(self.fx.d, self.fx.a), self.fx.type_type) + self.assert_subtype(self.fx.callable_type(self.fx.d, self.fx.a), self.fx.type_type) self.assert_strict_subtype( - self.fx.callable_type(self.fx.d, self.fx.b), - self.fx.callable(self.fx.d, self.fx.a)) + self.fx.callable_type(self.fx.d, self.fx.b), self.fx.callable(self.fx.d, self.fx.a) + ) - self.assert_strict_subtype(self.fx.callable_type(self.fx.a, self.fx.b), - self.fx.callable(self.fx.a, self.fx.b)) + self.assert_strict_subtype( + self.fx.callable_type(self.fx.a, self.fx.b), self.fx.callable(self.fx.a, self.fx.b) + ) def test_type_var_tuple(self) -> None: - self.assert_subtype( - Instance(self.fx.gvi, []), - Instance(self.fx.gvi, []), - ) + self.assert_subtype(Instance(self.fx.gvi, []), Instance(self.fx.gvi, [])) self.assert_subtype( Instance(self.fx.gvi, [self.fx.a, self.fx.b]), Instance(self.fx.gvi, [self.fx.a, self.fx.b]), @@ -191,8 +208,7 @@ def test_type_var_tuple(self) -> None: Instance(self.fx.gvi, [self.fx.b, self.fx.a]), ) self.assert_not_subtype( - Instance(self.fx.gvi, [self.fx.a, self.fx.b]), - Instance(self.fx.gvi, [self.fx.a]), + Instance(self.fx.gvi, [self.fx.a, self.fx.b]), Instance(self.fx.gvi, [self.fx.a]) ) self.assert_subtype( @@ -209,12 +225,10 @@ def test_type_var_tuple(self) -> None: Instance(self.fx.gvi, [self.fx.anyt]), ) self.assert_not_subtype( - Instance(self.fx.gvi, [UnpackType(self.fx.ss)]), - Instance(self.fx.gvi, []), + Instance(self.fx.gvi, [UnpackType(self.fx.ss)]), Instance(self.fx.gvi, []) ) self.assert_not_subtype( - Instance(self.fx.gvi, [UnpackType(self.fx.ss)]), - Instance(self.fx.gvi, [self.fx.anyt]), + Instance(self.fx.gvi, [UnpackType(self.fx.ss)]), Instance(self.fx.gvi, [self.fx.anyt]) ) def test_type_var_tuple_with_prefix_suffix(self) -> None: @@ -259,35 +273,59 @@ def test_type_var_tuple_with_prefix_suffix(self) -> None: def test_type_var_tuple_unpacked_tuple(self) -> None: self.assert_subtype( - Instance(self.fx.gvi, [ - UnpackType(TupleType( - [self.fx.a, self.fx.b], fallback=Instance(self.fx.std_tuplei, [self.fx.o]), - )) - ]), + Instance( + self.fx.gvi, + [ + UnpackType( + TupleType( + [self.fx.a, self.fx.b], + fallback=Instance(self.fx.std_tuplei, [self.fx.o]), + ) + ) + ], + ), Instance(self.fx.gvi, [self.fx.a, self.fx.b]), ) self.assert_subtype( - Instance(self.fx.gvi, [ - UnpackType(TupleType( - [self.fx.a, self.fx.b], fallback=Instance(self.fx.std_tuplei, [self.fx.o]), - )) - ]), + Instance( + self.fx.gvi, + [ + UnpackType( + TupleType( + [self.fx.a, self.fx.b], + fallback=Instance(self.fx.std_tuplei, [self.fx.o]), + ) + ) + ], + ), Instance(self.fx.gvi, [self.fx.anyt, self.fx.anyt]), ) self.assert_not_subtype( - Instance(self.fx.gvi, [ - UnpackType(TupleType( - [self.fx.a, self.fx.b], fallback=Instance(self.fx.std_tuplei, [self.fx.o]), - )) - ]), + Instance( + self.fx.gvi, + [ + UnpackType( + TupleType( + [self.fx.a, self.fx.b], + fallback=Instance(self.fx.std_tuplei, [self.fx.o]), + ) + ) + ], + ), Instance(self.fx.gvi, [self.fx.a]), ) self.assert_not_subtype( - Instance(self.fx.gvi, [ - UnpackType(TupleType( - [self.fx.a, self.fx.b], fallback=Instance(self.fx.std_tuplei, [self.fx.o]), - )) - ]), + Instance( + self.fx.gvi, + [ + UnpackType( + TupleType( + [self.fx.a, self.fx.b], + fallback=Instance(self.fx.std_tuplei, [self.fx.o]), + ) + ) + ], + ), # Order flipped here. Instance(self.fx.gvi, [self.fx.b, self.fx.a]), ) @@ -295,9 +333,7 @@ def test_type_var_tuple_unpacked_tuple(self) -> None: def test_type_var_tuple_unpacked_variable_length_tuple(self) -> None: self.assert_strict_subtype( Instance(self.fx.gvi, [self.fx.a, self.fx.a]), - Instance(self.fx.gvi, [ - UnpackType(Instance(self.fx.std_tuplei, [self.fx.a])), - ]), + Instance(self.fx.gvi, [UnpackType(Instance(self.fx.std_tuplei, [self.fx.a]))]), ) # IDEA: Maybe add these test cases (they are tested pretty well in type @@ -311,10 +347,10 @@ def test_type_var_tuple_unpacked_variable_length_tuple(self) -> None: # * generic function types def assert_subtype(self, s: Type, t: Type) -> None: - assert is_subtype(s, t), f'{s} not subtype of {t}' + assert is_subtype(s, t), f"{s} not subtype of {t}" def assert_not_subtype(self, s: Type, t: Type) -> None: - assert not is_subtype(s, t), f'{s} subtype of {t}' + assert not is_subtype(s, t), f"{s} subtype of {t}" def assert_strict_subtype(self, s: Type, t: Type) -> None: self.assert_subtype(s, t) diff --git a/mypy/test/testtransform.py b/mypy/test/testtransform.py index 3b9b77a9cf586..8d54899527b8b 100644 --- a/mypy/test/testtransform.py +++ b/mypy/test/testtransform.py @@ -3,26 +3,26 @@ import os.path from mypy import build +from mypy.errors import CompileError from mypy.modulefinder import BuildSource -from mypy.test.helpers import ( - assert_string_arrays_equal, normalize_error_messages, parse_options -) -from mypy.test.data import DataDrivenTestCase, DataSuite from mypy.test.config import test_temp_dir +from mypy.test.data import DataDrivenTestCase, DataSuite +from mypy.test.helpers import assert_string_arrays_equal, normalize_error_messages, parse_options from mypy.test.visitors import TypeAssertTransformVisitor -from mypy.errors import CompileError class TransformSuite(DataSuite): required_out_section = True # Reuse semantic analysis test cases. - files = ['semanal-basic.test', - 'semanal-expressions.test', - 'semanal-classes.test', - 'semanal-types.test', - 'semanal-modules.test', - 'semanal-statements.test', - 'semanal-abstractclasses.test'] + files = [ + "semanal-basic.test", + "semanal-expressions.test", + "semanal-classes.test", + "semanal-types.test", + "semanal-modules.test", + "semanal-statements.test", + "semanal-abstractclasses.test", + ] native_sep = True def run_case(self, testcase: DataDrivenTestCase) -> None: @@ -33,15 +33,15 @@ def test_transform(testcase: DataDrivenTestCase) -> None: """Perform an identity transform test case.""" try: - src = '\n'.join(testcase.input) + src = "\n".join(testcase.input) options = parse_options(src, testcase, 1) options.use_builtins_fixtures = True options.semantic_analysis_only = True options.enable_incomplete_features = True options.show_traceback = True - result = build.build(sources=[BuildSource('main', None, src)], - options=options, - alt_lib_path=test_temp_dir) + result = build.build( + sources=[BuildSource("main", None, src)], options=options, alt_lib_path=test_temp_dir + ) a = result.errors if a: raise CompileError(a) @@ -53,22 +53,29 @@ def test_transform(testcase: DataDrivenTestCase) -> None: # Omit the builtins module and files with a special marker in the # path. # TODO the test is not reliable - if (not f.path.endswith((os.sep + 'builtins.pyi', - 'typing_extensions.pyi', - 'typing.pyi', - 'abc.pyi', - 'sys.pyi')) - and not os.path.basename(f.path).startswith('_') - and not os.path.splitext( - os.path.basename(f.path))[0].endswith('_')): + if ( + not f.path.endswith( + ( + os.sep + "builtins.pyi", + "typing_extensions.pyi", + "typing.pyi", + "abc.pyi", + "sys.pyi", + ) + ) + and not os.path.basename(f.path).startswith("_") + and not os.path.splitext(os.path.basename(f.path))[0].endswith("_") + ): t = TypeAssertTransformVisitor() t.test_only = True f = t.mypyfile(f) - a += str(f).split('\n') + a += str(f).split("\n") except CompileError as e: a = e.messages if testcase.normalize_output: a = normalize_error_messages(a) assert_string_arrays_equal( - testcase.output, a, - f'Invalid semantic analyzer output ({testcase.file}, line {testcase.line})') + testcase.output, + a, + f"Invalid semantic analyzer output ({testcase.file}, line {testcase.line})", + ) diff --git a/mypy/test/testtypegen.py b/mypy/test/testtypegen.py index a91cd0a2972d3..089637630db2d 100644 --- a/mypy/test/testtypegen.py +++ b/mypy/test/testtypegen.py @@ -3,38 +3,40 @@ import re from mypy import build +from mypy.errors import CompileError from mypy.modulefinder import BuildSource +from mypy.nodes import NameExpr +from mypy.options import Options from mypy.test.config import test_temp_dir from mypy.test.data import DataDrivenTestCase, DataSuite from mypy.test.helpers import assert_string_arrays_equal from mypy.test.visitors import SkippedNodeSearcher, ignore_node from mypy.util import short_type -from mypy.nodes import NameExpr -from mypy.errors import CompileError -from mypy.options import Options class TypeExportSuite(DataSuite): required_out_section = True - files = ['typexport-basic.test'] + files = ["typexport-basic.test"] def run_case(self, testcase: DataDrivenTestCase) -> None: try: line = testcase.input[0] - mask = '' - if line.startswith('##'): - mask = '(' + line[2:].strip() + ')$' + mask = "" + if line.startswith("##"): + mask = "(" + line[2:].strip() + ")$" - src = '\n'.join(testcase.input) + src = "\n".join(testcase.input) options = Options() options.strict_optional = False # TODO: Enable strict optional checking options.use_builtins_fixtures = True options.show_traceback = True options.export_types = True options.preserve_asts = True - result = build.build(sources=[BuildSource('main', None, src)], - options=options, - alt_lib_path=test_temp_dir) + result = build.build( + sources=[BuildSource("main", None, src)], + options=options, + alt_lib_path=test_temp_dir, + ) a = result.errors map = result.types nodes = map.keys() @@ -52,20 +54,20 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: if node.line is not None and node.line != -1 and map[node]: if ignore_node(node) or node in ignored: continue - if (re.match(mask, short_type(node)) - or (isinstance(node, NameExpr) - and re.match(mask, node.name))): + if re.match(mask, short_type(node)) or ( + isinstance(node, NameExpr) and re.match(mask, node.name) + ): # Include node in output. keys.append(node) - for key in sorted(keys, - key=lambda n: (n.line, short_type(n), - str(n) + str(map[n]))): - ts = str(map[key]).replace('*', '') # Remove erased tags - ts = ts.replace('__main__.', '') - a.append(f'{short_type(key)}({key.line}) : {ts}') + for key in sorted(keys, key=lambda n: (n.line, short_type(n), str(n) + str(map[n]))): + ts = str(map[key]).replace("*", "") # Remove erased tags + ts = ts.replace("__main__.", "") + a.append(f"{short_type(key)}({key.line}) : {ts}") except CompileError as e: a = e.messages assert_string_arrays_equal( - testcase.output, a, - f'Invalid type checker output ({testcase.file}, line {testcase.line})') + testcase.output, + a, + f"Invalid type checker output ({testcase.file}, line {testcase.line})", + ) diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 08469a60aba7e..fb9e3e80b854c 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -2,96 +2,140 @@ from typing import List, Tuple -from mypy.test.helpers import Suite, assert_equal, assert_type, skip from mypy.erasetype import erase_type, remove_instance_last_known_values from mypy.expandtype import expand_type -from mypy.join import join_types, join_simple +from mypy.indirection import TypeIndirectionVisitor +from mypy.join import join_simple, join_types from mypy.meet import meet_types, narrow_declared_type +from mypy.nodes import ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, CONTRAVARIANT, COVARIANT, INVARIANT from mypy.sametypes import is_same_type -from mypy.indirection import TypeIndirectionVisitor +from mypy.state import state +from mypy.subtypes import is_more_precise, is_proper_subtype, is_subtype +from mypy.test.helpers import Suite, assert_equal, assert_type, skip +from mypy.test.typefixture import InterfaceTypeFixture, TypeFixture +from mypy.typeops import false_only, make_simplified_union, true_only from mypy.types import ( - UnboundType, AnyType, CallableType, TupleType, TypeVarType, Type, Instance, NoneType, - Overloaded, TypeType, UnionType, UninhabitedType, TypeVarId, TypeOfAny, ProperType, - LiteralType, get_proper_type + AnyType, + CallableType, + Instance, + LiteralType, + NoneType, + Overloaded, + ProperType, + TupleType, + Type, + TypeOfAny, + TypeType, + TypeVarId, + TypeVarType, + UnboundType, + UninhabitedType, + UnionType, + get_proper_type, ) -from mypy.nodes import ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, CONTRAVARIANT, INVARIANT, COVARIANT -from mypy.subtypes import is_subtype, is_more_precise, is_proper_subtype -from mypy.test.typefixture import TypeFixture, InterfaceTypeFixture -from mypy.state import state -from mypy.typeops import true_only, false_only, make_simplified_union class TypesSuite(Suite): def setUp(self) -> None: - self.x = UnboundType('X') # Helpers - self.y = UnboundType('Y') + self.x = UnboundType("X") # Helpers + self.y = UnboundType("Y") self.fx = TypeFixture() self.function = self.fx.function def test_any(self) -> None: - assert_equal(str(AnyType(TypeOfAny.special_form)), 'Any') + assert_equal(str(AnyType(TypeOfAny.special_form)), "Any") def test_simple_unbound_type(self) -> None: - u = UnboundType('Foo') - assert_equal(str(u), 'Foo?') + u = UnboundType("Foo") + assert_equal(str(u), "Foo?") def test_generic_unbound_type(self) -> None: - u = UnboundType('Foo', [UnboundType('T'), AnyType(TypeOfAny.special_form)]) - assert_equal(str(u), 'Foo?[T?, Any]') + u = UnboundType("Foo", [UnboundType("T"), AnyType(TypeOfAny.special_form)]) + assert_equal(str(u), "Foo?[T?, Any]") def test_callable_type(self) -> None: - c = CallableType([self.x, self.y], - [ARG_POS, ARG_POS], - [None, None], - AnyType(TypeOfAny.special_form), self.function) - assert_equal(str(c), 'def (X?, Y?) -> Any') + c = CallableType( + [self.x, self.y], + [ARG_POS, ARG_POS], + [None, None], + AnyType(TypeOfAny.special_form), + self.function, + ) + assert_equal(str(c), "def (X?, Y?) -> Any") c2 = CallableType([], [], [], NoneType(), self.fx.function) - assert_equal(str(c2), 'def ()') + assert_equal(str(c2), "def ()") def test_callable_type_with_default_args(self) -> None: - c = CallableType([self.x, self.y], [ARG_POS, ARG_OPT], [None, None], - AnyType(TypeOfAny.special_form), self.function) - assert_equal(str(c), 'def (X?, Y? =) -> Any') - - c2 = CallableType([self.x, self.y], [ARG_OPT, ARG_OPT], [None, None], - AnyType(TypeOfAny.special_form), self.function) - assert_equal(str(c2), 'def (X? =, Y? =) -> Any') + c = CallableType( + [self.x, self.y], + [ARG_POS, ARG_OPT], + [None, None], + AnyType(TypeOfAny.special_form), + self.function, + ) + assert_equal(str(c), "def (X?, Y? =) -> Any") + + c2 = CallableType( + [self.x, self.y], + [ARG_OPT, ARG_OPT], + [None, None], + AnyType(TypeOfAny.special_form), + self.function, + ) + assert_equal(str(c2), "def (X? =, Y? =) -> Any") def test_callable_type_with_var_args(self) -> None: - c = CallableType([self.x], [ARG_STAR], [None], AnyType(TypeOfAny.special_form), - self.function) - assert_equal(str(c), 'def (*X?) -> Any') - - c2 = CallableType([self.x, self.y], [ARG_POS, ARG_STAR], - [None, None], AnyType(TypeOfAny.special_form), self.function) - assert_equal(str(c2), 'def (X?, *Y?) -> Any') - - c3 = CallableType([self.x, self.y], [ARG_OPT, ARG_STAR], [None, None], - AnyType(TypeOfAny.special_form), self.function) - assert_equal(str(c3), 'def (X? =, *Y?) -> Any') + c = CallableType( + [self.x], [ARG_STAR], [None], AnyType(TypeOfAny.special_form), self.function + ) + assert_equal(str(c), "def (*X?) -> Any") + + c2 = CallableType( + [self.x, self.y], + [ARG_POS, ARG_STAR], + [None, None], + AnyType(TypeOfAny.special_form), + self.function, + ) + assert_equal(str(c2), "def (X?, *Y?) -> Any") + + c3 = CallableType( + [self.x, self.y], + [ARG_OPT, ARG_STAR], + [None, None], + AnyType(TypeOfAny.special_form), + self.function, + ) + assert_equal(str(c3), "def (X? =, *Y?) -> Any") def test_tuple_type(self) -> None: - assert_equal(str(TupleType([], self.fx.std_tuple)), 'Tuple[]') - assert_equal(str(TupleType([self.x], self.fx.std_tuple)), 'Tuple[X?]') - assert_equal(str(TupleType([self.x, AnyType(TypeOfAny.special_form)], - self.fx.std_tuple)), 'Tuple[X?, Any]') + assert_equal(str(TupleType([], self.fx.std_tuple)), "Tuple[]") + assert_equal(str(TupleType([self.x], self.fx.std_tuple)), "Tuple[X?]") + assert_equal( + str(TupleType([self.x, AnyType(TypeOfAny.special_form)], self.fx.std_tuple)), + "Tuple[X?, Any]", + ) def test_type_variable_binding(self) -> None: - assert_equal(str(TypeVarType('X', 'X', 1, [], self.fx.o)), 'X`1') - assert_equal(str(TypeVarType('X', 'X', 1, [self.x, self.y], self.fx.o)), - 'X`1') + assert_equal(str(TypeVarType("X", "X", 1, [], self.fx.o)), "X`1") + assert_equal(str(TypeVarType("X", "X", 1, [self.x, self.y], self.fx.o)), "X`1") def test_generic_function_type(self) -> None: - c = CallableType([self.x, self.y], [ARG_POS, ARG_POS], [None, None], - self.y, self.function, name=None, - variables=[TypeVarType('X', 'X', -1, [], self.fx.o)]) - assert_equal(str(c), 'def [X] (X?, Y?) -> Y?') - - v = [TypeVarType('Y', 'Y', -1, [], self.fx.o), - TypeVarType('X', 'X', -2, [], self.fx.o)] + c = CallableType( + [self.x, self.y], + [ARG_POS, ARG_POS], + [None, None], + self.y, + self.function, + name=None, + variables=[TypeVarType("X", "X", -1, [], self.fx.o)], + ) + assert_equal(str(c), "def [X] (X?, Y?) -> Y?") + + v = [TypeVarType("Y", "Y", -1, [], self.fx.o), TypeVarType("X", "X", -2, [], self.fx.o)] c2 = CallableType([], [], [], NoneType(), self.function, name=None, variables=v) - assert_equal(str(c2), 'def [Y, X] ()') + assert_equal(str(c2), "def [Y, X] ()") def test_type_alias_expand_once(self) -> None: A, target = self.fx.def_alias_1(self.fx.a) @@ -109,22 +153,21 @@ def test_type_alias_expand_all(self) -> None: assert A.expand_all_if_possible() is None B = self.fx.non_rec_alias(self.fx.a) - C = self.fx.non_rec_alias(TupleType([B, B], Instance(self.fx.std_tuplei, - [B]))) - assert C.expand_all_if_possible() == TupleType([self.fx.a, self.fx.a], - Instance(self.fx.std_tuplei, - [self.fx.a])) + C = self.fx.non_rec_alias(TupleType([B, B], Instance(self.fx.std_tuplei, [B]))) + assert C.expand_all_if_possible() == TupleType( + [self.fx.a, self.fx.a], Instance(self.fx.std_tuplei, [self.fx.a]) + ) def test_indirection_no_infinite_recursion(self) -> None: A, _ = self.fx.def_alias_1(self.fx.a) visitor = TypeIndirectionVisitor() modules = A.accept(visitor) - assert modules == {'__main__', 'builtins'} + assert modules == {"__main__", "builtins"} A, _ = self.fx.def_alias_2(self.fx.a) visitor = TypeIndirectionVisitor() modules = A.accept(visitor) - assert modules == {'__main__', 'builtins'} + assert modules == {"__main__", "builtins"} class TypeOpsSuite(Suite): @@ -136,9 +179,15 @@ def setUp(self) -> None: # expand_type def test_trivial_expand(self) -> None: - for t in (self.fx.a, self.fx.o, self.fx.t, self.fx.nonet, - self.tuple(self.fx.a), - self.callable([], self.fx.a, self.fx.a), self.fx.anyt): + for t in ( + self.fx.a, + self.fx.o, + self.fx.t, + self.fx.nonet, + self.tuple(self.fx.a), + self.callable([], self.fx.a, self.fx.a), + self.fx.anyt, + ): self.assert_expand(t, [], t) self.assert_expand(t, [], t) self.assert_expand(t, [], t) @@ -161,11 +210,9 @@ def test_expand_basic_generic_types(self) -> None: # callable types # multiple arguments - def assert_expand(self, - orig: Type, - map_items: List[Tuple[TypeVarId, Type]], - result: Type, - ) -> None: + def assert_expand( + self, orig: Type, map_items: List[Tuple[TypeVarId, Type]], result: Type + ) -> None: lower_bounds = {} for id, t in map_items: @@ -173,7 +220,7 @@ def assert_expand(self, exp = expand_type(orig, lower_bounds) # Remove erased tags (asterisks). - assert_equal(str(exp).replace('*', ''), str(result)) + assert_equal(str(exp).replace("*", ""), str(result)) # erase_type @@ -186,8 +233,7 @@ def test_erase_with_type_variable(self) -> None: def test_erase_with_generic_type(self) -> None: self.assert_erase(self.fx.ga, self.fx.gdyn) - self.assert_erase(self.fx.hab, - Instance(self.fx.hi, [self.fx.anyt, self.fx.anyt])) + self.assert_erase(self.fx.hab, Instance(self.fx.hi, [self.fx.anyt, self.fx.anyt])) def test_erase_with_generic_type_recursive(self) -> None: tuple_any = Instance(self.fx.std_tuplei, [AnyType(TypeOfAny.explicit)]) @@ -200,20 +246,28 @@ def test_erase_with_tuple_type(self) -> None: self.assert_erase(self.tuple(self.fx.a), self.fx.std_tuple) def test_erase_with_function_type(self) -> None: - self.assert_erase(self.fx.callable(self.fx.a, self.fx.b), - CallableType(arg_types=[self.fx.anyt, self.fx.anyt], - arg_kinds=[ARG_STAR, ARG_STAR2], - arg_names=[None, None], - ret_type=self.fx.anyt, - fallback=self.fx.function)) + self.assert_erase( + self.fx.callable(self.fx.a, self.fx.b), + CallableType( + arg_types=[self.fx.anyt, self.fx.anyt], + arg_kinds=[ARG_STAR, ARG_STAR2], + arg_names=[None, None], + ret_type=self.fx.anyt, + fallback=self.fx.function, + ), + ) def test_erase_with_type_object(self) -> None: - self.assert_erase(self.fx.callable_type(self.fx.a, self.fx.b), - CallableType(arg_types=[self.fx.anyt, self.fx.anyt], - arg_kinds=[ARG_STAR, ARG_STAR2], - arg_names=[None, None], - ret_type=self.fx.anyt, - fallback=self.fx.type_type)) + self.assert_erase( + self.fx.callable_type(self.fx.a, self.fx.b), + CallableType( + arg_types=[self.fx.anyt, self.fx.anyt], + arg_kinds=[ARG_STAR, ARG_STAR2], + arg_names=[None, None], + ret_type=self.fx.anyt, + fallback=self.fx.type_type, + ), + ) def test_erase_with_type_type(self) -> None: self.assert_erase(self.fx.type_a, self.fx.type_a) @@ -230,10 +284,8 @@ def test_is_more_precise(self) -> None: assert is_more_precise(fx.b, fx.b) assert is_more_precise(fx.b, fx.b) assert is_more_precise(fx.b, fx.anyt) - assert is_more_precise(self.tuple(fx.b, fx.a), - self.tuple(fx.b, fx.a)) - assert is_more_precise(self.tuple(fx.b, fx.b), - self.tuple(fx.b, fx.a)) + assert is_more_precise(self.tuple(fx.b, fx.a), self.tuple(fx.b, fx.a)) + assert is_more_precise(self.tuple(fx.b, fx.b), self.tuple(fx.b, fx.a)) assert not is_more_precise(fx.a, fx.b) assert not is_more_precise(fx.anyt, fx.b) @@ -264,10 +316,8 @@ def test_is_proper_subtype(self) -> None: assert not is_proper_subtype(fx.t, fx.s) assert is_proper_subtype(fx.a, UnionType([fx.a, fx.b])) - assert is_proper_subtype(UnionType([fx.a, fx.b]), - UnionType([fx.a, fx.b, fx.c])) - assert not is_proper_subtype(UnionType([fx.a, fx.b]), - UnionType([fx.b, fx.c])) + assert is_proper_subtype(UnionType([fx.a, fx.b]), UnionType([fx.a, fx.b, fx.c])) + assert not is_proper_subtype(UnionType([fx.a, fx.b]), UnionType([fx.b, fx.c])) def test_is_proper_subtype_covariance(self) -> None: fx_co = self.fx_co @@ -356,8 +406,7 @@ def test_empty_tuple_always_false(self) -> None: assert not tuple_type.can_be_true def test_nonempty_tuple_always_true(self) -> None: - tuple_type = self.tuple(AnyType(TypeOfAny.special_form), - AnyType(TypeOfAny.special_form)) + tuple_type = self.tuple(AnyType(TypeOfAny.special_form), AnyType(TypeOfAny.special_form)) assert tuple_type.can_be_true assert not tuple_type.can_be_false @@ -441,8 +490,9 @@ def test_false_only_of_union(self) -> None: tup_type = self.tuple() # Union of something that is unknown, something that is always true, something # that is always false - union_type = UnionType([self.fx.a, self.tuple(AnyType(TypeOfAny.special_form)), - tup_type]) + union_type = UnionType( + [self.fx.a, self.tuple(AnyType(TypeOfAny.special_form)), tup_type] + ) assert_equal(len(union_type.items), 3) fo = false_only(union_type) assert isinstance(fo, UnionType) @@ -463,8 +513,9 @@ def test_simplified_union(self) -> None: self.assert_simplified_union([fx.ga, fx.gsba], fx.ga) self.assert_simplified_union([fx.a, UnionType([fx.d])], UnionType([fx.a, fx.d])) self.assert_simplified_union([fx.a, UnionType([fx.a])], fx.a) - self.assert_simplified_union([fx.b, UnionType([fx.c, UnionType([fx.d])])], - UnionType([fx.b, fx.c, fx.d])) + self.assert_simplified_union( + [fx.b, UnionType([fx.c, UnionType([fx.d])])], UnionType([fx.b, fx.c, fx.d]) + ) def test_simplified_union_with_literals(self) -> None: fx = self.fx @@ -477,10 +528,12 @@ def test_simplified_union_with_literals(self) -> None: self.assert_simplified_union([fx.lit1, fx.uninhabited], fx.lit1) self.assert_simplified_union([fx.lit1_inst, fx.a], fx.a) self.assert_simplified_union([fx.lit1_inst, fx.lit1_inst], fx.lit1_inst) - self.assert_simplified_union([fx.lit1_inst, fx.lit2_inst], - UnionType([fx.lit1_inst, fx.lit2_inst])) - self.assert_simplified_union([fx.lit1_inst, fx.lit3_inst], - UnionType([fx.lit1_inst, fx.lit3_inst])) + self.assert_simplified_union( + [fx.lit1_inst, fx.lit2_inst], UnionType([fx.lit1_inst, fx.lit2_inst]) + ) + self.assert_simplified_union( + [fx.lit1_inst, fx.lit3_inst], UnionType([fx.lit1_inst, fx.lit3_inst]) + ) self.assert_simplified_union([fx.lit1_inst, fx.uninhabited], fx.lit1_inst) self.assert_simplified_union([fx.lit1, fx.lit1_inst], UnionType([fx.lit1, fx.lit1_inst])) self.assert_simplified_union([fx.lit1, fx.lit2_inst], UnionType([fx.lit1, fx.lit2_inst])) @@ -491,10 +544,13 @@ def test_simplified_union_with_str_literals(self) -> None: self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.str_type], fx.str_type) self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1], fx.lit_str1) - self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.lit_str3], - UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3])) - self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.uninhabited], - UnionType([fx.lit_str1, fx.lit_str2])) + self.assert_simplified_union( + [fx.lit_str1, fx.lit_str2, fx.lit_str3], + UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3]), + ) + self.assert_simplified_union( + [fx.lit_str1, fx.lit_str2, fx.uninhabited], UnionType([fx.lit_str1, fx.lit_str2]) + ) def test_simplify_very_large_union(self) -> None: fx = self.fx @@ -507,26 +563,32 @@ def test_simplify_very_large_union(self) -> None: def test_simplified_union_with_str_instance_literals(self) -> None: fx = self.fx - self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str2_inst, fx.str_type], - fx.str_type) - self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str1_inst, fx.lit_str1_inst], - fx.lit_str1_inst) - self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str2_inst, fx.lit_str3_inst], - UnionType([fx.lit_str1_inst, - fx.lit_str2_inst, - fx.lit_str3_inst])) - self.assert_simplified_union([fx.lit_str1_inst, fx.lit_str2_inst, fx.uninhabited], - UnionType([fx.lit_str1_inst, fx.lit_str2_inst])) + self.assert_simplified_union( + [fx.lit_str1_inst, fx.lit_str2_inst, fx.str_type], fx.str_type + ) + self.assert_simplified_union( + [fx.lit_str1_inst, fx.lit_str1_inst, fx.lit_str1_inst], fx.lit_str1_inst + ) + self.assert_simplified_union( + [fx.lit_str1_inst, fx.lit_str2_inst, fx.lit_str3_inst], + UnionType([fx.lit_str1_inst, fx.lit_str2_inst, fx.lit_str3_inst]), + ) + self.assert_simplified_union( + [fx.lit_str1_inst, fx.lit_str2_inst, fx.uninhabited], + UnionType([fx.lit_str1_inst, fx.lit_str2_inst]), + ) def test_simplified_union_with_mixed_str_literals(self) -> None: fx = self.fx - self.assert_simplified_union([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst], - UnionType([fx.lit_str1, - fx.lit_str2, - fx.lit_str3_inst])) - self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], - UnionType([fx.lit_str1, fx.lit_str1_inst])) + self.assert_simplified_union( + [fx.lit_str1, fx.lit_str2, fx.lit_str3_inst], + UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst]), + ) + self.assert_simplified_union( + [fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], + UnionType([fx.lit_str1, fx.lit_str1_inst]), + ) def assert_simplified_union(self, original: List[Type], union: Type) -> None: assert_equal(make_simplified_union(original), union) @@ -547,13 +609,15 @@ def callable(self, vars: List[str], *a: Type) -> CallableType: for v in vars: tv.append(TypeVarType(v, v, n, [], self.fx.o)) n -= 1 - return CallableType(list(a[:-1]), - [ARG_POS] * (len(a) - 1), - [None] * (len(a) - 1), - a[-1], - self.fx.function, - name=None, - variables=tv) + return CallableType( + list(a[:-1]), + [ARG_POS] * (len(a) - 1), + [None] * (len(a) - 1), + a[-1], + self.fx.function, + name=None, + variables=tv, + ) class JoinSuite(Suite): @@ -575,54 +639,56 @@ def test_class_subtyping(self) -> None: def test_tuples(self) -> None: self.assert_join(self.tuple(), self.tuple(), self.tuple()) - self.assert_join(self.tuple(self.fx.a), - self.tuple(self.fx.a), - self.tuple(self.fx.a)) - self.assert_join(self.tuple(self.fx.b, self.fx.c), - self.tuple(self.fx.a, self.fx.d), - self.tuple(self.fx.a, self.fx.o)) - - self.assert_join(self.tuple(self.fx.a, self.fx.a), - self.fx.std_tuple, - self.var_tuple(self.fx.anyt)) - self.assert_join(self.tuple(self.fx.a), - self.tuple(self.fx.a, self.fx.a), - self.var_tuple(self.fx.a)) - self.assert_join(self.tuple(self.fx.b), - self.tuple(self.fx.a, self.fx.c), - self.var_tuple(self.fx.a)) - self.assert_join(self.tuple(), - self.tuple(self.fx.a), - self.var_tuple(self.fx.a)) + self.assert_join(self.tuple(self.fx.a), self.tuple(self.fx.a), self.tuple(self.fx.a)) + self.assert_join( + self.tuple(self.fx.b, self.fx.c), + self.tuple(self.fx.a, self.fx.d), + self.tuple(self.fx.a, self.fx.o), + ) + + self.assert_join( + self.tuple(self.fx.a, self.fx.a), self.fx.std_tuple, self.var_tuple(self.fx.anyt) + ) + self.assert_join( + self.tuple(self.fx.a), self.tuple(self.fx.a, self.fx.a), self.var_tuple(self.fx.a) + ) + self.assert_join( + self.tuple(self.fx.b), self.tuple(self.fx.a, self.fx.c), self.var_tuple(self.fx.a) + ) + self.assert_join(self.tuple(), self.tuple(self.fx.a), self.var_tuple(self.fx.a)) def test_var_tuples(self) -> None: - self.assert_join(self.tuple(self.fx.a), - self.var_tuple(self.fx.a), - self.var_tuple(self.fx.a)) - self.assert_join(self.var_tuple(self.fx.a), - self.tuple(self.fx.a), - self.var_tuple(self.fx.a)) - self.assert_join(self.var_tuple(self.fx.a), - self.tuple(), - self.var_tuple(self.fx.a)) + self.assert_join( + self.tuple(self.fx.a), self.var_tuple(self.fx.a), self.var_tuple(self.fx.a) + ) + self.assert_join( + self.var_tuple(self.fx.a), self.tuple(self.fx.a), self.var_tuple(self.fx.a) + ) + self.assert_join(self.var_tuple(self.fx.a), self.tuple(), self.var_tuple(self.fx.a)) def test_function_types(self) -> None: - self.assert_join(self.callable(self.fx.a, self.fx.b), - self.callable(self.fx.a, self.fx.b), - self.callable(self.fx.a, self.fx.b)) - - self.assert_join(self.callable(self.fx.a, self.fx.b), - self.callable(self.fx.b, self.fx.b), - self.callable(self.fx.b, self.fx.b)) - self.assert_join(self.callable(self.fx.a, self.fx.b), - self.callable(self.fx.a, self.fx.a), - self.callable(self.fx.a, self.fx.a)) - self.assert_join(self.callable(self.fx.a, self.fx.b), - self.fx.function, - self.fx.function) - self.assert_join(self.callable(self.fx.a, self.fx.b), - self.callable(self.fx.d, self.fx.b), - self.fx.function) + self.assert_join( + self.callable(self.fx.a, self.fx.b), + self.callable(self.fx.a, self.fx.b), + self.callable(self.fx.a, self.fx.b), + ) + + self.assert_join( + self.callable(self.fx.a, self.fx.b), + self.callable(self.fx.b, self.fx.b), + self.callable(self.fx.b, self.fx.b), + ) + self.assert_join( + self.callable(self.fx.a, self.fx.b), + self.callable(self.fx.a, self.fx.a), + self.callable(self.fx.a, self.fx.a), + ) + self.assert_join(self.callable(self.fx.a, self.fx.b), self.fx.function, self.fx.function) + self.assert_join( + self.callable(self.fx.a, self.fx.b), + self.callable(self.fx.d, self.fx.b), + self.fx.function, + ) def test_type_vars(self) -> None: self.assert_join(self.fx.t, self.fx.t, self.fx.t) @@ -631,27 +697,47 @@ def test_type_vars(self) -> None: def test_none(self) -> None: # Any type t joined with None results in t. - for t in [NoneType(), self.fx.a, self.fx.o, UnboundType('x'), - self.fx.t, self.tuple(), - self.callable(self.fx.a, self.fx.b), self.fx.anyt]: + for t in [ + NoneType(), + self.fx.a, + self.fx.o, + UnboundType("x"), + self.fx.t, + self.tuple(), + self.callable(self.fx.a, self.fx.b), + self.fx.anyt, + ]: self.assert_join(t, NoneType(), t) def test_unbound_type(self) -> None: - self.assert_join(UnboundType('x'), UnboundType('x'), self.fx.anyt) - self.assert_join(UnboundType('x'), UnboundType('y'), self.fx.anyt) + self.assert_join(UnboundType("x"), UnboundType("x"), self.fx.anyt) + self.assert_join(UnboundType("x"), UnboundType("y"), self.fx.anyt) # Any type t joined with an unbound type results in dynamic. Unbound # type means that there is an error somewhere in the program, so this # does not affect type safety (whatever the result). - for t in [self.fx.a, self.fx.o, self.fx.ga, self.fx.t, self.tuple(), - self.callable(self.fx.a, self.fx.b)]: - self.assert_join(t, UnboundType('X'), self.fx.anyt) + for t in [ + self.fx.a, + self.fx.o, + self.fx.ga, + self.fx.t, + self.tuple(), + self.callable(self.fx.a, self.fx.b), + ]: + self.assert_join(t, UnboundType("X"), self.fx.anyt) def test_any_type(self) -> None: # Join against 'Any' type always results in 'Any'. - for t in [self.fx.anyt, self.fx.a, self.fx.o, NoneType(), - UnboundType('x'), self.fx.t, self.tuple(), - self.callable(self.fx.a, self.fx.b)]: + for t in [ + self.fx.anyt, + self.fx.a, + self.fx.o, + NoneType(), + UnboundType("x"), + self.fx.t, + self.tuple(), + self.callable(self.fx.a, self.fx.b), + ]: self.assert_join(t, self.fx.anyt, self.fx.anyt) def test_mixed_truth_restricted_type_simple(self) -> None: @@ -672,10 +758,8 @@ def test_mixed_truth_restricted_type(self) -> None: def test_other_mixed_types(self) -> None: # In general, joining unrelated types produces object. - for t1 in [self.fx.a, self.fx.t, self.tuple(), - self.callable(self.fx.a, self.fx.b)]: - for t2 in [self.fx.a, self.fx.t, self.tuple(), - self.callable(self.fx.a, self.fx.b)]: + for t1 in [self.fx.a, self.fx.t, self.tuple(), self.callable(self.fx.a, self.fx.b)]: + for t2 in [self.fx.a, self.fx.t, self.tuple(), self.callable(self.fx.a, self.fx.b)]: if str(t1) != str(t2): self.assert_join(t1, t2, self.fx.o) @@ -683,8 +767,13 @@ def test_simple_generics(self) -> None: self.assert_join(self.fx.ga, self.fx.nonet, self.fx.ga) self.assert_join(self.fx.ga, self.fx.anyt, self.fx.anyt) - for t in [self.fx.a, self.fx.o, self.fx.t, self.tuple(), - self.callable(self.fx.a, self.fx.b)]: + for t in [ + self.fx.a, + self.fx.o, + self.fx.t, + self.tuple(), + self.callable(self.fx.a, self.fx.b), + ]: self.assert_join(t, self.fx.ga, self.fx.o) def test_generics_invariant(self) -> None: @@ -726,12 +815,11 @@ def test_generic_types_and_any(self) -> None: self.assert_join(self.fx_contra.gdyn, self.fx_contra.ga, self.fx_contra.gdyn) def test_callables_with_any(self) -> None: - self.assert_join(self.callable(self.fx.a, self.fx.a, self.fx.anyt, - self.fx.a), - self.callable(self.fx.a, self.fx.anyt, self.fx.a, - self.fx.anyt), - self.callable(self.fx.a, self.fx.anyt, self.fx.anyt, - self.fx.anyt)) + self.assert_join( + self.callable(self.fx.a, self.fx.a, self.fx.anyt, self.fx.a), + self.callable(self.fx.a, self.fx.anyt, self.fx.a, self.fx.anyt), + self.callable(self.fx.a, self.fx.anyt, self.fx.anyt, self.fx.anyt), + ) def test_overloaded(self) -> None: c = self.callable @@ -806,8 +894,7 @@ def test_simple_type_objects(self) -> None: self.assert_join(t1, t2, tr) self.assert_join(t1, self.fx.type_type, self.fx.type_type) - self.assert_join(self.fx.type_type, self.fx.type_type, - self.fx.type_type) + self.assert_join(self.fx.type_type, self.fx.type_type, self.fx.type_type) def test_type_type(self) -> None: self.assert_join(self.fx.type_a, self.fx.type_b, self.fx.type_a) @@ -838,20 +925,18 @@ def test_literal_type(self) -> None: self.assert_join(UnionType([d, lit3]), d, UnionType([d, lit3])) self.assert_join(UnionType([a, lit1]), lit1, a) self.assert_join(UnionType([a, lit1]), lit2, a) - self.assert_join(UnionType([lit1, lit2]), - UnionType([lit1, lit2]), - UnionType([lit1, lit2])) + self.assert_join(UnionType([lit1, lit2]), UnionType([lit1, lit2]), UnionType([lit1, lit2])) # The order in which we try joining two unions influences the # ordering of the items in the final produced unions. So, we # manually call 'assert_simple_join' and tune the output # after swapping the arguments here. - self.assert_simple_join(UnionType([lit1, lit2]), - UnionType([lit2, lit3]), - UnionType([lit1, lit2, lit3])) - self.assert_simple_join(UnionType([lit2, lit3]), - UnionType([lit1, lit2]), - UnionType([lit2, lit3, lit1])) + self.assert_simple_join( + UnionType([lit1, lit2]), UnionType([lit2, lit3]), UnionType([lit1, lit2, lit3]) + ) + self.assert_simple_join( + UnionType([lit2, lit3]), UnionType([lit1, lit2]), UnionType([lit2, lit3, lit1]) + ) # There are additional test cases in check-inference.test. @@ -865,10 +950,9 @@ def assert_simple_join(self, s: Type, t: Type, join: Type) -> None: result = join_types(s, t) actual = str(result) expected = str(join) - assert_equal(actual, expected, - f'join({s}, {t}) == {{}} ({{}} expected)') - assert is_subtype(s, result), f'{s} not subtype of {result}' - assert is_subtype(t, result), f'{t} not subtype of {result}' + assert_equal(actual, expected, f"join({s}, {t}) == {{}} ({{}} expected)") + assert is_subtype(s, result), f"{s} not subtype of {result}" + assert is_subtype(t, result), f"{t} not subtype of {result}" def tuple(self, *a: Type) -> TupleType: return TupleType(list(a), self.fx.std_tuple) @@ -882,8 +966,7 @@ def callable(self, *a: Type) -> CallableType: a1, ... an and return type r. """ n = len(a) - 1 - return CallableType(list(a[:-1]), [ARG_POS] * n, [None] * n, - a[-1], self.fx.function) + return CallableType(list(a[:-1]), [ARG_POS] * n, [None] * n, a[-1], self.fx.function) def type_callable(self, *a: Type) -> CallableType: """type_callable(a1, ..., an, r) constructs a callable with @@ -891,8 +974,7 @@ def type_callable(self, *a: Type) -> CallableType: represents a type. """ n = len(a) - 1 - return CallableType(list(a[:-1]), [ARG_POS] * n, [None] * n, - a[-1], self.fx.type_type) + return CallableType(list(a[:-1]), [ARG_POS] * n, [None] * n, a[-1], self.fx.type_type) class MeetSuite(Suite): @@ -912,31 +994,35 @@ def test_class_subtyping(self) -> None: def test_tuples(self) -> None: self.assert_meet(self.tuple(), self.tuple(), self.tuple()) - self.assert_meet(self.tuple(self.fx.a), - self.tuple(self.fx.a), - self.tuple(self.fx.a)) - self.assert_meet(self.tuple(self.fx.b, self.fx.c), - self.tuple(self.fx.a, self.fx.d), - self.tuple(self.fx.b, NoneType())) - - self.assert_meet(self.tuple(self.fx.a, self.fx.a), - self.fx.std_tuple, - self.tuple(self.fx.a, self.fx.a)) - self.assert_meet(self.tuple(self.fx.a), - self.tuple(self.fx.a, self.fx.a), - NoneType()) + self.assert_meet(self.tuple(self.fx.a), self.tuple(self.fx.a), self.tuple(self.fx.a)) + self.assert_meet( + self.tuple(self.fx.b, self.fx.c), + self.tuple(self.fx.a, self.fx.d), + self.tuple(self.fx.b, NoneType()), + ) + + self.assert_meet( + self.tuple(self.fx.a, self.fx.a), self.fx.std_tuple, self.tuple(self.fx.a, self.fx.a) + ) + self.assert_meet(self.tuple(self.fx.a), self.tuple(self.fx.a, self.fx.a), NoneType()) def test_function_types(self) -> None: - self.assert_meet(self.callable(self.fx.a, self.fx.b), - self.callable(self.fx.a, self.fx.b), - self.callable(self.fx.a, self.fx.b)) - - self.assert_meet(self.callable(self.fx.a, self.fx.b), - self.callable(self.fx.b, self.fx.b), - self.callable(self.fx.a, self.fx.b)) - self.assert_meet(self.callable(self.fx.a, self.fx.b), - self.callable(self.fx.a, self.fx.a), - self.callable(self.fx.a, self.fx.b)) + self.assert_meet( + self.callable(self.fx.a, self.fx.b), + self.callable(self.fx.a, self.fx.b), + self.callable(self.fx.a, self.fx.b), + ) + + self.assert_meet( + self.callable(self.fx.a, self.fx.b), + self.callable(self.fx.b, self.fx.b), + self.callable(self.fx.a, self.fx.b), + ) + self.assert_meet( + self.callable(self.fx.a, self.fx.b), + self.callable(self.fx.a, self.fx.a), + self.callable(self.fx.a, self.fx.b), + ) def test_type_vars(self) -> None: self.assert_meet(self.fx.t, self.fx.t, self.fx.t) @@ -949,28 +1035,46 @@ def test_none(self) -> None: self.assert_meet(NoneType(), self.fx.anyt, NoneType()) # Any type t joined with None results in None, unless t is Any. - for t in [self.fx.a, self.fx.o, UnboundType('x'), self.fx.t, - self.tuple(), self.callable(self.fx.a, self.fx.b)]: + for t in [ + self.fx.a, + self.fx.o, + UnboundType("x"), + self.fx.t, + self.tuple(), + self.callable(self.fx.a, self.fx.b), + ]: self.assert_meet(t, NoneType(), NoneType()) def test_unbound_type(self) -> None: - self.assert_meet(UnboundType('x'), UnboundType('x'), self.fx.anyt) - self.assert_meet(UnboundType('x'), UnboundType('y'), self.fx.anyt) + self.assert_meet(UnboundType("x"), UnboundType("x"), self.fx.anyt) + self.assert_meet(UnboundType("x"), UnboundType("y"), self.fx.anyt) - self.assert_meet(UnboundType('x'), self.fx.anyt, UnboundType('x')) + self.assert_meet(UnboundType("x"), self.fx.anyt, UnboundType("x")) # The meet of any type t with an unbound type results in dynamic. # Unbound type means that there is an error somewhere in the program, # so this does not affect type safety. - for t in [self.fx.a, self.fx.o, self.fx.t, self.tuple(), - self.callable(self.fx.a, self.fx.b)]: - self.assert_meet(t, UnboundType('X'), self.fx.anyt) + for t in [ + self.fx.a, + self.fx.o, + self.fx.t, + self.tuple(), + self.callable(self.fx.a, self.fx.b), + ]: + self.assert_meet(t, UnboundType("X"), self.fx.anyt) def test_dynamic_type(self) -> None: # Meet against dynamic type always results in dynamic. - for t in [self.fx.anyt, self.fx.a, self.fx.o, NoneType(), - UnboundType('x'), self.fx.t, self.tuple(), - self.callable(self.fx.a, self.fx.b)]: + for t in [ + self.fx.anyt, + self.fx.a, + self.fx.o, + NoneType(), + UnboundType("x"), + self.fx.t, + self.tuple(), + self.callable(self.fx.a, self.fx.b), + ]: self.assert_meet(t, self.fx.anyt, t) def test_simple_generics(self) -> None: @@ -983,8 +1087,7 @@ def test_simple_generics(self) -> None: self.assert_meet(self.fx.ga, self.fx.nonet, self.fx.nonet) self.assert_meet(self.fx.ga, self.fx.anyt, self.fx.ga) - for t in [self.fx.a, self.fx.t, self.tuple(), - self.callable(self.fx.a, self.fx.b)]: + for t in [self.fx.a, self.fx.t, self.tuple(), self.callable(self.fx.a, self.fx.b)]: self.assert_meet(t, self.fx.ga, self.fx.nonet) def test_generics_with_multiple_args(self) -> None: @@ -1005,12 +1108,11 @@ def test_generic_types_and_dynamic(self) -> None: self.assert_meet(self.fx.gdyn, self.fx.ga, self.fx.ga) def test_callables_with_dynamic(self) -> None: - self.assert_meet(self.callable(self.fx.a, self.fx.a, self.fx.anyt, - self.fx.a), - self.callable(self.fx.a, self.fx.anyt, self.fx.a, - self.fx.anyt), - self.callable(self.fx.a, self.fx.anyt, self.fx.anyt, - self.fx.anyt)) + self.assert_meet( + self.callable(self.fx.a, self.fx.a, self.fx.anyt, self.fx.a), + self.callable(self.fx.a, self.fx.anyt, self.fx.a, self.fx.anyt), + self.callable(self.fx.a, self.fx.anyt, self.fx.anyt, self.fx.anyt), + ) def test_meet_interface_types(self) -> None: self.assert_meet(self.fx.f, self.fx.f, self.fx.f) @@ -1080,10 +1182,9 @@ def assert_simple_meet(self, s: Type, t: Type, meet: Type) -> None: result = meet_types(s, t) actual = str(result) expected = str(meet) - assert_equal(actual, expected, - f'meet({s}, {t}) == {{}} ({{}} expected)') - assert is_subtype(result, s), f'{result} not subtype of {s}' - assert is_subtype(result, t), f'{result} not subtype of {t}' + assert_equal(actual, expected, f"meet({s}, {t}) == {{}} ({{}} expected)") + assert is_subtype(result, s), f"{result} not subtype of {s}" + assert is_subtype(result, t), f"{result} not subtype of {t}" def tuple(self, *a: Type) -> TupleType: return TupleType(list(a), self.fx.std_tuple) @@ -1093,9 +1194,7 @@ def callable(self, *a: Type) -> CallableType: a1, ... an and return type r. """ n = len(a) - 1 - return CallableType(list(a[:-1]), - [ARG_POS] * n, [None] * n, - a[-1], self.fx.function) + return CallableType(list(a[:-1]), [ARG_POS] * n, [None] * n, a[-1], self.fx.function) class SameTypeSuite(Suite): @@ -1131,15 +1230,14 @@ def assert_not_same(self, s: Type, t: Type, strict: bool = True) -> None: def assert_simple_is_same(self, s: Type, t: Type, expected: bool, strict: bool) -> None: actual = is_same_type(s, t) - assert_equal(actual, expected, - f'is_same_type({s}, {t}) is {{}} ({{}} expected)') + assert_equal(actual, expected, f"is_same_type({s}, {t}) is {{}} ({{}} expected)") if strict: - actual2 = (s == t) - assert_equal(actual2, expected, - f'({s} == {t}) is {{}} ({{}} expected)') - assert_equal(hash(s) == hash(t), expected, - f'(hash({s}) == hash({t}) is {{}} ({{}} expected)') + actual2 = s == t + assert_equal(actual2, expected, f"({s} == {t}) is {{}} ({{}} expected)") + assert_equal( + hash(s) == hash(t), expected, f"(hash({s}) == hash({t}) is {{}} ({{}} expected)" + ) class RemoveLastKnownValueSuite(Suite): @@ -1169,10 +1267,9 @@ def test_single_last_known_value(self) -> None: def test_last_known_values_with_merge(self) -> None: t = UnionType.make_union([self.fx.lit1_inst, self.fx.lit2_inst, self.fx.lit4_inst]) assert remove_instance_last_known_values(t) == self.fx.a - t = UnionType.make_union([self.fx.lit1_inst, - self.fx.b, - self.fx.lit2_inst, - self.fx.lit4_inst]) + t = UnionType.make_union( + [self.fx.lit1_inst, self.fx.b, self.fx.lit2_inst, self.fx.lit4_inst] + ) self.assert_union_result(t, [self.fx.a, self.fx.b]) def test_generics(self) -> None: diff --git a/mypy/test/testutil.py b/mypy/test/testutil.py index fe3cdfa7e7d2f..8b278dba35eae 100644 --- a/mypy/test/testutil.py +++ b/mypy/test/testutil.py @@ -1,5 +1,5 @@ import os -from unittest import mock, TestCase +from unittest import TestCase, mock from mypy.util import get_terminal_width @@ -9,7 +9,7 @@ def test_get_terminal_size_in_pty_defaults_to_80(self) -> None: # when run using a pty, `os.get_terminal_size()` returns `0, 0` ret = os.terminal_size((0, 0)) mock_environ = os.environ.copy() - mock_environ.pop('COLUMNS', None) - with mock.patch.object(os, 'get_terminal_size', return_value=ret): + mock_environ.pop("COLUMNS", None) + with mock.patch.object(os, "get_terminal_size", return_value=ret): with mock.patch.dict(os.environ, values=mock_environ, clear=True): assert get_terminal_width() == 80 diff --git a/mypy/test/typefixture.py b/mypy/test/typefixture.py index f5c47d968ba82..85276259b7b83 100644 --- a/mypy/test/typefixture.py +++ b/mypy/test/typefixture.py @@ -5,15 +5,36 @@ from typing import List, Optional, Tuple +from mypy.nodes import ( + ARG_OPT, + ARG_POS, + ARG_STAR, + COVARIANT, + MDEF, + Block, + ClassDef, + FuncDef, + SymbolTable, + SymbolTableNode, + TypeAlias, + TypeInfo, +) from mypy.semanal_shared import set_callable_name from mypy.types import ( - Type, AnyType, NoneType, Instance, CallableType, TypeVarType, TypeType, - UninhabitedType, TypeOfAny, TypeAliasType, UnionType, LiteralType, - TypeVarLikeType, TypeVarTupleType -) -from mypy.nodes import ( - TypeInfo, ClassDef, FuncDef, Block, ARG_POS, ARG_OPT, ARG_STAR, SymbolTable, - COVARIANT, TypeAlias, SymbolTableNode, MDEF, + AnyType, + CallableType, + Instance, + LiteralType, + NoneType, + Type, + TypeAliasType, + TypeOfAny, + TypeType, + TypeVarLikeType, + TypeVarTupleType, + TypeVarType, + UninhabitedType, + UnionType, ) @@ -25,29 +46,30 @@ class TypeFixture: def __init__(self, variance: int = COVARIANT) -> None: # The 'object' class - self.oi = self.make_type_info('builtins.object') # class object - self.o = Instance(self.oi, []) # object + self.oi = self.make_type_info("builtins.object") # class object + self.o = Instance(self.oi, []) # object # Type variables (these are effectively global) - def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type, - variance: int) -> TypeVarType: + def make_type_var( + name: str, id: int, values: List[Type], upper_bound: Type, variance: int + ) -> TypeVarType: return TypeVarType(name, name, id, values, upper_bound, variance) def make_type_var_tuple(name: str, id: int, upper_bound: Type) -> TypeVarTupleType: return TypeVarTupleType(name, name, id, upper_bound) - self.t = make_type_var('T', 1, [], self.o, variance) # T`1 (type variable) - self.tf = make_type_var('T', -1, [], self.o, variance) # T`-1 (type variable) - self.tf2 = make_type_var('T', -2, [], self.o, variance) # T`-2 (type variable) - self.s = make_type_var('S', 2, [], self.o, variance) # S`2 (type variable) - self.s1 = make_type_var('S', 1, [], self.o, variance) # S`1 (type variable) - self.sf = make_type_var('S', -2, [], self.o, variance) # S`-2 (type variable) - self.sf1 = make_type_var('S', -1, [], self.o, variance) # S`-1 (type variable) + self.t = make_type_var("T", 1, [], self.o, variance) # T`1 (type variable) + self.tf = make_type_var("T", -1, [], self.o, variance) # T`-1 (type variable) + self.tf2 = make_type_var("T", -2, [], self.o, variance) # T`-2 (type variable) + self.s = make_type_var("S", 2, [], self.o, variance) # S`2 (type variable) + self.s1 = make_type_var("S", 1, [], self.o, variance) # S`1 (type variable) + self.sf = make_type_var("S", -2, [], self.o, variance) # S`-2 (type variable) + self.sf1 = make_type_var("S", -1, [], self.o, variance) # S`-1 (type variable) - self.ts = make_type_var_tuple('Ts', 1, self.o) # Ts`1 (type var tuple) - self.ss = make_type_var_tuple('Ss', 2, self.o) # Ss`2 (type var tuple) - self.us = make_type_var_tuple('Us', 3, self.o) # Us`3 (type var tuple) + self.ts = make_type_var_tuple("Ts", 1, self.o) # Ts`1 (type var tuple) + self.ss = make_type_var_tuple("Ss", 2, self.o) # Ss`2 (type var tuple) + self.us = make_type_var_tuple("Us", 3, self.o) # Us`3 (type var tuple) # Simple types self.anyt = AnyType(TypeOfAny.special_form) @@ -57,112 +79,111 @@ def make_type_var_tuple(name: str, id: int, upper_bound: Type) -> TypeVarTupleTy # Abstract class TypeInfos # class F - self.fi = self.make_type_info('F', is_abstract=True) + self.fi = self.make_type_info("F", is_abstract=True) # class F2 - self.f2i = self.make_type_info('F2', is_abstract=True) + self.f2i = self.make_type_info("F2", is_abstract=True) # class F3(F) - self.f3i = self.make_type_info('F3', is_abstract=True, mro=[self.fi]) + self.f3i = self.make_type_info("F3", is_abstract=True, mro=[self.fi]) # Class TypeInfos - self.std_tuplei = self.make_type_info('builtins.tuple', - mro=[self.oi], - typevars=['T'], - variances=[COVARIANT]) # class tuple - self.type_typei = self.make_type_info('builtins.type') # class type - self.bool_type_info = self.make_type_info('builtins.bool') - self.str_type_info = self.make_type_info('builtins.str') - self.functioni = self.make_type_info('builtins.function') # function TODO - self.ai = self.make_type_info('A', mro=[self.oi]) # class A - self.bi = self.make_type_info('B', mro=[self.ai, self.oi]) # class B(A) - self.ci = self.make_type_info('C', mro=[self.ai, self.oi]) # class C(A) - self.di = self.make_type_info('D', mro=[self.oi]) # class D + self.std_tuplei = self.make_type_info( + "builtins.tuple", mro=[self.oi], typevars=["T"], variances=[COVARIANT] + ) # class tuple + self.type_typei = self.make_type_info("builtins.type") # class type + self.bool_type_info = self.make_type_info("builtins.bool") + self.str_type_info = self.make_type_info("builtins.str") + self.functioni = self.make_type_info("builtins.function") # function TODO + self.ai = self.make_type_info("A", mro=[self.oi]) # class A + self.bi = self.make_type_info("B", mro=[self.ai, self.oi]) # class B(A) + self.ci = self.make_type_info("C", mro=[self.ai, self.oi]) # class C(A) + self.di = self.make_type_info("D", mro=[self.oi]) # class D # class E(F) - self.ei = self.make_type_info('E', mro=[self.fi, self.oi]) + self.ei = self.make_type_info("E", mro=[self.fi, self.oi]) # class E2(F2, F) - self.e2i = self.make_type_info('E2', mro=[self.f2i, self.fi, self.oi]) + self.e2i = self.make_type_info("E2", mro=[self.f2i, self.fi, self.oi]) # class E3(F, F2) - self.e3i = self.make_type_info('E3', mro=[self.fi, self.f2i, self.oi]) + self.e3i = self.make_type_info("E3", mro=[self.fi, self.f2i, self.oi]) # Generic class TypeInfos # G[T] - self.gi = self.make_type_info('G', mro=[self.oi], - typevars=['T'], - variances=[variance]) + self.gi = self.make_type_info("G", mro=[self.oi], typevars=["T"], variances=[variance]) # G2[T] - self.g2i = self.make_type_info('G2', mro=[self.oi], - typevars=['T'], - variances=[variance]) + self.g2i = self.make_type_info("G2", mro=[self.oi], typevars=["T"], variances=[variance]) # H[S, T] - self.hi = self.make_type_info('H', mro=[self.oi], - typevars=['S', 'T'], - variances=[variance, variance]) + self.hi = self.make_type_info( + "H", mro=[self.oi], typevars=["S", "T"], variances=[variance, variance] + ) # GS[T, S] <: G[S] - self.gsi = self.make_type_info('GS', mro=[self.gi, self.oi], - typevars=['T', 'S'], - variances=[variance, variance], - bases=[Instance(self.gi, [self.s])]) + self.gsi = self.make_type_info( + "GS", + mro=[self.gi, self.oi], + typevars=["T", "S"], + variances=[variance, variance], + bases=[Instance(self.gi, [self.s])], + ) # GS2[S] <: G[S] - self.gs2i = self.make_type_info('GS2', mro=[self.gi, self.oi], - typevars=['S'], - variances=[variance], - bases=[Instance(self.gi, [self.s1])]) - - self.gvi = self.make_type_info('GV', mro=[self.oi], - typevars=['Ts'], - typevar_tuple_index=0) + self.gs2i = self.make_type_info( + "GS2", + mro=[self.gi, self.oi], + typevars=["S"], + variances=[variance], + bases=[Instance(self.gi, [self.s1])], + ) + + self.gvi = self.make_type_info("GV", mro=[self.oi], typevars=["Ts"], typevar_tuple_index=0) # list[T] - self.std_listi = self.make_type_info('builtins.list', mro=[self.oi], - typevars=['T'], - variances=[variance]) + self.std_listi = self.make_type_info( + "builtins.list", mro=[self.oi], typevars=["T"], variances=[variance] + ) # Instance types - self.std_tuple = Instance(self.std_tuplei, [self.anyt]) # tuple - self.type_type = Instance(self.type_typei, []) # type + self.std_tuple = Instance(self.std_tuplei, [self.anyt]) # tuple + self.type_type = Instance(self.type_typei, []) # type self.function = Instance(self.functioni, []) # function TODO self.str_type = Instance(self.str_type_info, []) - self.a = Instance(self.ai, []) # A - self.b = Instance(self.bi, []) # B - self.c = Instance(self.ci, []) # C - self.d = Instance(self.di, []) # D + self.a = Instance(self.ai, []) # A + self.b = Instance(self.bi, []) # B + self.c = Instance(self.ci, []) # C + self.d = Instance(self.di, []) # D - self.e = Instance(self.ei, []) # E - self.e2 = Instance(self.e2i, []) # E2 - self.e3 = Instance(self.e3i, []) # E3 + self.e = Instance(self.ei, []) # E + self.e2 = Instance(self.e2i, []) # E2 + self.e3 = Instance(self.e3i, []) # E3 - self.f = Instance(self.fi, []) # F - self.f2 = Instance(self.f2i, []) # F2 - self.f3 = Instance(self.f3i, []) # F3 + self.f = Instance(self.fi, []) # F + self.f2 = Instance(self.f2i, []) # F2 + self.f3 = Instance(self.f3i, []) # F3 # Generic instance types - self.ga = Instance(self.gi, [self.a]) # G[A] - self.gb = Instance(self.gi, [self.b]) # G[B] - self.gd = Instance(self.gi, [self.d]) # G[D] - self.go = Instance(self.gi, [self.o]) # G[object] - self.gt = Instance(self.gi, [self.t]) # G[T`1] - self.gtf = Instance(self.gi, [self.tf]) # G[T`-1] - self.gtf2 = Instance(self.gi, [self.tf2]) # G[T`-2] - self.gs = Instance(self.gi, [self.s]) # G[S] - self.gdyn = Instance(self.gi, [self.anyt]) # G[Any] - self.gn = Instance(self.gi, [NoneType()]) # G[None] - - self.g2a = Instance(self.g2i, [self.a]) # G2[A] + self.ga = Instance(self.gi, [self.a]) # G[A] + self.gb = Instance(self.gi, [self.b]) # G[B] + self.gd = Instance(self.gi, [self.d]) # G[D] + self.go = Instance(self.gi, [self.o]) # G[object] + self.gt = Instance(self.gi, [self.t]) # G[T`1] + self.gtf = Instance(self.gi, [self.tf]) # G[T`-1] + self.gtf2 = Instance(self.gi, [self.tf2]) # G[T`-2] + self.gs = Instance(self.gi, [self.s]) # G[S] + self.gdyn = Instance(self.gi, [self.anyt]) # G[Any] + self.gn = Instance(self.gi, [NoneType()]) # G[None] + + self.g2a = Instance(self.g2i, [self.a]) # G2[A] self.gsaa = Instance(self.gsi, [self.a, self.a]) # GS[A, A] self.gsab = Instance(self.gsi, [self.a, self.b]) # GS[A, B] self.gsba = Instance(self.gsi, [self.b, self.a]) # GS[B, A] - self.gs2a = Instance(self.gs2i, [self.a]) # GS2[A] - self.gs2b = Instance(self.gs2i, [self.b]) # GS2[B] - self.gs2d = Instance(self.gs2i, [self.d]) # GS2[D] + self.gs2a = Instance(self.gs2i, [self.a]) # GS2[A] + self.gs2b = Instance(self.gs2i, [self.b]) # GS2[B] + self.gs2d = Instance(self.gs2i, [self.d]) # GS2[D] - self.hab = Instance(self.hi, [self.a, self.b]) # H[A, B] - self.haa = Instance(self.hi, [self.a, self.a]) # H[A, A] - self.hbb = Instance(self.hi, [self.b, self.b]) # H[B, B] - self.hts = Instance(self.hi, [self.t, self.s]) # H[T, S] - self.had = Instance(self.hi, [self.a, self.d]) # H[A, D] - self.hao = Instance(self.hi, [self.a, self.o]) # H[A, object] + self.hab = Instance(self.hi, [self.a, self.b]) # H[A, B] + self.haa = Instance(self.hi, [self.a, self.a]) # H[A, A] + self.hbb = Instance(self.hi, [self.b, self.b]) # H[B, B] + self.hts = Instance(self.hi, [self.t, self.s]) # H[T, S] + self.had = Instance(self.hi, [self.a, self.d]) # H[A, D] + self.hao = Instance(self.hi, [self.a, self.o]) # H[A, object] self.lsta = Instance(self.std_listi, [self.a]) # List[A] self.lstb = Instance(self.std_listi, [self.b]) # List[B] @@ -195,7 +216,7 @@ def make_type_var_tuple(name: str, id: int, upper_bound: Type) -> TypeVarTupleTy def _add_bool_dunder(self, type_info: TypeInfo) -> None: signature = CallableType([], [], [], Instance(self.bool_type_info, []), self.function) - bool_func = FuncDef('__bool__', [], Block([])) + bool_func = FuncDef("__bool__", [], Block([])) bool_func.type = set_callable_name(signature, bool_func) type_info.names[bool_func.name] = SymbolTableNode(MDEF, bool_func) @@ -205,16 +226,18 @@ def callable(self, *a: Type) -> CallableType: """callable(a1, ..., an, r) constructs a callable with argument types a1, ... an and return type r. """ - return CallableType(list(a[:-1]), [ARG_POS] * (len(a) - 1), - [None] * (len(a) - 1), a[-1], self.function) + return CallableType( + list(a[:-1]), [ARG_POS] * (len(a) - 1), [None] * (len(a) - 1), a[-1], self.function + ) def callable_type(self, *a: Type) -> CallableType: """callable_type(a1, ..., an, r) constructs a callable with argument types a1, ... an and return type r, and which represents a type. """ - return CallableType(list(a[:-1]), [ARG_POS] * (len(a) - 1), - [None] * (len(a) - 1), a[-1], self.type_type) + return CallableType( + list(a[:-1]), [ARG_POS] * (len(a) - 1), [None] * (len(a) - 1), a[-1], self.type_type + ) def callable_default(self, min_args: int, *a: Type) -> CallableType: """callable_default(min_args, a1, ..., an, r) constructs a @@ -222,45 +245,53 @@ def callable_default(self, min_args: int, *a: Type) -> CallableType: with min_args mandatory fixed arguments. """ n = len(a) - 1 - return CallableType(list(a[:-1]), - [ARG_POS] * min_args + [ARG_OPT] * (n - min_args), - [None] * n, - a[-1], self.function) + return CallableType( + list(a[:-1]), + [ARG_POS] * min_args + [ARG_OPT] * (n - min_args), + [None] * n, + a[-1], + self.function, + ) def callable_var_arg(self, min_args: int, *a: Type) -> CallableType: """callable_var_arg(min_args, a1, ..., an, r) constructs a callable with argument types a1, ... *an and return type r. """ n = len(a) - 1 - return CallableType(list(a[:-1]), - [ARG_POS] * min_args + - [ARG_OPT] * (n - 1 - min_args) + - [ARG_STAR], [None] * n, - a[-1], self.function) - - def make_type_info(self, name: str, - module_name: Optional[str] = None, - is_abstract: bool = False, - mro: Optional[List[TypeInfo]] = None, - bases: Optional[List[Instance]] = None, - typevars: Optional[List[str]] = None, - typevar_tuple_index: Optional[int] = None, - variances: Optional[List[int]] = None) -> TypeInfo: + return CallableType( + list(a[:-1]), + [ARG_POS] * min_args + [ARG_OPT] * (n - 1 - min_args) + [ARG_STAR], + [None] * n, + a[-1], + self.function, + ) + + def make_type_info( + self, + name: str, + module_name: Optional[str] = None, + is_abstract: bool = False, + mro: Optional[List[TypeInfo]] = None, + bases: Optional[List[Instance]] = None, + typevars: Optional[List[str]] = None, + typevar_tuple_index: Optional[int] = None, + variances: Optional[List[int]] = None, + ) -> TypeInfo: """Make a TypeInfo suitable for use in unit tests.""" class_def = ClassDef(name, Block([]), None, []) class_def.fullname = name if module_name is None: - if '.' in name: - module_name = name.rsplit('.', 1)[0] + if "." in name: + module_name = name.rsplit(".", 1)[0] else: - module_name = '__main__' + module_name = "__main__" if typevars: v: List[TypeVarLikeType] = [] for id, n in enumerate(typevars, 1): - if typevar_tuple_index is not None and id-1 == typevar_tuple_index: + if typevar_tuple_index is not None and id - 1 == typevar_tuple_index: v.append(TypeVarTupleType(n, n, id, self.o)) else: if variances: @@ -273,7 +304,7 @@ def make_type_info(self, name: str, info = TypeInfo(SymbolTable(), class_def, module_name) if mro is None: mro = [] - if name != 'builtins.object': + if name != "builtins.object": mro.append(self.oi) info.mro = [info] + mro if bases is None: @@ -288,22 +319,24 @@ def make_type_info(self, name: str, def def_alias_1(self, base: Instance) -> Tuple[TypeAliasType, Type]: A = TypeAliasType(None, []) - target = Instance(self.std_tuplei, - [UnionType([base, A])]) # A = Tuple[Union[base, A], ...] - AN = TypeAlias(target, '__main__.A', -1, -1) + target = Instance( + self.std_tuplei, [UnionType([base, A])] + ) # A = Tuple[Union[base, A], ...] + AN = TypeAlias(target, "__main__.A", -1, -1) A.alias = AN return A, target def def_alias_2(self, base: Instance) -> Tuple[TypeAliasType, Type]: A = TypeAliasType(None, []) - target = UnionType([base, - Instance(self.std_tuplei, [A])]) # A = Union[base, Tuple[A, ...]] - AN = TypeAlias(target, '__main__.A', -1, -1) + target = UnionType( + [base, Instance(self.std_tuplei, [A])] + ) # A = Union[base, Tuple[A, ...]] + AN = TypeAlias(target, "__main__.A", -1, -1) A.alias = AN return A, target def non_rec_alias(self, target: Type) -> TypeAliasType: - AN = TypeAlias(target, '__main__.A', -1, -1) + AN = TypeAlias(target, "__main__.A", -1, -1) return TypeAliasType(AN, []) @@ -314,13 +347,12 @@ class InterfaceTypeFixture(TypeFixture): def __init__(self) -> None: super().__init__() # GF[T] - self.gfi = self.make_type_info('GF', typevars=['T'], is_abstract=True) + self.gfi = self.make_type_info("GF", typevars=["T"], is_abstract=True) # M1 <: GF[A] - self.m1i = self.make_type_info('M1', - is_abstract=True, - mro=[self.gfi, self.oi], - bases=[Instance(self.gfi, [self.a])]) + self.m1i = self.make_type_info( + "M1", is_abstract=True, mro=[self.gfi, self.oi], bases=[Instance(self.gfi, [self.a])] + ) self.gfa = Instance(self.gfi, [self.a]) # GF[A] self.gfb = Instance(self.gfi, [self.b]) # GF[B] diff --git a/mypy/test/visitors.py b/mypy/test/visitors.py index b1a84e3529e19..e202963b62a43 100644 --- a/mypy/test/visitors.py +++ b/mypy/test/visitors.py @@ -9,10 +9,15 @@ from typing import Set from mypy.nodes import ( - NameExpr, TypeVarExpr, CallExpr, Expression, MypyFile, AssignmentStmt, IntExpr + AssignmentStmt, + CallExpr, + Expression, + IntExpr, + MypyFile, + NameExpr, + TypeVarExpr, ) from mypy.traverser import TraverserVisitor - from mypy.treetransform import TransformVisitor from mypy.types import Type @@ -24,7 +29,7 @@ def __init__(self) -> None: self.is_typing = False def visit_mypy_file(self, f: MypyFile) -> None: - self.is_typing = f.fullname == 'typing' or f.fullname == 'builtins' + self.is_typing = f.fullname == "typing" or f.fullname == "builtins" super().visit_mypy_file(f) def visit_assignment_stmt(self, s: AssignmentStmt) -> None: @@ -53,12 +58,11 @@ def ignore_node(node: Expression) -> bool: # from the typing module is not easy, we just to strip them all away. if isinstance(node, TypeVarExpr): return True - if isinstance(node, NameExpr) and node.fullname == 'builtins.object': + if isinstance(node, NameExpr) and node.fullname == "builtins.object": return True - if isinstance(node, NameExpr) and node.fullname == 'builtins.None': + if isinstance(node, NameExpr) and node.fullname == "builtins.None": return True - if isinstance(node, CallExpr) and (ignore_node(node.callee) or - node.analyzed): + if isinstance(node, CallExpr) and (ignore_node(node.callee) or node.analyzed): return True return False diff --git a/mypy/traverser.py b/mypy/traverser.py index d4e87b820dfbc..1c2fa8c04dcb3 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -1,25 +1,78 @@ """Generic node traverser visitor""" from typing import List, Tuple + from mypy_extensions import mypyc_attr -from mypy.patterns import ( - AsPattern, OrPattern, ValuePattern, SequencePattern, StarredPattern, MappingPattern, - ClassPattern -) -from mypy.visitor import NodeVisitor from mypy.nodes import ( - AssertTypeExpr, Block, MypyFile, FuncBase, FuncItem, CallExpr, ClassDef, Decorator, FuncDef, - ExpressionStmt, AssignmentStmt, OperatorAssignmentStmt, WhileStmt, - ForStmt, ReturnStmt, AssertStmt, DelStmt, IfStmt, RaiseStmt, - TryStmt, WithStmt, MatchStmt, NameExpr, MemberExpr, OpExpr, SliceExpr, CastExpr, - RevealExpr, UnaryExpr, ListExpr, TupleExpr, DictExpr, SetExpr, IndexExpr, AssignmentExpr, - GeneratorExpr, ListComprehension, SetComprehension, DictionaryComprehension, - ConditionalExpr, TypeApplication, ExecStmt, Import, ImportFrom, - LambdaExpr, ComparisonExpr, OverloadedFuncDef, YieldFromExpr, - YieldExpr, StarExpr, BackquoteExpr, AwaitExpr, PrintStmt, SuperExpr, Node, REVEAL_TYPE, + REVEAL_TYPE, + AssertStmt, + AssertTypeExpr, + AssignmentExpr, + AssignmentStmt, + AwaitExpr, + BackquoteExpr, + Block, + CallExpr, + CastExpr, + ClassDef, + ComparisonExpr, + ConditionalExpr, + Decorator, + DelStmt, + DictExpr, + DictionaryComprehension, + ExecStmt, Expression, + ExpressionStmt, + ForStmt, + FuncBase, + FuncDef, + FuncItem, + GeneratorExpr, + IfStmt, + Import, + ImportFrom, + IndexExpr, + LambdaExpr, + ListComprehension, + ListExpr, + MatchStmt, + MemberExpr, + MypyFile, + NameExpr, + Node, + OperatorAssignmentStmt, + OpExpr, + OverloadedFuncDef, + PrintStmt, + RaiseStmt, + ReturnStmt, + RevealExpr, + SetComprehension, + SetExpr, + SliceExpr, + StarExpr, + SuperExpr, + TryStmt, + TupleExpr, + TypeApplication, + UnaryExpr, + WhileStmt, + WithStmt, + YieldExpr, + YieldFromExpr, ) +from mypy.patterns import ( + AsPattern, + ClassPattern, + MappingPattern, + OrPattern, + SequencePattern, + StarredPattern, + ValuePattern, +) +from mypy.visitor import NodeVisitor @mypyc_attr(allow_interpreted_subclasses=True) @@ -249,8 +302,7 @@ def visit_index_expr(self, o: IndexExpr) -> None: o.analyzed.accept(self) def visit_generator_expr(self, o: GeneratorExpr) -> None: - for index, sequence, conditions in zip(o.indices, o.sequences, - o.condlists): + for index, sequence, conditions in zip(o.indices, o.sequences, o.condlists): sequence.accept(self) index.accept(self) for cond in conditions: @@ -258,8 +310,7 @@ def visit_generator_expr(self, o: GeneratorExpr) -> None: o.left_expr.accept(self) def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None: - for index, sequence, conditions in zip(o.indices, o.sequences, - o.condlists): + for index, sequence, conditions in zip(o.indices, o.sequences, o.condlists): sequence.accept(self) index.accept(self) for cond in conditions: @@ -357,7 +408,7 @@ def __init__(self) -> None: self.found = False def visit_return_stmt(self, o: ReturnStmt) -> None: - if (o.expr is None or isinstance(o.expr, NameExpr) and o.expr.name == 'None'): + if o.expr is None or isinstance(o.expr, NameExpr) and o.expr.name == "None": return self.found = True diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 0bc72274354ad..7ac73d36ca152 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -3,29 +3,99 @@ Subclass TransformVisitor to perform non-trivial transformations. """ -from typing import List, Dict, cast, Optional, Iterable +from typing import Dict, Iterable, List, Optional, cast from mypy.nodes import ( - AssertTypeExpr, MypyFile, Import, Node, ImportAll, ImportFrom, FuncItem, FuncDef, - OverloadedFuncDef, ClassDef, Decorator, Block, Var, - OperatorAssignmentStmt, ExpressionStmt, AssignmentStmt, ReturnStmt, - RaiseStmt, AssertStmt, DelStmt, BreakStmt, ContinueStmt, - PassStmt, GlobalDecl, WhileStmt, ForStmt, IfStmt, TryStmt, WithStmt, - CastExpr, RevealExpr, TupleExpr, GeneratorExpr, ListComprehension, ListExpr, - ConditionalExpr, DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, - UnicodeExpr, FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, - SliceExpr, OpExpr, UnaryExpr, LambdaExpr, TypeApplication, PrintStmt, - SymbolTable, RefExpr, TypeVarExpr, ParamSpecExpr, NewTypeExpr, PromoteExpr, - ComparisonExpr, TempNode, StarExpr, Statement, Expression, - YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SetComprehension, - DictionaryComprehension, ComplexExpr, TypeAliasExpr, EllipsisExpr, - YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, AssignmentExpr, - OverloadPart, EnumCallExpr, REVEAL_TYPE, GDEF, TypeVarTupleExpr + GDEF, + REVEAL_TYPE, + Argument, + AssertStmt, + AssertTypeExpr, + AssignmentExpr, + AssignmentStmt, + AwaitExpr, + BackquoteExpr, + Block, + BreakStmt, + BytesExpr, + CallExpr, + CastExpr, + ClassDef, + ComparisonExpr, + ComplexExpr, + ConditionalExpr, + ContinueStmt, + Decorator, + DelStmt, + DictExpr, + DictionaryComprehension, + EllipsisExpr, + EnumCallExpr, + ExecStmt, + Expression, + ExpressionStmt, + FloatExpr, + ForStmt, + FuncDef, + FuncItem, + GeneratorExpr, + GlobalDecl, + IfStmt, + Import, + ImportAll, + ImportFrom, + IndexExpr, + IntExpr, + LambdaExpr, + ListComprehension, + ListExpr, + MemberExpr, + MypyFile, + NamedTupleExpr, + NameExpr, + NewTypeExpr, + Node, + NonlocalDecl, + OperatorAssignmentStmt, + OpExpr, + OverloadedFuncDef, + OverloadPart, + ParamSpecExpr, + PassStmt, + PrintStmt, + PromoteExpr, + RaiseStmt, + RefExpr, + ReturnStmt, + RevealExpr, + SetComprehension, + SetExpr, + SliceExpr, + StarExpr, + Statement, + StrExpr, + SuperExpr, + SymbolTable, + TempNode, + TryStmt, + TupleExpr, + TypeAliasExpr, + TypeApplication, + TypedDictExpr, + TypeVarExpr, + TypeVarTupleExpr, + UnaryExpr, + UnicodeExpr, + Var, + WhileStmt, + WithStmt, + YieldExpr, + YieldFromExpr, ) -from mypy.types import Type, FunctionLike, ProperType from mypy.traverser import TraverserVisitor -from mypy.visitor import NodeVisitor +from mypy.types import FunctionLike, ProperType, Type from mypy.util import replace_object_state +from mypy.visitor import NodeVisitor class TransformVisitor(NodeVisitor[Node]): @@ -65,10 +135,8 @@ def __init__(self) -> None: def visit_mypy_file(self, node: MypyFile) -> MypyFile: assert self.test_only, "This visitor should not be used for whole files." # NOTE: The 'names' and 'imports' instance variables will be empty! - ignored_lines = {line: codes[:] - for line, codes in node.ignored_lines.items()} - new = MypyFile(self.statements(node.defs), [], node.is_bom, - ignored_lines=ignored_lines) + ignored_lines = {line: codes[:] for line, codes in node.ignored_lines.items()} + new = MypyFile(self.statements(node.defs), [], node.is_bom, ignored_lines=ignored_lines) new._fullname = node._fullname new.path = node.path new.names = SymbolTable() @@ -110,10 +178,12 @@ def visit_func_def(self, node: FuncDef) -> FuncDef: for stmt in node.body.body: stmt.accept(init) - new = FuncDef(node.name, - [self.copy_argument(arg) for arg in node.arguments], - self.block(node.body), - cast(Optional[FunctionLike], self.optional_type(node.type))) + new = FuncDef( + node.name, + [self.copy_argument(arg) for arg in node.arguments], + self.block(node.body), + cast(Optional[FunctionLike], self.optional_type(node.type)), + ) self.copy_function_attributes(new, node) @@ -139,14 +209,15 @@ def visit_func_def(self, node: FuncDef) -> FuncDef: return new def visit_lambda_expr(self, node: LambdaExpr) -> LambdaExpr: - new = LambdaExpr([self.copy_argument(arg) for arg in node.arguments], - self.block(node.body), - cast(Optional[FunctionLike], self.optional_type(node.type))) + new = LambdaExpr( + [self.copy_argument(arg) for arg in node.arguments], + self.block(node.body), + cast(Optional[FunctionLike], self.optional_type(node.type)), + ) self.copy_function_attributes(new, node) return new - def copy_function_attributes(self, new: FuncItem, - original: FuncItem) -> None: + def copy_function_attributes(self, new: FuncItem, original: FuncItem) -> None: new.info = original.info new.min_args = original.min_args new.max_pos = original.max_pos @@ -173,15 +244,16 @@ def visit_overloaded_func_def(self, node: OverloadedFuncDef) -> OverloadedFuncDe return new def visit_class_def(self, node: ClassDef) -> ClassDef: - new = ClassDef(node.name, - self.block(node.defs), - node.type_vars, - self.expressions(node.base_type_exprs), - self.optional_expr(node.metaclass)) + new = ClassDef( + node.name, + self.block(node.defs), + node.type_vars, + self.expressions(node.base_type_exprs), + self.optional_expr(node.metaclass), + ) new.fullname = node.fullname new.info = node.info - new.decorators = [self.expr(decorator) - for decorator in node.decorators] + new.decorators = [self.expr(decorator) for decorator in node.decorators] return new def visit_global_decl(self, node: GlobalDecl) -> GlobalDecl: @@ -197,8 +269,7 @@ def visit_decorator(self, node: Decorator) -> Decorator: # Note that a Decorator must be transformed to a Decorator. func = self.visit_func_def(node.func) func.line = node.func.line - new = Decorator(func, self.expressions(node.decorators), - self.visit_var(node.var)) + new = Decorator(func, self.expressions(node.decorators), self.visit_var(node.var)) new.is_overload = node.is_overload return new @@ -231,31 +302,34 @@ def visit_assignment_stmt(self, node: AssignmentStmt) -> AssignmentStmt: return self.duplicate_assignment(node) def duplicate_assignment(self, node: AssignmentStmt) -> AssignmentStmt: - new = AssignmentStmt(self.expressions(node.lvalues), - self.expr(node.rvalue), - self.optional_type(node.unanalyzed_type)) + new = AssignmentStmt( + self.expressions(node.lvalues), + self.expr(node.rvalue), + self.optional_type(node.unanalyzed_type), + ) new.line = node.line new.is_final_def = node.is_final_def new.type = self.optional_type(node.type) return new - def visit_operator_assignment_stmt(self, - node: OperatorAssignmentStmt) -> OperatorAssignmentStmt: - return OperatorAssignmentStmt(node.op, - self.expr(node.lvalue), - self.expr(node.rvalue)) + def visit_operator_assignment_stmt( + self, node: OperatorAssignmentStmt + ) -> OperatorAssignmentStmt: + return OperatorAssignmentStmt(node.op, self.expr(node.lvalue), self.expr(node.rvalue)) def visit_while_stmt(self, node: WhileStmt) -> WhileStmt: - return WhileStmt(self.expr(node.expr), - self.block(node.body), - self.optional_block(node.else_body)) + return WhileStmt( + self.expr(node.expr), self.block(node.body), self.optional_block(node.else_body) + ) def visit_for_stmt(self, node: ForStmt) -> ForStmt: - new = ForStmt(self.expr(node.index), - self.expr(node.expr), - self.block(node.body), - self.optional_block(node.else_body), - self.optional_type(node.unanalyzed_index_type)) + new = ForStmt( + self.expr(node.index), + self.expr(node.expr), + self.block(node.body), + self.optional_block(node.else_body), + self.optional_type(node.unanalyzed_index_type), + ) new.is_async = node.is_async new.index_type = self.optional_type(node.index_type) return new @@ -270,9 +344,11 @@ def visit_del_stmt(self, node: DelStmt) -> DelStmt: return DelStmt(self.expr(node.expr)) def visit_if_stmt(self, node: IfStmt) -> IfStmt: - return IfStmt(self.expressions(node.expr), - self.blocks(node.body), - self.optional_block(node.else_body)) + return IfStmt( + self.expressions(node.expr), + self.blocks(node.body), + self.optional_block(node.else_body), + ) def visit_break_stmt(self, node: BreakStmt) -> BreakStmt: return BreakStmt() @@ -284,35 +360,38 @@ def visit_pass_stmt(self, node: PassStmt) -> PassStmt: return PassStmt() def visit_raise_stmt(self, node: RaiseStmt) -> RaiseStmt: - return RaiseStmt(self.optional_expr(node.expr), - self.optional_expr(node.from_expr)) + return RaiseStmt(self.optional_expr(node.expr), self.optional_expr(node.from_expr)) def visit_try_stmt(self, node: TryStmt) -> TryStmt: - return TryStmt(self.block(node.body), - self.optional_names(node.vars), - self.optional_expressions(node.types), - self.blocks(node.handlers), - self.optional_block(node.else_body), - self.optional_block(node.finally_body)) + return TryStmt( + self.block(node.body), + self.optional_names(node.vars), + self.optional_expressions(node.types), + self.blocks(node.handlers), + self.optional_block(node.else_body), + self.optional_block(node.finally_body), + ) def visit_with_stmt(self, node: WithStmt) -> WithStmt: - new = WithStmt(self.expressions(node.expr), - self.optional_expressions(node.target), - self.block(node.body), - self.optional_type(node.unanalyzed_type)) + new = WithStmt( + self.expressions(node.expr), + self.optional_expressions(node.target), + self.block(node.body), + self.optional_type(node.unanalyzed_type), + ) new.is_async = node.is_async new.analyzed_types = [self.type(typ) for typ in node.analyzed_types] return new def visit_print_stmt(self, node: PrintStmt) -> PrintStmt: - return PrintStmt(self.expressions(node.args), - node.newline, - self.optional_expr(node.target)) + return PrintStmt( + self.expressions(node.args), node.newline, self.optional_expr(node.target) + ) def visit_exec_stmt(self, node: ExecStmt) -> ExecStmt: - return ExecStmt(self.expr(node.expr), - self.optional_expr(node.globals), - self.optional_expr(node.locals)) + return ExecStmt( + self.expr(node.expr), self.optional_expr(node.globals), self.optional_expr(node.locals) + ) def visit_star_expr(self, node: StarExpr) -> StarExpr: return StarExpr(node.expr) @@ -350,8 +429,7 @@ def duplicate_name(self, node: NameExpr) -> NameExpr: return new def visit_member_expr(self, node: MemberExpr) -> MemberExpr: - member = MemberExpr(self.expr(node.expr), - node.name) + member = MemberExpr(self.expr(node.expr), node.name) if node.def_var: # This refers to an attribute and we don't transform attributes by default, # just normal variables. @@ -387,11 +465,13 @@ def visit_await_expr(self, node: AwaitExpr) -> AwaitExpr: return AwaitExpr(self.expr(node.expr)) def visit_call_expr(self, node: CallExpr) -> CallExpr: - return CallExpr(self.expr(node.callee), - self.expressions(node.args), - node.arg_kinds[:], - node.arg_names[:], - self.optional_expr(node.analyzed)) + return CallExpr( + self.expr(node.callee), + self.expressions(node.args), + node.arg_kinds[:], + node.arg_names[:], + self.optional_expr(node.analyzed), + ) def visit_op_expr(self, node: OpExpr) -> OpExpr: new = OpExpr(node.op, self.expr(node.left), self.expr(node.right)) @@ -404,8 +484,7 @@ def visit_comparison_expr(self, node: ComparisonExpr) -> ComparisonExpr: return new def visit_cast_expr(self, node: CastExpr) -> CastExpr: - return CastExpr(self.expr(node.expr), - self.type(node.type)) + return CastExpr(self.expr(node.expr), self.type(node.type)) def visit_assert_type_expr(self, node: AssertTypeExpr) -> AssertTypeExpr: return AssertTypeExpr(self.expr(node.expr), self.type(node.type)) @@ -437,8 +516,9 @@ def visit_list_expr(self, node: ListExpr) -> ListExpr: return ListExpr(self.expressions(node.items)) def visit_dict_expr(self, node: DictExpr) -> DictExpr: - return DictExpr([(self.expr(key) if key else None, self.expr(value)) - for key, value in node.items]) + return DictExpr( + [(self.expr(key) if key else None, self.expr(value)) for key, value in node.items] + ) def visit_tuple_expr(self, node: TupleExpr) -> TupleExpr: return TupleExpr(self.expressions(node.items)) @@ -459,8 +539,7 @@ def visit_index_expr(self, node: IndexExpr) -> IndexExpr: return new def visit_type_application(self, node: TypeApplication) -> TypeApplication: - return TypeApplication(self.expr(node.expr), - self.types(node.types)) + return TypeApplication(self.expr(node.expr), self.types(node.types)) def visit_list_comprehension(self, node: ListComprehension) -> ListComprehension: generator = self.duplicate_generator(node.generator) @@ -472,43 +551,53 @@ def visit_set_comprehension(self, node: SetComprehension) -> SetComprehension: generator.set_line(node.generator.line, node.generator.column) return SetComprehension(generator) - def visit_dictionary_comprehension(self, node: DictionaryComprehension - ) -> DictionaryComprehension: - return DictionaryComprehension(self.expr(node.key), self.expr(node.value), - [self.expr(index) for index in node.indices], - [self.expr(s) for s in node.sequences], - [[self.expr(cond) for cond in conditions] - for conditions in node.condlists], - node.is_async) + def visit_dictionary_comprehension( + self, node: DictionaryComprehension + ) -> DictionaryComprehension: + return DictionaryComprehension( + self.expr(node.key), + self.expr(node.value), + [self.expr(index) for index in node.indices], + [self.expr(s) for s in node.sequences], + [[self.expr(cond) for cond in conditions] for conditions in node.condlists], + node.is_async, + ) def visit_generator_expr(self, node: GeneratorExpr) -> GeneratorExpr: return self.duplicate_generator(node) def duplicate_generator(self, node: GeneratorExpr) -> GeneratorExpr: - return GeneratorExpr(self.expr(node.left_expr), - [self.expr(index) for index in node.indices], - [self.expr(s) for s in node.sequences], - [[self.expr(cond) for cond in conditions] - for conditions in node.condlists], - node.is_async) + return GeneratorExpr( + self.expr(node.left_expr), + [self.expr(index) for index in node.indices], + [self.expr(s) for s in node.sequences], + [[self.expr(cond) for cond in conditions] for conditions in node.condlists], + node.is_async, + ) def visit_slice_expr(self, node: SliceExpr) -> SliceExpr: - return SliceExpr(self.optional_expr(node.begin_index), - self.optional_expr(node.end_index), - self.optional_expr(node.stride)) + return SliceExpr( + self.optional_expr(node.begin_index), + self.optional_expr(node.end_index), + self.optional_expr(node.stride), + ) def visit_conditional_expr(self, node: ConditionalExpr) -> ConditionalExpr: - return ConditionalExpr(self.expr(node.cond), - self.expr(node.if_expr), - self.expr(node.else_expr)) + return ConditionalExpr( + self.expr(node.cond), self.expr(node.if_expr), self.expr(node.else_expr) + ) def visit_backquote_expr(self, node: BackquoteExpr) -> BackquoteExpr: return BackquoteExpr(self.expr(node.expr)) def visit_type_var_expr(self, node: TypeVarExpr) -> TypeVarExpr: - return TypeVarExpr(node.name, node.fullname, - self.types(node.values), - self.type(node.upper_bound), variance=node.variance) + return TypeVarExpr( + node.name, + node.fullname, + self.types(node.values), + self.type(node.upper_bound), + variance=node.variance, + ) def visit_paramspec_expr(self, node: ParamSpecExpr) -> ParamSpecExpr: return ParamSpecExpr( @@ -593,8 +682,9 @@ def statements(self, statements: List[Statement]) -> List[Statement]: def expressions(self, expressions: List[Expression]) -> List[Expression]: return [self.expr(expr) for expr in expressions] - def optional_expressions(self, expressions: Iterable[Optional[Expression]] - ) -> List[Optional[Expression]]: + def optional_expressions( + self, expressions: Iterable[Optional[Expression]] + ) -> List[Optional[Expression]]: return [self.optional_expr(expr) for expr in expressions] def blocks(self, blocks: List[Block]) -> List[Block]: @@ -639,5 +729,6 @@ def visit_func_def(self, node: FuncDef) -> None: if node not in self.transformer.func_placeholder_map: # Haven't seen this FuncDef before, so create a placeholder node. self.transformer.func_placeholder_map[node] = FuncDef( - node.name, node.arguments, node.body, None) + node.name, node.arguments, node.body, None + ) super().visit_func_def(node) diff --git a/mypy/tvar_scope.py b/mypy/tvar_scope.py index ecb00938fec92..8464bb58b336e 100644 --- a/mypy/tvar_scope.py +++ b/mypy/tvar_scope.py @@ -1,9 +1,19 @@ -from typing import Optional, Dict, Union -from mypy.types import ( - TypeVarLikeType, TypeVarType, ParamSpecType, ParamSpecFlavor, TypeVarId, TypeVarTupleType, -) +from typing import Dict, Optional, Union + from mypy.nodes import ( - ParamSpecExpr, TypeVarExpr, TypeVarLikeExpr, SymbolTableNode, TypeVarTupleExpr, + ParamSpecExpr, + SymbolTableNode, + TypeVarExpr, + TypeVarLikeExpr, + TypeVarTupleExpr, +) +from mypy.types import ( + ParamSpecFlavor, + ParamSpecType, + TypeVarId, + TypeVarLikeType, + TypeVarTupleType, + TypeVarType, ) @@ -13,11 +23,13 @@ class TypeVarLikeScope: Node fullname -> TypeVarLikeType. """ - def __init__(self, - parent: 'Optional[TypeVarLikeScope]' = None, - is_class_scope: bool = False, - prohibited: 'Optional[TypeVarLikeScope]' = None, - namespace: str = '') -> None: + def __init__( + self, + parent: "Optional[TypeVarLikeScope]" = None, + is_class_scope: bool = False, + prohibited: "Optional[TypeVarLikeScope]" = None, + namespace: str = "", + ) -> None: """Initializer for TypeVarLikeScope Parameters: @@ -37,7 +49,7 @@ def __init__(self, self.func_id = parent.func_id self.class_id = parent.class_id - def get_function_scope(self) -> 'Optional[TypeVarLikeScope]': + def get_function_scope(self) -> "Optional[TypeVarLikeScope]": """Get the nearest parent that's a function scope, not a class scope""" it: Optional[TypeVarLikeScope] = self while it is not None and it.is_class_scope: @@ -53,11 +65,11 @@ def allow_binding(self, fullname: str) -> bool: return False return True - def method_frame(self) -> 'TypeVarLikeScope': + def method_frame(self) -> "TypeVarLikeScope": """A new scope frame for binding a method""" return TypeVarLikeScope(self, False, None) - def class_frame(self, namespace: str) -> 'TypeVarLikeScope': + def class_frame(self, namespace: str) -> "TypeVarLikeScope": """A new scope frame for binding a class. Prohibits *this* class's tvars""" return TypeVarLikeScope(self.get_function_scope(), True, self, namespace=namespace) @@ -70,7 +82,7 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType: self.func_id -= 1 i = self.func_id # TODO: Consider also using namespaces for functions - namespace = '' + namespace = "" if isinstance(tvar_expr, TypeVarExpr): tvar_def: TypeVarLikeType = TypeVarType( name, @@ -80,7 +92,7 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType: upper_bound=tvar_expr.upper_bound, variance=tvar_expr.variance, line=tvar_expr.line, - column=tvar_expr.column + column=tvar_expr.column, ) elif isinstance(tvar_expr, ParamSpecExpr): tvar_def = ParamSpecType( @@ -90,7 +102,7 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType: flavor=ParamSpecFlavor.BARE, upper_bound=tvar_expr.upper_bound, line=tvar_expr.line, - column=tvar_expr.column + column=tvar_expr.column, ) elif isinstance(tvar_expr, TypeVarTupleExpr): tvar_def = TypeVarTupleType( @@ -99,7 +111,7 @@ def bind_new(self, name: str, tvar_expr: TypeVarLikeExpr) -> TypeVarLikeType: i, upper_bound=tvar_expr.upper_bound, line=tvar_expr.line, - column=tvar_expr.column + column=tvar_expr.column, ) else: assert False @@ -120,7 +132,7 @@ def get_binding(self, item: Union[str, SymbolTableNode]) -> Optional[TypeVarLike return None def __str__(self) -> str: - me = ", ".join(f'{k}: {v.name}`{v.id}' for k, v in self.scope.items()) + me = ", ".join(f"{k}: {v.name}`{v.id}" for k, v in self.scope.items()) if self.parent is None: return me return f"{self.parent} <- {me}" diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 79b4cb12d512a..774488b7ac3f8 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -12,19 +12,45 @@ """ from abc import abstractmethod +from typing import Any, Callable, Generic, Iterable, List, Optional, Sequence, Set, TypeVar, cast + +from mypy_extensions import mypyc_attr, trait + from mypy.backports import OrderedDict -from typing import Generic, TypeVar, cast, Any, List, Callable, Iterable, Optional, Set, Sequence -from mypy_extensions import trait, mypyc_attr -T = TypeVar('T') +T = TypeVar("T") from mypy.types import ( - Type, AnyType, CallableType, Overloaded, TupleType, TypedDictType, LiteralType, - Parameters, RawExpressionType, Instance, NoneType, TypeType, - UnionType, TypeVarType, PartialType, DeletedType, UninhabitedType, TypeVarLikeType, - UnboundType, ErasedType, StarType, EllipsisType, TypeList, CallableArgument, - PlaceholderType, TypeAliasType, ParamSpecType, UnpackType, TypeVarTupleType, - get_proper_type + AnyType, + CallableArgument, + CallableType, + DeletedType, + EllipsisType, + ErasedType, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + PartialType, + PlaceholderType, + RawExpressionType, + StarType, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeList, + TypeType, + TypeVarLikeType, + TypeVarTupleType, + TypeVarType, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, + get_proper_type, ) @@ -126,7 +152,7 @@ def visit_unpack_type(self, t: UnpackType) -> T: class SyntheticTypeVisitor(TypeVisitor[T]): """A TypeVisitor that also knows how to visit synthetic AST constructs. - Not just real types.""" + Not just real types.""" @abstractmethod def visit_star_type(self, t: StarType) -> T: @@ -212,36 +238,38 @@ def visit_unpack_type(self, t: UnpackType) -> Type: return UnpackType(t.type.accept(self)) def visit_callable_type(self, t: CallableType) -> Type: - return t.copy_modified(arg_types=self.translate_types(t.arg_types), - ret_type=t.ret_type.accept(self), - variables=self.translate_variables(t.variables)) + return t.copy_modified( + arg_types=self.translate_types(t.arg_types), + ret_type=t.ret_type.accept(self), + variables=self.translate_variables(t.variables), + ) def visit_tuple_type(self, t: TupleType) -> Type: - return TupleType(self.translate_types(t.items), - # TODO: This appears to be unsafe. - cast(Any, t.partial_fallback.accept(self)), - t.line, t.column) + return TupleType( + self.translate_types(t.items), + # TODO: This appears to be unsafe. + cast(Any, t.partial_fallback.accept(self)), + t.line, + t.column, + ) def visit_typeddict_type(self, t: TypedDictType) -> Type: - items = OrderedDict([ - (item_name, item_type.accept(self)) - for (item_name, item_type) in t.items.items() - ]) - return TypedDictType(items, - t.required_keys, - # TODO: This appears to be unsafe. - cast(Any, t.fallback.accept(self)), - t.line, t.column) + items = OrderedDict( + [(item_name, item_type.accept(self)) for (item_name, item_type) in t.items.items()] + ) + return TypedDictType( + items, + t.required_keys, + # TODO: This appears to be unsafe. + cast(Any, t.fallback.accept(self)), + t.line, + t.column, + ) def visit_literal_type(self, t: LiteralType) -> Type: fallback = t.fallback.accept(self) assert isinstance(fallback, Instance) # type: ignore - return LiteralType( - value=t.value, - fallback=fallback, - line=t.line, - column=t.column, - ) + return LiteralType(value=t.value, fallback=fallback, line=t.line, column=t.column) def visit_union_type(self, t: UnionType) -> Type: return UnionType(self.translate_types(t.items), t.line, t.column) @@ -249,8 +277,9 @@ def visit_union_type(self, t: UnionType) -> Type: def translate_types(self, types: Iterable[Type]) -> List[Type]: return [t.accept(self) for t in types] - def translate_variables(self, - variables: Sequence[TypeVarLikeType]) -> Sequence[TypeVarLikeType]: + def translate_variables( + self, variables: Sequence[TypeVarLikeType] + ) -> Sequence[TypeVarLikeType]: return variables def visit_overloaded(self, t: Overloaded) -> Type: diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 2700ff10758ea..78cfb8b59935b 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -1,78 +1,132 @@ """Semantic analysis of types""" import itertools -from itertools import chain from contextlib import contextmanager -from mypy.backports import OrderedDict +from itertools import chain +from typing import Callable, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, TypeVar -from typing import Callable, List, Optional, Set, Tuple, Iterator, TypeVar, Iterable, Sequence from typing_extensions import Final, Protocol -from mypy.messages import MessageBuilder, quote_type_string, format_type_bare +from mypy import errorcodes as codes, message_registry, nodes +from mypy.backports import OrderedDict +from mypy.errorcodes import ErrorCode +from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type +from mypy.messages import MessageBuilder, format_type_bare, quote_type_string +from mypy.nodes import ( + ARG_NAMED, + ARG_NAMED_OPT, + ARG_OPT, + ARG_POS, + ARG_STAR, + ARG_STAR2, + SYMBOL_FUNCBASE_TYPES, + ArgKind, + Context, + Decorator, + Expression, + MypyFile, + ParamSpecExpr, + PlaceholderNode, + SymbolTableNode, + TypeAlias, + TypeInfo, + TypeVarExpr, + TypeVarLikeExpr, + TypeVarTupleExpr, + Var, + check_arg_kinds, + check_arg_names, + get_nongen_builtins, +) from mypy.options import Options +from mypy.plugin import AnalyzeTypeContext, Plugin, TypeAnalyzerPluginInterface +from mypy.semanal_shared import SemanticAnalyzerCoreInterface, paramspec_args, paramspec_kwargs +from mypy.tvar_scope import TypeVarLikeScope from mypy.types import ( - NEVER_NAMES, Type, UnboundType, TupleType, TypedDictType, UnionType, Instance, AnyType, - CallableType, NoneType, ErasedType, DeletedType, TypeList, TypeVarType, SyntheticTypeVisitor, - StarType, PartialType, EllipsisType, UninhabitedType, TypeType, CallableArgument, - Parameters, TypeQuery, union_items, TypeOfAny, LiteralType, RawExpressionType, - PlaceholderType, Overloaded, get_proper_type, TypeAliasType, RequiredType, - TypeVarLikeType, ParamSpecType, ParamSpecFlavor, UnpackType, TypeVarTupleType, - callable_with_ellipsis, TYPE_ALIAS_NAMES, FINAL_TYPE_NAMES, - LITERAL_TYPE_NAMES, ANNOTATED_TYPE_NAMES, -) - -from mypy.nodes import ( - TypeInfo, Context, SymbolTableNode, Var, Expression, - get_nongen_builtins, check_arg_names, check_arg_kinds, ArgKind, ARG_POS, ARG_NAMED, - ARG_OPT, ARG_NAMED_OPT, ARG_STAR, ARG_STAR2, TypeVarExpr, TypeVarLikeExpr, ParamSpecExpr, - TypeAlias, PlaceholderNode, SYMBOL_FUNCBASE_TYPES, Decorator, MypyFile, - TypeVarTupleExpr + ANNOTATED_TYPE_NAMES, + FINAL_TYPE_NAMES, + LITERAL_TYPE_NAMES, + NEVER_NAMES, + TYPE_ALIAS_NAMES, + AnyType, + CallableArgument, + CallableType, + DeletedType, + EllipsisType, + ErasedType, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecFlavor, + ParamSpecType, + PartialType, + PlaceholderType, + RawExpressionType, + RequiredType, + StarType, + SyntheticTypeVisitor, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeList, + TypeOfAny, + TypeQuery, + TypeType, + TypeVarLikeType, + TypeVarTupleType, + TypeVarType, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, + callable_with_ellipsis, + get_proper_type, + union_items, ) from mypy.typetraverser import TypeTraverserVisitor -from mypy.tvar_scope import TypeVarLikeScope -from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError -from mypy.plugin import Plugin, TypeAnalyzerPluginInterface, AnalyzeTypeContext -from mypy.semanal_shared import SemanticAnalyzerCoreInterface, paramspec_args, paramspec_kwargs -from mypy.errorcodes import ErrorCode -from mypy import nodes, message_registry, errorcodes as codes -T = TypeVar('T') +T = TypeVar("T") type_constructors: Final = { - 'typing.Callable', - 'typing.Optional', - 'typing.Tuple', - 'typing.Type', - 'typing.Union', + "typing.Callable", + "typing.Optional", + "typing.Tuple", + "typing.Type", + "typing.Union", *LITERAL_TYPE_NAMES, *ANNOTATED_TYPE_NAMES, } ARG_KINDS_BY_CONSTRUCTOR: Final = { - 'mypy_extensions.Arg': ARG_POS, - 'mypy_extensions.DefaultArg': ARG_OPT, - 'mypy_extensions.NamedArg': ARG_NAMED, - 'mypy_extensions.DefaultNamedArg': ARG_NAMED_OPT, - 'mypy_extensions.VarArg': ARG_STAR, - 'mypy_extensions.KwArg': ARG_STAR2, + "mypy_extensions.Arg": ARG_POS, + "mypy_extensions.DefaultArg": ARG_OPT, + "mypy_extensions.NamedArg": ARG_NAMED, + "mypy_extensions.DefaultNamedArg": ARG_NAMED_OPT, + "mypy_extensions.VarArg": ARG_STAR, + "mypy_extensions.KwArg": ARG_STAR2, } GENERIC_STUB_NOT_AT_RUNTIME_TYPES: Final = { - 'queue.Queue', - 'builtins._PathLike', - 'asyncio.futures.Future', + "queue.Queue", + "builtins._PathLike", + "asyncio.futures.Future", } -def analyze_type_alias(node: Expression, - api: SemanticAnalyzerCoreInterface, - tvar_scope: TypeVarLikeScope, - plugin: Plugin, - options: Options, - is_typeshed_stub: bool, - allow_placeholder: bool = False, - in_dynamic_func: bool = False, - global_scope: bool = True) -> Optional[Tuple[Type, Set[str]]]: +def analyze_type_alias( + node: Expression, + api: SemanticAnalyzerCoreInterface, + tvar_scope: TypeVarLikeScope, + plugin: Plugin, + options: Options, + is_typeshed_stub: bool, + allow_placeholder: bool = False, + in_dynamic_func: bool = False, + global_scope: bool = True, +) -> Optional[Tuple[Type, Set[str]]]: """Analyze r.h.s. of a (potential) type alias definition. If `node` is valid as a type alias rvalue, return the resulting type and a set of @@ -82,11 +136,17 @@ def analyze_type_alias(node: Expression, try: type = expr_to_unanalyzed_type(node, options, api.is_stub_file) except TypeTranslationError: - api.fail('Invalid type alias: expression is not a valid type', node) + api.fail("Invalid type alias: expression is not a valid type", node) return None - analyzer = TypeAnalyser(api, tvar_scope, plugin, options, is_typeshed_stub, - defining_alias=True, - allow_placeholder=allow_placeholder) + analyzer = TypeAnalyser( + api, + tvar_scope, + plugin, + options, + is_typeshed_stub, + defining_alias=True, + allow_placeholder=allow_placeholder, + ) analyzer.in_dynamic_func = in_dynamic_func analyzer.global_scope = global_scope res = type.accept(analyzer) @@ -94,7 +154,7 @@ def analyze_type_alias(node: Expression, def no_subscript_builtin_alias(name: str, propose_alt: bool = True) -> str: - class_name = name.split('.')[-1] + class_name = name.split(".")[-1] msg = f'"{class_name}" is not subscriptable' # This should never be called if the python_version is 3.9 or newer nongen_builtins = get_nongen_builtins((3, 8)) @@ -119,19 +179,22 @@ class TypeAnalyser(SyntheticTypeVisitor[Type], TypeAnalyzerPluginInterface): # Is this called from global scope? global_scope: bool = True - def __init__(self, - api: SemanticAnalyzerCoreInterface, - tvar_scope: TypeVarLikeScope, - plugin: Plugin, - options: Options, - is_typeshed_stub: bool, *, - defining_alias: bool = False, - allow_tuple_literal: bool = False, - allow_unbound_tvars: bool = False, - allow_placeholder: bool = False, - allow_required: bool = False, - allow_param_spec_literals: bool = False, - report_invalid_types: bool = True) -> None: + def __init__( + self, + api: SemanticAnalyzerCoreInterface, + tvar_scope: TypeVarLikeScope, + plugin: Plugin, + options: Options, + is_typeshed_stub: bool, + *, + defining_alias: bool = False, + allow_tuple_literal: bool = False, + allow_unbound_tvars: bool = False, + allow_placeholder: bool = False, + allow_required: bool = False, + allow_param_spec_literals: bool = False, + report_invalid_types: bool = True, + ) -> None: self.api = api self.lookup_qualified = api.lookup_qualified self.lookup_fqn_func = api.lookup_fully_qualified @@ -145,9 +208,8 @@ def __init__(self, self.nesting_level = 0 # Should we allow new type syntax when targeting older Python versions # like 'list[int]' or 'X | Y' (allowed in stubs and with `__future__` import)? - self.always_allow_new_syntax = ( - self.api.is_stub_file - or self.api.is_future_flag_set('annotations') + self.always_allow_new_syntax = self.api.is_stub_file or self.api.is_future_flag_set( + "annotations" ) # Should we accept unbound type variables (always OK in aliases)? self.allow_unbound_tvars = allow_unbound_tvars or defining_alias @@ -200,17 +262,20 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) self.api.record_incomplete_ref() return AnyType(TypeOfAny.special_form) if node is None: - self.fail(f'Internal error (node is None, kind={sym.kind})', t) + self.fail(f"Internal error (node is None, kind={sym.kind})", t) return AnyType(TypeOfAny.special_form) fullname = node.fullname hook = self.plugin.get_type_analyze_hook(fullname) if hook is not None: return hook(AnalyzeTypeContext(t, t, self)) - if (fullname in get_nongen_builtins(self.options.python_version) - and t.args - and not self.always_allow_new_syntax): - self.fail(no_subscript_builtin_alias(fullname, - propose_alt=not self.defining_alias), t) + if ( + fullname in get_nongen_builtins(self.options.python_version) + and t.args + and not self.always_allow_new_syntax + ): + self.fail( + no_subscript_builtin_alias(fullname, propose_alt=not self.defining_alias), t + ) tvar_def = self.tvar_scope.get_binding(sym) if isinstance(sym.node, ParamSpecExpr): if tvar_def is None: @@ -221,12 +286,20 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) self.fail(f'ParamSpec "{t.name}" used with arguments', t) # Change the line number return ParamSpecType( - tvar_def.name, tvar_def.fullname, tvar_def.id, tvar_def.flavor, - tvar_def.upper_bound, line=t.line, column=t.column, + tvar_def.name, + tvar_def.fullname, + tvar_def.id, + tvar_def.flavor, + tvar_def.upper_bound, + line=t.line, + column=t.column, ) if isinstance(sym.node, TypeVarExpr) and tvar_def is not None and self.defining_alias: - self.fail('Can\'t use bound type variable "{}"' - ' to define generic alias'.format(t.name), t) + self.fail( + 'Can\'t use bound type variable "{}"' + " to define generic alias".format(t.name), + t, + ) return AnyType(TypeOfAny.from_error) if isinstance(sym.node, TypeVarExpr) and tvar_def is not None: assert isinstance(tvar_def, TypeVarType) @@ -234,14 +307,23 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) self.fail(f'Type variable "{t.name}" used with arguments', t) # Change the line number return TypeVarType( - tvar_def.name, tvar_def.fullname, tvar_def.id, tvar_def.values, - tvar_def.upper_bound, tvar_def.variance, line=t.line, column=t.column, + tvar_def.name, + tvar_def.fullname, + tvar_def.id, + tvar_def.values, + tvar_def.upper_bound, + tvar_def.variance, + line=t.line, + column=t.column, ) if isinstance(sym.node, TypeVarTupleExpr) and ( tvar_def is not None and self.defining_alias ): - self.fail('Can\'t use bound type variable "{}"' - ' to define generic alias'.format(t.name), t) + self.fail( + 'Can\'t use bound type variable "{}"' + " to define generic alias".format(t.name), + t, + ) return AnyType(TypeOfAny.from_error) if isinstance(sym.node, TypeVarTupleExpr): if tvar_def is None: @@ -252,8 +334,12 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) self.fail(f'Type variable "{t.name}" used with arguments', t) # Change the line number return TypeVarTupleType( - tvar_def.name, tvar_def.fullname, tvar_def.id, - tvar_def.upper_bound, line=t.line, column=t.column, + tvar_def.name, + tvar_def.fullname, + tvar_def.id, + tvar_def.upper_bound, + line=t.line, + column=t.column, ) special = self.try_analyze_special_unbound_type(t, fullname) if special is not None: @@ -262,14 +348,22 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) self.aliases_used.add(fullname) an_args = self.anal_array(t.args) disallow_any = self.options.disallow_any_generics and not self.is_typeshed_stub - res = expand_type_alias(node, an_args, self.fail, node.no_args, t, - unexpanded_type=t, - disallow_any=disallow_any) + res = expand_type_alias( + node, + an_args, + self.fail, + node.no_args, + t, + unexpanded_type=t, + disallow_any=disallow_any, + ) # The only case where expand_type_alias() can return an incorrect instance is # when it is top-level instance, so no need to recurse. - if (isinstance(res, Instance) and # type: ignore[misc] - len(res.args) != len(res.type.type_vars) and - not self.defining_alias): + if ( + isinstance(res, Instance) # type: ignore[misc] + and len(res.args) != len(res.type.type_vars) + and not self.defining_alias + ): fix_instance( res, self.fail, @@ -277,7 +371,8 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) disallow_any=disallow_any, python_version=self.options.python_version, use_generic_error=True, - unexpanded_type=t) + unexpanded_type=t, + ) if node.eager: # TODO: Generate error if recursive (once we have recursive types) res = get_proper_type(res) @@ -287,7 +382,7 @@ def visit_unbound_type_nonoptional(self, t: UnboundType, defining_literal: bool) elif node.fullname in TYPE_ALIAS_NAMES: return AnyType(TypeOfAny.special_form) # Concatenate is an operator, no need for a proper type - elif node.fullname in ('typing_extensions.Concatenate', 'typing.Concatenate'): + elif node.fullname in ("typing_extensions.Concatenate", "typing.Concatenate"): # We check the return type further up the stack for valid use locations return self.apply_concatenate_operator(t) else: @@ -299,25 +394,23 @@ def cannot_resolve_type(self, t: UnboundType) -> None: # TODO: Move error message generation to messages.py. We'd first # need access to MessageBuilder here. Also move the similar # message generation logic in semanal.py. - self.api.fail( - f'Cannot resolve name "{t.name}" (possible cyclic definition)', - t) + self.api.fail(f'Cannot resolve name "{t.name}" (possible cyclic definition)', t) def apply_concatenate_operator(self, t: UnboundType) -> Type: if len(t.args) == 0: - self.api.fail('Concatenate needs type arguments', t) + self.api.fail("Concatenate needs type arguments", t) return AnyType(TypeOfAny.from_error) # last argument has to be ParamSpec ps = self.anal_type(t.args[-1], allow_param_spec=True) if not isinstance(ps, ParamSpecType): - self.api.fail('The last parameter to Concatenate needs to be a ParamSpec', t) + self.api.fail("The last parameter to Concatenate needs to be a ParamSpec", t) return AnyType(TypeOfAny.from_error) # TODO: this may not work well with aliases, if those worked. # Those should be special-cased. elif ps.prefix.arg_types: - self.api.fail('Nested Concatenates are invalid', t) + self.api.fail("Nested Concatenates are invalid", t) args = self.anal_array(t.args[:-1]) pre = ps.prefix @@ -325,9 +418,9 @@ def apply_concatenate_operator(self, t: UnboundType) -> Type: # mypy can't infer this :( names: List[Optional[str]] = [None] * len(args) - pre = Parameters(args + pre.arg_types, - [ARG_POS] * len(args) + pre.arg_kinds, - names + pre.arg_names) + pre = Parameters( + args + pre.arg_types, [ARG_POS] * len(args) + pre.arg_kinds, names + pre.arg_names + ) return ps.copy_modified(prefix=pre) def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Optional[Type]: @@ -335,22 +428,24 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt Return the bound type if successful, and return None if the type is a normal type. """ - if fullname == 'builtins.None': + if fullname == "builtins.None": return NoneType() - elif fullname == 'typing.Any' or fullname == 'builtins.Any': + elif fullname == "typing.Any" or fullname == "builtins.Any": return AnyType(TypeOfAny.explicit) elif fullname in FINAL_TYPE_NAMES: - self.fail("Final can be only used as an outermost qualifier" - " in a variable annotation", t) + self.fail( + "Final can be only used as an outermost qualifier" " in a variable annotation", t + ) return AnyType(TypeOfAny.from_error) - elif (fullname == 'typing.Tuple' or - (fullname == 'builtins.tuple' - and (self.always_allow_new_syntax or self.options.python_version >= (3, 9)))): + elif fullname == "typing.Tuple" or ( + fullname == "builtins.tuple" + and (self.always_allow_new_syntax or self.options.python_version >= (3, 9)) + ): # Tuple is special because it is involved in builtin import cycle # and may be not ready when used. - sym = self.api.lookup_fully_qualified_or_none('builtins.tuple') + sym = self.api.lookup_fully_qualified_or_none("builtins.tuple") if not sym or isinstance(sym.node, PlaceholderNode): - if self.api.is_incomplete_namespace('builtins'): + if self.api.is_incomplete_namespace("builtins"): self.api.record_incomplete_ref() else: self.fail('Name "tuple" is not defined', t) @@ -358,30 +453,30 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt if len(t.args) == 0 and not t.empty_tuple_index: # Bare 'Tuple' is same as 'tuple' any_type = self.get_omitted_any(t) - return self.named_type('builtins.tuple', [any_type], - line=t.line, column=t.column) + return self.named_type("builtins.tuple", [any_type], line=t.line, column=t.column) if len(t.args) == 2 and isinstance(t.args[1], EllipsisType): # Tuple[T, ...] (uniform, variable-length tuple) - instance = self.named_type('builtins.tuple', [self.anal_type(t.args[0])]) + instance = self.named_type("builtins.tuple", [self.anal_type(t.args[0])]) instance.line = t.line return instance return self.tuple_type(self.anal_array(t.args)) - elif fullname == 'typing.Union': + elif fullname == "typing.Union": items = self.anal_array(t.args) return UnionType.make_union(items) - elif fullname == 'typing.Optional': + elif fullname == "typing.Optional": if len(t.args) != 1: - self.fail('Optional[...] must have exactly one type argument', t) + self.fail("Optional[...] must have exactly one type argument", t) return AnyType(TypeOfAny.from_error) item = self.anal_type(t.args[0]) return make_optional_type(item) - elif fullname == 'typing.Callable': + elif fullname == "typing.Callable": return self.analyze_callable_type(t) - elif (fullname == 'typing.Type' or - (fullname == 'builtins.type' - and (self.always_allow_new_syntax or self.options.python_version >= (3, 9)))): + elif fullname == "typing.Type" or ( + fullname == "builtins.type" + and (self.always_allow_new_syntax or self.options.python_version >= (3, 9)) + ): if len(t.args) == 0: - if fullname == 'typing.Type': + if fullname == "typing.Type": any_type = self.get_omitted_any(t) return TypeType(any_type, line=t.line, column=t.column) else: @@ -389,17 +484,17 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt # See https://github.com/python/mypy/issues/9476 for more information return None if len(t.args) != 1: - type_str = 'Type[...]' if fullname == 'typing.Type' else 'type[...]' - self.fail(type_str + ' must have exactly one type argument', t) + type_str = "Type[...]" if fullname == "typing.Type" else "type[...]" + self.fail(type_str + " must have exactly one type argument", t) item = self.anal_type(t.args[0]) return TypeType.make_normalized(item, line=t.line) - elif fullname == 'typing.ClassVar': + elif fullname == "typing.ClassVar": if self.nesting_level > 0: - self.fail('Invalid type: ClassVar nested inside other type', t) + self.fail("Invalid type: ClassVar nested inside other type", t) if len(t.args) == 0: return AnyType(TypeOfAny.from_omitted_generics, line=t.line, column=t.column) if len(t.args) != 1: - self.fail('ClassVar[...] must have at most one type argument', t) + self.fail("ClassVar[...] must have at most one type argument", t) return AnyType(TypeOfAny.from_error) return self.anal_type(t.args[0]) elif fullname in NEVER_NAMES: @@ -408,11 +503,14 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt return self.analyze_literal_type(t) elif fullname in ANNOTATED_TYPE_NAMES: if len(t.args) < 2: - self.fail("Annotated[...] must have exactly one type argument" - " and at least one annotation", t) + self.fail( + "Annotated[...] must have exactly one type argument" + " and at least one annotation", + t, + ) return AnyType(TypeOfAny.from_error) return self.anal_type(t.args[0]) - elif fullname in ('typing_extensions.Required', 'typing.Required'): + elif fullname in ("typing_extensions.Required", "typing.Required"): if not self.allow_required: self.fail("Required[] can be only used in a TypedDict definition", t) return AnyType(TypeOfAny.from_error) @@ -420,7 +518,7 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt self.fail("Required[] must have exactly one type argument", t) return AnyType(TypeOfAny.from_error) return RequiredType(self.anal_type(t.args[0]), required=True) - elif fullname in ('typing_extensions.NotRequired', 'typing.NotRequired'): + elif fullname in ("typing_extensions.NotRequired", "typing.NotRequired"): if not self.allow_required: self.fail("NotRequired[] can be only used in a TypedDict definition", t) return AnyType(TypeOfAny.from_error) @@ -430,30 +528,30 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Opt return RequiredType(self.anal_type(t.args[0]), required=False) elif self.anal_type_guard_arg(t, fullname) is not None: # In most contexts, TypeGuard[...] acts as an alias for bool (ignoring its args) - return self.named_type('builtins.bool') - elif fullname in ('typing.Unpack', 'typing_extensions.Unpack'): + return self.named_type("builtins.bool") + elif fullname in ("typing.Unpack", "typing_extensions.Unpack"): # We don't want people to try to use this yet. if not self.options.enable_incomplete_features: self.fail('"Unpack" is not supported by mypy yet', t) return AnyType(TypeOfAny.from_error) - return UnpackType( - self.anal_type(t.args[0]), line=t.line, column=t.column, - ) + return UnpackType(self.anal_type(t.args[0]), line=t.line, column=t.column) return None def get_omitted_any(self, typ: Type, fullname: Optional[str] = None) -> AnyType: disallow_any = not self.is_typeshed_stub and self.options.disallow_any_generics - return get_omitted_any(disallow_any, self.fail, self.note, typ, - self.options.python_version, fullname) + return get_omitted_any( + disallow_any, self.fail, self.note, typ, self.options.python_version, fullname + ) def analyze_type_with_type_info( - self, info: TypeInfo, args: Sequence[Type], ctx: Context) -> Type: + self, info: TypeInfo, args: Sequence[Type], ctx: Context + ) -> Type: """Bind unbound type when were able to find target TypeInfo. This handles simple cases like 'int', 'modname.UserClass[str]', etc. """ - if len(args) > 0 and info.fullname == 'builtins.tuple': + if len(args) > 0 and info.fullname == "builtins.tuple": fallback = Instance(info, [AnyType(TypeOfAny.special_form)], ctx.line) return TupleType(self.anal_array(args), fallback, ctx.line) @@ -465,8 +563,9 @@ def analyze_type_with_type_info( # checked only later, since we do not always know the # valid count at this point. Thus we may construct an # Instance with an invalid number of type arguments. - instance = Instance(info, self.anal_array(args, allow_param_spec=True), - ctx.line, ctx.column) + instance = Instance( + info, self.anal_array(args, allow_param_spec=True), ctx.line, ctx.column + ) # "aesthetic" paramspec literals # these do not support mypy_extensions VarArgs, etc. as they were already analyzed @@ -476,9 +575,14 @@ def analyze_type_with_type_info( first_arg = get_proper_type(instance.args[0]) # TODO: can I use tuple syntax to isinstance multiple in 3.6? - if not (len(instance.args) == 1 and (isinstance(first_arg, Parameters) or - isinstance(first_arg, ParamSpecType) or - isinstance(first_arg, AnyType))): + if not ( + len(instance.args) == 1 + and ( + isinstance(first_arg, Parameters) + or isinstance(first_arg, ParamSpecType) + or isinstance(first_arg, AnyType) + ) + ): args = instance.args instance.args = (Parameters(args, [ARG_POS] * len(args), [None] * len(args)),) @@ -490,39 +594,43 @@ def analyze_type_with_type_info( # Check type argument count. if not valid_arg_length and not self.defining_alias: - fix_instance(instance, self.fail, self.note, - disallow_any=self.options.disallow_any_generics and - not self.is_typeshed_stub, - python_version=self.options.python_version) + fix_instance( + instance, + self.fail, + self.note, + disallow_any=self.options.disallow_any_generics and not self.is_typeshed_stub, + python_version=self.options.python_version, + ) tup = info.tuple_type if tup is not None: # The class has a Tuple[...] base class so it will be # represented as a tuple type. if args: - self.fail('Generic tuple types not supported', ctx) + self.fail("Generic tuple types not supported", ctx) return AnyType(TypeOfAny.from_error) - return tup.copy_modified(items=self.anal_array(tup.items), - fallback=instance) + return tup.copy_modified(items=self.anal_array(tup.items), fallback=instance) td = info.typeddict_type if td is not None: # The class has a TypedDict[...] base class so it will be # represented as a typeddict type. if args: - self.fail('Generic TypedDict types not supported', ctx) + self.fail("Generic TypedDict types not supported", ctx) return AnyType(TypeOfAny.from_error) # Create a named TypedDictType - return td.copy_modified(item_types=self.anal_array(list(td.items.values())), - fallback=instance) + return td.copy_modified( + item_types=self.anal_array(list(td.items.values())), fallback=instance + ) - if info.fullname == 'types.NoneType': + if info.fullname == "types.NoneType": self.fail("NoneType should not be used as a type, please use None instead", ctx) return NoneType(ctx.line, ctx.column) return instance - def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTableNode, - defining_literal: bool) -> Type: + def analyze_unbound_type_without_type_info( + self, t: UnboundType, sym: SymbolTableNode, defining_literal: bool + ) -> Type: """Figure out what an unbound type that doesn't refer to a TypeInfo node means. This is something unusual. We try our best to find out what it is. @@ -539,13 +647,16 @@ def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTabl if isinstance(sym.node, Var): typ = get_proper_type(sym.node.type) if isinstance(typ, AnyType): - return AnyType(TypeOfAny.from_unimported_type, - missing_import_name=typ.missing_import_name) + return AnyType( + TypeOfAny.from_unimported_type, missing_import_name=typ.missing_import_name + ) # Option 2: # Unbound type variable. Currently these may be still valid, # for example when defining a generic type alias. - unbound_tvar = (isinstance(sym.node, (TypeVarExpr, TypeVarTupleExpr)) and - self.tvar_scope.get_binding(sym) is None) + unbound_tvar = ( + isinstance(sym.node, (TypeVarExpr, TypeVarTupleExpr)) + and self.tvar_scope.get_binding(sym) is None + ) if self.allow_unbound_tvars and unbound_tvar: return t @@ -563,7 +674,8 @@ def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTabl base_enum_short_name = sym.node.info.name if not defining_literal: msg = message_registry.INVALID_TYPE_RAW_ENUM_VALUE.format( - base_enum_short_name, value) + base_enum_short_name, value + ) self.fail(msg, t) return AnyType(TypeOfAny.from_error) return LiteralType( @@ -579,14 +691,16 @@ def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTabl # TODO: Move this message building logic to messages.py. notes: List[str] = [] if isinstance(sym.node, Var): - notes.append('See https://mypy.readthedocs.io/en/' - 'stable/common_issues.html#variables-vs-type-aliases') + notes.append( + "See https://mypy.readthedocs.io/en/" + "stable/common_issues.html#variables-vs-type-aliases" + ) message = 'Variable "{}" is not valid as a type' elif isinstance(sym.node, (SYMBOL_FUNCBASE_TYPES, Decorator)): message = 'Function "{}" is not valid as a type' - if name == 'builtins.any': + if name == "builtins.any": notes.append('Perhaps you meant "typing.Any" instead of "any"?') - elif name == 'builtins.callable': + elif name == "builtins.callable": notes.append('Perhaps you meant "typing.Callable" instead of "callable"?') else: notes.append('Perhaps you need "Callable[...]" or a callback protocol?') @@ -595,11 +709,17 @@ def analyze_unbound_type_without_type_info(self, t: UnboundType, sym: SymbolTabl message = 'Module "{}" is not valid as a type' elif unbound_tvar: message = 'Type variable "{}" is unbound' - short = name.split('.')[-1] - notes.append(('(Hint: Use "Generic[{}]" or "Protocol[{}]" base class' - ' to bind "{}" inside a class)').format(short, short, short)) - notes.append('(Hint: Use "{}" in function signature to bind "{}"' - ' inside a function)'.format(short, short)) + short = name.split(".")[-1] + notes.append( + ( + '(Hint: Use "Generic[{}]" or "Protocol[{}]" base class' + ' to bind "{}" inside a class)' + ).format(short, short, short) + ) + notes.append( + '(Hint: Use "{}" in function signature to bind "{}"' + " inside a function)".format(short, short) + ) else: message = 'Cannot interpret reference "{}" as a type' self.fail(message.format(name), t, code=codes.VALID_TYPE) @@ -644,7 +764,7 @@ def visit_type_list(self, t: TypeList) -> Type: return AnyType(TypeOfAny.from_error) def visit_callable_argument(self, t: CallableArgument) -> Type: - self.fail('Invalid type', t) + self.fail("Invalid type", t) return AnyType(TypeOfAny.from_error) def visit_instance(self, t: Instance) -> Type: @@ -685,15 +805,15 @@ def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type: ] else: arg_types = self.anal_array(t.arg_types, nested=nested) - ret = t.copy_modified(arg_types=arg_types, - ret_type=self.anal_type(t.ret_type, nested=nested), - # If the fallback isn't filled in yet, - # its type will be the falsey FakeInfo - fallback=(t.fallback if t.fallback.type - else self.named_type('builtins.function')), - variables=self.anal_var_defs(variables), - type_guard=special, - ) + ret = t.copy_modified( + arg_types=arg_types, + ret_type=self.anal_type(t.ret_type, nested=nested), + # If the fallback isn't filled in yet, + # its type will be the falsey FakeInfo + fallback=(t.fallback if t.fallback.type else self.named_type("builtins.function")), + variables=self.anal_var_defs(variables), + type_guard=special, + ) return ret def anal_type_guard(self, t: Type) -> Optional[Type]: @@ -705,7 +825,7 @@ def anal_type_guard(self, t: Type) -> Optional[Type]: return None def anal_type_guard_arg(self, t: UnboundType, fullname: str) -> Optional[Type]: - if fullname in ('typing_extensions.TypeGuard', 'typing.TypeGuard'): + if fullname in ("typing_extensions.TypeGuard", "typing.TypeGuard"): if len(t.args) != 1: self.fail("TypeGuard must have exactly one type argument", t) return AnyType(TypeOfAny.from_error) @@ -715,9 +835,9 @@ def anal_type_guard_arg(self, t: UnboundType, fullname: str) -> Optional[Type]: def anal_star_arg_type(self, t: Type, kind: ArgKind, nested: bool) -> Type: """Analyze signature argument type for *args and **kwargs argument.""" # TODO: Check that suffix and kind match - if isinstance(t, UnboundType) and t.name and '.' in t.name and not t.args: - components = t.name.split('.') - sym = self.lookup_qualified('.'.join(components[:-1]), t) + if isinstance(t, UnboundType) and t.name and "." in t.name and not t.args: + components = t.name.split(".") + sym = self.lookup_qualified(".".join(components[:-1]), t) if sym is not None and isinstance(sym.node, ParamSpecExpr): tvar_def = self.tvar_scope.get_binding(sym) if isinstance(tvar_def, ParamSpecType): @@ -727,9 +847,14 @@ def anal_star_arg_type(self, t: Type, kind: ArgKind, nested: bool) -> Type: make_paramspec = paramspec_kwargs else: assert False, kind - return make_paramspec(tvar_def.name, tvar_def.fullname, tvar_def.id, - named_type_func=self.named_type, - line=t.line, column=t.column) + return make_paramspec( + tvar_def.name, + tvar_def.fullname, + tvar_def.id, + named_type_func=self.named_type, + line=t.line, + column=t.column, + ) return self.anal_type(t, nested=nested) def visit_overloaded(self, t: Overloaded) -> Type: @@ -744,36 +869,47 @@ def visit_tuple_type(self, t: TupleType) -> Type: # Types such as (t1, t2, ...) only allowed in assignment statements. They'll # generate errors elsewhere, and Tuple[t1, t2, ...] must be used instead. if t.implicit and not self.allow_tuple_literal: - self.fail('Syntax error in type annotation', t, code=codes.SYNTAX) + self.fail("Syntax error in type annotation", t, code=codes.SYNTAX) if len(t.items) == 0: - self.note('Suggestion: Use Tuple[()] instead of () for an empty tuple, or ' - 'None for a function without a return value', t, code=codes.SYNTAX) + self.note( + "Suggestion: Use Tuple[()] instead of () for an empty tuple, or " + "None for a function without a return value", + t, + code=codes.SYNTAX, + ) elif len(t.items) == 1: - self.note('Suggestion: Is there a spurious trailing comma?', t, code=codes.SYNTAX) + self.note("Suggestion: Is there a spurious trailing comma?", t, code=codes.SYNTAX) else: - self.note('Suggestion: Use Tuple[T1, ..., Tn] instead of (T1, ..., Tn)', t, - code=codes.SYNTAX) + self.note( + "Suggestion: Use Tuple[T1, ..., Tn] instead of (T1, ..., Tn)", + t, + code=codes.SYNTAX, + ) return AnyType(TypeOfAny.from_error) star_count = sum(1 for item in t.items if isinstance(item, StarType)) if star_count > 1: - self.fail('At most one star type allowed in a tuple', t) + self.fail("At most one star type allowed in a tuple", t) if t.implicit: - return TupleType([AnyType(TypeOfAny.from_error) for _ in t.items], - self.named_type('builtins.tuple'), - t.line) + return TupleType( + [AnyType(TypeOfAny.from_error) for _ in t.items], + self.named_type("builtins.tuple"), + t.line, + ) else: return AnyType(TypeOfAny.from_error) any_type = AnyType(TypeOfAny.special_form) # If the fallback isn't filled in yet, its type will be the falsey FakeInfo - fallback = (t.partial_fallback if t.partial_fallback.type - else self.named_type('builtins.tuple', [any_type])) + fallback = ( + t.partial_fallback + if t.partial_fallback.type + else self.named_type("builtins.tuple", [any_type]) + ) return TupleType(self.anal_array(t.items), fallback, t.line) def visit_typeddict_type(self, t: TypedDictType) -> Type: - items = OrderedDict([ - (item_name, self.anal_type(item_type)) - for (item_name, item_type) in t.items.items() - ]) + items = OrderedDict( + [(item_name, self.anal_type(item_type)) for (item_name, item_type) in t.items.items()] + ) return TypedDictType(items, set(t.required_keys), t.fallback) def visit_raw_expression_type(self, t: RawExpressionType) -> Type: @@ -788,11 +924,11 @@ def visit_raw_expression_type(self, t: RawExpressionType) -> Type: # instead. if self.report_invalid_types: - if t.base_type_name in ('builtins.int', 'builtins.bool'): + if t.base_type_name in ("builtins.int", "builtins.bool"): # The only time it makes sense to use an int or bool is inside of # a literal type. msg = f"Invalid type: try using Literal[{repr(t.literal_value)}] instead?" - elif t.base_type_name in ('builtins.float', 'builtins.complex'): + elif t.base_type_name in ("builtins.float", "builtins.complex"): # We special-case warnings for floats and complex numbers. msg = f"Invalid type: {t.simple_name()} literals cannot be used as a type" else: @@ -801,7 +937,7 @@ def visit_raw_expression_type(self, t: RawExpressionType) -> Type: # but not ints or bools is because whenever we see an out-of-place # string, it's unclear if the user meant to construct a literal type # or just misspelled a regular type. So we avoid guessing. - msg = 'Invalid type comment or annotation' + msg = "Invalid type comment or annotation" self.fail(msg, t, code=codes.VALID_TYPE) if t.note is not None: @@ -816,10 +952,12 @@ def visit_star_type(self, t: StarType) -> Type: return StarType(self.anal_type(t.type), t.line) def visit_union_type(self, t: UnionType) -> Type: - if (t.uses_pep604_syntax is True - and t.is_evaluated is True - and not self.always_allow_new_syntax - and not self.options.python_version >= (3, 10)): + if ( + t.uses_pep604_syntax is True + and t.is_evaluated is True + and not self.always_allow_new_syntax + and not self.options.python_version >= (3, 10) + ): self.fail("X | Y syntax for unions requires Python 3.10", t) return UnionType(self.anal_array(t.items), t.line) @@ -829,10 +967,9 @@ def visit_partial_type(self, t: PartialType) -> Type: def visit_ellipsis_type(self, t: EllipsisType) -> Type: if self.allow_param_spec_literals: any_type = AnyType(TypeOfAny.explicit) - return Parameters([any_type, any_type], - [ARG_STAR, ARG_STAR2], - [None, None], - is_ellipsis_args=True) + return Parameters( + [any_type, any_type], [ARG_STAR, ARG_STAR2], [None, None], is_ellipsis_args=True + ) else: self.fail('Unexpected "..."', t) return AnyType(TypeOfAny.from_error) @@ -851,10 +988,7 @@ def visit_placeholder_type(self, t: PlaceholderType) -> Type: return self.analyze_type_with_type_info(n.node, t.args, t) def analyze_callable_args_for_paramspec( - self, - callable_args: Type, - ret_type: Type, - fallback: Instance, + self, callable_args: Type, ret_type: Type, fallback: Instance ) -> Optional[CallableType]: """Construct a 'Callable[P, RET]', where P is ParamSpec, return None if we cannot.""" if not isinstance(callable_args, UnboundType): @@ -867,10 +1001,14 @@ def analyze_callable_args_for_paramspec( return None return CallableType( - [paramspec_args(tvar_def.name, tvar_def.fullname, tvar_def.id, - named_type_func=self.named_type), - paramspec_kwargs(tvar_def.name, tvar_def.fullname, tvar_def.id, - named_type_func=self.named_type)], + [ + paramspec_args( + tvar_def.name, tvar_def.fullname, tvar_def.id, named_type_func=self.named_type + ), + paramspec_kwargs( + tvar_def.name, tvar_def.fullname, tvar_def.id, named_type_func=self.named_type + ), + ], [nodes.ARG_STAR, nodes.ARG_STAR2], [None, None], ret_type=ret_type, @@ -878,10 +1016,7 @@ def analyze_callable_args_for_paramspec( ) def analyze_callable_args_for_concatenate( - self, - callable_args: Type, - ret_type: Type, - fallback: Instance, + self, callable_args: Type, ret_type: Type, fallback: Instance ) -> Optional[CallableType]: """Construct a 'Callable[C, RET]', where C is Concatenate[..., P], returning None if we cannot. @@ -893,7 +1028,7 @@ def analyze_callable_args_for_concatenate( return None if sym.node is None: return None - if sym.node.fullname not in ('typing_extensions.Concatenate', 'typing.Concatenate'): + if sym.node.fullname not in ("typing_extensions.Concatenate", "typing.Concatenate"): return None tvar_def = self.anal_type(callable_args, allow_param_spec=True) @@ -905,11 +1040,15 @@ def analyze_callable_args_for_concatenate( # we don't set the prefix here as generic arguments will get updated at some point # in the future. CallableType.param_spec() accounts for this. return CallableType( - [*prefix.arg_types, - paramspec_args(tvar_def.name, tvar_def.fullname, tvar_def.id, - named_type_func=self.named_type), - paramspec_kwargs(tvar_def.name, tvar_def.fullname, tvar_def.id, - named_type_func=self.named_type)], + [ + *prefix.arg_types, + paramspec_args( + tvar_def.name, tvar_def.fullname, tvar_def.id, named_type_func=self.named_type + ), + paramspec_kwargs( + tvar_def.name, tvar_def.fullname, tvar_def.id, named_type_func=self.named_type + ), + ], [*prefix.arg_kinds, nodes.ARG_STAR, nodes.ARG_STAR2], [*prefix.arg_names, None, None], ret_type=ret_type, @@ -918,7 +1057,7 @@ def analyze_callable_args_for_concatenate( ) def analyze_callable_type(self, t: UnboundType) -> Type: - fallback = self.named_type('builtins.function') + fallback = self.named_type("builtins.function") if len(t.args) == 0: # Callable (bare). Treat as Callable[..., Any]. any_type = self.get_omitted_any(t) @@ -932,35 +1071,27 @@ def analyze_callable_type(self, t: UnboundType) -> Type: if analyzed_args is None: return AnyType(TypeOfAny.from_error) args, kinds, names = analyzed_args - ret = CallableType(args, - kinds, - names, - ret_type=ret_type, - fallback=fallback) + ret = CallableType(args, kinds, names, ret_type=ret_type, fallback=fallback) elif isinstance(callable_args, EllipsisType): # Callable[..., RET] (with literal ellipsis; accept arbitrary arguments) - ret = callable_with_ellipsis(AnyType(TypeOfAny.explicit), - ret_type=ret_type, - fallback=fallback) + ret = callable_with_ellipsis( + AnyType(TypeOfAny.explicit), ret_type=ret_type, fallback=fallback + ) else: # Callable[P, RET] (where P is ParamSpec) maybe_ret = self.analyze_callable_args_for_paramspec( - callable_args, - ret_type, - fallback - ) or self.analyze_callable_args_for_concatenate( - callable_args, - ret_type, - fallback - ) + callable_args, ret_type, fallback + ) or self.analyze_callable_args_for_concatenate(callable_args, ret_type, fallback) if maybe_ret is None: # Callable[?, RET] (where ? is something invalid) self.fail( - 'The first argument to Callable must be a ' - 'list of types, parameter specification, or "..."', t) + "The first argument to Callable must be a " + 'list of types, parameter specification, or "..."', + t, + ) self.note( - 'See https://mypy.readthedocs.io/en/stable/kinds_of_types.html#callable-types-and-lambdas', # noqa: E501 - t + "See https://mypy.readthedocs.io/en/stable/kinds_of_types.html#callable-types-and-lambdas", # noqa: E501 + t, ) return AnyType(TypeOfAny.from_error) ret = maybe_ret @@ -973,9 +1104,9 @@ def analyze_callable_type(self, t: UnboundType) -> Type: assert isinstance(ret, CallableType) return ret.accept(self) - def analyze_callable_args(self, arglist: TypeList) -> Optional[Tuple[List[Type], - List[ArgKind], - List[Optional[str]]]]: + def analyze_callable_args( + self, arglist: TypeList + ) -> Optional[Tuple[List[Type], List[ArgKind], List[Optional[str]]]]: args: List[Type] = [] kinds: List[ArgKind] = [] names: List[Optional[str]] = [] @@ -997,8 +1128,9 @@ def analyze_callable_args(self, arglist: TypeList) -> Optional[Tuple[List[Type], kind = ARG_KINDS_BY_CONSTRUCTOR[found.fullname] kinds.append(kind) if arg.name is not None and kind.is_star(): - self.fail("{} arguments should not have names".format( - arg.constructor), arg) + self.fail( + "{} arguments should not have names".format(arg.constructor), arg + ) return None else: args.append(arg) @@ -1011,7 +1143,7 @@ def analyze_callable_args(self, arglist: TypeList) -> Optional[Tuple[List[Type], def analyze_literal_type(self, t: UnboundType) -> Type: if len(t.args) == 0: - self.fail('Literal[...] must have at least one parameter', t) + self.fail("Literal[...] must have at least one parameter", t) return AnyType(TypeOfAny.from_error) output: List[Type] = [] @@ -1027,12 +1159,14 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L # This UnboundType was originally defined as a string. if isinstance(arg, UnboundType) and arg.original_str_expr is not None: assert arg.original_str_fallback is not None - return [LiteralType( - value=arg.original_str_expr, - fallback=self.named_type_with_normalized_str(arg.original_str_fallback), - line=arg.line, - column=arg.column, - )] + return [ + LiteralType( + value=arg.original_str_expr, + fallback=self.named_type_with_normalized_str(arg.original_str_fallback), + line=arg.line, + column=arg.column, + ) + ] # If arg is an UnboundType that was *not* originally defined as # a string, try expanding it in case it's a type alias or something. @@ -1066,10 +1200,10 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L # A raw literal. Convert it directly into a literal if we can. if arg.literal_value is None: name = arg.simple_name() - if name in ('float', 'complex'): + if name in ("float", "complex"): msg = f'Parameter {idx} of Literal[...] cannot be of type "{name}"' else: - msg = 'Invalid type: Literal[...] cannot contain arbitrary expressions' + msg = "Invalid type: Literal[...] cannot contain arbitrary expressions" self.fail(msg, ctx) # Note: we deliberately ignore arg.note here: the extra info might normally be # helpful, but it generally won't make sense in the context of a Literal[...]. @@ -1094,7 +1228,7 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L out.extend(union_result) return out else: - self.fail(f'Parameter {idx} of Literal[...] is invalid', ctx) + self.fail(f"Parameter {idx} of Literal[...] is invalid", ctx) return None def analyze_type(self, t: Type) -> Type: @@ -1113,8 +1247,7 @@ def tvar_scope_frame(self) -> Iterator[None]: yield self.tvar_scope = old_scope - def infer_type_variables(self, - type: CallableType) -> List[Tuple[str, TypeVarLikeExpr]]: + def infer_type_variables(self, type: CallableType) -> List[Tuple[str, TypeVarLikeExpr]]: """Return list of unique type variables referred to in a callable.""" names: List[str] = [] tvars: List[TypeVarLikeExpr] = [] @@ -1151,8 +1284,9 @@ def bind_function_type_variables( return fun_type.variables typevars = self.infer_type_variables(fun_type) # Do not define a new type variable if already defined in scope. - typevars = [(name, tvar) for name, tvar in typevars - if not self.is_defined_type_var(name, defn)] + typevars = [ + (name, tvar) for name, tvar in typevars if not self.is_defined_type_var(name, defn) + ] defs: List[TypeVarLikeType] = [] for name, tvar in typevars: if not self.tvar_scope.allow_binding(tvar.fullname): @@ -1170,10 +1304,9 @@ def is_defined_type_var(self, tvar: str, context: Context) -> bool: return False return self.tvar_scope.get_binding(tvar_node) is not None - def anal_array(self, - a: Iterable[Type], - nested: bool = True, *, - allow_param_spec: bool = False) -> List[Type]: + def anal_array( + self, a: Iterable[Type], nested: bool = True, *, allow_param_spec: bool = False + ) -> List[Type]: res: List[Type] = [] for t in a: res.append(self.anal_type(t, nested, allow_param_spec=allow_param_spec)) @@ -1190,21 +1323,20 @@ def anal_type(self, t: Type, nested: bool = True, *, allow_param_spec: bool = Fa if nested: self.nesting_level -= 1 self.allow_required = old_allow_required - if (not allow_param_spec - and isinstance(analyzed, ParamSpecType) - and analyzed.flavor == ParamSpecFlavor.BARE): + if ( + not allow_param_spec + and isinstance(analyzed, ParamSpecType) + and analyzed.flavor == ParamSpecFlavor.BARE + ): if analyzed.prefix.arg_types: - self.fail('Invalid location for Concatenate', t) - self.note( - 'You can use Concatenate as the first argument to Callable', - t - ) + self.fail("Invalid location for Concatenate", t) + self.note("You can use Concatenate as the first argument to Callable", t) else: self.fail(f'Invalid location for ParamSpec "{analyzed.name}"', t) self.note( - 'You can use ParamSpec as the first argument to Callable, e.g., ' + "You can use ParamSpec as the first argument to Callable, e.g., " "'Callable[{}, int]'".format(analyzed.name), - t + t, ) return analyzed @@ -1217,7 +1349,7 @@ def anal_var_def(self, var_def: TypeVarLikeType) -> TypeVarLikeType: self.anal_array(var_def.values), var_def.upper_bound.accept(self), var_def.variance, - var_def.line + var_def.line, ) else: return var_def @@ -1230,25 +1362,29 @@ def named_type_with_normalized_str(self, fully_qualified_name: str) -> Instance: unalias `builtins.bytes` and `builtins.unicode` to `builtins.str` as appropriate. """ python_version = self.options.python_version - if python_version[0] == 2 and fully_qualified_name == 'builtins.bytes': - fully_qualified_name = 'builtins.str' - if python_version[0] >= 3 and fully_qualified_name == 'builtins.unicode': - fully_qualified_name = 'builtins.str' + if python_version[0] == 2 and fully_qualified_name == "builtins.bytes": + fully_qualified_name = "builtins.str" + if python_version[0] >= 3 and fully_qualified_name == "builtins.unicode": + fully_qualified_name = "builtins.str" return self.named_type(fully_qualified_name) - def named_type(self, fully_qualified_name: str, - args: Optional[List[Type]] = None, - line: int = -1, - column: int = -1) -> Instance: + def named_type( + self, + fully_qualified_name: str, + args: Optional[List[Type]] = None, + line: int = -1, + column: int = -1, + ) -> Instance: node = self.lookup_fqn_func(fully_qualified_name) assert isinstance(node.node, TypeInfo) any_type = AnyType(TypeOfAny.special_form) - return Instance(node.node, args or [any_type] * len(node.node.defn.type_vars), - line=line, column=column) + return Instance( + node.node, args or [any_type] * len(node.node.defn.type_vars), line=line, column=column + ) def tuple_type(self, items: List[Type]) -> TupleType: any_type = AnyType(TypeOfAny.special_form) - return TupleType(items, fallback=self.named_type('builtins.tuple', [any_type])) + return TupleType(items, fallback=self.named_type("builtins.tuple", [any_type])) @contextmanager def set_allow_param_spec_literals(self, to: bool) -> Iterator[None]: @@ -1264,27 +1400,30 @@ def set_allow_param_spec_literals(self, to: bool) -> Iterator[None]: class MsgCallback(Protocol): - def __call__( - self, - __msg: str, - __ctx: Context, - *, - code: Optional[ErrorCode] = None - ) -> None: ... - - -def get_omitted_any(disallow_any: bool, fail: MsgCallback, note: MsgCallback, - orig_type: Type, python_version: Tuple[int, int], - fullname: Optional[str] = None, - unexpanded_type: Optional[Type] = None) -> AnyType: + def __call__(self, __msg: str, __ctx: Context, *, code: Optional[ErrorCode] = None) -> None: + ... + + +def get_omitted_any( + disallow_any: bool, + fail: MsgCallback, + note: MsgCallback, + orig_type: Type, + python_version: Tuple[int, int], + fullname: Optional[str] = None, + unexpanded_type: Optional[Type] = None, +) -> AnyType: if disallow_any: nongen_builtins = get_nongen_builtins(python_version) if fullname in nongen_builtins: typ = orig_type # We use a dedicated error message for builtin generics (as the most common case). alternative = nongen_builtins[fullname] - fail(message_registry.IMPLICIT_GENERIC_ANY_BUILTIN.format(alternative), typ, - code=codes.TYPE_ARG) + fail( + message_registry.IMPLICIT_GENERIC_ANY_BUILTIN.format(alternative), + typ, + code=codes.TYPE_ARG, + ) else: typ = unexpanded_type or orig_type type_str = typ.name if isinstance(typ, UnboundType) else format_type_bare(typ) @@ -1292,7 +1431,8 @@ def get_omitted_any(disallow_any: bool, fail: MsgCallback, note: MsgCallback, fail( message_registry.BARE_GENERIC.format(quote_type_string(type_str)), typ, - code=codes.TYPE_ARG) + code=codes.TYPE_ARG, + ) base_type = get_proper_type(orig_type) base_fullname = ( base_type.type.fullname if isinstance(base_type, Instance) else fullname @@ -1307,7 +1447,8 @@ def get_omitted_any(disallow_any: bool, fail: MsgCallback, note: MsgCallback, "escaping, see https://mypy.readthedocs.io/en/stable/runtime_troubles.html" "#not-generic-runtime", typ, - code=codes.TYPE_ARG) + code=codes.TYPE_ARG, + ) any_type = AnyType(TypeOfAny.from_error, line=typ.line, column=typ.column) else: @@ -1317,10 +1458,15 @@ def get_omitted_any(disallow_any: bool, fail: MsgCallback, note: MsgCallback, return any_type -def fix_instance(t: Instance, fail: MsgCallback, note: MsgCallback, - disallow_any: bool, python_version: Tuple[int, int], - use_generic_error: bool = False, - unexpanded_type: Optional[Type] = None,) -> None: +def fix_instance( + t: Instance, + fail: MsgCallback, + note: MsgCallback, + disallow_any: bool, + python_version: Tuple[int, int], + use_generic_error: bool = False, + unexpanded_type: Optional[Type] = None, +) -> None: """Fix a malformed instance by replacing all type arguments with Any. Also emit a suitable error if this is not due to implicit Any's. @@ -1330,20 +1476,21 @@ def fix_instance(t: Instance, fail: MsgCallback, note: MsgCallback, fullname: Optional[str] = None else: fullname = t.type.fullname - any_type = get_omitted_any(disallow_any, fail, note, t, python_version, fullname, - unexpanded_type) + any_type = get_omitted_any( + disallow_any, fail, note, t, python_version, fullname, unexpanded_type + ) t.args = (any_type,) * len(t.type.type_vars) return # Invalid number of type parameters. n = len(t.type.type_vars) - s = f'{n} type arguments' + s = f"{n} type arguments" if n == 0: - s = 'no type arguments' + s = "no type arguments" elif n == 1: - s = '1 type argument' + s = "1 type argument" act = str(len(t.args)) - if act == '0': - act = 'none' + if act == "0": + act = "none" fail(f'"{t.type.name}" expects {s}, but {act} given', t, code=codes.TYPE_ARG) # Construct the correct number of type arguments, as # otherwise the type checker may crash as it expects @@ -1352,10 +1499,16 @@ def fix_instance(t: Instance, fail: MsgCallback, note: MsgCallback, t.invalid = True -def expand_type_alias(node: TypeAlias, args: List[Type], - fail: MsgCallback, no_args: bool, ctx: Context, *, - unexpanded_type: Optional[Type] = None, - disallow_any: bool = False) -> Type: +def expand_type_alias( + node: TypeAlias, + args: List[Type], + fail: MsgCallback, + no_args: bool, + ctx: Context, + *, + unexpanded_type: Optional[Type] = None, + disallow_any: bool = False, +) -> Type: """Expand a (generic) type alias target following the rules outlined in TypeAlias docstring. Here: @@ -1370,9 +1523,14 @@ def expand_type_alias(node: TypeAlias, args: List[Type], act_len = len(args) if exp_len > 0 and act_len == 0: # Interpret bare Alias same as normal generic, i.e., Alias[Any, Any, ...] - return set_any_tvars(node, ctx.line, ctx.column, - disallow_any=disallow_any, fail=fail, - unexpanded_type=unexpanded_type) + return set_any_tvars( + node, + ctx.line, + ctx.column, + disallow_any=disallow_any, + fail=fail, + unexpanded_type=unexpanded_type, + ) if exp_len == 0 and act_len == 0: if no_args: assert isinstance(node.target, Instance) # type: ignore[misc] @@ -1380,34 +1538,45 @@ def expand_type_alias(node: TypeAlias, args: List[Type], # no_args aliases like L = List in the docstring for TypeAlias class. return Instance(node.target.type, [], line=ctx.line, column=ctx.column) return TypeAliasType(node, [], line=ctx.line, column=ctx.column) - if (exp_len == 0 and act_len > 0 - and isinstance(node.target, Instance) # type: ignore[misc] - and no_args): + if ( + exp_len == 0 + and act_len > 0 + and isinstance(node.target, Instance) # type: ignore[misc] + and no_args + ): tp = Instance(node.target.type, args) tp.line = ctx.line tp.column = ctx.column return tp if act_len != exp_len: - fail('Bad number of arguments for type alias, expected: %s, given: %s' - % (exp_len, act_len), ctx) + fail( + "Bad number of arguments for type alias, expected: %s, given: %s" % (exp_len, act_len), + ctx, + ) return set_any_tvars(node, ctx.line, ctx.column, from_error=True) typ = TypeAliasType(node, args, ctx.line, ctx.column) assert typ.alias is not None # HACK: Implement FlexibleAlias[T, typ] by expanding it to typ here. - if (isinstance(typ.alias.target, Instance) # type: ignore - and typ.alias.target.type.fullname == 'mypy_extensions.FlexibleAlias'): + if ( + isinstance(typ.alias.target, Instance) # type: ignore + and typ.alias.target.type.fullname == "mypy_extensions.FlexibleAlias" + ): exp = get_proper_type(typ) assert isinstance(exp, Instance) return exp.args[-1] return typ -def set_any_tvars(node: TypeAlias, - newline: int, newcolumn: int, *, - from_error: bool = False, - disallow_any: bool = False, - fail: Optional[MsgCallback] = None, - unexpanded_type: Optional[Type] = None) -> Type: +def set_any_tvars( + node: TypeAlias, + newline: int, + newcolumn: int, + *, + from_error: bool = False, + disallow_any: bool = False, + fail: Optional[MsgCallback] = None, + unexpanded_type: Optional[Type] = None, +) -> Type: if from_error or disallow_any: type_of_any = TypeOfAny.from_error else: @@ -1417,8 +1586,11 @@ def set_any_tvars(node: TypeAlias, otype = unexpanded_type or node.target type_str = otype.name if isinstance(otype, UnboundType) else format_type_bare(otype) - fail(message_registry.BARE_GENERIC.format(quote_type_string(type_str)), - Context(newline, newcolumn), code=codes.TYPE_ARG) + fail( + message_registry.BARE_GENERIC.format(quote_type_string(type_str)), + Context(newline, newcolumn), + code=codes.TYPE_ARG, + ) any_type = AnyType(type_of_any, line=newline, column=newcolumn) return TypeAliasType(node, [any_type] * len(node.alias_tvars), newline, newcolumn) @@ -1441,12 +1613,14 @@ def flatten_tvars(ll: Iterable[List[T]]) -> List[T]: class TypeVarLikeQuery(TypeQuery[TypeVarLikeList]): """Find TypeVar and ParamSpec references in an unbound type.""" - def __init__(self, - lookup: Callable[[str, Context], Optional[SymbolTableNode]], - scope: 'TypeVarLikeScope', - *, - include_callables: bool = True, - include_bound_tvars: bool = False) -> None: + def __init__( + self, + lookup: Callable[[str, Context], Optional[SymbolTableNode]], + scope: "TypeVarLikeScope", + *, + include_callables: bool = True, + include_bound_tvars: bool = False, + ) -> None: self.include_callables = include_callables self.lookup = lookup self.scope = scope @@ -1464,17 +1638,20 @@ def visit_unbound_type(self, t: UnboundType) -> TypeVarLikeList: name = t.name node = None # Special case P.args and P.kwargs for ParamSpecs only. - if name.endswith('args'): - if name.endswith('.args') or name.endswith('.kwargs'): - base = '.'.join(name.split('.')[:-1]) + if name.endswith("args"): + if name.endswith(".args") or name.endswith(".kwargs"): + base = ".".join(name.split(".")[:-1]) n = self.lookup(base, t) if n is not None and isinstance(n.node, ParamSpecExpr): node = n name = base if node is None: node = self.lookup(name, t) - if node and isinstance(node.node, TypeVarLikeExpr) and ( - self.include_bound_tvars or self.scope.get_binding(node) is None): + if ( + node + and isinstance(node.node, TypeVarLikeExpr) + and (self.include_bound_tvars or self.scope.get_binding(node) is None) + ): assert isinstance(node.node, TypeVarLikeExpr) return [(name, node.node)] elif not self.include_callables and self._seems_like_callable(t): @@ -1494,15 +1671,14 @@ def visit_callable_type(self, t: CallableType) -> TypeVarLikeList: return [] -def check_for_explicit_any(typ: Optional[Type], - options: Options, - is_typeshed_stub: bool, - msg: MessageBuilder, - context: Context) -> None: - if (options.disallow_any_explicit and - not is_typeshed_stub and - typ and - has_explicit_any(typ)): +def check_for_explicit_any( + typ: Optional[Type], + options: Options, + is_typeshed_stub: bool, + msg: MessageBuilder, + context: Context, +) -> None: + if options.disallow_any_explicit and not is_typeshed_stub and typ and has_explicit_any(typ): msg.explicit_any(context) @@ -1576,15 +1752,15 @@ def make_optional_type(t: Type) -> Type: if isinstance(t, NoneType): return t elif isinstance(t, UnionType): - items = [item for item in union_items(t) - if not isinstance(item, NoneType)] + items = [item for item in union_items(t) if not isinstance(item, NoneType)] return UnionType(items + [NoneType()], t.line, t.column) else: return UnionType([t, NoneType()], t.line, t.column) -def fix_instance_types(t: Type, fail: MsgCallback, note: MsgCallback, - python_version: Tuple[int, int]) -> None: +def fix_instance_types( + t: Type, fail: MsgCallback, note: MsgCallback, python_version: Tuple[int, int] +) -> None: """Recursively fix all instance types (type argument count) in a given type. For example 'Union[Dict, List[str, int]]' will be transformed into @@ -1604,5 +1780,11 @@ def __init__( def visit_instance(self, typ: Instance) -> None: super().visit_instance(typ) if len(typ.args) != len(typ.type.type_vars): - fix_instance(typ, self.fail, self.note, disallow_any=False, - python_version=self.python_version, use_generic_error=True) + fix_instance( + typ, + self.fail, + self.note, + disallow_any=False, + python_version=self.python_version, + use_generic_error=True, + ) diff --git a/mypy/typeops.py b/mypy/typeops.py index b47c3c246fd2c..7fc012fd3c78d 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -5,35 +5,70 @@ since these may assume that MROs are ready. """ -from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any, Union -from typing_extensions import Type as TypingType import itertools import sys +from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, TypeVar, Union, cast + +from typing_extensions import Type as TypingType +from mypy.copytype import copy_type +from mypy.expandtype import expand_type, expand_type_by_instance +from mypy.maptype import map_instance_to_supertype +from mypy.nodes import ( + ARG_POS, + ARG_STAR, + ARG_STAR2, + SYMBOL_FUNCBASE_TYPES, + Decorator, + Expression, + FuncBase, + FuncDef, + FuncItem, + OverloadedFuncDef, + StrExpr, + TypeInfo, + Var, +) +from mypy.state import state from mypy.types import ( - TupleType, Instance, FunctionLike, Type, CallableType, TypeVarLikeType, Overloaded, - TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType, - AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types, - TypeAliasType, TypeQuery, ParamSpecType, Parameters, UnpackType, TypeVarTupleType, ENUM_REMOVED_PROPS, + AnyType, + CallableType, + FormalArgument, + FunctionLike, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + ProperType, + TupleType, + Type, + TypeAliasType, + TypeOfAny, + TypeQuery, + TypeType, + TypeVarLikeType, + TypeVarTupleType, + TypeVarType, + UninhabitedType, + UnionType, + UnpackType, + get_proper_type, + get_proper_types, ) -from mypy.nodes import ( - FuncBase, FuncItem, FuncDef, OverloadedFuncDef, TypeInfo, ARG_STAR, ARG_STAR2, ARG_POS, - Expression, StrExpr, Var, Decorator, SYMBOL_FUNCBASE_TYPES -) -from mypy.maptype import map_instance_to_supertype -from mypy.expandtype import expand_type_by_instance, expand_type -from mypy.copytype import copy_type - from mypy.typevars import fill_typevars -from mypy.state import state - def is_recursive_pair(s: Type, t: Type) -> bool: """Is this a pair of recursive type aliases?""" - return (isinstance(s, TypeAliasType) and isinstance(t, TypeAliasType) and - s.is_recursive and t.is_recursive) + return ( + isinstance(s, TypeAliasType) + and isinstance(t, TypeAliasType) + and s.is_recursive + and t.is_recursive + ) def tuple_fallback(typ: TupleType) -> Instance: @@ -41,7 +76,7 @@ def tuple_fallback(typ: TupleType) -> Instance: from mypy.join import join_type_list info = typ.partial_fallback.type - if info.fullname != 'builtins.tuple': + if info.fullname != "builtins.tuple": return typ.partial_fallback items = [] for item in typ.items: @@ -61,11 +96,9 @@ def tuple_fallback(typ: TupleType) -> Instance: return Instance(info, [join_type_list(items)]) -def type_object_type_from_function(signature: FunctionLike, - info: TypeInfo, - def_info: TypeInfo, - fallback: Instance, - is_new: bool) -> FunctionLike: +def type_object_type_from_function( + signature: FunctionLike, info: TypeInfo, def_info: TypeInfo, fallback: Instance, is_new: bool +) -> FunctionLike: # We first need to record all non-trivial (explicit) self types in __init__, # since they will not be available after we bind them. Note, we use explicit # self-types only in the defining class, similar to __new__ (but not exactly the same, @@ -73,8 +106,14 @@ def type_object_type_from_function(signature: FunctionLike, # classes such as subprocess.Popen. default_self = fill_typevars(info) if not is_new and not info.is_newtype: - orig_self_types = [(it.arg_types[0] if it.arg_types and it.arg_types[0] != default_self - and it.arg_kinds[0] == ARG_POS else None) for it in signature.items] + orig_self_types = [ + ( + it.arg_types[0] + if it.arg_types and it.arg_types[0] != default_self and it.arg_kinds[0] == ARG_POS + else None + ) + for it in signature.items + ] else: orig_self_types = [None] * len(signature.items) @@ -92,9 +131,9 @@ def type_object_type_from_function(signature: FunctionLike, signature = cast(FunctionLike, map_type_from_supertype(signature, info, def_info)) special_sig: Optional[str] = None - if def_info.fullname == 'builtins.dict': + if def_info.fullname == "builtins.dict": # Special signature! - special_sig = 'dict' + special_sig = "dict" if isinstance(signature, CallableType): return class_callable(signature, info, fallback, special_sig, is_new, orig_self_types[0]) @@ -107,9 +146,14 @@ def type_object_type_from_function(signature: FunctionLike, return Overloaded(items) -def class_callable(init_type: CallableType, info: TypeInfo, type_type: Instance, - special_sig: Optional[str], - is_new: bool, orig_self_type: Optional[Type] = None) -> CallableType: +def class_callable( + init_type: CallableType, + info: TypeInfo, + type_type: Instance, + special_sig: Optional[str], + is_new: bool, + orig_self_type: Optional[Type] = None, +) -> CallableType: """Create a type object type based on the signature of __init__.""" variables: List[TypeVarLikeType] = [] variables.extend(info.defn.type_vars) @@ -136,15 +180,17 @@ def class_callable(init_type: CallableType, info: TypeInfo, type_type: Instance, ret_type = default_ret_type callable_type = init_type.copy_modified( - ret_type=ret_type, fallback=type_type, name=None, variables=variables, - special_sig=special_sig) + ret_type=ret_type, + fallback=type_type, + name=None, + variables=variables, + special_sig=special_sig, + ) c = callable_type.with_name(info.name) return c -def map_type_from_supertype(typ: Type, - sub_info: TypeInfo, - super_info: TypeInfo) -> Type: +def map_type_from_supertype(typ: Type, sub_info: TypeInfo, super_info: TypeInfo) -> Type: """Map type variables in a type defined in a supertype context to be valid in the subtype context. Assume that the result is unique; if more than one type is possible, return one of the alternatives. @@ -181,11 +227,12 @@ def supported_self_type(typ: ProperType) -> bool: """ if isinstance(typ, TypeType): return supported_self_type(typ.item) - return (isinstance(typ, TypeVarType) or - (isinstance(typ, Instance) and typ != fill_typevars(typ.type))) + return isinstance(typ, TypeVarType) or ( + isinstance(typ, Instance) and typ != fill_typevars(typ.type) + ) -F = TypeVar('F', bound=FunctionLike) +F = TypeVar("F", bound=FunctionLike) def bind_self(method: F, original_type: Optional[Type] = None, is_classmethod: bool = False) -> F: @@ -212,8 +259,9 @@ class B(A): pass """ if isinstance(method, Overloaded): - return cast(F, Overloaded([bind_self(c, original_type, is_classmethod) - for c in method.items])) + return cast( + F, Overloaded([bind_self(c, original_type, is_classmethod) for c in method.items]) + ) assert isinstance(method, CallableType) func = method if not func.arg_types: @@ -238,18 +286,19 @@ class B(A): pass original_type = get_proper_type(original_type) all_ids = func.type_var_ids() - typeargs = infer_type_arguments(all_ids, self_param_type, original_type, - is_supertype=True) - if (is_classmethod - # TODO: why do we need the extra guards here? - and any(isinstance(get_proper_type(t), UninhabitedType) for t in typeargs) - and isinstance(original_type, (Instance, TypeVarType, TupleType))): + typeargs = infer_type_arguments(all_ids, self_param_type, original_type, is_supertype=True) + if ( + is_classmethod + # TODO: why do we need the extra guards here? + and any(isinstance(get_proper_type(t), UninhabitedType) for t in typeargs) + and isinstance(original_type, (Instance, TypeVarType, TupleType)) + ): # In case we call a classmethod through an instance x, fallback to type(x) - typeargs = infer_type_arguments(all_ids, self_param_type, TypeType(original_type), - is_supertype=True) + typeargs = infer_type_arguments( + all_ids, self_param_type, TypeType(original_type), is_supertype=True + ) - ids = [tid for tid in all_ids - if any(tid == t.id for t in get_type_vars(self_param_type))] + ids = [tid for tid in all_ids if any(tid == t.id for t in get_type_vars(self_param_type))] # Technically, some constrains might be unsolvable, make them . to_apply = [t if t is not None else UninhabitedType() for t in typeargs] @@ -268,12 +317,14 @@ def expand(target: Type) -> Type: original_type = get_proper_type(original_type) if isinstance(original_type, CallableType) and original_type.is_type_obj(): original_type = TypeType.make_normalized(original_type.ret_type) - res = func.copy_modified(arg_types=arg_types, - arg_kinds=func.arg_kinds[1:], - arg_names=func.arg_names[1:], - variables=variables, - ret_type=ret_type, - bound_args=[original_type]) + res = func.copy_modified( + arg_types=arg_types, + arg_kinds=func.arg_kinds[1:], + arg_names=func.arg_names[1:], + variables=variables, + ret_type=ret_type, + bound_args=[original_type], + ) return cast(F, res) @@ -288,8 +339,9 @@ def erase_to_bound(t: Type) -> Type: return t -def callable_corresponding_argument(typ: Union[CallableType, Parameters], - model: FormalArgument) -> Optional[FormalArgument]: +def callable_corresponding_argument( + typ: Union[CallableType, Parameters], model: FormalArgument +) -> Optional[FormalArgument]: """Return the argument a function that corresponds to `model`""" by_name = typ.argument_by_name(model.name) @@ -307,10 +359,12 @@ def callable_corresponding_argument(typ: Union[CallableType, Parameters], # def left(__a: int = ..., *, a: int = ...) -> None: ... from mypy.subtypes import is_equivalent - if (not (by_name.required or by_pos.required) - and by_pos.name is None - and by_name.pos is None - and is_equivalent(by_name.typ, by_pos.typ)): + if ( + not (by_name.required or by_pos.required) + and by_pos.name is None + and by_name.pos is None + and is_equivalent(by_name.typ, by_pos.typ) + ): return FormalArgument(by_name.name, by_pos.pos, by_name.typ, False) return by_name if by_name is not None else by_pos @@ -325,12 +379,12 @@ def simple_literal_value_key(t: ProperType) -> Optional[Tuple[str, ...]]: Instance with string last_known_value are supported. """ if isinstance(t, LiteralType): - if t.fallback.type.is_enum or t.fallback.type.fullname == 'builtins.str': + if t.fallback.type.is_enum or t.fallback.type.fullname == "builtins.str": assert isinstance(t.value, str) - return 'literal', t.value, t.fallback.type.fullname + return "literal", t.value, t.fallback.type.fullname if isinstance(t, Instance): if t.last_known_value is not None and isinstance(t.last_known_value.value, str): - return 'instance', t.last_known_value.value, t.type.fullname + return "instance", t.last_known_value.value, t.type.fullname return None @@ -346,16 +400,20 @@ def simple_literal_type(t: Optional[ProperType]) -> Optional[Instance]: def is_simple_literal(t: ProperType) -> bool: """Fast way to check if simple_literal_value_key() would return a non-None value.""" if isinstance(t, LiteralType): - return t.fallback.type.is_enum or t.fallback.type.fullname == 'builtins.str' + return t.fallback.type.is_enum or t.fallback.type.fullname == "builtins.str" if isinstance(t, Instance): return t.last_known_value is not None and isinstance(t.last_known_value.value, str) return False -def make_simplified_union(items: Sequence[Type], - line: int = -1, column: int = -1, - *, keep_erased: bool = False, - contract_literals: bool = True) -> ProperType: +def make_simplified_union( + items: Sequence[Type], + line: int = -1, + column: int = -1, + *, + keep_erased: bool = False, + contract_literals: bool = True, +) -> ProperType: """Build union type with redundant union items removed. If only a single item remains, this may return a non-union type. @@ -449,9 +507,8 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> ): continue # actual redundancy checks - if ( - is_redundant_literal_instance(item, tj) # XXX? - and is_proper_subtype(tj, item, keep_erased_types=keep_erased) + if is_redundant_literal_instance(item, tj) and is_proper_subtype( # XXX? + tj, item, keep_erased_types=keep_erased ): # We found a redundant item in the union. removed.add(j) @@ -591,24 +648,29 @@ def function_type(func: FuncBase, fallback: Instance) -> FunctionLike: # TODO: should we instead always set the type in semantic analyzer? assert isinstance(func, OverloadedFuncDef) any_type = AnyType(TypeOfAny.from_error) - dummy = CallableType([any_type, any_type], - [ARG_STAR, ARG_STAR2], - [None, None], any_type, - fallback, - line=func.line, is_ellipsis_args=True) + dummy = CallableType( + [any_type, any_type], + [ARG_STAR, ARG_STAR2], + [None, None], + any_type, + fallback, + line=func.line, + is_ellipsis_args=True, + ) # Return an Overloaded, because some callers may expect that # an OverloadedFuncDef has an Overloaded type. return Overloaded([dummy]) -def callable_type(fdef: FuncItem, fallback: Instance, - ret_type: Optional[Type] = None) -> CallableType: +def callable_type( + fdef: FuncItem, fallback: Instance, ret_type: Optional[Type] = None +) -> CallableType: # TODO: somewhat unfortunate duplication with prepare_method_signature in semanal if fdef.info and not fdef.is_static and fdef.arg_names: self_type: Type = fill_typevars(fdef.info) - if fdef.is_class or fdef.name == '__new__': + if fdef.is_class or fdef.name == "__new__": self_type = TypeType.make_normalized(self_type) - args = [self_type] + [AnyType(TypeOfAny.unannotated)] * (len(fdef.arg_names)-1) + args = [self_type] + [AnyType(TypeOfAny.unannotated)] * (len(fdef.arg_names) - 1) else: args = [AnyType(TypeOfAny.unannotated)] * len(fdef.arg_names) @@ -668,12 +730,12 @@ def try_getting_int_literals_from_type(typ: Type) -> Optional[List[int]]: return try_getting_literals_from_type(typ, int, "builtins.int") -T = TypeVar('T') +T = TypeVar("T") -def try_getting_literals_from_type(typ: Type, - target_literal_type: TypingType[T], - target_fullname: str) -> Optional[List[T]]: +def try_getting_literals_from_type( + typ: Type, target_literal_type: TypingType[T], target_fullname: str +) -> Optional[List[T]]: """If the given expression or type corresponds to a Literal or union of Literals where the underlying values correspond to the given target type, returns a list of those underlying values. Otherwise, @@ -713,8 +775,9 @@ def is_literal_type_like(t: Optional[Type]) -> bool: elif isinstance(t, UnionType): return any(is_literal_type_like(item) for item in t.items) elif isinstance(t, TypeVarType): - return (is_literal_type_like(t.upper_bound) - or any(is_literal_type_like(item) for item in t.values)) + return is_literal_type_like(t.upper_bound) or any( + is_literal_type_like(item) for item in t.values + ) else: return False @@ -762,8 +825,7 @@ class Status(Enum): if isinstance(typ, UnionType): items = [ - try_expanding_sum_type_to_union(item, target_fullname) - for item in typ.relevant_items() + try_expanding_sum_type_to_union(item, target_fullname) for item in typ.relevant_items() ] return make_simplified_union(items, contract_literals=False) elif isinstance(typ, Instance) and typ.type.fullname == target_fullname: @@ -787,8 +849,7 @@ class Status(Enum): return make_simplified_union(new_items, contract_literals=False) elif typ.type.fullname == "builtins.bool": return make_simplified_union( - [LiteralType(True, typ), LiteralType(False, typ)], - contract_literals=False + [LiteralType(True, typ), LiteralType(False, typ)], contract_literals=False ) return typ @@ -813,10 +874,12 @@ def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType] fullname = typ.fallback.type.fullname if typ.fallback.type.is_enum or isinstance(typ.value, bool): if fullname not in sum_types: - sum_types[fullname] = (set(typ.fallback.get_enum_values()) - if typ.fallback.type.is_enum - else {True, False}, - []) + sum_types[fullname] = ( + set(typ.fallback.get_enum_values()) + if typ.fallback.type.is_enum + else {True, False}, + [], + ) literals, indexes = sum_types[fullname] literals.discard(typ.value) indexes.append(idx) @@ -824,8 +887,11 @@ def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType] first, *rest = indexes proper_types[first] = typ.fallback marked_for_deletion |= set(rest) - return list(itertools.compress(proper_types, [(i not in marked_for_deletion) - for i in range(len(proper_types))])) + return list( + itertools.compress( + proper_types, [(i not in marked_for_deletion) for i in range(len(proper_types))] + ) + ) def coerce_to_literal(typ: Type) -> Type: @@ -875,7 +941,7 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool method = typ.type.get(name) if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)): if method.node.info: - return not method.node.info.fullname.startswith('builtins.') + return not method.node.info.fullname.startswith("builtins.") return False if isinstance(typ, UnionType): if check_all: diff --git a/mypy/types.py b/mypy/types.py index a70c6885dff51..5d830d8091d61 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2,25 +2,40 @@ import sys from abc import abstractmethod - from typing import ( - Any, TypeVar, Dict, List, Tuple, cast, Set, Optional, Union, Iterable, NamedTuple, - Sequence + Any, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, + cast, ) -from typing_extensions import ClassVar, Final, TYPE_CHECKING, overload, TypeAlias as _TypeAlias -from mypy.backports import OrderedDict +from typing_extensions import TYPE_CHECKING, ClassVar, Final, TypeAlias as _TypeAlias, overload + import mypy.nodes -from mypy.state import state +from mypy.backports import OrderedDict +from mypy.bogus_type import Bogus from mypy.nodes import ( - INVARIANT, SymbolNode, FuncDef, FakeInfo, - ArgKind, ARG_POS, ARG_STAR, ARG_STAR2, + ARG_POS, + ARG_STAR, + ARG_STAR2, + INVARIANT, + ArgKind, + FakeInfo, + FuncDef, + SymbolNode, ) +from mypy.state import state from mypy.util import IdMapper -from mypy.bogus_type import Bogus - -T = TypeVar('T') +T = TypeVar("T") JsonDict: _TypeAlias = Dict[str, Any] @@ -62,14 +77,11 @@ # semantic analyzer! if TYPE_CHECKING: from mypy.type_visitor import ( - TypeVisitor as TypeVisitor, SyntheticTypeVisitor as SyntheticTypeVisitor, + TypeVisitor as TypeVisitor, ) -TYPED_NAMEDTUPLE_NAMES: Final = ( - "typing.NamedTuple", - "typing_extensions.NamedTuple" -) +TYPED_NAMEDTUPLE_NAMES: Final = ("typing.NamedTuple", "typing_extensions.NamedTuple") # Supported names of TypedDict type constructors. TPDICT_NAMES: Final = ( @@ -86,80 +98,52 @@ ) # Supported names of Protocol base class. -PROTOCOL_NAMES: Final = ( - 'typing.Protocol', - 'typing_extensions.Protocol', -) +PROTOCOL_NAMES: Final = ("typing.Protocol", "typing_extensions.Protocol") # Supported TypeAlias names. -TYPE_ALIAS_NAMES: Final = ( - "typing.TypeAlias", - "typing_extensions.TypeAlias", -) +TYPE_ALIAS_NAMES: Final = ("typing.TypeAlias", "typing_extensions.TypeAlias") # Supported Final type names. -FINAL_TYPE_NAMES: Final = ( - 'typing.Final', - 'typing_extensions.Final', -) +FINAL_TYPE_NAMES: Final = ("typing.Final", "typing_extensions.Final") # Supported @final decorator names. -FINAL_DECORATOR_NAMES: Final = ( - 'typing.final', - 'typing_extensions.final', -) +FINAL_DECORATOR_NAMES: Final = ("typing.final", "typing_extensions.final") # Supported Literal type names. -LITERAL_TYPE_NAMES: Final = ( - 'typing.Literal', - 'typing_extensions.Literal', -) +LITERAL_TYPE_NAMES: Final = ("typing.Literal", "typing_extensions.Literal") # Supported Annotated type names. -ANNOTATED_TYPE_NAMES: Final = ( - 'typing.Annotated', - 'typing_extensions.Annotated', -) +ANNOTATED_TYPE_NAMES: Final = ("typing.Annotated", "typing_extensions.Annotated") # We use this constant in various places when checking `tuple` subtyping: TUPLE_LIKE_INSTANCE_NAMES: Final = ( - 'builtins.tuple', - 'typing.Iterable', - 'typing.Container', - 'typing.Sequence', - 'typing.Reversible', + "builtins.tuple", + "typing.Iterable", + "typing.Container", + "typing.Sequence", + "typing.Reversible", ) REVEAL_TYPE_NAMES: Final = ( - 'builtins.reveal_type', - 'typing.reveal_type', - 'typing_extensions.reveal_type', + "builtins.reveal_type", + "typing.reveal_type", + "typing_extensions.reveal_type", ) -ASSERT_TYPE_NAMES: Final = ( - 'typing.assert_type', - 'typing_extensions.assert_type', -) +ASSERT_TYPE_NAMES: Final = ("typing.assert_type", "typing_extensions.assert_type") -OVERLOAD_NAMES: Final = ( - 'typing.overload', - 'typing_extensions.overload', -) +OVERLOAD_NAMES: Final = ("typing.overload", "typing_extensions.overload") # Attributes that can optionally be defined in the body of a subclass of # enum.Enum but are removed from the class __dict__ by EnumMeta. -ENUM_REMOVED_PROPS: Final = ( - '_ignore_', - '_order_', - '__order__', -) +ENUM_REMOVED_PROPS: Final = ("_ignore_", "_order_", "__order__") NEVER_NAMES: Final = ( - 'typing.NoReturn', - 'typing_extensions.NoReturn', - 'mypy_extensions.NoReturn', - 'typing.Never', - 'typing_extensions.Never', + "typing.NoReturn", + "typing_extensions.NoReturn", + "mypy_extensions.NoReturn", + "typing.Never", + "typing_extensions.Never", ) # A placeholder used for Bogus[...] parameters @@ -197,20 +181,20 @@ class TypeOfAny: suggestion_engine: Final = 9 -def deserialize_type(data: Union[JsonDict, str]) -> 'Type': +def deserialize_type(data: Union[JsonDict, str]) -> "Type": if isinstance(data, str): return Instance.deserialize(data) - classname = data['.class'] + classname = data[".class"] method = deserialize_map.get(classname) if method is not None: return method(data) - raise NotImplementedError(f'unexpected .class {classname}') + raise NotImplementedError(f"unexpected .class {classname}") class Type(mypy.nodes.Context): """Abstract base class for all types.""" - __slots__ = ('can_be_true', 'can_be_false') + __slots__ = ("can_be_true", "can_be_false") # 'can_be_true' and 'can_be_false' mean whether the value of the # expression can be true or false in a boolean context. They are useful # when inferring the type of logic expressions like `x and y`. @@ -232,18 +216,18 @@ def can_be_true_default(self) -> bool: def can_be_false_default(self) -> bool: return True - def accept(self, visitor: 'TypeVisitor[T]') -> T: - raise RuntimeError('Not implemented') + def accept(self, visitor: "TypeVisitor[T]") -> T: + raise RuntimeError("Not implemented") def __repr__(self) -> str: return self.accept(TypeStrVisitor()) def serialize(self) -> Union[JsonDict, str]: - raise NotImplementedError(f'Cannot serialize {self.__class__.__name__} instance') + raise NotImplementedError(f"Cannot serialize {self.__class__.__name__} instance") @classmethod - def deserialize(cls, data: JsonDict) -> 'Type': - raise NotImplementedError(f'Cannot deserialize {cls.__name__} instance') + def deserialize(cls, data: JsonDict) -> "Type": + raise NotImplementedError(f"Cannot deserialize {cls.__name__} instance") def is_singleton_type(self) -> bool: return False @@ -266,10 +250,15 @@ class Node: can be represented in a tree-like manner. """ - __slots__ = ('alias', 'args', 'type_ref') + __slots__ = ("alias", "args", "type_ref") - def __init__(self, alias: Optional[mypy.nodes.TypeAlias], args: List[Type], - line: int = -1, column: int = -1) -> None: + def __init__( + self, + alias: Optional[mypy.nodes.TypeAlias], + args: List[Type], + line: int = -1, + column: int = -1, + ) -> None: self.alias = alias self.args = args self.type_ref: Optional[str] = None @@ -288,17 +277,18 @@ def _expand_once(self) -> Type: # as their target. assert isinstance(self.alias.target, Instance) # type: ignore[misc] return self.alias.target.copy_modified(args=self.args) - return replace_alias_tvars(self.alias.target, self.alias.alias_tvars, self.args, - self.line, self.column) + return replace_alias_tvars( + self.alias.target, self.alias.alias_tvars, self.args, self.line, self.column + ) - def _partial_expansion(self) -> Tuple['ProperType', bool]: + def _partial_expansion(self) -> Tuple["ProperType", bool]: # Private method mostly for debugging and testing. unroller = UnrollAliasVisitor(set()) unrolled = self.accept(unroller) assert isinstance(unrolled, ProperType) return unrolled, unroller.recursed - def expand_all_if_possible(self) -> Optional['ProperType']: + def expand_all_if_possible(self) -> Optional["ProperType"]: """Attempt a full expansion of the type alias (including nested aliases). If the expansion is not possible, i.e. the alias is (mutually-)recursive, @@ -311,7 +301,7 @@ def expand_all_if_possible(self) -> Optional['ProperType']: @property def is_recursive(self) -> bool: - assert self.alias is not None, 'Unfixed type alias' + assert self.alias is not None, "Unfixed type alias" is_recursive = self.alias._is_recursive if is_recursive is None: is_recursive = self.expand_all_if_possible() is None @@ -330,7 +320,7 @@ def can_be_false_default(self) -> bool: return self.alias.target.can_be_false return super().can_be_false_default() - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_type_alias_type(self) def __hash__(self) -> int: @@ -340,8 +330,7 @@ def __eq__(self, other: object) -> bool: # Note: never use this to determine subtype relationships, use is_subtype(). if not isinstance(other, TypeAliasType): return NotImplemented - return (self.alias == other.alias - and self.args == other.args) + return self.alias == other.alias and self.args == other.args def serialize(self) -> JsonDict: assert self.alias is not None @@ -353,29 +342,27 @@ def serialize(self) -> JsonDict: return data @classmethod - def deserialize(cls, data: JsonDict) -> 'TypeAliasType': - assert data['.class'] == 'TypeAliasType' + def deserialize(cls, data: JsonDict) -> "TypeAliasType": + assert data[".class"] == "TypeAliasType" args: List[Type] = [] - if 'args' in data: - args_list = data['args'] + if "args" in data: + args_list = data["args"] assert isinstance(args_list, list) args = [deserialize_type(arg) for arg in args_list] alias = TypeAliasType(None, args) - alias.type_ref = data['type_ref'] + alias.type_ref = data["type_ref"] return alias - def copy_modified(self, *, - args: Optional[List[Type]] = None) -> 'TypeAliasType': + def copy_modified(self, *, args: Optional[List[Type]] = None) -> "TypeAliasType": return TypeAliasType( - self.alias, - args if args is not None else self.args.copy(), - self.line, self.column) + self.alias, args if args is not None else self.args.copy(), self.line, self.column + ) class TypeGuardedType(Type): """Only used by find_isinstance_check() etc.""" - __slots__ = ('type_guard',) + __slots__ = ("type_guard",) def __init__(self, type_guard: Type): super().__init__(line=type_guard.line, column=type_guard.column) @@ -399,7 +386,7 @@ def __repr__(self) -> str: else: return f"NotRequired[{self.item}]" - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return self.item.accept(visitor) @@ -438,13 +425,13 @@ class TypeVarId: # definition!), or '' namespace: str - def __init__(self, raw_id: int, meta_level: int = 0, *, namespace: str = '') -> None: + def __init__(self, raw_id: int, meta_level: int = 0, *, namespace: str = "") -> None: self.raw_id = raw_id self.meta_level = meta_level self.namespace = namespace @staticmethod - def new(meta_level: int) -> 'TypeVarId': + def new(meta_level: int) -> "TypeVarId": raw_id = TypeVarId.next_raw_id TypeVarId.next_raw_id += 1 return TypeVarId(raw_id, meta_level) @@ -454,9 +441,11 @@ def __repr__(self) -> str: def __eq__(self, other: object) -> bool: if isinstance(other, TypeVarId): - return (self.raw_id == other.raw_id and - self.meta_level == other.meta_level and - self.namespace == other.namespace) + return ( + self.raw_id == other.raw_id + and self.meta_level == other.meta_level + and self.namespace == other.namespace + ) else: return False @@ -472,7 +461,7 @@ def is_meta_var(self) -> bool: class TypeVarLikeType(ProperType): - __slots__ = ('name', 'fullname', 'id', 'upper_bound') + __slots__ = ("name", "fullname", "id", "upper_bound") name: str # Name (may be qualified) fullname: str # Fully qualified name @@ -480,8 +469,13 @@ class TypeVarLikeType(ProperType): upper_bound: Type def __init__( - self, name: str, fullname: str, id: Union[TypeVarId, int], upper_bound: Type, - line: int = -1, column: int = -1 + self, + name: str, + fullname: str, + id: Union[TypeVarId, int], + upper_bound: Type, + line: int = -1, + column: int = -1, ) -> None: super().__init__(line, column) self.name = name @@ -495,33 +489,49 @@ def serialize(self) -> JsonDict: raise NotImplementedError @classmethod - def deserialize(cls, data: JsonDict) -> 'TypeVarLikeType': + def deserialize(cls, data: JsonDict) -> "TypeVarLikeType": raise NotImplementedError class TypeVarType(TypeVarLikeType): """Type that refers to a type variable.""" - __slots__ = ('values', 'variance') + __slots__ = ("values", "variance") values: List[Type] # Value restriction, empty list if no restriction variance: int - def __init__(self, name: str, fullname: str, id: Union[TypeVarId, int], values: List[Type], - upper_bound: Type, variance: int = INVARIANT, line: int = -1, - column: int = -1) -> None: + def __init__( + self, + name: str, + fullname: str, + id: Union[TypeVarId, int], + values: List[Type], + upper_bound: Type, + variance: int = INVARIANT, + line: int = -1, + column: int = -1, + ) -> None: super().__init__(name, fullname, id, upper_bound, line, column) assert values is not None, "No restrictions must be represented by empty list" self.values = values self.variance = variance @staticmethod - def new_unification_variable(old: 'TypeVarType') -> 'TypeVarType': + def new_unification_variable(old: "TypeVarType") -> "TypeVarType": new_id = TypeVarId.new(meta_level=1) - return TypeVarType(old.name, old.fullname, new_id, old.values, - old.upper_bound, old.variance, old.line, old.column) + return TypeVarType( + old.name, + old.fullname, + new_id, + old.values, + old.upper_bound, + old.variance, + old.line, + old.column, + ) - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_type_var(self) def __hash__(self) -> int: @@ -534,26 +544,27 @@ def __eq__(self, other: object) -> bool: def serialize(self) -> JsonDict: assert not self.id.is_meta_var() - return {'.class': 'TypeVarType', - 'name': self.name, - 'fullname': self.fullname, - 'id': self.id.raw_id, - 'namespace': self.id.namespace, - 'values': [v.serialize() for v in self.values], - 'upper_bound': self.upper_bound.serialize(), - 'variance': self.variance, - } + return { + ".class": "TypeVarType", + "name": self.name, + "fullname": self.fullname, + "id": self.id.raw_id, + "namespace": self.id.namespace, + "values": [v.serialize() for v in self.values], + "upper_bound": self.upper_bound.serialize(), + "variance": self.variance, + } @classmethod - def deserialize(cls, data: JsonDict) -> 'TypeVarType': - assert data['.class'] == 'TypeVarType' + def deserialize(cls, data: JsonDict) -> "TypeVarType": + assert data[".class"] == "TypeVarType" return TypeVarType( - data['name'], - data['fullname'], - TypeVarId(data['id'], namespace=data['namespace']), - [deserialize_type(v) for v in data['values']], - deserialize_type(data['upper_bound']), - data['variance'], + data["name"], + data["fullname"], + TypeVarId(data["id"], namespace=data["namespace"]), + [deserialize_type(v) for v in data["values"]], + deserialize_type(data["upper_bound"]), + data["variance"], ) @@ -584,34 +595,58 @@ class ParamSpecType(TypeVarLikeType): always just 'object'). """ - __slots__ = ('flavor', 'prefix') + __slots__ = ("flavor", "prefix") flavor: int - prefix: 'Parameters' + prefix: "Parameters" def __init__( - self, name: str, fullname: str, id: Union[TypeVarId, int], flavor: int, - upper_bound: Type, *, line: int = -1, column: int = -1, - prefix: Optional['Parameters'] = None + self, + name: str, + fullname: str, + id: Union[TypeVarId, int], + flavor: int, + upper_bound: Type, + *, + line: int = -1, + column: int = -1, + prefix: Optional["Parameters"] = None, ) -> None: super().__init__(name, fullname, id, upper_bound, line=line, column=column) self.flavor = flavor self.prefix = prefix or Parameters([], [], []) @staticmethod - def new_unification_variable(old: 'ParamSpecType') -> 'ParamSpecType': + def new_unification_variable(old: "ParamSpecType") -> "ParamSpecType": new_id = TypeVarId.new(meta_level=1) - return ParamSpecType(old.name, old.fullname, new_id, old.flavor, old.upper_bound, - line=old.line, column=old.column, prefix=old.prefix) + return ParamSpecType( + old.name, + old.fullname, + new_id, + old.flavor, + old.upper_bound, + line=old.line, + column=old.column, + prefix=old.prefix, + ) - def with_flavor(self, flavor: int) -> 'ParamSpecType': - return ParamSpecType(self.name, self.fullname, self.id, flavor, - upper_bound=self.upper_bound, prefix=self.prefix) + def with_flavor(self, flavor: int) -> "ParamSpecType": + return ParamSpecType( + self.name, + self.fullname, + self.id, + flavor, + upper_bound=self.upper_bound, + prefix=self.prefix, + ) - def copy_modified(self, *, - id: Bogus[Union[TypeVarId, int]] = _dummy, - flavor: Bogus[int] = _dummy, - prefix: Bogus['Parameters'] = _dummy) -> 'ParamSpecType': + def copy_modified( + self, + *, + id: Bogus[Union[TypeVarId, int]] = _dummy, + flavor: Bogus[int] = _dummy, + prefix: Bogus["Parameters"] = _dummy, + ) -> "ParamSpecType": return ParamSpecType( self.name, self.fullname, @@ -623,15 +658,15 @@ def copy_modified(self, *, prefix=prefix if prefix is not _dummy else self.prefix, ) - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_param_spec(self) def name_with_suffix(self) -> str: n = self.name if self.flavor == ParamSpecFlavor.ARGS: - return f'{n}.args' + return f"{n}.args" elif self.flavor == ParamSpecFlavor.KWARGS: - return f'{n}.kwargs' + return f"{n}.kwargs" return n def __hash__(self) -> int: @@ -646,25 +681,25 @@ def __eq__(self, other: object) -> bool: def serialize(self) -> JsonDict: assert not self.id.is_meta_var() return { - '.class': 'ParamSpecType', - 'name': self.name, - 'fullname': self.fullname, - 'id': self.id.raw_id, - 'flavor': self.flavor, - 'upper_bound': self.upper_bound.serialize(), - 'prefix': self.prefix.serialize() + ".class": "ParamSpecType", + "name": self.name, + "fullname": self.fullname, + "id": self.id.raw_id, + "flavor": self.flavor, + "upper_bound": self.upper_bound.serialize(), + "prefix": self.prefix.serialize(), } @classmethod - def deserialize(cls, data: JsonDict) -> 'ParamSpecType': - assert data['.class'] == 'ParamSpecType' + def deserialize(cls, data: JsonDict) -> "ParamSpecType": + assert data[".class"] == "ParamSpecType" return ParamSpecType( - data['name'], - data['fullname'], - data['id'], - data['flavor'], - deserialize_type(data['upper_bound']), - prefix=Parameters.deserialize(data['prefix']) + data["name"], + data["fullname"], + data["id"], + data["flavor"], + deserialize_type(data["upper_bound"]), + prefix=Parameters.deserialize(data["prefix"]), ) @@ -673,26 +708,25 @@ class TypeVarTupleType(TypeVarLikeType): See PEP646 for more information. """ + def serialize(self) -> JsonDict: assert not self.id.is_meta_var() - return {'.class': 'TypeVarTupleType', - 'name': self.name, - 'fullname': self.fullname, - 'id': self.id.raw_id, - 'upper_bound': self.upper_bound.serialize(), - } + return { + ".class": "TypeVarTupleType", + "name": self.name, + "fullname": self.fullname, + "id": self.id.raw_id, + "upper_bound": self.upper_bound.serialize(), + } @classmethod - def deserialize(cls, data: JsonDict) -> 'TypeVarTupleType': - assert data['.class'] == 'TypeVarTupleType' + def deserialize(cls, data: JsonDict) -> "TypeVarTupleType": + assert data[".class"] == "TypeVarTupleType" return TypeVarTupleType( - data['name'], - data['fullname'], - data['id'], - deserialize_type(data['upper_bound']), + data["name"], data["fullname"], data["id"], deserialize_type(data["upper_bound"]) ) - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_type_var_tuple(self) def __hash__(self) -> int: @@ -704,28 +738,36 @@ def __eq__(self, other: object) -> bool: return self.id == other.id @staticmethod - def new_unification_variable(old: 'TypeVarTupleType') -> 'TypeVarTupleType': + def new_unification_variable(old: "TypeVarTupleType") -> "TypeVarTupleType": new_id = TypeVarId.new(meta_level=1) - return TypeVarTupleType(old.name, old.fullname, new_id, old.upper_bound, - line=old.line, column=old.column) + return TypeVarTupleType( + old.name, old.fullname, new_id, old.upper_bound, line=old.line, column=old.column + ) class UnboundType(ProperType): """Instance type that has not been bound during semantic analysis.""" - __slots__ = ('name', 'args', 'optional', 'empty_tuple_index', - 'original_str_expr', 'original_str_fallback') - - def __init__(self, - name: Optional[str], - args: Optional[Sequence[Type]] = None, - line: int = -1, - column: int = -1, - optional: bool = False, - empty_tuple_index: bool = False, - original_str_expr: Optional[str] = None, - original_str_fallback: Optional[str] = None, - ) -> None: + __slots__ = ( + "name", + "args", + "optional", + "empty_tuple_index", + "original_str_expr", + "original_str_fallback", + ) + + def __init__( + self, + name: Optional[str], + args: Optional[Sequence[Type]] = None, + line: int = -1, + column: int = -1, + optional: bool = False, + empty_tuple_index: bool = False, + original_str_expr: Optional[str] = None, + original_str_fallback: Optional[str] = None, + ) -> None: super().__init__(line, column) if not args: args = [] @@ -752,9 +794,7 @@ def __init__(self, self.original_str_expr = original_str_expr self.original_str_fallback = original_str_fallback - def copy_modified(self, - args: Bogus[Optional[Sequence[Type]]] = _dummy, - ) -> 'UnboundType': + def copy_modified(self, args: Bogus[Optional[Sequence[Type]]] = _dummy) -> "UnboundType": if args is _dummy: args = self.args return UnboundType( @@ -768,7 +808,7 @@ def copy_modified(self, original_str_fallback=self.original_str_fallback, ) - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_unbound_type(self) def __hash__(self) -> int: @@ -777,26 +817,32 @@ def __hash__(self) -> int: def __eq__(self, other: object) -> bool: if not isinstance(other, UnboundType): return NotImplemented - return (self.name == other.name and self.optional == other.optional and - self.args == other.args and self.original_str_expr == other.original_str_expr and - self.original_str_fallback == other.original_str_fallback) + return ( + self.name == other.name + and self.optional == other.optional + and self.args == other.args + and self.original_str_expr == other.original_str_expr + and self.original_str_fallback == other.original_str_fallback + ) def serialize(self) -> JsonDict: - return {'.class': 'UnboundType', - 'name': self.name, - 'args': [a.serialize() for a in self.args], - 'expr': self.original_str_expr, - 'expr_fallback': self.original_str_fallback, - } + return { + ".class": "UnboundType", + "name": self.name, + "args": [a.serialize() for a in self.args], + "expr": self.original_str_expr, + "expr_fallback": self.original_str_fallback, + } @classmethod - def deserialize(cls, data: JsonDict) -> 'UnboundType': - assert data['.class'] == 'UnboundType' - return UnboundType(data['name'], - [deserialize_type(a) for a in data['args']], - original_str_expr=data['expr'], - original_str_fallback=data['expr_fallback'], - ) + def deserialize(cls, data: JsonDict) -> "UnboundType": + assert data[".class"] == "UnboundType" + return UnboundType( + data["name"], + [deserialize_type(a) for a in data["args"]], + original_str_expr=data["expr"], + original_str_fallback=data["expr_fallback"], + ) class CallableArgument(ProperType): @@ -805,20 +851,26 @@ class CallableArgument(ProperType): Note that this is a synthetic type for helping parse ASTs, not a real type. """ - __slots__ = ('typ', 'name', 'constructor') + __slots__ = ("typ", "name", "constructor") typ: Type name: Optional[str] constructor: Optional[str] - def __init__(self, typ: Type, name: Optional[str], constructor: Optional[str], - line: int = -1, column: int = -1) -> None: + def __init__( + self, + typ: Type, + name: Optional[str], + constructor: Optional[str], + line: int = -1, + column: int = -1, + ) -> None: super().__init__(line, column) self.typ = typ self.name = name self.constructor = constructor - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: assert isinstance(visitor, SyntheticTypeVisitor) return visitor.visit_callable_argument(self) @@ -835,7 +887,7 @@ class TypeList(ProperType): types before they are processed into Callable types. """ - __slots__ = ('items',) + __slots__ = ("items",) items: List[Type] @@ -843,7 +895,7 @@ def __init__(self, items: List[Type], line: int = -1, column: int = -1) -> None: super().__init__(line, column) self.items = items - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: assert isinstance(visitor, SyntheticTypeVisitor) return visitor.visit_type_list(self) @@ -858,20 +910,18 @@ class UnpackType(ProperType): The inner type should be either a TypeVarTuple, a constant size tuple, or a variable length tuple, or a union of one of those. """ + __slots__ = ["type"] def __init__(self, typ: Type, line: int = -1, column: int = -1) -> None: super().__init__(line, column) self.type = typ - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_unpack_type(self) def serialize(self) -> JsonDict: - return { - ".class": "UnpackType", - "type": self.type.serialize(), - } + return {".class": "UnpackType", "type": self.type.serialize()} @classmethod def deserialize(cls, data: JsonDict) -> "UnpackType": @@ -883,14 +933,16 @@ def deserialize(cls, data: JsonDict) -> "UnpackType": class AnyType(ProperType): """The type 'Any'.""" - __slots__ = ('type_of_any', 'source_any', 'missing_import_name') + __slots__ = ("type_of_any", "source_any", "missing_import_name") - def __init__(self, - type_of_any: int, - source_any: Optional['AnyType'] = None, - missing_import_name: Optional[str] = None, - line: int = -1, - column: int = -1) -> None: + def __init__( + self, + type_of_any: int, + source_any: Optional["AnyType"] = None, + missing_import_name: Optional[str] = None, + line: int = -1, + column: int = -1, + ) -> None: super().__init__(line, column) self.type_of_any = type_of_any # If this Any was created as a result of interacting with another 'Any', record the source @@ -905,8 +957,10 @@ def __init__(self, self.missing_import_name = source_any.missing_import_name # Only unimported type anys and anys from other anys should have an import name - assert (missing_import_name is None or - type_of_any in (TypeOfAny.from_unimported_type, TypeOfAny.from_another_any)) + assert missing_import_name is None or type_of_any in ( + TypeOfAny.from_unimported_type, + TypeOfAny.from_another_any, + ) # Only Anys that come from another Any can have source_any. assert type_of_any != TypeOfAny.from_another_any or source_any is not None # We should not have chains of Anys. @@ -916,21 +970,26 @@ def __init__(self, def is_from_error(self) -> bool: return self.type_of_any == TypeOfAny.from_error - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_any(self) - def copy_modified(self, - # Mark with Bogus because _dummy is just an object (with type Any) - type_of_any: Bogus[int] = _dummy, - original_any: Bogus[Optional['AnyType']] = _dummy, - ) -> 'AnyType': + def copy_modified( + self, + # Mark with Bogus because _dummy is just an object (with type Any) + type_of_any: Bogus[int] = _dummy, + original_any: Bogus[Optional["AnyType"]] = _dummy, + ) -> "AnyType": if type_of_any is _dummy: type_of_any = self.type_of_any if original_any is _dummy: original_any = self.source_any - return AnyType(type_of_any=type_of_any, source_any=original_any, - missing_import_name=self.missing_import_name, - line=self.line, column=self.column) + return AnyType( + type_of_any=type_of_any, + source_any=original_any, + missing_import_name=self.missing_import_name, + line=self.line, + column=self.column, + ) def __hash__(self) -> int: return hash(AnyType) @@ -939,17 +998,22 @@ def __eq__(self, other: object) -> bool: return isinstance(other, AnyType) def serialize(self) -> JsonDict: - return {'.class': 'AnyType', 'type_of_any': self.type_of_any, - 'source_any': self.source_any.serialize() if self.source_any is not None else None, - 'missing_import_name': self.missing_import_name} + return { + ".class": "AnyType", + "type_of_any": self.type_of_any, + "source_any": self.source_any.serialize() if self.source_any is not None else None, + "missing_import_name": self.missing_import_name, + } @classmethod - def deserialize(cls, data: JsonDict) -> 'AnyType': - assert data['.class'] == 'AnyType' - source = data['source_any'] - return AnyType(data['type_of_any'], - AnyType.deserialize(source) if source is not None else None, - data['missing_import_name']) + def deserialize(cls, data: JsonDict) -> "AnyType": + assert data[".class"] == "AnyType" + source = data["source_any"] + return AnyType( + data["type_of_any"], + AnyType.deserialize(source) if source is not None else None, + data["missing_import_name"], + ) class UninhabitedType(ProperType): @@ -966,7 +1030,7 @@ class UninhabitedType(ProperType): is_subtype(UninhabitedType, T) = True """ - __slots__ = ('ambiguous', 'is_noreturn',) + __slots__ = ("ambiguous", "is_noreturn") is_noreturn: bool # Does this come from a NoReturn? Purely for error messages. # It is important to track whether this is an actual NoReturn type, or just a result @@ -985,7 +1049,7 @@ def can_be_true_default(self) -> bool: def can_be_false_default(self) -> bool: return False - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_uninhabited_type(self) def __hash__(self) -> int: @@ -995,13 +1059,12 @@ def __eq__(self, other: object) -> bool: return isinstance(other, UninhabitedType) def serialize(self) -> JsonDict: - return {'.class': 'UninhabitedType', - 'is_noreturn': self.is_noreturn} + return {".class": "UninhabitedType", "is_noreturn": self.is_noreturn} @classmethod - def deserialize(cls, data: JsonDict) -> 'UninhabitedType': - assert data['.class'] == 'UninhabitedType' - return UninhabitedType(is_noreturn=data['is_noreturn']) + def deserialize(cls, data: JsonDict) -> "UninhabitedType": + assert data[".class"] == "UninhabitedType" + return UninhabitedType(is_noreturn=data["is_noreturn"]) class NoneType(ProperType): @@ -1024,15 +1087,15 @@ def __hash__(self) -> int: def __eq__(self, other: object) -> bool: return isinstance(other, NoneType) - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_none_type(self) def serialize(self) -> JsonDict: - return {'.class': 'NoneType'} + return {".class": "NoneType"} @classmethod - def deserialize(cls, data: JsonDict) -> 'NoneType': - assert data['.class'] == 'NoneType' + def deserialize(cls, data: JsonDict) -> "NoneType": + assert data[".class"] == "NoneType" return NoneType() def is_singleton_type(self) -> bool: @@ -1053,7 +1116,7 @@ class ErasedType(ProperType): __slots__ = () - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_erased_type(self) @@ -1063,7 +1126,7 @@ class DeletedType(ProperType): These can be used as lvalues but not rvalues. """ - __slots__ = ('source',) + __slots__ = ("source",) source: Optional[str] # May be None; name that generated this value @@ -1071,17 +1134,16 @@ def __init__(self, source: Optional[str] = None, line: int = -1, column: int = - super().__init__(line, column) self.source = source - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_deleted_type(self) def serialize(self) -> JsonDict: - return {'.class': 'DeletedType', - 'source': self.source} + return {".class": "DeletedType", "source": self.source} @classmethod - def deserialize(cls, data: JsonDict) -> 'DeletedType': - assert data['.class'] == 'DeletedType' - return DeletedType(data['source']) + def deserialize(cls, data: JsonDict) -> "DeletedType": + assert data[".class"] == "DeletedType" + return DeletedType(data["source"]) # Fake TypeInfo to be used as a placeholder during Instance de-serialization. @@ -1119,11 +1181,17 @@ def try_getting_instance_fallback(typ: ProperType) -> Optional[Instance]: """ - __slots__ = ('type', 'args', 'invalid', 'type_ref', 'last_known_value', '_hash') + __slots__ = ("type", "args", "invalid", "type_ref", "last_known_value", "_hash") - def __init__(self, typ: mypy.nodes.TypeInfo, args: Sequence[Type], - line: int = -1, column: int = -1, *, - last_known_value: Optional['LiteralType'] = None) -> None: + def __init__( + self, + typ: mypy.nodes.TypeInfo, + args: Sequence[Type], + line: int = -1, + column: int = -1, + *, + last_known_value: Optional["LiteralType"] = None, + ) -> None: super().__init__(line, column) self.type = typ self.args = tuple(args) @@ -1180,7 +1248,7 @@ def __init__(self, typ: mypy.nodes.TypeInfo, args: Sequence[Type], # Cached hash value self._hash = -1 - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_instance(self) def __hash__(self) -> int: @@ -1191,51 +1259,55 @@ def __hash__(self) -> int: def __eq__(self, other: object) -> bool: if not isinstance(other, Instance): return NotImplemented - return (self.type == other.type - and self.args == other.args - and self.last_known_value == other.last_known_value) + return ( + self.type == other.type + and self.args == other.args + and self.last_known_value == other.last_known_value + ) def serialize(self) -> Union[JsonDict, str]: assert self.type is not None type_ref = self.type.fullname if not self.args and not self.last_known_value: return type_ref - data: JsonDict = { - ".class": "Instance", - } + data: JsonDict = {".class": "Instance"} data["type_ref"] = type_ref data["args"] = [arg.serialize() for arg in self.args] if self.last_known_value is not None: - data['last_known_value'] = self.last_known_value.serialize() + data["last_known_value"] = self.last_known_value.serialize() return data @classmethod - def deserialize(cls, data: Union[JsonDict, str]) -> 'Instance': + def deserialize(cls, data: Union[JsonDict, str]) -> "Instance": if isinstance(data, str): inst = Instance(NOT_READY, []) inst.type_ref = data return inst - assert data['.class'] == 'Instance' + assert data[".class"] == "Instance" args: List[Type] = [] - if 'args' in data: - args_list = data['args'] + if "args" in data: + args_list = data["args"] assert isinstance(args_list, list) args = [deserialize_type(arg) for arg in args_list] inst = Instance(NOT_READY, args) - inst.type_ref = data['type_ref'] # Will be fixed up by fixup.py later. - if 'last_known_value' in data: - inst.last_known_value = LiteralType.deserialize(data['last_known_value']) + inst.type_ref = data["type_ref"] # Will be fixed up by fixup.py later. + if "last_known_value" in data: + inst.last_known_value = LiteralType.deserialize(data["last_known_value"]) return inst - def copy_modified(self, *, - args: Bogus[List[Type]] = _dummy, - last_known_value: Bogus[Optional['LiteralType']] = _dummy) -> 'Instance': + def copy_modified( + self, + *, + args: Bogus[List[Type]] = _dummy, + last_known_value: Bogus[Optional["LiteralType"]] = _dummy, + ) -> "Instance": return Instance( self.type, args if args is not _dummy else self.args, self.line, self.column, - last_known_value=last_known_value if last_known_value is not _dummy + last_known_value=last_known_value + if last_known_value is not _dummy else self.last_known_value, ) @@ -1246,22 +1318,22 @@ def is_singleton_type(self) -> bool: # TODO: # Also make this return True if the type corresponds to NotImplemented? return ( - self.type.is_enum and len(self.get_enum_values()) == 1 - or self.type.fullname == 'builtins.ellipsis' + self.type.is_enum + and len(self.get_enum_values()) == 1 + or self.type.fullname == "builtins.ellipsis" ) def get_enum_values(self) -> List[str]: """Return the list of values for an Enum.""" return [ - name for name, sym in self.type.names.items() - if isinstance(sym.node, mypy.nodes.Var) + name for name, sym in self.type.names.items() if isinstance(sym.node, mypy.nodes.Var) ] class FunctionLike(ProperType): """Abstract base class for function types.""" - __slots__ = ('fallback',) + __slots__ = ("fallback",) fallback: Instance @@ -1270,20 +1342,25 @@ def __init__(self, line: int = -1, column: int = -1) -> None: self.can_be_false = False @abstractmethod - def is_type_obj(self) -> bool: pass + def is_type_obj(self) -> bool: + pass @abstractmethod - def type_object(self) -> mypy.nodes.TypeInfo: pass + def type_object(self) -> mypy.nodes.TypeInfo: + pass @property @abstractmethod - def items(self) -> List['CallableType']: pass + def items(self) -> List["CallableType"]: + pass @abstractmethod - def with_name(self, name: str) -> 'FunctionLike': pass + def with_name(self, name: str) -> "FunctionLike": + pass @abstractmethod - def get_name(self) -> Optional[str]: pass + def get_name(self) -> Optional[str]: + pass class FormalArgument(NamedTuple): @@ -1300,23 +1377,27 @@ class Parameters(ProperType): """Type that represents the parameters to a function. Used for ParamSpec analysis.""" - __slots__ = ('arg_types', - 'arg_kinds', - 'arg_names', - 'min_args', - 'is_ellipsis_args', - 'variables') - - def __init__(self, - arg_types: Sequence[Type], - arg_kinds: List[ArgKind], - arg_names: Sequence[Optional[str]], - *, - variables: Optional[Sequence[TypeVarLikeType]] = None, - is_ellipsis_args: bool = False, - line: int = -1, - column: int = -1 - ) -> None: + + __slots__ = ( + "arg_types", + "arg_kinds", + "arg_names", + "min_args", + "is_ellipsis_args", + "variables", + ) + + def __init__( + self, + arg_types: Sequence[Type], + arg_kinds: List[ArgKind], + arg_names: Sequence[Optional[str]], + *, + variables: Optional[Sequence[TypeVarLikeType]] = None, + is_ellipsis_args: bool = False, + line: int = -1, + column: int = -1, + ) -> None: super().__init__(line, column) self.arg_types = list(arg_types) self.arg_kinds = arg_kinds @@ -1326,21 +1407,23 @@ def __init__(self, self.is_ellipsis_args = is_ellipsis_args self.variables = variables or [] - def copy_modified(self, - arg_types: Bogus[Sequence[Type]] = _dummy, - arg_kinds: Bogus[List[ArgKind]] = _dummy, - arg_names: Bogus[Sequence[Optional[str]]] = _dummy, - *, - variables: Bogus[Sequence[TypeVarLikeType]] = _dummy, - is_ellipsis_args: Bogus[bool] = _dummy - ) -> 'Parameters': + def copy_modified( + self, + arg_types: Bogus[Sequence[Type]] = _dummy, + arg_kinds: Bogus[List[ArgKind]] = _dummy, + arg_names: Bogus[Sequence[Optional[str]]] = _dummy, + *, + variables: Bogus[Sequence[TypeVarLikeType]] = _dummy, + is_ellipsis_args: Bogus[bool] = _dummy, + ) -> "Parameters": return Parameters( arg_types=arg_types if arg_types is not _dummy else self.arg_types, arg_kinds=arg_kinds if arg_kinds is not _dummy else self.arg_kinds, arg_names=arg_names if arg_names is not _dummy else self.arg_names, - is_ellipsis_args=(is_ellipsis_args if is_ellipsis_args is not _dummy - else self.is_ellipsis_args), - variables=variables if variables is not _dummy else self.variables + is_ellipsis_args=( + is_ellipsis_args if is_ellipsis_args is not _dummy else self.is_ellipsis_args + ), + variables=variables if variables is not _dummy else self.variables, ) # the following are copied from CallableType. Is there a way to decrease code duplication? @@ -1377,12 +1460,7 @@ def formal_arguments(self, include_star_args: bool = False) -> List[FormalArgume required = kind.is_required() pos = None if done_with_positional else i - arg = FormalArgument( - self.arg_names[i], - pos, - self.arg_types[i], - required - ) + arg = FormalArgument(self.arg_names[i], pos, self.arg_types[i], required) args.append(arg) return args @@ -1391,7 +1469,8 @@ def argument_by_name(self, name: Optional[str]) -> Optional[FormalArgument]: return None seen_star = False for i, (arg_name, kind, typ) in enumerate( - zip(self.arg_names, self.arg_kinds, self.arg_types)): + zip(self.arg_names, self.arg_kinds, self.arg_types) + ): # No more positional arguments after these. if kind.is_named() or kind.is_star(): seen_star = True @@ -1417,54 +1496,61 @@ def argument_by_position(self, position: Optional[int]) -> Optional[FormalArgume else: return self.try_synthesizing_arg_from_vararg(position) - def try_synthesizing_arg_from_kwarg(self, - name: Optional[str]) -> Optional[FormalArgument]: + def try_synthesizing_arg_from_kwarg(self, name: Optional[str]) -> Optional[FormalArgument]: kw_arg = self.kw_arg() if kw_arg is not None: return FormalArgument(name, None, kw_arg.typ, False) else: return None - def try_synthesizing_arg_from_vararg(self, - position: Optional[int]) -> Optional[FormalArgument]: + def try_synthesizing_arg_from_vararg( + self, position: Optional[int] + ) -> Optional[FormalArgument]: var_arg = self.var_arg() if var_arg is not None: return FormalArgument(None, position, var_arg.typ, False) else: return None - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_parameters(self) def serialize(self) -> JsonDict: - return {'.class': 'Parameters', - 'arg_types': [t.serialize() for t in self.arg_types], - 'arg_kinds': [int(x.value) for x in self.arg_kinds], - 'arg_names': self.arg_names, - 'variables': [tv.serialize() for tv in self.variables], - } + return { + ".class": "Parameters", + "arg_types": [t.serialize() for t in self.arg_types], + "arg_kinds": [int(x.value) for x in self.arg_kinds], + "arg_names": self.arg_names, + "variables": [tv.serialize() for tv in self.variables], + } @classmethod - def deserialize(cls, data: JsonDict) -> 'Parameters': - assert data['.class'] == 'Parameters' + def deserialize(cls, data: JsonDict) -> "Parameters": + assert data[".class"] == "Parameters" return Parameters( - [deserialize_type(t) for t in data['arg_types']], - [ArgKind(x) for x in data['arg_kinds']], - data['arg_names'], - variables=[cast(TypeVarLikeType, deserialize_type(v)) for v in data['variables']], + [deserialize_type(t) for t in data["arg_types"]], + [ArgKind(x) for x in data["arg_kinds"]], + data["arg_names"], + variables=[cast(TypeVarLikeType, deserialize_type(v)) for v in data["variables"]], ) def __hash__(self) -> int: - return hash((self.is_ellipsis_args, tuple(self.arg_types), - tuple(self.arg_names), tuple(self.arg_kinds))) + return hash( + ( + self.is_ellipsis_args, + tuple(self.arg_types), + tuple(self.arg_names), + tuple(self.arg_kinds), + ) + ) def __eq__(self, other: object) -> bool: if isinstance(other, Parameters) or isinstance(other, CallableType): return ( - self.arg_types == other.arg_types and - self.arg_names == other.arg_names and - self.arg_kinds == other.arg_kinds and - self.is_ellipsis_args == other.is_ellipsis_args + self.arg_types == other.arg_types + and self.arg_names == other.arg_names + and self.arg_kinds == other.arg_kinds + and self.is_ellipsis_args == other.is_ellipsis_args ) else: return NotImplemented @@ -1473,54 +1559,56 @@ def __eq__(self, other: object) -> bool: class CallableType(FunctionLike): """Type of a non-overloaded callable object (such as function).""" - __slots__ = ('arg_types', # Types of function arguments - 'arg_kinds', # ARG_ constants - 'arg_names', # Argument names; None if not a keyword argument - 'min_args', # Minimum number of arguments; derived from arg_kinds - 'ret_type', # Return value type - 'name', # Name (may be None; for error messages and plugins) - 'definition', # For error messages. May be None. - 'variables', # Type variables for a generic function - 'is_ellipsis_args', # Is this Callable[..., t] (with literal '...')? - 'is_classmethod_class', # Is this callable constructed for the benefit - # of a classmethod's 'cls' argument? - 'implicit', # Was this type implicitly generated instead of explicitly - # specified by the user? - 'special_sig', # Non-None for signatures that require special handling - # (currently only value is 'dict' for a signature similar to - # 'dict') - 'from_type_type', # Was this callable generated by analyzing Type[...] - # instantiation? - 'bound_args', # Bound type args, mostly unused but may be useful for - # tools that consume mypy ASTs - 'def_extras', # Information about original definition we want to serialize. - # This is used for more detailed error messages. - 'type_guard', # T, if -> TypeGuard[T] (ret_type is bool in this case). - 'from_concatenate', # whether this callable is from a concatenate object - # (this is used for error messages) - ) - - def __init__(self, - # maybe this should be refactored to take a Parameters object - arg_types: Sequence[Type], - arg_kinds: List[ArgKind], - arg_names: Sequence[Optional[str]], - ret_type: Type, - fallback: Instance, - name: Optional[str] = None, - definition: Optional[SymbolNode] = None, - variables: Optional[Sequence[TypeVarLikeType]] = None, - line: int = -1, - column: int = -1, - is_ellipsis_args: bool = False, - implicit: bool = False, - special_sig: Optional[str] = None, - from_type_type: bool = False, - bound_args: Sequence[Optional[Type]] = (), - def_extras: Optional[Dict[str, Any]] = None, - type_guard: Optional[Type] = None, - from_concatenate: bool = False - ) -> None: + __slots__ = ( + "arg_types", # Types of function arguments + "arg_kinds", # ARG_ constants + "arg_names", # Argument names; None if not a keyword argument + "min_args", # Minimum number of arguments; derived from arg_kinds + "ret_type", # Return value type + "name", # Name (may be None; for error messages and plugins) + "definition", # For error messages. May be None. + "variables", # Type variables for a generic function + "is_ellipsis_args", # Is this Callable[..., t] (with literal '...')? + "is_classmethod_class", # Is this callable constructed for the benefit + # of a classmethod's 'cls' argument? + "implicit", # Was this type implicitly generated instead of explicitly + # specified by the user? + "special_sig", # Non-None for signatures that require special handling + # (currently only value is 'dict' for a signature similar to + # 'dict') + "from_type_type", # Was this callable generated by analyzing Type[...] + # instantiation? + "bound_args", # Bound type args, mostly unused but may be useful for + # tools that consume mypy ASTs + "def_extras", # Information about original definition we want to serialize. + # This is used for more detailed error messages. + "type_guard", # T, if -> TypeGuard[T] (ret_type is bool in this case). + "from_concatenate", # whether this callable is from a concatenate object + # (this is used for error messages) + ) + + def __init__( + self, + # maybe this should be refactored to take a Parameters object + arg_types: Sequence[Type], + arg_kinds: List[ArgKind], + arg_names: Sequence[Optional[str]], + ret_type: Type, + fallback: Instance, + name: Optional[str] = None, + definition: Optional[SymbolNode] = None, + variables: Optional[Sequence[TypeVarLikeType]] = None, + line: int = -1, + column: int = -1, + is_ellipsis_args: bool = False, + implicit: bool = False, + special_sig: Optional[str] = None, + from_type_type: bool = False, + bound_args: Sequence[Optional[Type]] = (), + def_extras: Optional[Dict[str, Any]] = None, + type_guard: Optional[Type] = None, + from_concatenate: bool = False, + ) -> None: super().__init__(line, column) assert len(arg_types) == len(arg_kinds) == len(arg_names) if variables is None: @@ -1531,7 +1619,7 @@ def __init__(self, self.min_args = arg_kinds.count(ARG_POS) self.ret_type = ret_type self.fallback = fallback - assert not name or ' 'CallableType': + def copy_modified( + self, + arg_types: Bogus[Sequence[Type]] = _dummy, + arg_kinds: Bogus[List[ArgKind]] = _dummy, + arg_names: Bogus[List[Optional[str]]] = _dummy, + ret_type: Bogus[Type] = _dummy, + fallback: Bogus[Instance] = _dummy, + name: Bogus[Optional[str]] = _dummy, + definition: Bogus[SymbolNode] = _dummy, + variables: Bogus[Sequence[TypeVarLikeType]] = _dummy, + line: Bogus[int] = _dummy, + column: Bogus[int] = _dummy, + is_ellipsis_args: Bogus[bool] = _dummy, + implicit: Bogus[bool] = _dummy, + special_sig: Bogus[Optional[str]] = _dummy, + from_type_type: Bogus[bool] = _dummy, + bound_args: Bogus[List[Optional[Type]]] = _dummy, + def_extras: Bogus[Dict[str, Any]] = _dummy, + type_guard: Bogus[Optional[Type]] = _dummy, + from_concatenate: Bogus[bool] = _dummy, + ) -> "CallableType": return CallableType( arg_types=arg_types if arg_types is not _dummy else self.arg_types, arg_kinds=arg_kinds if arg_kinds is not _dummy else self.arg_kinds, @@ -1595,15 +1682,17 @@ def copy_modified(self, line=line if line is not _dummy else self.line, column=column if column is not _dummy else self.column, is_ellipsis_args=( - is_ellipsis_args if is_ellipsis_args is not _dummy else self.is_ellipsis_args), + is_ellipsis_args if is_ellipsis_args is not _dummy else self.is_ellipsis_args + ), implicit=implicit if implicit is not _dummy else self.implicit, special_sig=special_sig if special_sig is not _dummy else self.special_sig, from_type_type=from_type_type if from_type_type is not _dummy else self.from_type_type, bound_args=bound_args if bound_args is not _dummy else self.bound_args, def_extras=def_extras if def_extras is not _dummy else dict(self.def_extras), type_guard=type_guard if type_guard is not _dummy else self.type_guard, - from_concatenate=(from_concatenate if from_concatenate is not _dummy - else self.from_concatenate), + from_concatenate=( + from_concatenate if from_concatenate is not _dummy else self.from_concatenate + ), ) def var_arg(self) -> Optional[FormalArgument]: @@ -1643,10 +1732,10 @@ def type_object(self) -> mypy.nodes.TypeInfo: assert isinstance(ret, Instance) return ret.type - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_callable_type(self) - def with_name(self, name: str) -> 'CallableType': + def with_name(self, name: str) -> "CallableType": """Return a copy of this type with the specified name.""" return self.copy_modified(ret_type=self.ret_type, name=name) @@ -1680,12 +1769,7 @@ def formal_arguments(self, include_star_args: bool = False) -> List[FormalArgume required = kind.is_required() pos = None if done_with_positional else i - arg = FormalArgument( - self.arg_names[i], - pos, - self.arg_types[i], - required - ) + arg = FormalArgument(self.arg_names[i], pos, self.arg_types[i], required) args.append(arg) return args @@ -1694,7 +1778,8 @@ def argument_by_name(self, name: Optional[str]) -> Optional[FormalArgument]: return None seen_star = False for i, (arg_name, kind, typ) in enumerate( - zip(self.arg_names, self.arg_kinds, self.arg_types)): + zip(self.arg_names, self.arg_kinds, self.arg_types) + ): # No more positional arguments after these. if kind.is_named() or kind.is_star(): seen_star = True @@ -1720,16 +1805,16 @@ def argument_by_position(self, position: Optional[int]) -> Optional[FormalArgume else: return self.try_synthesizing_arg_from_vararg(position) - def try_synthesizing_arg_from_kwarg(self, - name: Optional[str]) -> Optional[FormalArgument]: + def try_synthesizing_arg_from_kwarg(self, name: Optional[str]) -> Optional[FormalArgument]: kw_arg = self.kw_arg() if kw_arg is not None: return FormalArgument(name, None, kw_arg.typ, False) else: return None - def try_synthesizing_arg_from_vararg(self, - position: Optional[int]) -> Optional[FormalArgument]: + def try_synthesizing_arg_from_vararg( + self, position: Optional[int] + ) -> Optional[FormalArgument]: var_arg = self.var_arg() if var_arg is not None: return FormalArgument(None, position, var_arg.typ, False) @@ -1737,7 +1822,7 @@ def try_synthesizing_arg_from_vararg(self, return None @property - def items(self) -> List['CallableType']: + def items(self) -> List["CallableType"]: return [self] def is_generic(self) -> bool: @@ -1768,26 +1853,36 @@ def param_spec(self) -> Optional[ParamSpecType]: if not prefix.arg_types: # TODO: confirm that all arg kinds are positional prefix = Parameters(self.arg_types[:-2], self.arg_kinds[:-2], self.arg_names[:-2]) - return ParamSpecType(arg_type.name, arg_type.fullname, arg_type.id, ParamSpecFlavor.BARE, - arg_type.upper_bound, prefix=prefix) + return ParamSpecType( + arg_type.name, + arg_type.fullname, + arg_type.id, + ParamSpecFlavor.BARE, + arg_type.upper_bound, + prefix=prefix, + ) - def expand_param_spec(self, - c: Union['CallableType', Parameters], - no_prefix: bool = False) -> 'CallableType': + def expand_param_spec( + self, c: Union["CallableType", Parameters], no_prefix: bool = False + ) -> "CallableType": variables = c.variables if no_prefix: - return self.copy_modified(arg_types=c.arg_types, - arg_kinds=c.arg_kinds, - arg_names=c.arg_names, - is_ellipsis_args=c.is_ellipsis_args, - variables=[*variables, *self.variables]) + return self.copy_modified( + arg_types=c.arg_types, + arg_kinds=c.arg_kinds, + arg_names=c.arg_names, + is_ellipsis_args=c.is_ellipsis_args, + variables=[*variables, *self.variables], + ) else: - return self.copy_modified(arg_types=self.arg_types[:-2] + c.arg_types, - arg_kinds=self.arg_kinds[:-2] + c.arg_kinds, - arg_names=self.arg_names[:-2] + c.arg_names, - is_ellipsis_args=c.is_ellipsis_args, - variables=[*variables, *self.variables]) + return self.copy_modified( + arg_types=self.arg_types[:-2] + c.arg_types, + arg_kinds=self.arg_kinds[:-2] + c.arg_kinds, + arg_names=self.arg_names[:-2] + c.arg_names, + is_ellipsis_args=c.is_ellipsis_args, + variables=[*variables, *self.variables], + ) def __hash__(self) -> int: # self.is_type_obj() will fail if self.fallback.type is a FakeInfo @@ -1795,62 +1890,73 @@ def __hash__(self) -> int: is_type_obj = 2 else: is_type_obj = self.is_type_obj() - return hash((self.ret_type, is_type_obj, - self.is_ellipsis_args, self.name, - tuple(self.arg_types), tuple(self.arg_names), tuple(self.arg_kinds))) + return hash( + ( + self.ret_type, + is_type_obj, + self.is_ellipsis_args, + self.name, + tuple(self.arg_types), + tuple(self.arg_names), + tuple(self.arg_kinds), + ) + ) def __eq__(self, other: object) -> bool: if isinstance(other, CallableType): - return (self.ret_type == other.ret_type and - self.arg_types == other.arg_types and - self.arg_names == other.arg_names and - self.arg_kinds == other.arg_kinds and - self.name == other.name and - self.is_type_obj() == other.is_type_obj() and - self.is_ellipsis_args == other.is_ellipsis_args) + return ( + self.ret_type == other.ret_type + and self.arg_types == other.arg_types + and self.arg_names == other.arg_names + and self.arg_kinds == other.arg_kinds + and self.name == other.name + and self.is_type_obj() == other.is_type_obj() + and self.is_ellipsis_args == other.is_ellipsis_args + ) else: return NotImplemented def serialize(self) -> JsonDict: # TODO: As an optimization, leave out everything related to # generic functions for non-generic functions. - return {'.class': 'CallableType', - 'arg_types': [t.serialize() for t in self.arg_types], - 'arg_kinds': [int(x.value) for x in self.arg_kinds], - 'arg_names': self.arg_names, - 'ret_type': self.ret_type.serialize(), - 'fallback': self.fallback.serialize(), - 'name': self.name, - # We don't serialize the definition (only used for error messages). - 'variables': [v.serialize() for v in self.variables], - 'is_ellipsis_args': self.is_ellipsis_args, - 'implicit': self.implicit, - 'bound_args': [(None if t is None else t.serialize()) - for t in self.bound_args], - 'def_extras': dict(self.def_extras), - 'type_guard': self.type_guard.serialize() if self.type_guard is not None else None, - 'from_concatenate': self.from_concatenate, - } + return { + ".class": "CallableType", + "arg_types": [t.serialize() for t in self.arg_types], + "arg_kinds": [int(x.value) for x in self.arg_kinds], + "arg_names": self.arg_names, + "ret_type": self.ret_type.serialize(), + "fallback": self.fallback.serialize(), + "name": self.name, + # We don't serialize the definition (only used for error messages). + "variables": [v.serialize() for v in self.variables], + "is_ellipsis_args": self.is_ellipsis_args, + "implicit": self.implicit, + "bound_args": [(None if t is None else t.serialize()) for t in self.bound_args], + "def_extras": dict(self.def_extras), + "type_guard": self.type_guard.serialize() if self.type_guard is not None else None, + "from_concatenate": self.from_concatenate, + } @classmethod - def deserialize(cls, data: JsonDict) -> 'CallableType': - assert data['.class'] == 'CallableType' + def deserialize(cls, data: JsonDict) -> "CallableType": + assert data[".class"] == "CallableType" # TODO: Set definition to the containing SymbolNode? return CallableType( - [deserialize_type(t) for t in data['arg_types']], - [ArgKind(x) for x in data['arg_kinds']], - data['arg_names'], - deserialize_type(data['ret_type']), - Instance.deserialize(data['fallback']), - name=data['name'], - variables=[cast(TypeVarLikeType, deserialize_type(v)) for v in data['variables']], - is_ellipsis_args=data['is_ellipsis_args'], - implicit=data['implicit'], - bound_args=[(None if t is None else deserialize_type(t)) for t in data['bound_args']], - def_extras=data['def_extras'], - type_guard=(deserialize_type(data['type_guard']) - if data['type_guard'] is not None else None), - from_concatenate=data['from_concatenate'], + [deserialize_type(t) for t in data["arg_types"]], + [ArgKind(x) for x in data["arg_kinds"]], + data["arg_names"], + deserialize_type(data["ret_type"]), + Instance.deserialize(data["fallback"]), + name=data["name"], + variables=[cast(TypeVarLikeType, deserialize_type(v)) for v in data["variables"]], + is_ellipsis_args=data["is_ellipsis_args"], + implicit=data["implicit"], + bound_args=[(None if t is None else deserialize_type(t)) for t in data["bound_args"]], + def_extras=data["def_extras"], + type_guard=( + deserialize_type(data["type_guard"]) if data["type_guard"] is not None else None + ), + from_concatenate=data["from_concatenate"], ) @@ -1863,7 +1969,7 @@ class Overloaded(FunctionLike): implementation. """ - __slots__ = ('_items',) + __slots__ = ("_items",) _items: List[CallableType] # Must not be empty @@ -1889,7 +1995,7 @@ def type_object(self) -> mypy.nodes.TypeInfo: # query only (any) one of them. return self._items[0].type_object() - def with_name(self, name: str) -> 'Overloaded': + def with_name(self, name: str) -> "Overloaded": ni: List[CallableType] = [] for it in self._items: ni.append(it.with_name(name)) @@ -1898,7 +2004,7 @@ def with_name(self, name: str) -> 'Overloaded': def get_name(self) -> Optional[str]: return self._items[0].name - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_overloaded(self) def __hash__(self) -> int: @@ -1910,14 +2016,12 @@ def __eq__(self, other: object) -> bool: return self.items == other.items def serialize(self) -> JsonDict: - return {'.class': 'Overloaded', - 'items': [t.serialize() for t in self.items], - } + return {".class": "Overloaded", "items": [t.serialize() for t in self.items]} @classmethod - def deserialize(cls, data: JsonDict) -> 'Overloaded': - assert data['.class'] == 'Overloaded' - return Overloaded([CallableType.deserialize(t) for t in data['items']]) + def deserialize(cls, data: JsonDict) -> "Overloaded": + assert data[".class"] == "Overloaded" + return Overloaded([CallableType.deserialize(t) for t in data["items"]]) class TupleType(ProperType): @@ -1933,14 +2037,20 @@ class TupleType(ProperType): implicit: If True, derived from a tuple expression (t,....) instead of Tuple[t, ...] """ - __slots__ = ('items', 'partial_fallback', 'implicit') + __slots__ = ("items", "partial_fallback", "implicit") items: List[Type] partial_fallback: Instance implicit: bool - def __init__(self, items: List[Type], fallback: Instance, line: int = -1, - column: int = -1, implicit: bool = False) -> None: + def __init__( + self, + items: List[Type], + fallback: Instance, + line: int = -1, + column: int = -1, + implicit: bool = False, + ) -> None: self.partial_fallback = fallback self.items = items self.implicit = implicit @@ -1963,14 +2073,14 @@ def can_be_false_default(self) -> bool: def can_be_any_bool(self) -> bool: return bool( self.partial_fallback.type - and self.partial_fallback.type.fullname != 'builtins.tuple' - and self.partial_fallback.type.names.get('__bool__') + and self.partial_fallback.type.fullname != "builtins.tuple" + and self.partial_fallback.type.names.get("__bool__") ) def length(self) -> int: return len(self.items) - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_tuple_type(self) def __hash__(self) -> int: @@ -1982,31 +2092,41 @@ def __eq__(self, other: object) -> bool: return self.items == other.items and self.partial_fallback == other.partial_fallback def serialize(self) -> JsonDict: - return {'.class': 'TupleType', - 'items': [t.serialize() for t in self.items], - 'partial_fallback': self.partial_fallback.serialize(), - 'implicit': self.implicit, - } + return { + ".class": "TupleType", + "items": [t.serialize() for t in self.items], + "partial_fallback": self.partial_fallback.serialize(), + "implicit": self.implicit, + } @classmethod - def deserialize(cls, data: JsonDict) -> 'TupleType': - assert data['.class'] == 'TupleType' - return TupleType([deserialize_type(t) for t in data['items']], - Instance.deserialize(data['partial_fallback']), - implicit=data['implicit']) - - def copy_modified(self, *, fallback: Optional[Instance] = None, - items: Optional[List[Type]] = None) -> 'TupleType': + def deserialize(cls, data: JsonDict) -> "TupleType": + assert data[".class"] == "TupleType" + return TupleType( + [deserialize_type(t) for t in data["items"]], + Instance.deserialize(data["partial_fallback"]), + implicit=data["implicit"], + ) + + def copy_modified( + self, *, fallback: Optional[Instance] = None, items: Optional[List[Type]] = None + ) -> "TupleType": if fallback is None: fallback = self.partial_fallback if items is None: items = self.items return TupleType(items, fallback, self.line, self.column) - def slice(self, begin: Optional[int], end: Optional[int], - stride: Optional[int]) -> 'TupleType': - return TupleType(self.items[begin:end:stride], self.partial_fallback, - self.line, self.column, self.implicit) + def slice( + self, begin: Optional[int], end: Optional[int], stride: Optional[int] + ) -> "TupleType": + return TupleType( + self.items[begin:end:stride], + self.partial_fallback, + self.line, + self.column, + self.implicit, + ) class TypedDictType(ProperType): @@ -2029,14 +2149,20 @@ class TypedDictType(ProperType): TODO: The fallback structure is perhaps overly complicated. """ - __slots__ = ('items', 'required_keys', 'fallback') + __slots__ = ("items", "required_keys", "fallback") items: "OrderedDict[str, Type]" # item_name -> item_type required_keys: Set[str] fallback: Instance - def __init__(self, items: 'OrderedDict[str, Type]', required_keys: Set[str], - fallback: Instance, line: int = -1, column: int = -1) -> None: + def __init__( + self, + items: "OrderedDict[str, Type]", + required_keys: Set[str], + fallback: Instance, + line: int = -1, + column: int = -1, + ) -> None: super().__init__(line, column) self.items = items self.required_keys = required_keys @@ -2044,12 +2170,11 @@ def __init__(self, items: 'OrderedDict[str, Type]', required_keys: Set[str], self.can_be_true = len(self.items) > 0 self.can_be_false = len(self.required_keys) == 0 - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_typeddict_type(self) def __hash__(self) -> int: - return hash((frozenset(self.items.items()), self.fallback, - frozenset(self.required_keys))) + return hash((frozenset(self.items.items()), self.fallback, frozenset(self.required_keys))) def __eq__(self, other: object) -> bool: if isinstance(other, TypedDictType): @@ -2063,32 +2188,38 @@ def __eq__(self, other: object) -> bool: return NotImplemented def serialize(self) -> JsonDict: - return {'.class': 'TypedDictType', - 'items': [[n, t.serialize()] for (n, t) in self.items.items()], - 'required_keys': sorted(self.required_keys), - 'fallback': self.fallback.serialize(), - } + return { + ".class": "TypedDictType", + "items": [[n, t.serialize()] for (n, t) in self.items.items()], + "required_keys": sorted(self.required_keys), + "fallback": self.fallback.serialize(), + } @classmethod - def deserialize(cls, data: JsonDict) -> 'TypedDictType': - assert data['.class'] == 'TypedDictType' - return TypedDictType(OrderedDict([(n, deserialize_type(t)) - for (n, t) in data['items']]), - set(data['required_keys']), - Instance.deserialize(data['fallback'])) + def deserialize(cls, data: JsonDict) -> "TypedDictType": + assert data[".class"] == "TypedDictType" + return TypedDictType( + OrderedDict([(n, deserialize_type(t)) for (n, t) in data["items"]]), + set(data["required_keys"]), + Instance.deserialize(data["fallback"]), + ) def is_anonymous(self) -> bool: return self.fallback.type.fullname in TPDICT_FB_NAMES - def as_anonymous(self) -> 'TypedDictType': + def as_anonymous(self) -> "TypedDictType": if self.is_anonymous(): return self assert self.fallback.type.typeddict_type is not None return self.fallback.type.typeddict_type.as_anonymous() - def copy_modified(self, *, fallback: Optional[Instance] = None, - item_types: Optional[List[Type]] = None, - required_keys: Optional[Set[str]] = None) -> 'TypedDictType': + def copy_modified( + self, + *, + fallback: Optional[Instance] = None, + item_types: Optional[List[Type]] = None, + required_keys: Optional[Set[str]] = None, + ) -> "TypedDictType": if fallback is None: fallback = self.fallback if item_types is None: @@ -2103,18 +2234,19 @@ def create_anonymous_fallback(self) -> Instance: anonymous = self.as_anonymous() return anonymous.fallback - def names_are_wider_than(self, other: 'TypedDictType') -> bool: + def names_are_wider_than(self, other: "TypedDictType") -> bool: return len(other.items.keys() - self.items.keys()) == 0 - def zip(self, right: 'TypedDictType') -> Iterable[Tuple[str, Type, Type]]: + def zip(self, right: "TypedDictType") -> Iterable[Tuple[str, Type, Type]]: left = self for (item_name, left_item_type) in left.items.items(): right_item_type = right.items.get(item_name) if right_item_type is not None: yield (item_name, left_item_type, right_item_type) - def zipall(self, right: 'TypedDictType') \ - -> Iterable[Tuple[str, Optional[Type], Optional[Type]]]: + def zipall( + self, right: "TypedDictType" + ) -> Iterable[Tuple[str, Optional[Type], Optional[Type]]]: left = self for (item_name, left_item_type) in left.items.items(): right_item_type = right.items.get(item_name) @@ -2169,15 +2301,16 @@ class RawExpressionType(ProperType): ) """ - __slots__ = ('literal_value', 'base_type_name', 'note') + __slots__ = ("literal_value", "base_type_name", "note") - def __init__(self, - literal_value: Optional[LiteralValue], - base_type_name: str, - line: int = -1, - column: int = -1, - note: Optional[str] = None, - ) -> None: + def __init__( + self, + literal_value: Optional[LiteralValue], + base_type_name: str, + line: int = -1, + column: int = -1, + note: Optional[str] = None, + ) -> None: super().__init__(line, column) self.literal_value = literal_value self.base_type_name = base_type_name @@ -2186,7 +2319,7 @@ def __init__(self, def simple_name(self) -> str: return self.base_type_name.replace("builtins.", "") - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: assert isinstance(visitor, SyntheticTypeVisitor) return visitor.visit_raw_expression_type(self) @@ -2198,8 +2331,10 @@ def __hash__(self) -> int: def __eq__(self, other: object) -> bool: if isinstance(other, RawExpressionType): - return (self.base_type_name == other.base_type_name - and self.literal_value == other.literal_value) + return ( + self.base_type_name == other.base_type_name + and self.literal_value == other.literal_value + ) else: return NotImplemented @@ -2219,10 +2354,12 @@ class LiteralType(ProperType): As another example, `Literal[Color.RED]` (where Color is an enum) is represented as `LiteralType(value="RED", fallback=instance_of_color)'. """ - __slots__ = ('value', 'fallback', '_hash') - def __init__(self, value: LiteralValue, fallback: Instance, - line: int = -1, column: int = -1) -> None: + __slots__ = ("value", "fallback", "_hash") + + def __init__( + self, value: LiteralValue, fallback: Instance, line: int = -1, column: int = -1 + ) -> None: self.value = value super().__init__(line, column) self.fallback = fallback @@ -2234,7 +2371,7 @@ def can_be_false_default(self) -> bool: def can_be_true_default(self) -> bool: return bool(self.value) - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_literal_type(self) def __hash__(self) -> int: @@ -2263,16 +2400,16 @@ def value_repr(self) -> str: # If this is backed by an enum, if self.is_enum_literal(): - return f'{fallback_name}.{self.value}' + return f"{fallback_name}.{self.value}" - if fallback_name == 'builtins.bytes': + if fallback_name == "builtins.bytes": # Note: 'builtins.bytes' only appears in Python 3, so we want to # explicitly prefix with a "b" - return 'b' + raw - elif fallback_name == 'builtins.unicode': + return "b" + raw + elif fallback_name == "builtins.unicode": # Similarly, 'builtins.unicode' only appears in Python 2, where we also # want to explicitly prefix - return 'u' + raw + return "u" + raw else: # 'builtins.str' could mean either depending on context, but either way # we don't prefix: it's the "native" string. And of course, if value is @@ -2281,18 +2418,15 @@ def value_repr(self) -> str: def serialize(self) -> Union[JsonDict, str]: return { - '.class': 'LiteralType', - 'value': self.value, - 'fallback': self.fallback.serialize(), + ".class": "LiteralType", + "value": self.value, + "fallback": self.fallback.serialize(), } @classmethod - def deserialize(cls, data: JsonDict) -> 'LiteralType': - assert data['.class'] == 'LiteralType' - return LiteralType( - value=data['value'], - fallback=Instance.deserialize(data['fallback']), - ) + def deserialize(cls, data: JsonDict) -> "LiteralType": + assert data[".class"] == "LiteralType" + return LiteralType(value=data["value"], fallback=Instance.deserialize(data["fallback"])) def is_singleton_type(self) -> bool: return self.is_enum_literal() or isinstance(self.value, bool) @@ -2304,7 +2438,7 @@ class StarType(ProperType): This is not a real type but a syntactic AST construct. """ - __slots__ = ('type',) + __slots__ = ("type",) type: Type @@ -2312,7 +2446,7 @@ def __init__(self, type: Type, line: int = -1, column: int = -1) -> None: super().__init__(line, column) self.type = type - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: assert isinstance(visitor, SyntheticTypeVisitor) return visitor.visit_star_type(self) @@ -2323,10 +2457,16 @@ def serialize(self) -> JsonDict: class UnionType(ProperType): """The union type Union[T1, ..., Tn] (at least one type argument).""" - __slots__ = ('items', 'is_evaluated', 'uses_pep604_syntax') + __slots__ = ("items", "is_evaluated", "uses_pep604_syntax") - def __init__(self, items: Sequence[Type], line: int = -1, column: int = -1, - is_evaluated: bool = True, uses_pep604_syntax: bool = False) -> None: + def __init__( + self, + items: Sequence[Type], + line: int = -1, + column: int = -1, + is_evaluated: bool = True, + uses_pep604_syntax: bool = False, + ) -> None: super().__init__(line, column) self.items = flatten_nested_unions(items) self.can_be_true = any(item.can_be_true for item in items) @@ -2346,12 +2486,13 @@ def __eq__(self, other: object) -> bool: @overload @staticmethod - def make_union(items: Sequence[ProperType], - line: int = -1, column: int = -1) -> ProperType: ... + def make_union(items: Sequence[ProperType], line: int = -1, column: int = -1) -> ProperType: + ... @overload @staticmethod - def make_union(items: Sequence[Type], line: int = -1, column: int = -1) -> Type: ... + def make_union(items: Sequence[Type], line: int = -1, column: int = -1) -> Type: + ... @staticmethod def make_union(items: Sequence[Type], line: int = -1, column: int = -1) -> Type: @@ -2365,7 +2506,7 @@ def make_union(items: Sequence[Type], line: int = -1, column: int = -1) -> Type: def length(self) -> int: return len(self.items) - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_union_type(self) def has_readable_member(self, name: str) -> bool: @@ -2374,9 +2515,11 @@ def has_readable_member(self, name: str) -> bool: TODO: Deal with attributes of TupleType etc. TODO: This should probably be refactored to go elsewhere. """ - return all((isinstance(x, UnionType) and x.has_readable_member(name)) or - (isinstance(x, Instance) and x.type.has_readable_member(name)) - for x in get_proper_types(self.relevant_items())) + return all( + (isinstance(x, UnionType) and x.has_readable_member(name)) + or (isinstance(x, Instance) and x.type.has_readable_member(name)) + for x in get_proper_types(self.relevant_items()) + ) def relevant_items(self) -> List[Type]: """Removes NoneTypes from Unions when strict Optional checking is off.""" @@ -2386,14 +2529,12 @@ def relevant_items(self) -> List[Type]: return [i for i in get_proper_types(self.items) if not isinstance(i, NoneType)] def serialize(self) -> JsonDict: - return {'.class': 'UnionType', - 'items': [t.serialize() for t in self.items], - } + return {".class": "UnionType", "items": [t.serialize() for t in self.items]} @classmethod - def deserialize(cls, data: JsonDict) -> 'UnionType': - assert data['.class'] == 'UnionType' - return UnionType([deserialize_type(t) for t in data['items']]) + def deserialize(cls, data: JsonDict) -> "UnionType": + assert data[".class"] == "UnionType" + return UnionType([deserialize_type(t) for t in data["items"]]) class PartialType(ProperType): @@ -2411,7 +2552,7 @@ class PartialType(ProperType): x = 1 # Infer actual type int for x """ - __slots__ = ('type', 'var', 'value_type') + __slots__ = ("type", "var", "value_type") # None for the 'None' partial type; otherwise a generic class type: Optional[mypy.nodes.TypeInfo] @@ -2420,16 +2561,18 @@ class PartialType(ProperType): # the type argument is Any and will be replaced later. value_type: Optional[Instance] - def __init__(self, - type: 'Optional[mypy.nodes.TypeInfo]', - var: 'mypy.nodes.Var', - value_type: 'Optional[Instance]' = None) -> None: + def __init__( + self, + type: "Optional[mypy.nodes.TypeInfo]", + var: "mypy.nodes.Var", + value_type: "Optional[Instance]" = None, + ) -> None: super().__init__() self.type = type self.var = var self.value_type = value_type - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_partial_type(self) @@ -2443,7 +2586,7 @@ class EllipsisType(ProperType): __slots__ = () - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: assert isinstance(visitor, SyntheticTypeVisitor) return visitor.visit_ellipsis_type(self) @@ -2479,15 +2622,19 @@ class TypeType(ProperType): assumption). """ - __slots__ = ('item',) + __slots__ = ("item",) # This can't be everything, but it can be a class reference, # a generic class instance, a union, Any, a type variable... item: ProperType - def __init__(self, item: Bogus[Union[Instance, AnyType, TypeVarType, TupleType, NoneType, - CallableType]], *, - line: int = -1, column: int = -1) -> None: + def __init__( + self, + item: Bogus[Union[Instance, AnyType, TypeVarType, TupleType, NoneType, CallableType]], + *, + line: int = -1, + column: int = -1, + ) -> None: """To ensure Type[Union[A, B]] is always represented as Union[Type[A], Type[B]], item of type UnionType must be handled through make_normalized static method. """ @@ -2500,11 +2647,12 @@ def make_normalized(item: Type, *, line: int = -1, column: int = -1) -> ProperTy if isinstance(item, UnionType): return UnionType.make_union( [TypeType.make_normalized(union_item) for union_item in item.items], - line=line, column=column + line=line, + column=column, ) return TypeType(item, line=line, column=column) # type: ignore[arg-type] - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: return visitor.visit_type_type(self) def __hash__(self) -> int: @@ -2516,12 +2664,12 @@ def __eq__(self, other: object) -> bool: return self.item == other.item def serialize(self) -> JsonDict: - return {'.class': 'TypeType', 'item': self.item.serialize()} + return {".class": "TypeType", "item": self.item.serialize()} @classmethod def deserialize(cls, data: JsonDict) -> Type: - assert data['.class'] == 'TypeType' - return TypeType.make_normalized(deserialize_type(data['item'])) + assert data[".class"] == "TypeType" + return TypeType.make_normalized(deserialize_type(data["item"])) class PlaceholderType(ProperType): @@ -2540,14 +2688,14 @@ class str(Sequence[str]): ... exist. """ - __slots__ = ('fullname', 'args') + __slots__ = ("fullname", "args") def __init__(self, fullname: Optional[str], args: List[Type], line: int) -> None: super().__init__(line) self.fullname = fullname # Must be a valid full name of an actual node (or None). self.args = args - def accept(self, visitor: 'TypeVisitor[T]') -> T: + def accept(self, visitor: "TypeVisitor[T]") -> T: assert isinstance(visitor, SyntheticTypeVisitor) return visitor.visit_placeholder_type(self) @@ -2558,9 +2706,13 @@ def serialize(self) -> str: @overload -def get_proper_type(typ: None) -> None: ... +def get_proper_type(typ: None) -> None: + ... + + @overload -def get_proper_type(typ: Type) -> ProperType: ... +def get_proper_type(typ: Type) -> ProperType: + ... def get_proper_type(typ: Optional[Type]) -> Optional[ProperType]: @@ -2584,13 +2736,18 @@ def get_proper_type(typ: Optional[Type]) -> Optional[ProperType]: @overload -def get_proper_types(it: Iterable[Type]) -> List[ProperType]: ... # type: ignore[misc] +def get_proper_types(it: Iterable[Type]) -> List[ProperType]: # type: ignore[misc] + ... + + @overload -def get_proper_types(it: Iterable[Optional[Type]]) -> List[Optional[ProperType]]: ... +def get_proper_types(it: Iterable[Optional[Type]]) -> List[Optional[ProperType]]: + ... -def get_proper_types(it: Iterable[Optional[Type]] - ) -> Union[List[ProperType], List[Optional[ProperType]]]: +def get_proper_types( + it: Iterable[Optional[Type]], +) -> Union[List[ProperType], List[Optional[ProperType]]]: return [get_proper_type(t) for t in it] @@ -2599,10 +2756,10 @@ def get_proper_types(it: Iterable[Optional[Type]] # Import them here, after the types are defined. # This is intended as a re-export also. from mypy.type_visitor import ( # noqa - TypeVisitor as TypeVisitor, SyntheticTypeVisitor as SyntheticTypeVisitor, - TypeTranslator as TypeTranslator, TypeQuery as TypeQuery, + TypeTranslator as TypeTranslator, + TypeVisitor as TypeVisitor, ) @@ -2623,13 +2780,13 @@ def __init__(self, id_mapper: Optional[IdMapper] = None) -> None: self.any_as_dots = False def visit_unbound_type(self, t: UnboundType) -> str: - s = t.name + '?' + s = t.name + "?" if t.args: - s += f'[{self.list_str(t.args)}]' + s += f"[{self.list_str(t.args)}]" return s def visit_type_list(self, t: TypeList) -> str: - return f'' + return f"" def visit_callable_argument(self, t: CallableArgument) -> str: typ = t.typ.accept(self) @@ -2640,8 +2797,8 @@ def visit_callable_argument(self, t: CallableArgument) -> str: def visit_any(self, t: AnyType) -> str: if self.any_as_dots and t.type_of_any == TypeOfAny.special_form: - return '...' - return 'Any' + return "..." + return "Any" def visit_none_type(self, t: NoneType) -> str: return "None" @@ -2662,82 +2819,82 @@ def visit_instance(self, t: Instance) -> str: if t.last_known_value and not t.args: # Instances with a literal fallback should never be generic. If they are, # something went wrong so we fall back to showing the full Instance repr. - s = f'{t.last_known_value}?' + s = f"{t.last_known_value}?" else: - s = t.type.fullname or t.type.name or '' + s = t.type.fullname or t.type.name or "" if t.args: - if t.type.fullname == 'builtins.tuple': + if t.type.fullname == "builtins.tuple": assert len(t.args) == 1 - s += f'[{self.list_str(t.args)}, ...]' + s += f"[{self.list_str(t.args)}, ...]" else: - s += f'[{self.list_str(t.args)}]' + s += f"[{self.list_str(t.args)}]" if self.id_mapper: - s += f'<{self.id_mapper.id(t.type)}>' + s += f"<{self.id_mapper.id(t.type)}>" return s def visit_type_var(self, t: TypeVarType) -> str: if t.name is None: # Anonymous type variable type (only numeric id). - s = f'`{t.id}' + s = f"`{t.id}" else: # Named type variable type. - s = f'{t.name}`{t.id}' + s = f"{t.name}`{t.id}" if self.id_mapper and t.upper_bound: - s += f'(upper_bound={t.upper_bound.accept(self)})' + s += f"(upper_bound={t.upper_bound.accept(self)})" return s def visit_param_spec(self, t: ParamSpecType) -> str: # prefixes are displayed as Concatenate - s = '' + s = "" if t.prefix.arg_types: - s += f'[{self.list_str(t.prefix.arg_types)}, **' + s += f"[{self.list_str(t.prefix.arg_types)}, **" if t.name is None: # Anonymous type variable type (only numeric id). - s += f'`{t.id}' + s += f"`{t.id}" else: # Named type variable type. - s += f'{t.name_with_suffix()}`{t.id}' + s += f"{t.name_with_suffix()}`{t.id}" if t.prefix.arg_types: - s += ']' + s += "]" return s def visit_parameters(self, t: Parameters) -> str: # This is copied from visit_callable -- is there a way to decrease duplication? if t.is_ellipsis_args: - return '...' + return "..." - s = '' + s = "" bare_asterisk = False for i in range(len(t.arg_types)): - if s != '': - s += ', ' + if s != "": + s += ", " if t.arg_kinds[i].is_named() and not bare_asterisk: - s += '*, ' + s += "*, " bare_asterisk = True if t.arg_kinds[i] == ARG_STAR: - s += '*' + s += "*" if t.arg_kinds[i] == ARG_STAR2: - s += '**' + s += "**" name = t.arg_names[i] if name: - s += f'{name}: ' + s += f"{name}: " r = t.arg_types[i].accept(self) s += r if t.arg_kinds[i].is_optional(): - s += ' =' + s += " =" - return f'[{s}]' + return f"[{s}]" def visit_type_var_tuple(self, t: TypeVarTupleType) -> str: if t.name is None: # Anonymous type variable type (only numeric id). - s = f'`{t.id}' + s = f"`{t.id}" else: # Named type variable type. - s = f'{t.name}`{t.id}' + s = f"{t.name}`{t.id}" return s def visit_callable_type(self, t: CallableType) -> str: @@ -2747,38 +2904,38 @@ def visit_callable_type(self, t: CallableType) -> str: else: num_skip = 0 - s = '' + s = "" bare_asterisk = False for i in range(len(t.arg_types) - num_skip): - if s != '': - s += ', ' + if s != "": + s += ", " if t.arg_kinds[i].is_named() and not bare_asterisk: - s += '*, ' + s += "*, " bare_asterisk = True if t.arg_kinds[i] == ARG_STAR: - s += '*' + s += "*" if t.arg_kinds[i] == ARG_STAR2: - s += '**' + s += "**" name = t.arg_names[i] if name: - s += name + ': ' + s += name + ": " s += t.arg_types[i].accept(self) if t.arg_kinds[i].is_optional(): - s += ' =' + s += " =" if param_spec is not None: n = param_spec.name if s: - s += ', ' - s += f'*{n}.args, **{n}.kwargs' + s += ", " + s += f"*{n}.args, **{n}.kwargs" - s = f'({s})' + s = f"({s})" if not isinstance(get_proper_type(t.ret_type), NoneType): if t.type_guard is not None: - s += f' -> TypeGuard[{t.type_guard.accept(self)}]' + s += f" -> TypeGuard[{t.type_guard.accept(self)}]" else: - s += f' -> {t.ret_type.accept(self)}' + s += f" -> {t.ret_type.accept(self)}" if t.variables: vs = [] @@ -2787,9 +2944,9 @@ def visit_callable_type(self, t: CallableType) -> str: # We reimplement TypeVarType.__repr__ here in order to support id_mapper. if var.values: vals = f"({', '.join(val.accept(self) for val in var.values)})" - vs.append(f'{var.name} in {vals}') - elif not is_named_instance(var.upper_bound, 'builtins.object'): - vs.append(f'{var.name} <: {var.upper_bound.accept(self)}') + vs.append(f"{var.name} in {vals}") + elif not is_named_instance(var.upper_bound, "builtins.object"): + vs.append(f"{var.name} <: {var.upper_bound.accept(self)}") else: vs.append(var.name) else: @@ -2797,7 +2954,7 @@ def visit_callable_type(self, t: CallableType) -> str: vs.append(var.name) s = f"[{', '.join(vs)}] {s}" - return f'def {s}' + return f"def {s}" def visit_overloaded(self, t: Overloaded) -> str: a = [] @@ -2809,54 +2966,56 @@ def visit_tuple_type(self, t: TupleType) -> str: s = self.list_str(t.items) if t.partial_fallback and t.partial_fallback.type: fallback_name = t.partial_fallback.type.fullname - if fallback_name != 'builtins.tuple': - return f'Tuple[{s}, fallback={t.partial_fallback.accept(self)}]' - return f'Tuple[{s}]' + if fallback_name != "builtins.tuple": + return f"Tuple[{s}, fallback={t.partial_fallback.accept(self)}]" + return f"Tuple[{s}]" def visit_typeddict_type(self, t: TypedDictType) -> str: def item_str(name: str, typ: str) -> str: if name in t.required_keys: - return f'{name!r}: {typ}' + return f"{name!r}: {typ}" else: - return f'{name!r}?: {typ}' + return f"{name!r}?: {typ}" - s = '{' + ', '.join(item_str(name, typ.accept(self)) - for name, typ in t.items.items()) + '}' - prefix = '' + s = ( + "{" + + ", ".join(item_str(name, typ.accept(self)) for name, typ in t.items.items()) + + "}" + ) + prefix = "" if t.fallback and t.fallback.type: if t.fallback.type.fullname not in TPDICT_FB_NAMES: - prefix = repr(t.fallback.type.fullname) + ', ' - return f'TypedDict({prefix}{s})' + prefix = repr(t.fallback.type.fullname) + ", " + return f"TypedDict({prefix}{s})" def visit_raw_expression_type(self, t: RawExpressionType) -> str: return repr(t.literal_value) def visit_literal_type(self, t: LiteralType) -> str: - return f'Literal[{t.value_repr()}]' + return f"Literal[{t.value_repr()}]" def visit_star_type(self, t: StarType) -> str: s = t.type.accept(self) - return f'*{s}' + return f"*{s}" def visit_union_type(self, t: UnionType) -> str: s = self.list_str(t.items) - return f'Union[{s}]' + return f"Union[{s}]" def visit_partial_type(self, t: PartialType) -> str: if t.type is None: - return '' + return "" else: - return ''.format(t.type.name, - ', '.join(['?'] * len(t.type.type_vars))) + return "".format(t.type.name, ", ".join(["?"] * len(t.type.type_vars))) def visit_ellipsis_type(self, t: EllipsisType) -> str: - return '...' + return "..." def visit_type_type(self, t: TypeType) -> str: - return f'Type[{t.item.accept(self)}]' + return f"Type[{t.item.accept(self)}]" def visit_placeholder_type(self, t: PlaceholderType) -> str: - return f'' + return f"" def visit_type_alias_type(self, t: TypeAliasType) -> str: if t.alias is not None: @@ -2865,10 +3024,10 @@ def visit_type_alias_type(self, t: TypeAliasType) -> str: type_str = unrolled.accept(self) self.any_as_dots = False return type_str - return '' + return "" def visit_unpack_type(self, t: UnpackType) -> str: - return f'Unpack[{t.type.accept(self)}]' + return f"Unpack[{t.type.accept(self)}]" def list_str(self, a: Iterable[Type]) -> str: """Convert items of an array to strings (pretty-print types) @@ -2877,7 +3036,7 @@ def list_str(self, a: Iterable[Type]) -> str: res = [] for t in a: res.append(t.accept(self)) - return ', '.join(res) + return ", ".join(res) class UnrollAliasVisitor(TypeTranslator): @@ -2906,8 +3065,7 @@ def strip_type(typ: Type) -> ProperType: if isinstance(typ, CallableType): return typ.copy_modified(name=None) elif isinstance(typ, Overloaded): - return Overloaded([cast(CallableType, strip_type(item)) - for item in typ.items]) + return Overloaded([cast(CallableType, strip_type(item)) for item in typ.items]) else: return typ @@ -2942,8 +3100,9 @@ def visit_type_var(self, typ: TypeVarType) -> Type: return typ -def replace_alias_tvars(tp: Type, vars: List[str], subs: List[Type], - newline: int, newcolumn: int) -> Type: +def replace_alias_tvars( + tp: Type, vars: List[str], subs: List[Type], newline: int, newcolumn: int +) -> Type: """Replace type variables in a generic type alias tp with substitutions subs resetting context. Length of subs should be already checked. """ @@ -2967,8 +3126,9 @@ def has_type_vars(typ: Type) -> bool: return typ.accept(HasTypeVars()) -def flatten_nested_unions(types: Iterable[Type], - handle_type_alias_type: bool = False) -> List[Type]: +def flatten_nested_unions( + types: Iterable[Type], handle_type_alias_type: bool = False +) -> List[Type]: """Flatten nested unions in a type list.""" # This and similar functions on unions can cause infinite recursion # if passed a "pathological" alias like A = Union[int, A] or similar. @@ -2979,8 +3139,9 @@ def flatten_nested_unions(types: Iterable[Type], # TODO: avoid duplicate types in unions (e.g. using hash) for tp in types: if isinstance(tp, ProperType) and isinstance(tp, UnionType): - flat_items.extend(flatten_nested_unions(tp.items, - handle_type_alias_type=handle_type_alias_type)) + flat_items.extend( + flatten_nested_unions(tp.items, handle_type_alias_type=handle_type_alias_type) + ) else: flat_items.append(tp) return flat_items @@ -3018,15 +3179,17 @@ def is_generic_instance(tp: Type) -> bool: def is_optional(t: Type) -> bool: t = get_proper_type(t) - return isinstance(t, UnionType) and any(isinstance(get_proper_type(e), NoneType) - for e in t.items) + return isinstance(t, UnionType) and any( + isinstance(get_proper_type(e), NoneType) for e in t.items + ) def remove_optional(typ: Type) -> Type: typ = get_proper_type(typ) if isinstance(typ, UnionType): - return UnionType.make_union([t for t in typ.items - if not isinstance(get_proper_type(t), NoneType)]) + return UnionType.make_union( + [t for t in typ.items if not isinstance(get_proper_type(t), NoneType)] + ) else: return typ @@ -3043,7 +3206,7 @@ def is_literal_type(typ: ProperType, fallback_fullname: str, value: LiteralValue names: Final = globals().copy() -names.pop('NOT_READY', None) +names.pop("NOT_READY", None) deserialize_map: Final = { key: obj.deserialize for key, obj in names.items() @@ -3051,13 +3214,13 @@ def is_literal_type(typ: ProperType, fallback_fullname: str, value: LiteralValue } -def callable_with_ellipsis(any_type: AnyType, - ret_type: Type, - fallback: Instance) -> CallableType: +def callable_with_ellipsis(any_type: AnyType, ret_type: Type, fallback: Instance) -> CallableType: """Construct type Callable[..., ret_type].""" - return CallableType([any_type, any_type], - [ARG_STAR, ARG_STAR2], - [None, None], - ret_type=ret_type, - fallback=fallback, - is_ellipsis_args=True) + return CallableType( + [any_type, any_type], + [ARG_STAR, ARG_STAR2], + [None, None], + ret_type=ret_type, + fallback=fallback, + is_ellipsis_args=True, + ) diff --git a/mypy/typestate.py b/mypy/typestate.py index bbb593ce0daf8..91cfb9562139d 100644 --- a/mypy/typestate.py +++ b/mypy/typestate.py @@ -3,12 +3,13 @@ and potentially other mutable TypeInfo state. This module contains mutable global state. """ -from typing import Dict, Set, Tuple, Optional, List +from typing import Dict, List, Optional, Set, Tuple + from typing_extensions import ClassVar, Final, TypeAlias as _TypeAlias from mypy.nodes import TypeInfo -from mypy.types import Instance, TypeAliasType, get_proper_type, Type from mypy.server.trigger import make_trigger +from mypy.types import Instance, Type, TypeAliasType, get_proper_type # Represents that the 'left' instance is a subtype of the 'right' instance SubtypeRelationship: _TypeAlias = Tuple[Instance, Instance] @@ -32,6 +33,7 @@ class TypeState: The protocol dependencies however are only stored here, and shouldn't be deleted unless not needed any more (e.g. during daemon shutdown). """ + # '_subtype_caches' keeps track of (subtype, supertype) pairs where supertypes are # instances of the given TypeInfo. The cache also keeps track of whether the check # was done in strict optional mode and of the specific *kind* of subtyping relationship, @@ -91,16 +93,18 @@ class TypeState: @staticmethod def is_assumed_subtype(left: Type, right: Type) -> bool: for (l, r) in reversed(TypeState._assuming): - if (get_proper_type(l) == get_proper_type(left) - and get_proper_type(r) == get_proper_type(right)): + if get_proper_type(l) == get_proper_type(left) and get_proper_type( + r + ) == get_proper_type(right): return True return False @staticmethod def is_assumed_proper_subtype(left: Type, right: Type) -> bool: for (l, r) in reversed(TypeState._assuming_proper): - if (get_proper_type(l) == get_proper_type(left) - and get_proper_type(r) == get_proper_type(right)): + if get_proper_type(l) == get_proper_type(left) and get_proper_type( + r + ) == get_proper_type(right): return True return False @@ -138,8 +142,7 @@ def is_cached_subtype_check(kind: SubtypeKind, left: Instance, right: Instance) return (left, right) in subcache @staticmethod - def record_subtype_cache_entry(kind: SubtypeKind, - left: Instance, right: Instance) -> None: + def record_subtype_cache_entry(kind: SubtypeKind, left: Instance, right: Instance) -> None: if left.last_known_value is not None or right.last_known_value is not None: # These are unlikely to match, due to the large space of # possible values. Avoid uselessly increasing cache sizes. @@ -159,11 +162,12 @@ def reset_protocol_deps() -> None: def record_protocol_subtype_check(left_type: TypeInfo, right_type: TypeInfo) -> None: assert right_type.is_protocol TypeState._rechecked_types.add(left_type) - TypeState._attempted_protocols.setdefault( - left_type.fullname, set()).add(right_type.fullname) - TypeState._checked_against_members.setdefault( - left_type.fullname, - set()).update(right_type.protocol_members) + TypeState._attempted_protocols.setdefault(left_type.fullname, set()).add( + right_type.fullname + ) + TypeState._checked_against_members.setdefault(left_type.fullname, set()).update( + right_type.protocol_members + ) @staticmethod def _snapshot_protocol_deps() -> Dict[str, Set[str]]: @@ -202,14 +206,14 @@ def __iter__(self) -> Iterator[int]: # a concrete class may not be reprocessed, so not all -> deps # are added. for base_info in info.mro[:-1]: - trigger = make_trigger(f'{base_info.fullname}.{attr}') - if 'typing' in trigger or 'builtins' in trigger: + trigger = make_trigger(f"{base_info.fullname}.{attr}") + if "typing" in trigger or "builtins" in trigger: # TODO: avoid everything from typeshed continue deps.setdefault(trigger, set()).add(make_trigger(info.fullname)) for proto in TypeState._attempted_protocols[info.fullname]: trigger = make_trigger(info.fullname) - if 'typing' in trigger or 'builtins' in trigger: + if "typing" in trigger or "builtins" in trigger: continue # If any class that was checked against a protocol changes, # we need to reset the subtype cache for the protocol. @@ -228,8 +232,9 @@ def update_protocol_deps(second_map: Optional[Dict[str, Set[str]]] = None) -> No type checked types. If second_map is given, update it as well. This is currently used by FineGrainedBuildManager that maintains normal (non-protocol) dependencies. """ - assert TypeState.proto_deps is not None, ( - "This should not be called after failed cache load") + assert ( + TypeState.proto_deps is not None + ), "This should not be called after failed cache load" new_deps = TypeState._snapshot_protocol_deps() for trigger, targets in new_deps.items(): TypeState.proto_deps.setdefault(trigger, set()).update(targets) diff --git a/mypy/typetraverser.py b/mypy/typetraverser.py index 7d959c97b66ba..b2591afbc5d33 100644 --- a/mypy/typetraverser.py +++ b/mypy/typetraverser.py @@ -3,11 +3,35 @@ from mypy_extensions import trait from mypy.types import ( - Type, SyntheticTypeVisitor, AnyType, UninhabitedType, NoneType, ErasedType, DeletedType, - TypeVarType, LiteralType, Instance, CallableType, TupleType, TypedDictType, UnionType, - Overloaded, TypeType, CallableArgument, UnboundType, TypeList, StarType, EllipsisType, - PlaceholderType, PartialType, RawExpressionType, TypeAliasType, ParamSpecType, Parameters, - UnpackType, TypeVarTupleType, + AnyType, + CallableArgument, + CallableType, + DeletedType, + EllipsisType, + ErasedType, + Instance, + LiteralType, + NoneType, + Overloaded, + Parameters, + ParamSpecType, + PartialType, + PlaceholderType, + RawExpressionType, + StarType, + SyntheticTypeVisitor, + TupleType, + Type, + TypeAliasType, + TypedDictType, + TypeList, + TypeType, + TypeVarTupleType, + TypeVarType, + UnboundType, + UninhabitedType, + UnionType, + UnpackType, ) diff --git a/mypy/typevars.py b/mypy/typevars.py index bd1c325b4c817..aefdf339587c2 100644 --- a/mypy/typevars.py +++ b/mypy/typevars.py @@ -1,11 +1,18 @@ -from typing import Union, List - -from mypy.nodes import TypeInfo +from typing import List, Union from mypy.erasetype import erase_typevars +from mypy.nodes import TypeInfo from mypy.types import ( - Instance, TypeVarType, TupleType, Type, TypeOfAny, AnyType, ParamSpecType, - TypeVarTupleType, UnpackType, TypeVarLikeType + AnyType, + Instance, + ParamSpecType, + TupleType, + Type, + TypeOfAny, + TypeVarLikeType, + TypeVarTupleType, + TypeVarType, + UnpackType, ) @@ -21,17 +28,24 @@ def fill_typevars(typ: TypeInfo) -> Union[Instance, TupleType]: # Change the line number if isinstance(tv, TypeVarType): tv = TypeVarType( - tv.name, tv.fullname, tv.id, tv.values, - tv.upper_bound, tv.variance, line=-1, column=-1, + tv.name, + tv.fullname, + tv.id, + tv.values, + tv.upper_bound, + tv.variance, + line=-1, + column=-1, ) elif isinstance(tv, TypeVarTupleType): - tv = UnpackType(TypeVarTupleType( - tv.name, tv.fullname, tv.id, tv.upper_bound, line=-1, column=-1 - )) + tv = UnpackType( + TypeVarTupleType(tv.name, tv.fullname, tv.id, tv.upper_bound, line=-1, column=-1) + ) else: assert isinstance(tv, ParamSpecType) - tv = ParamSpecType(tv.name, tv.fullname, tv.id, tv.flavor, tv.upper_bound, - line=-1, column=-1) + tv = ParamSpecType( + tv.name, tv.fullname, tv.id, tv.flavor, tv.upper_bound, line=-1, column=-1 + ) tvs.append(tv) inst = Instance(typ, tvs) if typ.tuple_type is None: diff --git a/mypy/typevartuples.py b/mypy/typevartuples.py index dbbd780f9d562..a4b71da3f5f98 100644 --- a/mypy/typevartuples.py +++ b/mypy/typevartuples.py @@ -1,8 +1,8 @@ """Helpers for interacting with type var tuples.""" -from typing import TypeVar, Optional, Tuple, Sequence +from typing import Optional, Sequence, Tuple, TypeVar -from mypy.types import Instance, UnpackType, ProperType, get_proper_type, Type +from mypy.types import Instance, ProperType, Type, UnpackType, get_proper_type def find_unpack_in_list(items: Sequence[Type]) -> Optional[int]: @@ -24,9 +24,7 @@ def find_unpack_in_list(items: Sequence[Type]) -> Optional[int]: def split_with_prefix_and_suffix( - types: Tuple[T, ...], - prefix: int, - suffix: int, + types: Tuple[T, ...], prefix: int, suffix: int ) -> Tuple[Tuple[T, ...], Tuple[T, ...], Tuple[T, ...]]: if suffix: return (types[:prefix], types[prefix:-suffix], types[-suffix:]) @@ -35,14 +33,12 @@ def split_with_prefix_and_suffix( def split_with_instance( - typ: Instance + typ: Instance, ) -> Tuple[Tuple[Type, ...], Tuple[Type, ...], Tuple[Type, ...]]: assert typ.type.type_var_tuple_prefix is not None assert typ.type.type_var_tuple_suffix is not None return split_with_prefix_and_suffix( - typ.args, - typ.type.type_var_tuple_prefix, - typ.type.type_var_tuple_suffix, + typ.args, typ.type.type_var_tuple_prefix, typ.type.type_var_tuple_suffix ) diff --git a/mypy/util.py b/mypy/util.py index 957fabbb06867..8e6aa8acda932 100644 --- a/mypy/util.py +++ b/mypy/util.py @@ -1,30 +1,43 @@ """Utility functions with no non-trivial dependencies.""" +import hashlib +import io import os import pathlib import re +import shutil import subprocess import sys -import hashlib -import io -import shutil import time - from typing import ( - TypeVar, List, Tuple, Optional, Dict, Sequence, Iterable, Container, IO, Callable, Union, Sized + IO, + Callable, + Container, + Dict, + Iterable, + List, + Optional, + Sequence, + Sized, + Tuple, + TypeVar, + Union, ) -from typing_extensions import Final, Type, Literal + +from typing_extensions import Final, Literal, Type try: import curses + import _curses # noqa + CURSES_ENABLED = True except ImportError: CURSES_ENABLED = False -T = TypeVar('T') +T = TypeVar("T") -ENCODING_RE: Final = re.compile(br"([ \t\v]*#.*(\r\n?|\n))??[ \t\v]*#.*coding[:=][ \t]*([-\w.]+)") +ENCODING_RE: Final = re.compile(rb"([ \t\v]*#.*(\r\n?|\n))??[ \t\v]*#.*coding[:=][ \t]*([-\w.]+)") DEFAULT_SOURCE_OFFSET: Final = 4 DEFAULT_COLUMNS: Final = 80 @@ -47,9 +60,9 @@ "C:\\Python27\\python.exe", ] -SPECIAL_DUNDERS: Final = frozenset(( - "__init__", "__new__", "__call__", "__init_subclass__", "__class_getitem__", -)) +SPECIAL_DUNDERS: Final = frozenset( + ("__init__", "__new__", "__call__", "__init_subclass__", "__class_getitem__") +) def is_dunder(name: str, exclude_special: bool = False) -> bool: @@ -65,7 +78,7 @@ def is_dunder(name: str, exclude_special: bool = False) -> bool: def is_sunder(name: str) -> bool: - return not is_dunder(name) and name.startswith('_') and name.endswith('_') + return not is_dunder(name) and name.startswith("_") and name.endswith("_") def split_module_names(mod_name: str) -> List[str]: @@ -75,8 +88,8 @@ def split_module_names(mod_name: str) -> List[str]: ['a.b.c', 'a.b', and 'a']. """ out = [mod_name] - while '.' in mod_name: - mod_name = mod_name.rsplit('.', 1)[0] + while "." in mod_name: + mod_name = mod_name.rsplit(".", 1)[0] out.append(mod_name) return out @@ -92,8 +105,8 @@ def split_target(modules: Iterable[str], target: str) -> Optional[Tuple[str, str remaining: List[str] = [] while True: if target in modules: - return target, '.'.join(remaining) - components = target.rsplit('.', 1) + return target, ".".join(remaining) + components = target.rsplit(".", 1) if len(components) == 1: return None target = components[0] @@ -106,9 +119,9 @@ def short_type(obj: object) -> str: If obj is None, return 'nil'. For example, if obj is 1, return 'int'. """ if obj is None: - return 'nil' + return "nil" t = str(type(obj)) - return t.split('.')[-1].rstrip("'>") + return t.split(".")[-1].rstrip("'>") def find_python_encoding(text: bytes, pyversion: Tuple[int, int]) -> Tuple[str, int]: @@ -116,13 +129,13 @@ def find_python_encoding(text: bytes, pyversion: Tuple[int, int]) -> Tuple[str, result = ENCODING_RE.match(text) if result: line = 2 if result.group(1) else 1 - encoding = result.group(3).decode('ascii') + encoding = result.group(3).decode("ascii") # Handle some aliases that Python is happy to accept and that are used in the wild. - if encoding.startswith(('iso-latin-1-', 'latin-1-')) or encoding == 'iso-latin-1': - encoding = 'latin-1' + if encoding.startswith(("iso-latin-1-", "latin-1-")) or encoding == "iso-latin-1": + encoding = "latin-1" return encoding, line else: - default_encoding = 'utf8' if pyversion[0] >= 3 else 'ascii' + default_encoding = "utf8" if pyversion[0] >= 3 else "ascii" return default_encoding, -1 @@ -153,8 +166,8 @@ def decode_python_encoding(source: bytes, pyversion: Tuple[int, int]) -> str: Returns the source as a string. """ # check for BOM UTF-8 encoding and strip it out if present - if source.startswith(b'\xef\xbb\xbf'): - encoding = 'utf8' + if source.startswith(b"\xef\xbb\xbf"): + encoding = "utf8" source = source[3:] else: # look at first two lines and check if PEP-263 coding is present @@ -167,8 +180,9 @@ def decode_python_encoding(source: bytes, pyversion: Tuple[int, int]) -> str: return source_text -def read_py_file(path: str, read: Callable[[str], bytes], - pyversion: Tuple[int, int]) -> Optional[List[str]]: +def read_py_file( + path: str, read: Callable[[str], bytes], pyversion: Tuple[int, int] +) -> Optional[List[str]]: """Try reading a Python file as list of source lines. Return None if something goes wrong. @@ -206,27 +220,27 @@ def trim_source_line(line: str, max_len: int, col: int, min_width: int) -> Tuple # If column is not too large so that there is still min_width after it, # the line doesn't need to be trimmed at the start. if col + min_width < max_len: - return line[:max_len] + '...', 0 + return line[:max_len] + "...", 0 # Otherwise, if the column is not too close to the end, trim both sides. if col < len(line) - min_width - 1: offset = col - max_len + min_width + 1 - return '...' + line[offset:col + min_width + 1] + '...', offset - 3 + return "..." + line[offset : col + min_width + 1] + "...", offset - 3 # Finally, if the column is near the end, just trim the start. - return '...' + line[-max_len:], len(line) - max_len - 3 + return "..." + line[-max_len:], len(line) - max_len - 3 def get_mypy_comments(source: str) -> List[Tuple[int, str]]: - PREFIX = '# mypy: ' + PREFIX = "# mypy: " # Don't bother splitting up the lines unless we know it is useful if PREFIX not in source: return [] - lines = source.split('\n') + lines = source.split("\n") results = [] for i, line in enumerate(lines): if line.startswith(PREFIX): - results.append((i + 1, line[len(PREFIX):])) + results.append((i + 1, line[len(PREFIX) :])) return results @@ -240,10 +254,9 @@ def try_find_python2_interpreter() -> Optional[str]: return _python2_interpreter for interpreter in default_python2_interpreter: try: - retcode = subprocess.Popen([ - interpreter, '-c', - 'import sys, typing; assert sys.version_info[:2] == (2, 7)' - ]).wait() + retcode = subprocess.Popen( + [interpreter, "-c", "import sys, typing; assert sys.version_info[:2] == (2, 7)"] + ).wait() if not retcode: _python2_interpreter = interpreter return interpreter @@ -276,25 +289,29 @@ def try_find_python2_interpreter() -> Optional[str]: """ -def write_junit_xml(dt: float, serious: bool, messages: List[str], path: str, - version: str, platform: str) -> None: +def write_junit_xml( + dt: float, serious: bool, messages: List[str], path: str, version: str, platform: str +) -> None: from xml.sax.saxutils import escape + if not messages and not serious: xml = PASS_TEMPLATE.format(time=dt, ver=version, platform=platform) elif not serious: - xml = FAIL_TEMPLATE.format(text=escape('\n'.join(messages)), time=dt, - ver=version, platform=platform) + xml = FAIL_TEMPLATE.format( + text=escape("\n".join(messages)), time=dt, ver=version, platform=platform + ) else: - xml = ERROR_TEMPLATE.format(text=escape('\n'.join(messages)), time=dt, - ver=version, platform=platform) + xml = ERROR_TEMPLATE.format( + text=escape("\n".join(messages)), time=dt, ver=version, platform=platform + ) # checks for a directory structure in path and creates folders if needed xml_dirs = os.path.dirname(os.path.abspath(path)) if not os.path.isdir(xml_dirs): os.makedirs(xml_dirs) - with open(path, 'wb') as f: - f.write(xml.encode('utf-8')) + with open(path, "wb") as f: + f.write(xml.encode("utf-8")) class IdMapper: @@ -319,7 +336,7 @@ def id(self, o: object) -> int: def get_prefix(fullname: str) -> str: """Drop the final component of a qualified name (e.g. ('x.y' -> 'x').""" - return fullname.rsplit('.', 1)[0] + return fullname.rsplit(".", 1)[0] def get_top_two_prefixes(fullname: str) -> Tuple[str, str]: @@ -329,14 +346,13 @@ def get_top_two_prefixes(fullname: str) -> Tuple[str, str]: If fullname has only one component, return (fullname, fullname). """ - components = fullname.split('.', 3) - return components[0], '.'.join(components[:2]) + components = fullname.split(".", 3) + return components[0], ".".join(components[:2]) -def correct_relative_import(cur_mod_id: str, - relative: int, - target: str, - is_cur_package_init_file: bool) -> Tuple[str, bool]: +def correct_relative_import( + cur_mod_id: str, relative: int, target: str, is_cur_package_init_file: bool +) -> Tuple[str, bool]: if relative == 0: return target, True parts = cur_mod_id.split(".") @@ -352,15 +368,16 @@ def correct_relative_import(cur_mod_id: str, fields_cache: Final[Dict[Type[object], List[str]]] = {} -def get_class_descriptors(cls: 'Type[object]') -> Sequence[str]: +def get_class_descriptors(cls: "Type[object]") -> Sequence[str]: import inspect # Lazy import for minor startup speed win + # Maintain a cache of type -> attributes defined by descriptors in the class # (that is, attributes from __slots__ and C extension classes) if cls not in fields_cache: members = inspect.getmembers( - cls, - lambda o: inspect.isgetsetdescriptor(o) or inspect.ismemberdescriptor(o)) - fields_cache[cls] = [x for x, y in members if x != '__weakref__' and x != '__dict__'] + cls, lambda o: inspect.isgetsetdescriptor(o) or inspect.ismemberdescriptor(o) + ) + fields_cache[cls] = [x for x, y in members if x != "__weakref__" and x != "__dict__"] return fields_cache[cls] @@ -372,7 +389,7 @@ def replace_object_state(new: object, old: object, copy_dict: bool = False) -> N Assume that both objects have the same __class__. """ - if hasattr(old, '__dict__'): + if hasattr(old, "__dict__"): if copy_dict: new.__dict__ = dict(old.__dict__) else: @@ -418,7 +435,7 @@ def get_unique_redefinition_name(name: str, existing: Container[str]) -> str: For example, for name 'foo' we try 'foo-redefinition', 'foo-redefinition2', 'foo-redefinition3', etc. until we find one that is not in existing. """ - r_name = name + '-redefinition' + r_name = name + "-redefinition" if r_name not in existing: return r_name @@ -432,27 +449,29 @@ def check_python_version(program: str) -> None: """Report issues with the Python used to run mypy, dmypy, or stubgen""" # Check for known bad Python versions. if sys.version_info[:2] < (3, 6): - sys.exit("Running {name} with Python 3.5 or lower is not supported; " - "please upgrade to 3.6 or newer".format(name=program)) + sys.exit( + "Running {name} with Python 3.5 or lower is not supported; " + "please upgrade to 3.6 or newer".format(name=program) + ) def count_stats(messages: List[str]) -> Tuple[int, int, int]: """Count total number of errors, notes and error_files in message list.""" - errors = [e for e in messages if ': error:' in e] - error_files = {e.split(':')[0] for e in errors} - notes = [e for e in messages if ': note:' in e] + errors = [e for e in messages if ": error:" in e] + error_files = {e.split(":")[0] for e in errors} + notes = [e for e in messages if ": note:" in e] return len(errors), len(notes), len(error_files) def split_words(msg: str) -> List[str]: """Split line of text into words (but not within quoted groups).""" - next_word = '' + next_word = "" res: List[str] = [] allow_break = True for c in msg: - if c == ' ' and allow_break: + if c == " " and allow_break: res.append(next_word) - next_word = '' + next_word = "" continue if c == '"': allow_break = not allow_break @@ -463,13 +482,14 @@ def split_words(msg: str) -> List[str]: def get_terminal_width() -> int: """Get current terminal width if possible, otherwise return the default one.""" - return (int(os.getenv('MYPY_FORCE_TERMINAL_WIDTH', '0')) - or shutil.get_terminal_size().columns - or DEFAULT_COLUMNS) + return ( + int(os.getenv("MYPY_FORCE_TERMINAL_WIDTH", "0")) + or shutil.get_terminal_size().columns + or DEFAULT_COLUMNS + ) -def soft_wrap(msg: str, max_len: int, first_offset: int, - num_indent: int = 0) -> str: +def soft_wrap(msg: str, max_len: int, first_offset: int, num_indent: int = 0) -> str: """Wrap a long error message into few lines. Breaks will only happen between words, and never inside a quoted group @@ -496,12 +516,12 @@ def soft_wrap(msg: str, max_len: int, first_offset: int, max_line_len = max_len - num_indent if lines else max_len - first_offset # Add 1 to account for space between words. if len(next_line) + len(next_word) + 1 <= max_line_len: - next_line += ' ' + next_word + next_line += " " + next_word else: lines.append(next_line) next_line = next_word lines.append(next_line) - padding = '\n' + ' ' * num_indent + padding = "\n" + " " * num_indent return padding.join(lines) @@ -521,8 +541,8 @@ def parse_gray_color(cup: bytes) -> str: """Reproduce a gray color in ANSI escape sequence""" if sys.platform == "win32": assert False, "curses is not available on Windows" - set_color = ''.join([cup[:-1].decode(), 'm']) - gray = curses.tparm(set_color.encode('utf-8'), 1, 89).decode() + set_color = "".join([cup[:-1].decode(), "m"]) + gray = curses.tparm(set_color.encode("utf-8"), 1, 89).decode() return gray @@ -531,54 +551,64 @@ class FancyFormatter: This currently only works on Linux and Mac. """ + def __init__(self, f_out: IO[str], f_err: IO[str], show_error_codes: bool) -> None: self.show_error_codes = show_error_codes # Check if we are in a human-facing terminal on a supported platform. - if sys.platform not in ('linux', 'darwin', 'win32'): + if sys.platform not in ("linux", "darwin", "win32"): self.dummy_term = True return - force_color = int(os.getenv('MYPY_FORCE_COLOR', '0')) + force_color = int(os.getenv("MYPY_FORCE_COLOR", "0")) if not force_color and (not f_out.isatty() or not f_err.isatty()): self.dummy_term = True return - if sys.platform == 'win32': + if sys.platform == "win32": self.dummy_term = not self.initialize_win_colors() else: self.dummy_term = not self.initialize_unix_colors() if not self.dummy_term: - self.colors = {'red': self.RED, 'green': self.GREEN, - 'blue': self.BLUE, 'yellow': self.YELLOW, - 'none': ''} + self.colors = { + "red": self.RED, + "green": self.GREEN, + "blue": self.BLUE, + "yellow": self.YELLOW, + "none": "", + } def initialize_win_colors(self) -> bool: """Return True if initialization was successful and we can use colors, False otherwise""" # Windows ANSI escape sequences are only supported on Threshold 2 and above. # we check with an assert at runtime and an if check for mypy, as asserts do not # yet narrow platform - assert sys.platform == 'win32' - if sys.platform == 'win32': + assert sys.platform == "win32" + if sys.platform == "win32": winver = sys.getwindowsversion() - if (winver.major < MINIMUM_WINDOWS_MAJOR_VT100 - or winver.build < MINIMUM_WINDOWS_BUILD_VT100): + if ( + winver.major < MINIMUM_WINDOWS_MAJOR_VT100 + or winver.build < MINIMUM_WINDOWS_BUILD_VT100 + ): return False import ctypes + kernel32 = ctypes.windll.kernel32 ENABLE_PROCESSED_OUTPUT = 0x1 ENABLE_WRAP_AT_EOL_OUTPUT = 0x2 ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x4 STD_OUTPUT_HANDLE = -11 - kernel32.SetConsoleMode(kernel32.GetStdHandle(STD_OUTPUT_HANDLE), - ENABLE_PROCESSED_OUTPUT - | ENABLE_WRAP_AT_EOL_OUTPUT - | ENABLE_VIRTUAL_TERMINAL_PROCESSING) - self.BOLD = '\033[1m' - self.UNDER = '\033[4m' - self.BLUE = '\033[94m' - self.GREEN = '\033[92m' - self.RED = '\033[91m' - self.YELLOW = '\033[93m' - self.NORMAL = '\033[0m' - self.DIM = '\033[2m' + kernel32.SetConsoleMode( + kernel32.GetStdHandle(STD_OUTPUT_HANDLE), + ENABLE_PROCESSED_OUTPUT + | ENABLE_WRAP_AT_EOL_OUTPUT + | ENABLE_VIRTUAL_TERMINAL_PROCESSING, + ) + self.BOLD = "\033[1m" + self.UNDER = "\033[4m" + self.BLUE = "\033[94m" + self.GREEN = "\033[92m" + self.RED = "\033[91m" + self.YELLOW = "\033[93m" + self.NORMAL = "\033[0m" + self.DIM = "\033[2m" return True return False @@ -600,11 +630,11 @@ def initialize_unix_colors(self) -> bool: except curses.error: # Most likely terminfo not found. return False - bold = curses.tigetstr('bold') - under = curses.tigetstr('smul') - set_color = curses.tigetstr('setaf') - set_eseq = curses.tigetstr('cup') - normal = curses.tigetstr('sgr0') + bold = curses.tigetstr("bold") + under = curses.tigetstr("smul") + set_color = curses.tigetstr("setaf") + set_eseq = curses.tigetstr("cup") + normal = curses.tigetstr("sgr0") if not (bold and under and set_color and set_eseq and normal): return False @@ -619,81 +649,93 @@ def initialize_unix_colors(self) -> bool: self.YELLOW = curses.tparm(set_color, curses.COLOR_YELLOW).decode() return True - def style(self, text: str, color: Literal['red', 'green', 'blue', 'yellow', 'none'], - bold: bool = False, underline: bool = False, dim: bool = False) -> str: + def style( + self, + text: str, + color: Literal["red", "green", "blue", "yellow", "none"], + bold: bool = False, + underline: bool = False, + dim: bool = False, + ) -> str: """Apply simple color and style (underlined or bold).""" if self.dummy_term: return text if bold: start = self.BOLD else: - start = '' + start = "" if underline: start += self.UNDER if dim: start += self.DIM return start + self.colors[color] + text + self.NORMAL - def fit_in_terminal(self, messages: List[str], - fixed_terminal_width: Optional[int] = None) -> List[str]: + def fit_in_terminal( + self, messages: List[str], fixed_terminal_width: Optional[int] = None + ) -> List[str]: """Improve readability by wrapping error messages and trimming source code.""" width = fixed_terminal_width or get_terminal_width() new_messages = messages.copy() for i, error in enumerate(messages): - if ': error:' in error: - loc, msg = error.split('error:', maxsplit=1) - msg = soft_wrap(msg, width, first_offset=len(loc) + len('error: ')) - new_messages[i] = loc + 'error:' + msg - if error.startswith(' ' * DEFAULT_SOURCE_OFFSET) and '^' not in error: + if ": error:" in error: + loc, msg = error.split("error:", maxsplit=1) + msg = soft_wrap(msg, width, first_offset=len(loc) + len("error: ")) + new_messages[i] = loc + "error:" + msg + if error.startswith(" " * DEFAULT_SOURCE_OFFSET) and "^" not in error: # TODO: detecting source code highlights through an indent can be surprising. # Restore original error message and error location. error = error[DEFAULT_SOURCE_OFFSET:] - marker_line = messages[i+1] - marker_column = marker_line.index('^') + marker_line = messages[i + 1] + marker_column = marker_line.index("^") column = marker_column - DEFAULT_SOURCE_OFFSET - if '~' not in marker_line: - marker = '^' + if "~" not in marker_line: + marker = "^" else: # +1 because both ends are included - marker = marker_line[marker_column:marker_line.rindex('~')+1] + marker = marker_line[marker_column : marker_line.rindex("~") + 1] # Let source have some space also on the right side, plus 6 # to accommodate ... on each side. max_len = width - DEFAULT_SOURCE_OFFSET - 6 source_line, offset = trim_source_line(error, max_len, column, MINIMUM_WIDTH) - new_messages[i] = ' ' * DEFAULT_SOURCE_OFFSET + source_line + new_messages[i] = " " * DEFAULT_SOURCE_OFFSET + source_line # Also adjust the error marker position and trim error marker is needed. - new_marker_line = ' ' * (DEFAULT_SOURCE_OFFSET + column - offset) + marker + new_marker_line = " " * (DEFAULT_SOURCE_OFFSET + column - offset) + marker if len(new_marker_line) > len(new_messages[i]) and len(marker) > 3: - new_marker_line = new_marker_line[:len(new_messages[i]) - 3] + '...' - new_messages[i+1] = new_marker_line + new_marker_line = new_marker_line[: len(new_messages[i]) - 3] + "..." + new_messages[i + 1] = new_marker_line return new_messages def colorize(self, error: str) -> str: """Colorize an output line by highlighting the status and error code.""" - if ': error:' in error: - loc, msg = error.split('error:', maxsplit=1) + if ": error:" in error: + loc, msg = error.split("error:", maxsplit=1) if not self.show_error_codes: - return (loc + self.style('error:', 'red', bold=True) + - self.highlight_quote_groups(msg)) - codepos = msg.rfind('[') + return ( + loc + self.style("error:", "red", bold=True) + self.highlight_quote_groups(msg) + ) + codepos = msg.rfind("[") if codepos != -1: code = msg[codepos:] msg = msg[:codepos] else: code = "" # no error code specified - return (loc + self.style('error:', 'red', bold=True) + - self.highlight_quote_groups(msg) + self.style(code, 'yellow')) - elif ': note:' in error: - loc, msg = error.split('note:', maxsplit=1) + return ( + loc + + self.style("error:", "red", bold=True) + + self.highlight_quote_groups(msg) + + self.style(code, "yellow") + ) + elif ": note:" in error: + loc, msg = error.split("note:", maxsplit=1) formatted = self.highlight_quote_groups(self.underline_link(msg)) - return loc + self.style('note:', 'blue') + formatted - elif error.startswith(' ' * DEFAULT_SOURCE_OFFSET): + return loc + self.style("note:", "blue") + formatted + elif error.startswith(" " * DEFAULT_SOURCE_OFFSET): # TODO: detecting source code highlights through an indent can be surprising. - if '^' not in error: - return self.style(error, 'none', dim=True) - return self.style(error, 'red') + if "^" not in error: + return self.style(error, "none", dim=True) + return self.style(error, "red") else: return error @@ -706,12 +748,12 @@ def highlight_quote_groups(self, msg: str) -> str: # Broken error message, don't do any formatting. return msg parts = msg.split('"') - out = '' + out = "" for i, part in enumerate(parts): if i % 2 == 0: - out += self.style(part, 'none') + out += self.style(part, "none") else: - out += self.style('"' + part + '"', 'none', bold=True) + out += self.style('"' + part + '"', "none", bold=True) return out def underline_link(self, note: str) -> str: @@ -719,14 +761,12 @@ def underline_link(self, note: str) -> str: This assumes there is at most one link in the message. """ - match = re.search(r'https?://\S*', note) + match = re.search(r"https?://\S*", note) if not match: return note start = match.start() end = match.end() - return (note[:start] + - self.style(note[start:end], 'none', underline=True) + - note[end:]) + return note[:start] + self.style(note[start:end], "none", underline=True) + note[end:] def format_success(self, n_sources: int, use_color: bool = True) -> str: """Format short summary in case of success. @@ -734,37 +774,41 @@ def format_success(self, n_sources: int, use_color: bool = True) -> str: n_sources is total number of files passed directly on command line, i.e. excluding stubs and followed imports. """ - msg = f'Success: no issues found in {n_sources} source file{plural_s(n_sources)}' + msg = f"Success: no issues found in {n_sources} source file{plural_s(n_sources)}" if not use_color: return msg - return self.style(msg, 'green', bold=True) + return self.style(msg, "green", bold=True) def format_error( - self, n_errors: int, n_files: int, n_sources: int, *, - blockers: bool = False, use_color: bool = True + self, + n_errors: int, + n_files: int, + n_sources: int, + *, + blockers: bool = False, + use_color: bool = True, ) -> str: """Format a short summary in case of errors.""" - msg = f'Found {n_errors} error{plural_s(n_errors)} in {n_files} file{plural_s(n_files)}' + msg = f"Found {n_errors} error{plural_s(n_errors)} in {n_files} file{plural_s(n_files)}" if blockers: - msg += ' (errors prevented further checking)' + msg += " (errors prevented further checking)" else: msg += f" (checked {n_sources} source file{plural_s(n_sources)})" if not use_color: return msg - return self.style(msg, 'red', bold=True) + return self.style(msg, "red", bold=True) def is_typeshed_file(file: str) -> bool: # gross, but no other clear way to tell - return 'typeshed' in os.path.abspath(file).split(os.sep) + return "typeshed" in os.path.abspath(file).split(os.sep) def is_stub_package_file(file: str) -> bool: # Use hacky heuristics to check whether file is part of a PEP 561 stub package. - if not file.endswith('.pyi'): + if not file.endswith(".pyi"): return False - return any(component.endswith('-stubs') - for component in os.path.abspath(file).split(os.sep)) + return any(component.endswith("-stubs") for component in os.path.abspath(file).split(os.sep)) def unnamed_function(name: Optional[str]) -> bool: @@ -783,6 +827,6 @@ def time_spent_us(t0: float) -> int: def plural_s(s: Union[int, Sized]) -> str: count = s if isinstance(s, int) else len(s) if count > 1: - return 's' + return "s" else: - return '' + return "" diff --git a/mypy/version.py b/mypy/version.py index 9c9a75b3dd356..71536d51b83b0 100644 --- a/mypy/version.py +++ b/mypy/version.py @@ -1,16 +1,17 @@ import os + from mypy import git # Base version. # - Release versions have the form "0.NNN". # - Dev versions have the form "0.NNN+dev" (PLUS sign to conform to PEP 440). # - For 1.0 we'll switch back to 1.2.3 form. -__version__ = '0.980+dev' +__version__ = "0.980+dev" base_version = __version__ mypy_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) -if __version__.endswith('+dev') and git.is_git_repo(mypy_dir) and git.have_git(): - __version__ += '.' + git.git_revision(mypy_dir).decode('utf-8') +if __version__.endswith("+dev") and git.is_git_repo(mypy_dir) and git.have_git(): + __version__ += "." + git.git_revision(mypy_dir).decode("utf-8") if git.is_dirty(mypy_dir): - __version__ += '.dirty' + __version__ += ".dirty" del mypy_dir diff --git a/mypy/visitor.py b/mypy/visitor.py index 94fde0b113191..52cd31e5791ed 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -1,9 +1,10 @@ """Generic abstract syntax tree node visitor""" from abc import abstractmethod -from typing import TypeVar, Generic +from typing import Generic, TypeVar + +from mypy_extensions import mypyc_attr, trait from typing_extensions import TYPE_CHECKING -from mypy_extensions import trait, mypyc_attr if TYPE_CHECKING: # break import cycle only needed for mypy @@ -11,194 +12,194 @@ import mypy.patterns -T = TypeVar('T') +T = TypeVar("T") @trait @mypyc_attr(allow_interpreted_subclasses=True) class ExpressionVisitor(Generic[T]): @abstractmethod - def visit_int_expr(self, o: 'mypy.nodes.IntExpr') -> T: + def visit_int_expr(self, o: "mypy.nodes.IntExpr") -> T: pass @abstractmethod - def visit_str_expr(self, o: 'mypy.nodes.StrExpr') -> T: + def visit_str_expr(self, o: "mypy.nodes.StrExpr") -> T: pass @abstractmethod - def visit_bytes_expr(self, o: 'mypy.nodes.BytesExpr') -> T: + def visit_bytes_expr(self, o: "mypy.nodes.BytesExpr") -> T: pass @abstractmethod - def visit_unicode_expr(self, o: 'mypy.nodes.UnicodeExpr') -> T: + def visit_unicode_expr(self, o: "mypy.nodes.UnicodeExpr") -> T: pass @abstractmethod - def visit_float_expr(self, o: 'mypy.nodes.FloatExpr') -> T: + def visit_float_expr(self, o: "mypy.nodes.FloatExpr") -> T: pass @abstractmethod - def visit_complex_expr(self, o: 'mypy.nodes.ComplexExpr') -> T: + def visit_complex_expr(self, o: "mypy.nodes.ComplexExpr") -> T: pass @abstractmethod - def visit_ellipsis(self, o: 'mypy.nodes.EllipsisExpr') -> T: + def visit_ellipsis(self, o: "mypy.nodes.EllipsisExpr") -> T: pass @abstractmethod - def visit_star_expr(self, o: 'mypy.nodes.StarExpr') -> T: + def visit_star_expr(self, o: "mypy.nodes.StarExpr") -> T: pass @abstractmethod - def visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> T: + def visit_name_expr(self, o: "mypy.nodes.NameExpr") -> T: pass @abstractmethod - def visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> T: + def visit_member_expr(self, o: "mypy.nodes.MemberExpr") -> T: pass @abstractmethod - def visit_yield_from_expr(self, o: 'mypy.nodes.YieldFromExpr') -> T: + def visit_yield_from_expr(self, o: "mypy.nodes.YieldFromExpr") -> T: pass @abstractmethod - def visit_yield_expr(self, o: 'mypy.nodes.YieldExpr') -> T: + def visit_yield_expr(self, o: "mypy.nodes.YieldExpr") -> T: pass @abstractmethod - def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T: + def visit_call_expr(self, o: "mypy.nodes.CallExpr") -> T: pass @abstractmethod - def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> T: + def visit_op_expr(self, o: "mypy.nodes.OpExpr") -> T: pass @abstractmethod - def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T: + def visit_comparison_expr(self, o: "mypy.nodes.ComparisonExpr") -> T: pass @abstractmethod - def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> T: + def visit_cast_expr(self, o: "mypy.nodes.CastExpr") -> T: pass @abstractmethod - def visit_assert_type_expr(self, o: 'mypy.nodes.AssertTypeExpr') -> T: + def visit_assert_type_expr(self, o: "mypy.nodes.AssertTypeExpr") -> T: pass @abstractmethod - def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> T: + def visit_reveal_expr(self, o: "mypy.nodes.RevealExpr") -> T: pass @abstractmethod - def visit_super_expr(self, o: 'mypy.nodes.SuperExpr') -> T: + def visit_super_expr(self, o: "mypy.nodes.SuperExpr") -> T: pass @abstractmethod - def visit_unary_expr(self, o: 'mypy.nodes.UnaryExpr') -> T: + def visit_unary_expr(self, o: "mypy.nodes.UnaryExpr") -> T: pass @abstractmethod - def visit_assignment_expr(self, o: 'mypy.nodes.AssignmentExpr') -> T: + def visit_assignment_expr(self, o: "mypy.nodes.AssignmentExpr") -> T: pass @abstractmethod - def visit_list_expr(self, o: 'mypy.nodes.ListExpr') -> T: + def visit_list_expr(self, o: "mypy.nodes.ListExpr") -> T: pass @abstractmethod - def visit_dict_expr(self, o: 'mypy.nodes.DictExpr') -> T: + def visit_dict_expr(self, o: "mypy.nodes.DictExpr") -> T: pass @abstractmethod - def visit_tuple_expr(self, o: 'mypy.nodes.TupleExpr') -> T: + def visit_tuple_expr(self, o: "mypy.nodes.TupleExpr") -> T: pass @abstractmethod - def visit_set_expr(self, o: 'mypy.nodes.SetExpr') -> T: + def visit_set_expr(self, o: "mypy.nodes.SetExpr") -> T: pass @abstractmethod - def visit_index_expr(self, o: 'mypy.nodes.IndexExpr') -> T: + def visit_index_expr(self, o: "mypy.nodes.IndexExpr") -> T: pass @abstractmethod - def visit_type_application(self, o: 'mypy.nodes.TypeApplication') -> T: + def visit_type_application(self, o: "mypy.nodes.TypeApplication") -> T: pass @abstractmethod - def visit_lambda_expr(self, o: 'mypy.nodes.LambdaExpr') -> T: + def visit_lambda_expr(self, o: "mypy.nodes.LambdaExpr") -> T: pass @abstractmethod - def visit_list_comprehension(self, o: 'mypy.nodes.ListComprehension') -> T: + def visit_list_comprehension(self, o: "mypy.nodes.ListComprehension") -> T: pass @abstractmethod - def visit_set_comprehension(self, o: 'mypy.nodes.SetComprehension') -> T: + def visit_set_comprehension(self, o: "mypy.nodes.SetComprehension") -> T: pass @abstractmethod - def visit_dictionary_comprehension(self, o: 'mypy.nodes.DictionaryComprehension') -> T: + def visit_dictionary_comprehension(self, o: "mypy.nodes.DictionaryComprehension") -> T: pass @abstractmethod - def visit_generator_expr(self, o: 'mypy.nodes.GeneratorExpr') -> T: + def visit_generator_expr(self, o: "mypy.nodes.GeneratorExpr") -> T: pass @abstractmethod - def visit_slice_expr(self, o: 'mypy.nodes.SliceExpr') -> T: + def visit_slice_expr(self, o: "mypy.nodes.SliceExpr") -> T: pass @abstractmethod - def visit_conditional_expr(self, o: 'mypy.nodes.ConditionalExpr') -> T: + def visit_conditional_expr(self, o: "mypy.nodes.ConditionalExpr") -> T: pass @abstractmethod - def visit_backquote_expr(self, o: 'mypy.nodes.BackquoteExpr') -> T: + def visit_backquote_expr(self, o: "mypy.nodes.BackquoteExpr") -> T: pass @abstractmethod - def visit_type_var_expr(self, o: 'mypy.nodes.TypeVarExpr') -> T: + def visit_type_var_expr(self, o: "mypy.nodes.TypeVarExpr") -> T: pass @abstractmethod - def visit_paramspec_expr(self, o: 'mypy.nodes.ParamSpecExpr') -> T: + def visit_paramspec_expr(self, o: "mypy.nodes.ParamSpecExpr") -> T: pass @abstractmethod - def visit_type_var_tuple_expr(self, o: 'mypy.nodes.TypeVarTupleExpr') -> T: + def visit_type_var_tuple_expr(self, o: "mypy.nodes.TypeVarTupleExpr") -> T: pass @abstractmethod - def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T: + def visit_type_alias_expr(self, o: "mypy.nodes.TypeAliasExpr") -> T: pass @abstractmethod - def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T: + def visit_namedtuple_expr(self, o: "mypy.nodes.NamedTupleExpr") -> T: pass @abstractmethod - def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T: + def visit_enum_call_expr(self, o: "mypy.nodes.EnumCallExpr") -> T: pass @abstractmethod - def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T: + def visit_typeddict_expr(self, o: "mypy.nodes.TypedDictExpr") -> T: pass @abstractmethod - def visit_newtype_expr(self, o: 'mypy.nodes.NewTypeExpr') -> T: + def visit_newtype_expr(self, o: "mypy.nodes.NewTypeExpr") -> T: pass @abstractmethod - def visit__promote_expr(self, o: 'mypy.nodes.PromoteExpr') -> T: + def visit__promote_expr(self, o: "mypy.nodes.PromoteExpr") -> T: pass @abstractmethod - def visit_await_expr(self, o: 'mypy.nodes.AwaitExpr') -> T: + def visit_await_expr(self, o: "mypy.nodes.AwaitExpr") -> T: pass @abstractmethod - def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> T: + def visit_temp_node(self, o: "mypy.nodes.TempNode") -> T: pass @@ -208,119 +209,119 @@ class StatementVisitor(Generic[T]): # Definitions @abstractmethod - def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T: + def visit_assignment_stmt(self, o: "mypy.nodes.AssignmentStmt") -> T: pass @abstractmethod - def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T: + def visit_for_stmt(self, o: "mypy.nodes.ForStmt") -> T: pass @abstractmethod - def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T: + def visit_with_stmt(self, o: "mypy.nodes.WithStmt") -> T: pass @abstractmethod - def visit_del_stmt(self, o: 'mypy.nodes.DelStmt') -> T: + def visit_del_stmt(self, o: "mypy.nodes.DelStmt") -> T: pass @abstractmethod - def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T: + def visit_func_def(self, o: "mypy.nodes.FuncDef") -> T: pass @abstractmethod - def visit_overloaded_func_def(self, o: 'mypy.nodes.OverloadedFuncDef') -> T: + def visit_overloaded_func_def(self, o: "mypy.nodes.OverloadedFuncDef") -> T: pass @abstractmethod - def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T: + def visit_class_def(self, o: "mypy.nodes.ClassDef") -> T: pass @abstractmethod - def visit_global_decl(self, o: 'mypy.nodes.GlobalDecl') -> T: + def visit_global_decl(self, o: "mypy.nodes.GlobalDecl") -> T: pass @abstractmethod - def visit_nonlocal_decl(self, o: 'mypy.nodes.NonlocalDecl') -> T: + def visit_nonlocal_decl(self, o: "mypy.nodes.NonlocalDecl") -> T: pass @abstractmethod - def visit_decorator(self, o: 'mypy.nodes.Decorator') -> T: + def visit_decorator(self, o: "mypy.nodes.Decorator") -> T: pass # Module structure @abstractmethod - def visit_import(self, o: 'mypy.nodes.Import') -> T: + def visit_import(self, o: "mypy.nodes.Import") -> T: pass @abstractmethod - def visit_import_from(self, o: 'mypy.nodes.ImportFrom') -> T: + def visit_import_from(self, o: "mypy.nodes.ImportFrom") -> T: pass @abstractmethod - def visit_import_all(self, o: 'mypy.nodes.ImportAll') -> T: + def visit_import_all(self, o: "mypy.nodes.ImportAll") -> T: pass # Statements @abstractmethod - def visit_block(self, o: 'mypy.nodes.Block') -> T: + def visit_block(self, o: "mypy.nodes.Block") -> T: pass @abstractmethod - def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> T: + def visit_expression_stmt(self, o: "mypy.nodes.ExpressionStmt") -> T: pass @abstractmethod - def visit_operator_assignment_stmt(self, o: 'mypy.nodes.OperatorAssignmentStmt') -> T: + def visit_operator_assignment_stmt(self, o: "mypy.nodes.OperatorAssignmentStmt") -> T: pass @abstractmethod - def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T: + def visit_while_stmt(self, o: "mypy.nodes.WhileStmt") -> T: pass @abstractmethod - def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T: + def visit_return_stmt(self, o: "mypy.nodes.ReturnStmt") -> T: pass @abstractmethod - def visit_assert_stmt(self, o: 'mypy.nodes.AssertStmt') -> T: + def visit_assert_stmt(self, o: "mypy.nodes.AssertStmt") -> T: pass @abstractmethod - def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T: + def visit_if_stmt(self, o: "mypy.nodes.IfStmt") -> T: pass @abstractmethod - def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T: + def visit_break_stmt(self, o: "mypy.nodes.BreakStmt") -> T: pass @abstractmethod - def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T: + def visit_continue_stmt(self, o: "mypy.nodes.ContinueStmt") -> T: pass @abstractmethod - def visit_pass_stmt(self, o: 'mypy.nodes.PassStmt') -> T: + def visit_pass_stmt(self, o: "mypy.nodes.PassStmt") -> T: pass @abstractmethod - def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T: + def visit_raise_stmt(self, o: "mypy.nodes.RaiseStmt") -> T: pass @abstractmethod - def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T: + def visit_try_stmt(self, o: "mypy.nodes.TryStmt") -> T: pass @abstractmethod - def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> T: + def visit_print_stmt(self, o: "mypy.nodes.PrintStmt") -> T: pass @abstractmethod - def visit_exec_stmt(self, o: 'mypy.nodes.ExecStmt') -> T: + def visit_exec_stmt(self, o: "mypy.nodes.ExecStmt") -> T: pass @abstractmethod - def visit_match_stmt(self, o: 'mypy.nodes.MatchStmt') -> T: + def visit_match_stmt(self, o: "mypy.nodes.MatchStmt") -> T: pass @@ -328,35 +329,35 @@ def visit_match_stmt(self, o: 'mypy.nodes.MatchStmt') -> T: @mypyc_attr(allow_interpreted_subclasses=True) class PatternVisitor(Generic[T]): @abstractmethod - def visit_as_pattern(self, o: 'mypy.patterns.AsPattern') -> T: + def visit_as_pattern(self, o: "mypy.patterns.AsPattern") -> T: pass @abstractmethod - def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> T: + def visit_or_pattern(self, o: "mypy.patterns.OrPattern") -> T: pass @abstractmethod - def visit_value_pattern(self, o: 'mypy.patterns.ValuePattern') -> T: + def visit_value_pattern(self, o: "mypy.patterns.ValuePattern") -> T: pass @abstractmethod - def visit_singleton_pattern(self, o: 'mypy.patterns.SingletonPattern') -> T: + def visit_singleton_pattern(self, o: "mypy.patterns.SingletonPattern") -> T: pass @abstractmethod - def visit_sequence_pattern(self, o: 'mypy.patterns.SequencePattern') -> T: + def visit_sequence_pattern(self, o: "mypy.patterns.SequencePattern") -> T: pass @abstractmethod - def visit_starred_pattern(self, o: 'mypy.patterns.StarredPattern') -> T: + def visit_starred_pattern(self, o: "mypy.patterns.StarredPattern") -> T: pass @abstractmethod - def visit_mapping_pattern(self, o: 'mypy.patterns.MappingPattern') -> T: + def visit_mapping_pattern(self, o: "mypy.patterns.MappingPattern") -> T: pass @abstractmethod - def visit_class_pattern(self, o: 'mypy.patterns.ClassPattern') -> T: + def visit_class_pattern(self, o: "mypy.patterns.ClassPattern") -> T: pass @@ -374,275 +375,273 @@ class NodeVisitor(Generic[T], ExpressionVisitor[T], StatementVisitor[T], Pattern # Not in superclasses: - def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T: + def visit_mypy_file(self, o: "mypy.nodes.MypyFile") -> T: pass # TODO: We have a visit_var method, but no visit_typeinfo or any # other non-Statement SymbolNode (accepting those will raise a # runtime error). Maybe this should be resolved in some direction. - def visit_var(self, o: 'mypy.nodes.Var') -> T: + def visit_var(self, o: "mypy.nodes.Var") -> T: pass # Module structure - def visit_import(self, o: 'mypy.nodes.Import') -> T: + def visit_import(self, o: "mypy.nodes.Import") -> T: pass - def visit_import_from(self, o: 'mypy.nodes.ImportFrom') -> T: + def visit_import_from(self, o: "mypy.nodes.ImportFrom") -> T: pass - def visit_import_all(self, o: 'mypy.nodes.ImportAll') -> T: + def visit_import_all(self, o: "mypy.nodes.ImportAll") -> T: pass # Definitions - def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T: + def visit_func_def(self, o: "mypy.nodes.FuncDef") -> T: pass - def visit_overloaded_func_def(self, - o: 'mypy.nodes.OverloadedFuncDef') -> T: + def visit_overloaded_func_def(self, o: "mypy.nodes.OverloadedFuncDef") -> T: pass - def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T: + def visit_class_def(self, o: "mypy.nodes.ClassDef") -> T: pass - def visit_global_decl(self, o: 'mypy.nodes.GlobalDecl') -> T: + def visit_global_decl(self, o: "mypy.nodes.GlobalDecl") -> T: pass - def visit_nonlocal_decl(self, o: 'mypy.nodes.NonlocalDecl') -> T: + def visit_nonlocal_decl(self, o: "mypy.nodes.NonlocalDecl") -> T: pass - def visit_decorator(self, o: 'mypy.nodes.Decorator') -> T: + def visit_decorator(self, o: "mypy.nodes.Decorator") -> T: pass - def visit_type_alias(self, o: 'mypy.nodes.TypeAlias') -> T: + def visit_type_alias(self, o: "mypy.nodes.TypeAlias") -> T: pass - def visit_placeholder_node(self, o: 'mypy.nodes.PlaceholderNode') -> T: + def visit_placeholder_node(self, o: "mypy.nodes.PlaceholderNode") -> T: pass # Statements - def visit_block(self, o: 'mypy.nodes.Block') -> T: + def visit_block(self, o: "mypy.nodes.Block") -> T: pass - def visit_expression_stmt(self, o: 'mypy.nodes.ExpressionStmt') -> T: + def visit_expression_stmt(self, o: "mypy.nodes.ExpressionStmt") -> T: pass - def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T: + def visit_assignment_stmt(self, o: "mypy.nodes.AssignmentStmt") -> T: pass - def visit_operator_assignment_stmt(self, - o: 'mypy.nodes.OperatorAssignmentStmt') -> T: + def visit_operator_assignment_stmt(self, o: "mypy.nodes.OperatorAssignmentStmt") -> T: pass - def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T: + def visit_while_stmt(self, o: "mypy.nodes.WhileStmt") -> T: pass - def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T: + def visit_for_stmt(self, o: "mypy.nodes.ForStmt") -> T: pass - def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T: + def visit_return_stmt(self, o: "mypy.nodes.ReturnStmt") -> T: pass - def visit_assert_stmt(self, o: 'mypy.nodes.AssertStmt') -> T: + def visit_assert_stmt(self, o: "mypy.nodes.AssertStmt") -> T: pass - def visit_del_stmt(self, o: 'mypy.nodes.DelStmt') -> T: + def visit_del_stmt(self, o: "mypy.nodes.DelStmt") -> T: pass - def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T: + def visit_if_stmt(self, o: "mypy.nodes.IfStmt") -> T: pass - def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T: + def visit_break_stmt(self, o: "mypy.nodes.BreakStmt") -> T: pass - def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T: + def visit_continue_stmt(self, o: "mypy.nodes.ContinueStmt") -> T: pass - def visit_pass_stmt(self, o: 'mypy.nodes.PassStmt') -> T: + def visit_pass_stmt(self, o: "mypy.nodes.PassStmt") -> T: pass - def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T: + def visit_raise_stmt(self, o: "mypy.nodes.RaiseStmt") -> T: pass - def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T: + def visit_try_stmt(self, o: "mypy.nodes.TryStmt") -> T: pass - def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T: + def visit_with_stmt(self, o: "mypy.nodes.WithStmt") -> T: pass - def visit_print_stmt(self, o: 'mypy.nodes.PrintStmt') -> T: + def visit_print_stmt(self, o: "mypy.nodes.PrintStmt") -> T: pass - def visit_exec_stmt(self, o: 'mypy.nodes.ExecStmt') -> T: + def visit_exec_stmt(self, o: "mypy.nodes.ExecStmt") -> T: pass - def visit_match_stmt(self, o: 'mypy.nodes.MatchStmt') -> T: + def visit_match_stmt(self, o: "mypy.nodes.MatchStmt") -> T: pass # Expressions (default no-op implementation) - def visit_int_expr(self, o: 'mypy.nodes.IntExpr') -> T: + def visit_int_expr(self, o: "mypy.nodes.IntExpr") -> T: pass - def visit_str_expr(self, o: 'mypy.nodes.StrExpr') -> T: + def visit_str_expr(self, o: "mypy.nodes.StrExpr") -> T: pass - def visit_bytes_expr(self, o: 'mypy.nodes.BytesExpr') -> T: + def visit_bytes_expr(self, o: "mypy.nodes.BytesExpr") -> T: pass - def visit_unicode_expr(self, o: 'mypy.nodes.UnicodeExpr') -> T: + def visit_unicode_expr(self, o: "mypy.nodes.UnicodeExpr") -> T: pass - def visit_float_expr(self, o: 'mypy.nodes.FloatExpr') -> T: + def visit_float_expr(self, o: "mypy.nodes.FloatExpr") -> T: pass - def visit_complex_expr(self, o: 'mypy.nodes.ComplexExpr') -> T: + def visit_complex_expr(self, o: "mypy.nodes.ComplexExpr") -> T: pass - def visit_ellipsis(self, o: 'mypy.nodes.EllipsisExpr') -> T: + def visit_ellipsis(self, o: "mypy.nodes.EllipsisExpr") -> T: pass - def visit_star_expr(self, o: 'mypy.nodes.StarExpr') -> T: + def visit_star_expr(self, o: "mypy.nodes.StarExpr") -> T: pass - def visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> T: + def visit_name_expr(self, o: "mypy.nodes.NameExpr") -> T: pass - def visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> T: + def visit_member_expr(self, o: "mypy.nodes.MemberExpr") -> T: pass - def visit_yield_from_expr(self, o: 'mypy.nodes.YieldFromExpr') -> T: + def visit_yield_from_expr(self, o: "mypy.nodes.YieldFromExpr") -> T: pass - def visit_yield_expr(self, o: 'mypy.nodes.YieldExpr') -> T: + def visit_yield_expr(self, o: "mypy.nodes.YieldExpr") -> T: pass - def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T: + def visit_call_expr(self, o: "mypy.nodes.CallExpr") -> T: pass - def visit_op_expr(self, o: 'mypy.nodes.OpExpr') -> T: + def visit_op_expr(self, o: "mypy.nodes.OpExpr") -> T: pass - def visit_comparison_expr(self, o: 'mypy.nodes.ComparisonExpr') -> T: + def visit_comparison_expr(self, o: "mypy.nodes.ComparisonExpr") -> T: pass - def visit_cast_expr(self, o: 'mypy.nodes.CastExpr') -> T: + def visit_cast_expr(self, o: "mypy.nodes.CastExpr") -> T: pass - def visit_assert_type_expr(self, o: 'mypy.nodes.AssertTypeExpr') -> T: + def visit_assert_type_expr(self, o: "mypy.nodes.AssertTypeExpr") -> T: pass - def visit_reveal_expr(self, o: 'mypy.nodes.RevealExpr') -> T: + def visit_reveal_expr(self, o: "mypy.nodes.RevealExpr") -> T: pass - def visit_super_expr(self, o: 'mypy.nodes.SuperExpr') -> T: + def visit_super_expr(self, o: "mypy.nodes.SuperExpr") -> T: pass - def visit_assignment_expr(self, o: 'mypy.nodes.AssignmentExpr') -> T: + def visit_assignment_expr(self, o: "mypy.nodes.AssignmentExpr") -> T: pass - def visit_unary_expr(self, o: 'mypy.nodes.UnaryExpr') -> T: + def visit_unary_expr(self, o: "mypy.nodes.UnaryExpr") -> T: pass - def visit_list_expr(self, o: 'mypy.nodes.ListExpr') -> T: + def visit_list_expr(self, o: "mypy.nodes.ListExpr") -> T: pass - def visit_dict_expr(self, o: 'mypy.nodes.DictExpr') -> T: + def visit_dict_expr(self, o: "mypy.nodes.DictExpr") -> T: pass - def visit_tuple_expr(self, o: 'mypy.nodes.TupleExpr') -> T: + def visit_tuple_expr(self, o: "mypy.nodes.TupleExpr") -> T: pass - def visit_set_expr(self, o: 'mypy.nodes.SetExpr') -> T: + def visit_set_expr(self, o: "mypy.nodes.SetExpr") -> T: pass - def visit_index_expr(self, o: 'mypy.nodes.IndexExpr') -> T: + def visit_index_expr(self, o: "mypy.nodes.IndexExpr") -> T: pass - def visit_type_application(self, o: 'mypy.nodes.TypeApplication') -> T: + def visit_type_application(self, o: "mypy.nodes.TypeApplication") -> T: pass - def visit_lambda_expr(self, o: 'mypy.nodes.LambdaExpr') -> T: + def visit_lambda_expr(self, o: "mypy.nodes.LambdaExpr") -> T: pass - def visit_list_comprehension(self, o: 'mypy.nodes.ListComprehension') -> T: + def visit_list_comprehension(self, o: "mypy.nodes.ListComprehension") -> T: pass - def visit_set_comprehension(self, o: 'mypy.nodes.SetComprehension') -> T: + def visit_set_comprehension(self, o: "mypy.nodes.SetComprehension") -> T: pass - def visit_dictionary_comprehension(self, o: 'mypy.nodes.DictionaryComprehension') -> T: + def visit_dictionary_comprehension(self, o: "mypy.nodes.DictionaryComprehension") -> T: pass - def visit_generator_expr(self, o: 'mypy.nodes.GeneratorExpr') -> T: + def visit_generator_expr(self, o: "mypy.nodes.GeneratorExpr") -> T: pass - def visit_slice_expr(self, o: 'mypy.nodes.SliceExpr') -> T: + def visit_slice_expr(self, o: "mypy.nodes.SliceExpr") -> T: pass - def visit_conditional_expr(self, o: 'mypy.nodes.ConditionalExpr') -> T: + def visit_conditional_expr(self, o: "mypy.nodes.ConditionalExpr") -> T: pass - def visit_backquote_expr(self, o: 'mypy.nodes.BackquoteExpr') -> T: + def visit_backquote_expr(self, o: "mypy.nodes.BackquoteExpr") -> T: pass - def visit_type_var_expr(self, o: 'mypy.nodes.TypeVarExpr') -> T: + def visit_type_var_expr(self, o: "mypy.nodes.TypeVarExpr") -> T: pass - def visit_paramspec_expr(self, o: 'mypy.nodes.ParamSpecExpr') -> T: + def visit_paramspec_expr(self, o: "mypy.nodes.ParamSpecExpr") -> T: pass - def visit_type_var_tuple_expr(self, o: 'mypy.nodes.TypeVarTupleExpr') -> T: + def visit_type_var_tuple_expr(self, o: "mypy.nodes.TypeVarTupleExpr") -> T: pass - def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> T: + def visit_type_alias_expr(self, o: "mypy.nodes.TypeAliasExpr") -> T: pass - def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> T: + def visit_namedtuple_expr(self, o: "mypy.nodes.NamedTupleExpr") -> T: pass - def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> T: + def visit_enum_call_expr(self, o: "mypy.nodes.EnumCallExpr") -> T: pass - def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> T: + def visit_typeddict_expr(self, o: "mypy.nodes.TypedDictExpr") -> T: pass - def visit_newtype_expr(self, o: 'mypy.nodes.NewTypeExpr') -> T: + def visit_newtype_expr(self, o: "mypy.nodes.NewTypeExpr") -> T: pass - def visit__promote_expr(self, o: 'mypy.nodes.PromoteExpr') -> T: + def visit__promote_expr(self, o: "mypy.nodes.PromoteExpr") -> T: pass - def visit_await_expr(self, o: 'mypy.nodes.AwaitExpr') -> T: + def visit_await_expr(self, o: "mypy.nodes.AwaitExpr") -> T: pass - def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> T: + def visit_temp_node(self, o: "mypy.nodes.TempNode") -> T: pass # Patterns - def visit_as_pattern(self, o: 'mypy.patterns.AsPattern') -> T: + def visit_as_pattern(self, o: "mypy.patterns.AsPattern") -> T: pass - def visit_or_pattern(self, o: 'mypy.patterns.OrPattern') -> T: + def visit_or_pattern(self, o: "mypy.patterns.OrPattern") -> T: pass - def visit_value_pattern(self, o: 'mypy.patterns.ValuePattern') -> T: + def visit_value_pattern(self, o: "mypy.patterns.ValuePattern") -> T: pass - def visit_singleton_pattern(self, o: 'mypy.patterns.SingletonPattern') -> T: + def visit_singleton_pattern(self, o: "mypy.patterns.SingletonPattern") -> T: pass - def visit_sequence_pattern(self, o: 'mypy.patterns.SequencePattern') -> T: + def visit_sequence_pattern(self, o: "mypy.patterns.SequencePattern") -> T: pass - def visit_starred_pattern(self, o: 'mypy.patterns.StarredPattern') -> T: + def visit_starred_pattern(self, o: "mypy.patterns.StarredPattern") -> T: pass - def visit_mapping_pattern(self, o: 'mypy.patterns.MappingPattern') -> T: + def visit_mapping_pattern(self, o: "mypy.patterns.MappingPattern") -> T: pass - def visit_class_pattern(self, o: 'mypy.patterns.ClassPattern') -> T: + def visit_class_pattern(self, o: "mypy.patterns.ClassPattern") -> T: pass diff --git a/mypyc/__main__.py b/mypyc/__main__.py index aaaf9a83c8c56..a37b500fae744 100644 --- a/mypyc/__main__.py +++ b/mypyc/__main__.py @@ -15,7 +15,7 @@ import subprocess import sys -base_path = os.path.join(os.path.dirname(__file__), '..') +base_path = os.path.join(os.path.dirname(__file__), "..") setup_format = """\ from setuptools import setup @@ -28,28 +28,28 @@ def main() -> None: - build_dir = 'build' # can this be overridden?? + build_dir = "build" # can this be overridden?? try: os.mkdir(build_dir) except FileExistsError: pass - opt_level = os.getenv("MYPYC_OPT_LEVEL", '3') - debug_level = os.getenv("MYPYC_DEBUG_LEVEL", '1') + opt_level = os.getenv("MYPYC_OPT_LEVEL", "3") + debug_level = os.getenv("MYPYC_DEBUG_LEVEL", "1") - setup_file = os.path.join(build_dir, 'setup.py') - with open(setup_file, 'w') as f: + setup_file = os.path.join(build_dir, "setup.py") + with open(setup_file, "w") as f: f.write(setup_format.format(sys.argv[1:], opt_level, debug_level)) # We don't use run_setup (like we do in the test suite) because it throws # away the error code from distutils, and we don't care about the slight # performance loss here. env = os.environ.copy() - base_path = os.path.join(os.path.dirname(__file__), '..') - env['PYTHONPATH'] = base_path + os.pathsep + env.get('PYTHONPATH', '') - cmd = subprocess.run([sys.executable, setup_file, 'build_ext', '--inplace'], env=env) + base_path = os.path.join(os.path.dirname(__file__), "..") + env["PYTHONPATH"] = base_path + os.pathsep + env.get("PYTHONPATH", "") + cmd = subprocess.run([sys.executable, setup_file, "build_ext", "--inplace"], env=env) sys.exit(cmd.returncode) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/mypyc/analysis/attrdefined.py b/mypyc/analysis/attrdefined.py index 6187d143711f8..77b6539eb8b7a 100644 --- a/mypyc/analysis/attrdefined.py +++ b/mypyc/analysis/attrdefined.py @@ -62,19 +62,35 @@ def foo(self) -> int: """ from typing import List, Set, Tuple + from typing_extensions import Final -from mypyc.ir.ops import ( - Register, Assign, AssignMulti, SetMem, SetAttr, Branch, Return, Unreachable, GetAttr, - Call, RegisterOp, BasicBlock, ControlOp -) -from mypyc.ir.rtypes import RInstance -from mypyc.ir.class_ir import ClassIR from mypyc.analysis.dataflow import ( - BaseAnalysisVisitor, AnalysisResult, get_cfg, CFG, MAYBE_ANALYSIS, run_analysis + CFG, + MAYBE_ANALYSIS, + AnalysisResult, + BaseAnalysisVisitor, + get_cfg, + run_analysis, ) from mypyc.analysis.selfleaks import analyze_self_leaks - +from mypyc.ir.class_ir import ClassIR +from mypyc.ir.ops import ( + Assign, + AssignMulti, + BasicBlock, + Branch, + Call, + ControlOp, + GetAttr, + Register, + RegisterOp, + Return, + SetAttr, + SetMem, + Unreachable, +) +from mypyc.ir.rtypes import RInstance # If True, print out all always-defined attributes of native classes (to aid # debugging and testing) @@ -110,12 +126,14 @@ def analyze_always_defined_attrs_in_class(cl: ClassIR, seen: Set[ClassIR]) -> No seen.add(cl) - if (cl.is_trait - or cl.inherits_python - or cl.allow_interpreted_subclasses - or cl.builtin_base is not None - or cl.children is None - or cl.is_serializable()): + if ( + cl.is_trait + or cl.inherits_python + or cl.allow_interpreted_subclasses + or cl.builtin_base is not None + or cl.children is None + or cl.is_serializable() + ): # Give up -- we can't enforce that attributes are always defined. return @@ -123,7 +141,7 @@ def analyze_always_defined_attrs_in_class(cl: ClassIR, seen: Set[ClassIR]) -> No for base in cl.mro[1:]: analyze_always_defined_attrs_in_class(base, seen) - m = cl.get_method('__init__') + m = cl.get_method("__init__") if m is None: cl._always_initialized_attrs = cl.attrs_with_defaults.copy() cl._sometimes_initialized_attrs = cl.attrs_with_defaults.copy() @@ -132,25 +150,26 @@ def analyze_always_defined_attrs_in_class(cl: ClassIR, seen: Set[ClassIR]) -> No cfg = get_cfg(m.blocks) dirty = analyze_self_leaks(m.blocks, self_reg, cfg) maybe_defined = analyze_maybe_defined_attrs_in_init( - m.blocks, self_reg, cl.attrs_with_defaults, cfg) + m.blocks, self_reg, cl.attrs_with_defaults, cfg + ) all_attrs: Set[str] = set() for base in cl.mro: all_attrs.update(base.attributes) maybe_undefined = analyze_maybe_undefined_attrs_in_init( - m.blocks, - self_reg, - initial_undefined=all_attrs - cl.attrs_with_defaults, - cfg=cfg) + m.blocks, self_reg, initial_undefined=all_attrs - cl.attrs_with_defaults, cfg=cfg + ) always_defined = find_always_defined_attributes( - m.blocks, self_reg, all_attrs, maybe_defined, maybe_undefined, dirty) + m.blocks, self_reg, all_attrs, maybe_defined, maybe_undefined, dirty + ) always_defined = {a for a in always_defined if not cl.is_deletable(a)} cl._always_initialized_attrs = always_defined if dump_always_defined: print(cl.name, sorted(always_defined)) cl._sometimes_initialized_attrs = find_sometimes_defined_attributes( - m.blocks, self_reg, maybe_defined, dirty) + m.blocks, self_reg, maybe_defined, dirty + ) mark_attr_initialiation_ops(m.blocks, self_reg, maybe_defined, dirty) @@ -164,12 +183,14 @@ def analyze_always_defined_attrs_in_class(cl: ClassIR, seen: Set[ClassIR]) -> No cl.init_self_leak = any_dirty -def find_always_defined_attributes(blocks: List[BasicBlock], - self_reg: Register, - all_attrs: Set[str], - maybe_defined: AnalysisResult[str], - maybe_undefined: AnalysisResult[str], - dirty: AnalysisResult[None]) -> Set[str]: +def find_always_defined_attributes( + blocks: List[BasicBlock], + self_reg: Register, + all_attrs: Set[str], + maybe_defined: AnalysisResult[str], + maybe_undefined: AnalysisResult[str], + dirty: AnalysisResult[None], +) -> Set[str]: """Find attributes that are always initialized in some basic blocks. The analysis results are expected to be up-to-date for the blocks. @@ -188,30 +209,36 @@ def find_always_defined_attributes(blocks: List[BasicBlock], # the get case, it's fine for the attribute to be undefined. # The set operation will then be treated as initialization. if isinstance(op, SetAttr) and op.obj is self_reg: - if (op.attr in maybe_undefined.before[block, i] - and op.attr in maybe_defined.before[block, i]): + if ( + op.attr in maybe_undefined.before[block, i] + and op.attr in maybe_defined.before[block, i] + ): attrs.discard(op.attr) # Treat an op that might run arbitrary code as an "exit" # in terms of the analysis -- we can't do any inference # afterwards reliably. if dirty.after[block, i]: if not dirty.before[block, i]: - attrs = attrs & (maybe_defined.after[block, i] - - maybe_undefined.after[block, i]) + attrs = attrs & ( + maybe_defined.after[block, i] - maybe_undefined.after[block, i] + ) break if isinstance(op, ControlOp): for target in op.targets(): # Gotos/branches can also be "exits". if not dirty.after[block, i] and dirty.before[target, 0]: - attrs = attrs & (maybe_defined.after[target, 0] - - maybe_undefined.after[target, 0]) + attrs = attrs & ( + maybe_defined.after[target, 0] - maybe_undefined.after[target, 0] + ) return attrs -def find_sometimes_defined_attributes(blocks: List[BasicBlock], - self_reg: Register, - maybe_defined: AnalysisResult[str], - dirty: AnalysisResult[None]) -> Set[str]: +def find_sometimes_defined_attributes( + blocks: List[BasicBlock], + self_reg: Register, + maybe_defined: AnalysisResult[str], + dirty: AnalysisResult[None], +) -> Set[str]: """Find attributes that are sometimes initialized in some basic blocks.""" attrs: Set[str] = set() for block in blocks: @@ -228,10 +255,12 @@ def find_sometimes_defined_attributes(blocks: List[BasicBlock], return attrs -def mark_attr_initialiation_ops(blocks: List[BasicBlock], - self_reg: Register, - maybe_defined: AnalysisResult[str], - dirty: AnalysisResult[None]) -> None: +def mark_attr_initialiation_ops( + blocks: List[BasicBlock], + self_reg: Register, + maybe_defined: AnalysisResult[str], + dirty: AnalysisResult[None], +) -> None: """Tag all SetAttr ops in the basic blocks that initialize attributes. Initialization ops assume that the previous attribute value is the error value, @@ -286,7 +315,7 @@ def visit_unreachable(self, op: Unreachable) -> Tuple[Set[str], Set[str]]: def visit_register_op(self, op: RegisterOp) -> Tuple[Set[str], Set[str]]: if isinstance(op, SetAttr) and op.obj is self.self_reg: return {op.attr}, set() - if isinstance(op, Call) and op.fn.class_name and op.fn.name == '__init__': + if isinstance(op, Call) and op.fn.class_name and op.fn.name == "__init__": return attributes_maybe_initialized_by_init_call(op), set() return set(), set() @@ -300,16 +329,17 @@ def visit_set_mem(self, op: SetMem) -> Tuple[Set[str], Set[str]]: return set(), set() -def analyze_maybe_defined_attrs_in_init(blocks: List[BasicBlock], - self_reg: Register, - attrs_with_defaults: Set[str], - cfg: CFG) -> AnalysisResult[str]: - return run_analysis(blocks=blocks, - cfg=cfg, - gen_and_kill=AttributeMaybeDefinedVisitor(self_reg), - initial=attrs_with_defaults, - backward=False, - kind=MAYBE_ANALYSIS) +def analyze_maybe_defined_attrs_in_init( + blocks: List[BasicBlock], self_reg: Register, attrs_with_defaults: Set[str], cfg: CFG +) -> AnalysisResult[str]: + return run_analysis( + blocks=blocks, + cfg=cfg, + gen_and_kill=AttributeMaybeDefinedVisitor(self_reg), + initial=attrs_with_defaults, + backward=False, + kind=MAYBE_ANALYSIS, + ) class AttributeMaybeUndefinedVisitor(BaseAnalysisVisitor[str]): @@ -334,7 +364,7 @@ def visit_unreachable(self, op: Unreachable) -> Tuple[Set[str], Set[str]]: def visit_register_op(self, op: RegisterOp) -> Tuple[Set[str], Set[str]]: if isinstance(op, SetAttr) and op.obj is self.self_reg: return set(), {op.attr} - if isinstance(op, Call) and op.fn.class_name and op.fn.name == '__init__': + if isinstance(op, Call) and op.fn.class_name and op.fn.name == "__init__": return set(), attributes_initialized_by_init_call(op) return set(), set() @@ -348,16 +378,17 @@ def visit_set_mem(self, op: SetMem) -> Tuple[Set[str], Set[str]]: return set(), set() -def analyze_maybe_undefined_attrs_in_init(blocks: List[BasicBlock], - self_reg: Register, - initial_undefined: Set[str], - cfg: CFG) -> AnalysisResult[str]: - return run_analysis(blocks=blocks, - cfg=cfg, - gen_and_kill=AttributeMaybeUndefinedVisitor(self_reg), - initial=initial_undefined, - backward=False, - kind=MAYBE_ANALYSIS) +def analyze_maybe_undefined_attrs_in_init( + blocks: List[BasicBlock], self_reg: Register, initial_undefined: Set[str], cfg: CFG +) -> AnalysisResult[str]: + return run_analysis( + blocks=blocks, + cfg=cfg, + gen_and_kill=AttributeMaybeUndefinedVisitor(self_reg), + initial=initial_undefined, + backward=False, + kind=MAYBE_ANALYSIS, + ) def update_always_defined_attrs_using_subclasses(cl: ClassIR, seen: Set[ClassIR]) -> None: diff --git a/mypyc/analysis/blockfreq.py b/mypyc/analysis/blockfreq.py index 547fb9ce10d3f..8269297c93a3e 100644 --- a/mypyc/analysis/blockfreq.py +++ b/mypyc/analysis/blockfreq.py @@ -9,7 +9,7 @@ from typing import Set -from mypyc.ir.ops import BasicBlock, Goto, Branch +from mypyc.ir.ops import BasicBlock, Branch, Goto def frequently_executed_blocks(entry_point: BasicBlock) -> Set[BasicBlock]: diff --git a/mypyc/analysis/dataflow.py b/mypyc/analysis/dataflow.py index 528c04af546fa..63af9e53102ad 100644 --- a/mypyc/analysis/dataflow.py +++ b/mypyc/analysis/dataflow.py @@ -1,17 +1,49 @@ """Data-flow analyses.""" from abc import abstractmethod +from typing import Dict, Generic, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar, Union -from typing import Dict, Tuple, List, Set, TypeVar, Iterator, Generic, Optional, Iterable, Union - +from mypyc.ir.func_ir import all_values from mypyc.ir.ops import ( - Value, ControlOp, - BasicBlock, OpVisitor, Assign, AssignMulti, Integer, LoadErrorValue, RegisterOp, Goto, Branch, - Return, Call, Box, Unbox, Cast, Op, Unreachable, TupleGet, TupleSet, GetAttr, SetAttr, - LoadLiteral, LoadStatic, InitStatic, MethodCall, RaiseStandardError, CallC, LoadGlobal, - Truncate, IntOp, LoadMem, GetElementPtr, LoadAddress, ComparisonOp, SetMem, KeepAlive, Extend + Assign, + AssignMulti, + BasicBlock, + Box, + Branch, + Call, + CallC, + Cast, + ComparisonOp, + ControlOp, + Extend, + GetAttr, + GetElementPtr, + Goto, + InitStatic, + Integer, + IntOp, + KeepAlive, + LoadAddress, + LoadErrorValue, + LoadGlobal, + LoadLiteral, + LoadMem, + LoadStatic, + MethodCall, + Op, + OpVisitor, + RaiseStandardError, + RegisterOp, + Return, + SetAttr, + SetMem, + Truncate, + TupleGet, + TupleSet, + Unbox, + Unreachable, + Value, ) -from mypyc.ir.func_ir import all_values class CFG: @@ -21,10 +53,12 @@ class CFG: non-empty set of exits. """ - def __init__(self, - succ: Dict[BasicBlock, List[BasicBlock]], - pred: Dict[BasicBlock, List[BasicBlock]], - exits: Set[BasicBlock]) -> None: + def __init__( + self, + succ: Dict[BasicBlock, List[BasicBlock]], + pred: Dict[BasicBlock, List[BasicBlock]], + exits: Set[BasicBlock], + ) -> None: assert exits self.succ = succ self.pred = pred @@ -32,10 +66,10 @@ def __init__(self, def __str__(self) -> str: lines = [] - lines.append('exits: %s' % sorted(self.exits, key=lambda e: e.label)) - lines.append('succ: %s' % self.succ) - lines.append('pred: %s' % self.pred) - return '\n'.join(lines) + lines.append("exits: %s" % sorted(self.exits, key=lambda e: e.label)) + lines.append("succ: %s" % self.succ) + lines.append("pred: %s" % self.pred) + return "\n".join(lines) def get_cfg(blocks: List[BasicBlock]) -> CFG: @@ -50,8 +84,9 @@ def get_cfg(blocks: List[BasicBlock]) -> CFG: exits = set() for block in blocks: - assert not any(isinstance(op, ControlOp) for op in block.ops[:-1]), ( - "Control-flow ops must be at the end of blocks") + assert not any( + isinstance(op, ControlOp) for op in block.ops[:-1] + ), "Control-flow ops must be at the end of blocks" succ = list(block.terminator.targets()) if not succ: @@ -114,18 +149,18 @@ def cleanup_cfg(blocks: List[BasicBlock]) -> None: changed = True -T = TypeVar('T') +T = TypeVar("T") AnalysisDict = Dict[Tuple[BasicBlock, int], Set[T]] class AnalysisResult(Generic[T]): - def __init__(self, before: 'AnalysisDict[T]', after: 'AnalysisDict[T]') -> None: + def __init__(self, before: "AnalysisDict[T]", after: "AnalysisDict[T]") -> None: self.before = before self.after = after def __str__(self) -> str: - return f'before: {self.before}\nafter: {self.after}\n' + return f"before: {self.before}\nafter: {self.after}\n" GenAndKill = Tuple[Set[T], Set[T]] @@ -257,8 +292,7 @@ def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]: def visit_assign(self, op: Assign) -> GenAndKill[Value]: # Loading an error value may undefine the register. - if (isinstance(op.src, LoadErrorValue) - and (op.src.undefines or self.strict_errors)): + if isinstance(op.src, LoadErrorValue) and (op.src.undefines or self.strict_errors): return set(), {op.dest} else: return {op.dest}, set() @@ -271,27 +305,30 @@ def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]: return set(), set() -def analyze_maybe_defined_regs(blocks: List[BasicBlock], - cfg: CFG, - initial_defined: Set[Value]) -> AnalysisResult[Value]: +def analyze_maybe_defined_regs( + blocks: List[BasicBlock], cfg: CFG, initial_defined: Set[Value] +) -> AnalysisResult[Value]: """Calculate potentially defined registers at each CFG location. A register is defined if it has a value along some path from the initial location. """ - return run_analysis(blocks=blocks, - cfg=cfg, - gen_and_kill=DefinedVisitor(), - initial=initial_defined, - backward=False, - kind=MAYBE_ANALYSIS) + return run_analysis( + blocks=blocks, + cfg=cfg, + gen_and_kill=DefinedVisitor(), + initial=initial_defined, + backward=False, + kind=MAYBE_ANALYSIS, + ) def analyze_must_defined_regs( - blocks: List[BasicBlock], - cfg: CFG, - initial_defined: Set[Value], - regs: Iterable[Value], - strict_errors: bool = False) -> AnalysisResult[Value]: + blocks: List[BasicBlock], + cfg: CFG, + initial_defined: Set[Value], + regs: Iterable[Value], + strict_errors: bool = False, +) -> AnalysisResult[Value]: """Calculate always defined registers at each CFG location. This analysis can work before exception insertion, since it is a @@ -301,13 +338,15 @@ def analyze_must_defined_regs( A register is defined if it has a value along all paths from the initial location. """ - return run_analysis(blocks=blocks, - cfg=cfg, - gen_and_kill=DefinedVisitor(strict_errors=strict_errors), - initial=initial_defined, - backward=False, - kind=MUST_ANALYSIS, - universe=set(regs)) + return run_analysis( + blocks=blocks, + cfg=cfg, + gen_and_kill=DefinedVisitor(strict_errors=strict_errors), + initial=initial_defined, + backward=False, + kind=MUST_ANALYSIS, + universe=set(regs), + ) class BorrowedArgumentsVisitor(BaseAnalysisVisitor[Value]): @@ -339,20 +378,21 @@ def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]: def analyze_borrowed_arguments( - blocks: List[BasicBlock], - cfg: CFG, - borrowed: Set[Value]) -> AnalysisResult[Value]: + blocks: List[BasicBlock], cfg: CFG, borrowed: Set[Value] +) -> AnalysisResult[Value]: """Calculate arguments that can use references borrowed from the caller. When assigning to an argument, it no longer is borrowed. """ - return run_analysis(blocks=blocks, - cfg=cfg, - gen_and_kill=BorrowedArgumentsVisitor(borrowed), - initial=borrowed, - backward=False, - kind=MUST_ANALYSIS, - universe=borrowed) + return run_analysis( + blocks=blocks, + cfg=cfg, + gen_and_kill=BorrowedArgumentsVisitor(borrowed), + initial=borrowed, + backward=False, + kind=MUST_ANALYSIS, + universe=borrowed, + ) class UndefinedVisitor(BaseAnalysisVisitor[Value]): @@ -378,9 +418,9 @@ def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]: return set(), set() -def analyze_undefined_regs(blocks: List[BasicBlock], - cfg: CFG, - initial_defined: Set[Value]) -> AnalysisResult[Value]: +def analyze_undefined_regs( + blocks: List[BasicBlock], cfg: CFG, initial_defined: Set[Value] +) -> AnalysisResult[Value]: """Calculate potentially undefined registers at each CFG location. A register is undefined if there is some path from initial block @@ -389,12 +429,14 @@ def analyze_undefined_regs(blocks: List[BasicBlock], Function arguments are assumed to be always defined. """ initial_undefined = set(all_values([], blocks)) - initial_defined - return run_analysis(blocks=blocks, - cfg=cfg, - gen_and_kill=UndefinedVisitor(), - initial=initial_undefined, - backward=False, - kind=MAYBE_ANALYSIS) + return run_analysis( + blocks=blocks, + cfg=cfg, + gen_and_kill=UndefinedVisitor(), + initial=initial_undefined, + backward=False, + kind=MAYBE_ANALYSIS, + ) def non_trivial_sources(op: Op) -> Set[Value]: @@ -435,19 +477,20 @@ def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]: return non_trivial_sources(op), set() -def analyze_live_regs(blocks: List[BasicBlock], - cfg: CFG) -> AnalysisResult[Value]: +def analyze_live_regs(blocks: List[BasicBlock], cfg: CFG) -> AnalysisResult[Value]: """Calculate live registers at each CFG location. A register is live at a location if it can be read along some CFG path starting from the location. """ - return run_analysis(blocks=blocks, - cfg=cfg, - gen_and_kill=LivenessVisitor(), - initial=set(), - backward=True, - kind=MAYBE_ANALYSIS) + return run_analysis( + blocks=blocks, + cfg=cfg, + gen_and_kill=LivenessVisitor(), + initial=set(), + backward=True, + kind=MAYBE_ANALYSIS, + ) # Analysis kinds @@ -455,13 +498,15 @@ def analyze_live_regs(blocks: List[BasicBlock], MAYBE_ANALYSIS = 1 -def run_analysis(blocks: List[BasicBlock], - cfg: CFG, - gen_and_kill: OpVisitor[GenAndKill[T]], - initial: Set[T], - kind: int, - backward: bool, - universe: Optional[Set[T]] = None) -> AnalysisResult[T]: +def run_analysis( + blocks: List[BasicBlock], + cfg: CFG, + gen_and_kill: OpVisitor[GenAndKill[T]], + initial: Set[T], + kind: int, + backward: bool, + universe: Optional[Set[T]] = None, +) -> AnalysisResult[T]: """Run a general set-based data flow analysis. Args: @@ -491,8 +536,8 @@ def run_analysis(blocks: List[BasicBlock], ops = list(reversed(ops)) for op in ops: opgen, opkill = op.accept(gen_and_kill) - gen = ((gen - opkill) | opgen) - kill = ((kill - opgen) | opkill) + gen = (gen - opkill) | opgen + kill = (kill - opgen) | opkill block_gen[block] = gen block_kill[block] = kill diff --git a/mypyc/analysis/ircheck.py b/mypyc/analysis/ircheck.py index 8217d9865c4bb..631ee30c78b3d 100644 --- a/mypyc/analysis/ircheck.py +++ b/mypyc/analysis/ircheck.py @@ -1,20 +1,66 @@ """Utilities for checking that internal ir is valid and consistent.""" -from typing import List, Union, Set, Tuple -from mypyc.ir.pprint import format_func +from typing import List, Set, Tuple, Union + +from mypyc.ir.func_ir import FUNC_STATICMETHOD, FuncIR from mypyc.ir.ops import ( - OpVisitor, BasicBlock, Op, ControlOp, Goto, Branch, Return, Unreachable, - Assign, AssignMulti, LoadErrorValue, LoadLiteral, GetAttr, SetAttr, LoadStatic, - InitStatic, TupleGet, TupleSet, IncRef, DecRef, Call, MethodCall, Cast, - Box, Unbox, RaiseStandardError, CallC, Truncate, LoadGlobal, IntOp, ComparisonOp, - LoadMem, SetMem, GetElementPtr, LoadAddress, KeepAlive, Register, Integer, - BaseAssign, Extend + Assign, + AssignMulti, + BaseAssign, + BasicBlock, + Box, + Branch, + Call, + CallC, + Cast, + ComparisonOp, + ControlOp, + DecRef, + Extend, + GetAttr, + GetElementPtr, + Goto, + IncRef, + InitStatic, + Integer, + IntOp, + KeepAlive, + LoadAddress, + LoadErrorValue, + LoadGlobal, + LoadLiteral, + LoadMem, + LoadStatic, + MethodCall, + Op, + OpVisitor, + RaiseStandardError, + Register, + Return, + SetAttr, + SetMem, + Truncate, + TupleGet, + TupleSet, + Unbox, + Unreachable, ) +from mypyc.ir.pprint import format_func from mypyc.ir.rtypes import ( - RType, RPrimitive, RUnion, is_object_rprimitive, RInstance, RArray, - int_rprimitive, list_rprimitive, dict_rprimitive, set_rprimitive, - range_rprimitive, str_rprimitive, bytes_rprimitive, tuple_rprimitive + RArray, + RInstance, + RPrimitive, + RType, + RUnion, + bytes_rprimitive, + dict_rprimitive, + int_rprimitive, + is_object_rprimitive, + list_rprimitive, + range_rprimitive, + set_rprimitive, + str_rprimitive, + tuple_rprimitive, ) -from mypyc.ir.func_ir import FuncIR, FUNC_STATICMETHOD class FnError: @@ -24,9 +70,7 @@ def __init__(self, source: Union[Op, BasicBlock], desc: str) -> None: def __eq__(self, other: object) -> bool: return ( - isinstance(other, FnError) - and self.source == other.source - and self.desc == other.desc + isinstance(other, FnError) and self.source == other.source and self.desc == other.desc ) def __repr__(self) -> str: @@ -42,27 +86,14 @@ def check_func_ir(fn: FuncIR) -> List[FnError]: for block in fn.blocks: if not block.terminated: errors.append( - FnError( - source=block.ops[-1] if block.ops else block, - desc="Block not terminated", - ) + FnError(source=block.ops[-1] if block.ops else block, desc="Block not terminated") ) for op in block.ops[:-1]: if isinstance(op, ControlOp): - errors.append( - FnError( - source=op, - desc="Block has operations after control op", - ) - ) + errors.append(FnError(source=op, desc="Block has operations after control op")) if op in op_set: - errors.append( - FnError( - source=op, - desc="Func has a duplicate op", - ) - ) + errors.append(FnError(source=op, desc="Func has a duplicate op")) op_set.add(op) errors.extend(check_op_sources_valid(fn)) @@ -86,7 +117,7 @@ def assert_func_ir_valid(fn: FuncIR) -> None: if errors: raise IrCheckException( "Internal error: Generated invalid IR: \n" - + "\n".join(format_func(fn, [(e.source, e.desc) for e in errors])), + + "\n".join(format_func(fn, [(e.source, e.desc) for e in errors])) ) @@ -98,9 +129,7 @@ def check_op_sources_valid(fn: FuncIR) -> List[FnError]: for block in fn.blocks: valid_ops.update(block.ops) - valid_registers.update( - [op.dest for op in block.ops if isinstance(op, BaseAssign)] - ) + valid_registers.update([op.dest for op in block.ops if isinstance(op, BaseAssign)]) valid_registers.update(fn.arg_regs) @@ -121,8 +150,7 @@ def check_op_sources_valid(fn: FuncIR) -> List[FnError]: if source not in valid_registers: errors.append( FnError( - source=op, - desc=f"Invalid op reference to register {source.name}", + source=op, desc=f"Invalid op reference to register {source.name}" ) ) @@ -130,14 +158,14 @@ def check_op_sources_valid(fn: FuncIR) -> List[FnError]: disjoint_types = { - int_rprimitive.name, - bytes_rprimitive.name, - str_rprimitive.name, - dict_rprimitive.name, - list_rprimitive.name, - set_rprimitive.name, - tuple_rprimitive.name, - range_rprimitive.name, + int_rprimitive.name, + bytes_rprimitive.name, + str_rprimitive.name, + dict_rprimitive.name, + list_rprimitive.name, + set_rprimitive.name, + tuple_rprimitive.name, + range_rprimitive.name, } @@ -177,15 +205,12 @@ def fail(self, source: Op, desc: str) -> None: def check_control_op_targets(self, op: ControlOp) -> None: for target in op.targets(): if target not in self.parent_fn.blocks: - self.fail( - source=op, desc=f"Invalid control operation target: {target.label}" - ) + self.fail(source=op, desc=f"Invalid control operation target: {target.label}") def check_type_coercion(self, op: Op, src: RType, dest: RType) -> None: if not can_coerce_to(src, dest): self.fail( - source=op, - desc=f"Cannot coerce source type {src.name} to dest type {dest.name}", + source=op, desc=f"Cannot coerce source type {src.name} to dest type {dest.name}" ) def visit_goto(self, op: Goto) -> None: @@ -216,13 +241,9 @@ def visit_load_error_value(self, op: LoadErrorValue) -> None: # has an error value. pass - def check_tuple_items_valid_literals( - self, op: LoadLiteral, t: Tuple[object, ...] - ) -> None: + def check_tuple_items_valid_literals(self, op: LoadLiteral, t: Tuple[object, ...]) -> None: for x in t: - if x is not None and not isinstance( - x, (str, bytes, bool, int, float, complex, tuple) - ): + if x is not None and not isinstance(x, (str, bytes, bool, int, float, complex, tuple)): self.fail(op, f"Invalid type for item of tuple literal: {type(x)})") if isinstance(x, tuple): self.check_tuple_items_valid_literals(op, x) diff --git a/mypyc/analysis/selfleaks.py b/mypyc/analysis/selfleaks.py index 4ba6cfb28eb39..dab066185a975 100644 --- a/mypyc/analysis/selfleaks.py +++ b/mypyc/analysis/selfleaks.py @@ -1,14 +1,44 @@ from typing import List, Set, Tuple +from mypyc.analysis.dataflow import CFG, MAYBE_ANALYSIS, AnalysisResult, run_analysis from mypyc.ir.ops import ( - OpVisitor, Register, Goto, Assign, AssignMulti, SetMem, Call, MethodCall, LoadErrorValue, - LoadLiteral, GetAttr, SetAttr, LoadStatic, InitStatic, TupleGet, TupleSet, Box, Unbox, - Cast, RaiseStandardError, CallC, Truncate, LoadGlobal, IntOp, ComparisonOp, LoadMem, - GetElementPtr, LoadAddress, KeepAlive, Branch, Return, Unreachable, RegisterOp, BasicBlock, - Extend + Assign, + AssignMulti, + BasicBlock, + Box, + Branch, + Call, + CallC, + Cast, + ComparisonOp, + Extend, + GetAttr, + GetElementPtr, + Goto, + InitStatic, + IntOp, + KeepAlive, + LoadAddress, + LoadErrorValue, + LoadGlobal, + LoadLiteral, + LoadMem, + LoadStatic, + MethodCall, + OpVisitor, + RaiseStandardError, + Register, + RegisterOp, + Return, + SetAttr, + SetMem, + Truncate, + TupleGet, + TupleSet, + Unbox, + Unreachable, ) from mypyc.ir.rtypes import RInstance -from mypyc.analysis.dataflow import MAYBE_ANALYSIS, run_analysis, AnalysisResult, CFG GenAndKill = Tuple[Set[None], Set[None]] @@ -55,7 +85,7 @@ def visit_set_mem(self, op: SetMem) -> GenAndKill: def visit_call(self, op: Call) -> GenAndKill: fn = op.fn - if fn.class_name and fn.name == '__init__': + if fn.class_name and fn.name == "__init__": self_type = op.fn.sig.args[0].type assert isinstance(self_type, RInstance) cl = self_type.class_ir @@ -146,12 +176,14 @@ def check_register_op(self, op: RegisterOp) -> GenAndKill: return CLEAN -def analyze_self_leaks(blocks: List[BasicBlock], - self_reg: Register, - cfg: CFG) -> AnalysisResult[None]: - return run_analysis(blocks=blocks, - cfg=cfg, - gen_and_kill=SelfLeakedVisitor(self_reg), - initial=set(), - backward=False, - kind=MAYBE_ANALYSIS) +def analyze_self_leaks( + blocks: List[BasicBlock], self_reg: Register, cfg: CFG +) -> AnalysisResult[None]: + return run_analysis( + blocks=blocks, + cfg=cfg, + gen_and_kill=SelfLeakedVisitor(self_reg), + initial=set(), + backward=False, + kind=MAYBE_ANALYSIS, + ) diff --git a/mypyc/build.py b/mypyc/build.py index f5ff0201ffaf9..4f0a7fcaf0222 100644 --- a/mypyc/build.py +++ b/mypyc/build.py @@ -18,29 +18,27 @@ hackily decide based on whether setuptools has been imported already. """ -import sys -import os.path import hashlib -import time +import os.path import re +import sys +import time +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union, cast -from typing import List, Tuple, Any, Optional, Dict, Union, Set, Iterable, cast from typing_extensions import TYPE_CHECKING, NoReturn, Type -from mypy.main import process_options -from mypy.errors import CompileError -from mypy.options import Options from mypy.build import BuildSource +from mypy.errors import CompileError from mypy.fscache import FileSystemCache +from mypy.main import process_options +from mypy.options import Options from mypy.util import write_junit_xml - -from mypyc.namegen import exported_name -from mypyc.options import CompilerOptions -from mypyc.errors import Errors +from mypyc.codegen import emitmodule from mypyc.common import RUNTIME_C_FILES, shared_lib_name +from mypyc.errors import Errors from mypyc.ir.pprint import format_modules - -from mypyc.codegen import emitmodule +from mypyc.namegen import exported_name +from mypyc.options import CompilerOptions if TYPE_CHECKING: from distutils.core import Extension # noqa @@ -52,13 +50,13 @@ if sys.version_info >= (3, 12): # Raise on Python 3.12, since distutils will go away forever raise -from distutils import sysconfig, ccompiler +from distutils import ccompiler, sysconfig -def get_extension() -> Type['Extension']: +def get_extension() -> Type["Extension"]: # We can work with either setuptools or distutils, and pick setuptools # if it has been imported. - use_setuptools = 'setuptools' in sys.modules + use_setuptools = "setuptools" in sys.modules if not use_setuptools: from distutils.core import Extension @@ -74,12 +72,12 @@ def setup_mypycify_vars() -> None: # The vars can contain ints but we only work with str ones vars = cast(Dict[str, str], sysconfig.get_config_vars()) - if sys.platform == 'darwin': + if sys.platform == "darwin": # Disable building 32-bit binaries, since we generate too much code # for a 32-bit Mach-O object. There has to be a better way to do this. - vars['LDSHARED'] = vars['LDSHARED'].replace('-arch i386', '') - vars['LDFLAGS'] = vars['LDFLAGS'].replace('-arch i386', '') - vars['CFLAGS'] = vars['CFLAGS'].replace('-arch i386', '') + vars["LDSHARED"] = vars["LDSHARED"].replace("-arch i386", "") + vars["LDFLAGS"] = vars["LDFLAGS"].replace("-arch i386", "") + vars["CFLAGS"] = vars["CFLAGS"].replace("-arch i386", "") def fail(message: str) -> NoReturn: @@ -87,11 +85,12 @@ def fail(message: str) -> NoReturn: sys.exit(message) -def get_mypy_config(mypy_options: List[str], - only_compile_paths: Optional[Iterable[str]], - compiler_options: CompilerOptions, - fscache: Optional[FileSystemCache], - ) -> Tuple[List[BuildSource], List[BuildSource], Options]: +def get_mypy_config( + mypy_options: List[str], + only_compile_paths: Optional[Iterable[str]], + compiler_options: CompilerOptions, + fscache: Optional[FileSystemCache], +) -> Tuple[List[BuildSource], List[BuildSource], Options]: """Construct mypy BuildSources and Options from file and options lists""" all_sources, options = process_options(mypy_options, fscache=fscache) if only_compile_paths is not None: @@ -101,8 +100,9 @@ def get_mypy_config(mypy_options: List[str], mypyc_sources = all_sources if compiler_options.separate: - mypyc_sources = [src for src in mypyc_sources - if src.path and not src.path.endswith('__init__.py')] + mypyc_sources = [ + src for src in mypyc_sources if src.path and not src.path.endswith("__init__.py") + ] if not mypyc_sources: return mypyc_sources, all_sources, options @@ -112,9 +112,9 @@ def get_mypy_config(mypy_options: List[str], options.python_version = sys.version_info[:2] if options.python_version[0] == 2: - fail('Python 2 not supported') + fail("Python 2 not supported") if not options.strict_optional: - fail('Disabling strict optional checking not supported') + fail("Disabling strict optional checking not supported") options.show_traceback = True # Needed to get types for all AST nodes options.export_types = True @@ -123,13 +123,14 @@ def get_mypy_config(mypy_options: List[str], options.preserve_asts = True for source in mypyc_sources: - options.per_module_options.setdefault(source.module, {})['mypyc'] = True + options.per_module_options.setdefault(source.module, {})["mypyc"] = True return mypyc_sources, all_sources, options def generate_c_extension_shim( - full_module_name: str, module_name: str, dir_name: str, group_name: str) -> str: + full_module_name: str, module_name: str, dir_name: str, group_name: str +) -> str: """Create a C extension shim with a passthrough PyInit function. Arguments: @@ -138,19 +139,22 @@ def generate_c_extension_shim( dir_name: the directory to place source code group_name: the name of the group """ - cname = '%s.c' % full_module_name.replace('.', os.sep) + cname = "%s.c" % full_module_name.replace(".", os.sep) cpath = os.path.join(dir_name, cname) # We load the C extension shim template from a file. # (So that the file could be reused as a bazel template also.) - with open(os.path.join(include_dir(), 'module_shim.tmpl')) as f: + with open(os.path.join(include_dir(), "module_shim.tmpl")) as f: shim_template = f.read() write_file( cpath, - shim_template.format(modname=module_name, - libname=shared_lib_name(group_name), - full_modname=exported_name(full_module_name))) + shim_template.format( + modname=module_name, + libname=shared_lib_name(group_name), + full_modname=exported_name(full_module_name), + ), + ) return cpath @@ -161,21 +165,22 @@ def group_name(modules: List[str]) -> str: return modules[0] h = hashlib.sha1() - h.update(','.join(modules).encode()) + h.update(",".join(modules).encode()) return h.hexdigest()[:20] def include_dir() -> str: """Find the path of the lib-rt dir that needs to be included""" - return os.path.join(os.path.abspath(os.path.dirname(__file__)), 'lib-rt') + return os.path.join(os.path.abspath(os.path.dirname(__file__)), "lib-rt") -def generate_c(sources: List[BuildSource], - options: Options, - groups: emitmodule.Groups, - fscache: FileSystemCache, - compiler_options: CompilerOptions, - ) -> Tuple[List[List[Tuple[str, str]]], str]: +def generate_c( + sources: List[BuildSource], + options: Options, + groups: emitmodule.Groups, + fscache: FileSystemCache, + compiler_options: CompilerOptions, +) -> Tuple[List[List[Tuple[str, str]]], str]: """Drive the actual core compilation step. The groups argument describes how modules are assigned to C @@ -191,7 +196,8 @@ def generate_c(sources: List[BuildSource], result = None try: result = emitmodule.parse_and_typecheck( - sources, options, compiler_options, groups, fscache) + sources, options, compiler_options, groups, fscache + ) messages = result.errors except CompileError as e: messages = e.messages @@ -205,7 +211,8 @@ def generate_c(sources: List[BuildSource], if not messages and result: errors = Errors() modules, ctext = emitmodule.compile_modules_to_c( - result, compiler_options=compiler_options, errors=errors, groups=groups) + result, compiler_options=compiler_options, errors=errors, groups=groups + ) if errors.num_errors: messages.extend(errors.new_messages()) @@ -216,9 +223,7 @@ def generate_c(sources: List[BuildSource], # ... you know, just in case. if options.junit_xml: - py_version = "{}_{}".format( - options.python_version[0], options.python_version[1] - ) + py_version = "{}_{}".format(options.python_version[0], options.python_version[1]) write_junit_xml( t2 - t0, serious, messages, options.junit_xml, py_version, options.platform ) @@ -227,16 +232,17 @@ def generate_c(sources: List[BuildSource], print("\n".join(messages)) sys.exit(1) - return ctext, '\n'.join(format_modules(modules)) + return ctext, "\n".join(format_modules(modules)) -def build_using_shared_lib(sources: List[BuildSource], - group_name: str, - cfiles: List[str], - deps: List[str], - build_dir: str, - extra_compile_args: List[str], - ) -> List['Extension']: +def build_using_shared_lib( + sources: List[BuildSource], + group_name: str, + cfiles: List[str], + deps: List[str], + build_dir: str, + extra_compile_args: List[str], +) -> List["Extension"]: """Produce the list of extension modules when a shared library is needed. This creates one shared library extension module that all of the @@ -248,47 +254,50 @@ def build_using_shared_lib(sources: List[BuildSource], extension module that exports the real initialization functions in Capsules stored in module attributes. """ - extensions = [get_extension()( - shared_lib_name(group_name), - sources=cfiles, - include_dirs=[include_dir(), build_dir], - depends=deps, - extra_compile_args=extra_compile_args, - )] + extensions = [ + get_extension()( + shared_lib_name(group_name), + sources=cfiles, + include_dirs=[include_dir(), build_dir], + depends=deps, + extra_compile_args=extra_compile_args, + ) + ] for source in sources: - module_name = source.module.split('.')[-1] + module_name = source.module.split(".")[-1] shim_file = generate_c_extension_shim(source.module, module_name, build_dir, group_name) # We include the __init__ in the "module name" we stick in the Extension, # since this seems to be needed for it to end up in the right place. full_module_name = source.module assert source.path - if os.path.split(source.path)[1] == '__init__.py': - full_module_name += '.__init__' - extensions.append(get_extension()( - full_module_name, - sources=[shim_file], - extra_compile_args=extra_compile_args, - )) + if os.path.split(source.path)[1] == "__init__.py": + full_module_name += ".__init__" + extensions.append( + get_extension()( + full_module_name, sources=[shim_file], extra_compile_args=extra_compile_args + ) + ) return extensions -def build_single_module(sources: List[BuildSource], - cfiles: List[str], - extra_compile_args: List[str], - ) -> List['Extension']: +def build_single_module( + sources: List[BuildSource], cfiles: List[str], extra_compile_args: List[str] +) -> List["Extension"]: """Produce the list of extension modules for a standalone extension. This contains just one module, since there is no need for a shared module. """ - return [get_extension()( - sources[0].module, - sources=cfiles, - include_dirs=[include_dir()], - extra_compile_args=extra_compile_args, - )] + return [ + get_extension()( + sources[0].module, + sources=cfiles, + include_dirs=[include_dir()], + extra_compile_args=extra_compile_args, + ) + ] def write_file(path: str, contents: str) -> None: @@ -300,15 +309,15 @@ def write_file(path: str, contents: str) -> None: """ # We encode it ourselves and open the files as binary to avoid windows # newline translation - encoded_contents = contents.encode('utf-8') + encoded_contents = contents.encode("utf-8") try: - with open(path, 'rb') as f: + with open(path, "rb") as f: old_contents: Optional[bytes] = f.read() except OSError: old_contents = None if old_contents != encoded_contents: os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, 'wb') as g: + with open(path, "wb") as g: g.write(encoded_contents) # Fudge the mtime forward because otherwise when two builds happen close @@ -381,19 +390,20 @@ def mypyc_build( separate: Union[bool, List[Tuple[List[str], Optional[str]]]] = False, only_compile_paths: Optional[Iterable[str]] = None, skip_cgen_input: Optional[Any] = None, - always_use_shared_lib: bool = False + always_use_shared_lib: bool = False, ) -> Tuple[emitmodule.Groups, List[Tuple[List[str], List[str]]]]: """Do the front and middle end of mypyc building, producing and writing out C source.""" fscache = FileSystemCache() mypyc_sources, all_sources, options = get_mypy_config( - paths, only_compile_paths, compiler_options, fscache) + paths, only_compile_paths, compiler_options, fscache + ) # We generate a shared lib if there are multiple modules or if any # of the modules are in package. (Because I didn't want to fuss # around with making the single module code handle packages.) use_shared_lib = ( len(mypyc_sources) > 1 - or any('.' in x.module for x in mypyc_sources) + or any("." in x.module for x in mypyc_sources) or always_use_shared_lib ) @@ -402,10 +412,11 @@ def mypyc_build( # We let the test harness just pass in the c file contents instead # so that it can do a corner-cutting version without full stubs. if not skip_cgen_input: - group_cfiles, ops_text = generate_c(all_sources, options, groups, fscache, - compiler_options=compiler_options) + group_cfiles, ops_text = generate_c( + all_sources, options, groups, fscache, compiler_options=compiler_options + ) # TODO: unique names? - write_file(os.path.join(compiler_options.target_dir, 'ops.txt'), ops_text) + write_file(os.path.join(compiler_options.target_dir, "ops.txt"), ops_text) else: group_cfiles = skip_cgen_input @@ -417,7 +428,7 @@ def mypyc_build( for cfile, ctext in cfiles: cfile = os.path.join(compiler_options.target_dir, cfile) write_file(cfile, ctext) - if os.path.splitext(cfile)[1] == '.c': + if os.path.splitext(cfile)[1] == ".c": cfilenames.append(cfile) deps = [os.path.join(compiler_options.target_dir, dep) for dep in get_header_deps(cfiles)] @@ -438,8 +449,8 @@ def mypycify( separate: Union[bool, List[Tuple[List[str], Optional[str]]]] = False, skip_cgen_input: Optional[Any] = None, target_dir: Optional[str] = None, - include_runtime_files: Optional[bool] = None -) -> List['Extension']: + include_runtime_files: Optional[bool] = None, +) -> List["Extension"]: """Main entry point to building using mypyc. This produces a list of Extension objects that should be passed as the @@ -511,44 +522,45 @@ def mypycify( build_dir = compiler_options.target_dir cflags: List[str] = [] - if compiler.compiler_type == 'unix': + if compiler.compiler_type == "unix": cflags += [ - f'-O{opt_level}', - f'-g{debug_level}', - '-Werror', '-Wno-unused-function', '-Wno-unused-label', - '-Wno-unreachable-code', '-Wno-unused-variable', - '-Wno-unused-command-line-argument', '-Wno-unknown-warning-option', + f"-O{opt_level}", + f"-g{debug_level}", + "-Werror", + "-Wno-unused-function", + "-Wno-unused-label", + "-Wno-unreachable-code", + "-Wno-unused-variable", + "-Wno-unused-command-line-argument", + "-Wno-unknown-warning-option", ] - if 'gcc' in compiler.compiler[0] or 'gnu-cc' in compiler.compiler[0]: + if "gcc" in compiler.compiler[0] or "gnu-cc" in compiler.compiler[0]: # This flag is needed for gcc but does not exist on clang. - cflags += ['-Wno-unused-but-set-variable'] - elif compiler.compiler_type == 'msvc': + cflags += ["-Wno-unused-but-set-variable"] + elif compiler.compiler_type == "msvc": # msvc doesn't have levels, '/O2' is full and '/Od' is disable - if opt_level == '0': - opt_level = 'd' - elif opt_level in ('1', '2', '3'): - opt_level = '2' - if debug_level == '0': + if opt_level == "0": + opt_level = "d" + elif opt_level in ("1", "2", "3"): + opt_level = "2" + if debug_level == "0": debug_level = "NONE" - elif debug_level == '1': + elif debug_level == "1": debug_level = "FASTLINK" - elif debug_level in ('2', '3'): + elif debug_level in ("2", "3"): debug_level = "FULL" cflags += [ - f'/O{opt_level}', - f'/DEBUG:{debug_level}', - '/wd4102', # unreferenced label - '/wd4101', # unreferenced local variable - '/wd4146', # negating unsigned int + f"/O{opt_level}", + f"/DEBUG:{debug_level}", + "/wd4102", # unreferenced label + "/wd4101", # unreferenced local variable + "/wd4146", # negating unsigned int ] if multi_file: # Disable whole program optimization in multi-file mode so # that we actually get the compilation speed and memory # use wins that multi-file mode is intended for. - cflags += [ - '/GL-', - '/wd9025', # warning about overriding /GL - ] + cflags += ["/GL-", "/wd9025"] # warning about overriding /GL # If configured to (defaults to yes in multi-file mode), copy the # runtime library in. Otherwise it just gets #included to save on @@ -557,17 +569,26 @@ def mypycify( if not compiler_options.include_runtime_files: for name in RUNTIME_C_FILES: rt_file = os.path.join(build_dir, name) - with open(os.path.join(include_dir(), name), encoding='utf-8') as f: + with open(os.path.join(include_dir(), name), encoding="utf-8") as f: write_file(rt_file, f.read()) shared_cfilenames.append(rt_file) extensions = [] for (group_sources, lib_name), (cfilenames, deps) in zip(groups, group_cfilenames): if lib_name: - extensions.extend(build_using_shared_lib( - group_sources, lib_name, cfilenames + shared_cfilenames, deps, build_dir, cflags)) + extensions.extend( + build_using_shared_lib( + group_sources, + lib_name, + cfilenames + shared_cfilenames, + deps, + build_dir, + cflags, + ) + ) else: - extensions.extend(build_single_module( - group_sources, cfilenames + shared_cfilenames, cflags)) + extensions.extend( + build_single_module(group_sources, cfilenames + shared_cfilenames, cflags) + ) return extensions diff --git a/mypyc/codegen/cstring.py b/mypyc/codegen/cstring.py index dba2bf8142462..c4d1a422f4d18 100644 --- a/mypyc/codegen/cstring.py +++ b/mypyc/codegen/cstring.py @@ -22,7 +22,6 @@ from typing_extensions import Final - CHAR_MAP: Final = [f"\\{i:03o}" for i in range(256)] # It is safe to use string.printable as it always uses the C locale. @@ -31,18 +30,18 @@ # These assignments must come last because we prioritize simple escape # sequences over any other representation. -for c in ('\'', '"', '\\', 'a', 'b', 'f', 'n', 'r', 't', 'v'): - escaped = f'\\{c}' - decoded = escaped.encode('ascii').decode('unicode_escape') +for c in ("'", '"', "\\", "a", "b", "f", "n", "r", "t", "v"): + escaped = f"\\{c}" + decoded = escaped.encode("ascii").decode("unicode_escape") CHAR_MAP[ord(decoded)] = escaped # This escape sequence is invalid in Python. -CHAR_MAP[ord('?')] = r'\?' +CHAR_MAP[ord("?")] = r"\?" def encode_bytes_as_c_string(b: bytes) -> str: """Produce contents of a C string literal for a byte string, without quotes.""" - escaped = ''.join([CHAR_MAP[i] for i in b]) + escaped = "".join([CHAR_MAP[i] for i in b]) return escaped diff --git a/mypyc/codegen/emit.py b/mypyc/codegen/emit.py index b1f886ee3f5f8..0c9f708472d03 100644 --- a/mypyc/codegen/emit.py +++ b/mypyc/codegen/emit.py @@ -1,30 +1,54 @@ """Utilities for emitting C code.""" -from mypy.backports import OrderedDict -from typing import List, Set, Dict, Optional, Callable, Union, Tuple -from typing_extensions import Final - import sys +from typing import Callable, Dict, List, Optional, Set, Tuple, Union +from typing_extensions import Final + +from mypy.backports import OrderedDict +from mypyc.codegen.literals import Literals from mypyc.common import ( - REG_PREFIX, ATTR_PREFIX, STATIC_PREFIX, TYPE_PREFIX, NATIVE_PREFIX, - FAST_ISINSTANCE_MAX_SUBCLASSES, use_vectorcall + ATTR_PREFIX, + FAST_ISINSTANCE_MAX_SUBCLASSES, + NATIVE_PREFIX, + REG_PREFIX, + STATIC_PREFIX, + TYPE_PREFIX, + use_vectorcall, ) +from mypyc.ir.class_ir import ClassIR, all_concrete_classes +from mypyc.ir.func_ir import FuncDecl from mypyc.ir.ops import BasicBlock, Value from mypyc.ir.rtypes import ( - RType, RTuple, RInstance, RUnion, RPrimitive, - is_float_rprimitive, is_bool_rprimitive, is_int_rprimitive, is_short_int_rprimitive, - is_list_rprimitive, is_dict_rprimitive, is_set_rprimitive, is_tuple_rprimitive, - is_none_rprimitive, is_object_rprimitive, object_rprimitive, is_str_rprimitive, - int_rprimitive, is_optional_type, optional_value_type, is_int32_rprimitive, - is_int64_rprimitive, is_bit_rprimitive, is_range_rprimitive, is_bytes_rprimitive, - is_fixed_width_rtype + RInstance, + RPrimitive, + RTuple, + RType, + RUnion, + int_rprimitive, + is_bit_rprimitive, + is_bool_rprimitive, + is_bytes_rprimitive, + is_dict_rprimitive, + is_fixed_width_rtype, + is_float_rprimitive, + is_int32_rprimitive, + is_int64_rprimitive, + is_int_rprimitive, + is_list_rprimitive, + is_none_rprimitive, + is_object_rprimitive, + is_optional_type, + is_range_rprimitive, + is_set_rprimitive, + is_short_int_rprimitive, + is_str_rprimitive, + is_tuple_rprimitive, + object_rprimitive, + optional_value_type, ) -from mypyc.ir.func_ir import FuncDecl -from mypyc.ir.class_ir import ClassIR, all_concrete_classes from mypyc.namegen import NameGenerator, exported_name from mypyc.sametype import is_same_type -from mypyc.codegen.literals import Literals # Whether to insert debug asserts for all error handling, to quickly # catch errors propagating without exceptions set. @@ -47,14 +71,15 @@ class HeaderDeclaration: other modules in the linking table. """ - def __init__(self, - decl: Union[str, List[str]], - defn: Optional[List[str]] = None, - *, - dependencies: Optional[Set[str]] = None, - is_type: bool = False, - needs_export: bool = False - ) -> None: + def __init__( + self, + decl: Union[str, List[str]], + defn: Optional[List[str]] = None, + *, + dependencies: Optional[Set[str]] = None, + is_type: bool = False, + needs_export: bool = False, + ) -> None: self.decl = [decl] if isinstance(decl, str) else decl self.defn = defn self.dependencies = dependencies or set() @@ -65,11 +90,12 @@ def __init__(self, class EmitterContext: """Shared emitter state for a compilation group.""" - def __init__(self, - names: NameGenerator, - group_name: Optional[str] = None, - group_map: Optional[Dict[str, Optional[str]]] = None, - ) -> None: + def __init__( + self, + names: NameGenerator, + group_name: Optional[str] = None, + group_map: Optional[Dict[str, Optional[str]]] = None, + ) -> None: """Setup shared emitter state. Args: @@ -114,11 +140,9 @@ def __init__(self, label: str) -> None: class TracebackAndGotoHandler(ErrorHandler): """Add traceback item and goto label on error.""" - def __init__(self, - label: str, - source_path: str, - module_name: str, - traceback_entry: Tuple[str, int]) -> None: + def __init__( + self, label: str, source_path: str, module_name: str, traceback_entry: Tuple[str, int] + ) -> None: self.label = label self.source_path = source_path self.module_name = module_name @@ -135,11 +159,12 @@ def __init__(self, value: str) -> None: class Emitter: """Helper for C code generation.""" - def __init__(self, - context: EmitterContext, - value_names: Optional[Dict[Value, str]] = None, - capi_version: Optional[Tuple[int, int]] = None, - ) -> None: + def __init__( + self, + context: EmitterContext, + value_names: Optional[Dict[Value, str]] = None, + capi_version: Optional[Tuple[int, int]] = None, + ) -> None: self.context = context self.capi_version = capi_version or sys.version_info[:2] self.names = context.names @@ -157,7 +182,7 @@ def dedent(self) -> None: assert self._indent >= 0 def label(self, label: BasicBlock) -> str: - return 'CPyL%s' % label.label + return "CPyL%s" % label.label def reg(self, reg: Value) -> str: return REG_PREFIX + self.value_names[reg] @@ -165,11 +190,11 @@ def reg(self, reg: Value) -> str: def attr(self, name: str) -> str: return ATTR_PREFIX + name - def emit_line(self, line: str = '') -> None: - if line.startswith('}'): + def emit_line(self, line: str = "") -> None: + if line.startswith("}"): self.dedent() - self.fragments.append(self._indent * ' ' + line + '\n') - if line.endswith('{'): + self.fragments.append(self._indent * " " + line + "\n") + if line.endswith("{"): self.indent() def emit_lines(self, *lines: str) -> None: @@ -182,23 +207,23 @@ def emit_label(self, label: Union[BasicBlock, str]) -> None: else: text = self.label(label) # Extra semicolon prevents an error when the next line declares a tempvar - self.fragments.append(f'{text}: ;\n') + self.fragments.append(f"{text}: ;\n") - def emit_from_emitter(self, emitter: 'Emitter') -> None: + def emit_from_emitter(self, emitter: "Emitter") -> None: self.fragments.extend(emitter.fragments) def emit_printf(self, fmt: str, *args: str) -> None: - fmt = fmt.replace('\n', '\\n') - self.emit_line('printf(%s);' % ', '.join(['"%s"' % fmt] + list(args))) - self.emit_line('fflush(stdout);') + fmt = fmt.replace("\n", "\\n") + self.emit_line("printf(%s);" % ", ".join(['"%s"' % fmt] + list(args))) + self.emit_line("fflush(stdout);") def temp_name(self) -> str: self.context.temp_counter += 1 - return '__tmp%d' % self.context.temp_counter + return "__tmp%d" % self.context.temp_counter def new_label(self) -> str: self.context.temp_counter += 1 - return '__LL%d' % self.context.temp_counter + return "__LL%d" % self.context.temp_counter def get_module_group_prefix(self, module_name: str) -> str: """Get the group prefix for a module (relative to the current group). @@ -222,9 +247,9 @@ def get_module_group_prefix(self, module_name: str) -> str: target_group_name = groups.get(module_name) if target_group_name and target_group_name != self.context.group_name: self.context.group_deps.add(target_group_name) - return f'exports_{exported_name(target_group_name)}.' + return f"exports_{exported_name(target_group_name)}." else: - return '' + return "" def get_group_prefix(self, obj: Union[ClassIR, FuncDecl]) -> str: """Get the group prefix for an object.""" @@ -241,12 +266,12 @@ def static_name(self, id: str, module: Optional[str], prefix: str = STATIC_PREFI overlap with other calls to this method within a compilation group. """ - lib_prefix = '' if not module else self.get_module_group_prefix(module) + lib_prefix = "" if not module else self.get_module_group_prefix(module) # If we are accessing static via the export table, we need to dereference # the pointer also. - star_maybe = '*' if lib_prefix else '' - suffix = self.names.private_name(module or '', id) - return f'{star_maybe}{lib_prefix}{prefix}{suffix}' + star_maybe = "*" if lib_prefix else "" + suffix = self.names.private_name(module or "", id) + return f"{star_maybe}{lib_prefix}{prefix}{suffix}" def type_struct_name(self, cl: ClassIR) -> str: return self.static_name(cl.name, cl.module_name, prefix=TYPE_PREFIX) @@ -257,14 +282,14 @@ def ctype(self, rtype: RType) -> str: def ctype_spaced(self, rtype: RType) -> str: """Adds a space after ctype for non-pointers.""" ctype = self.ctype(rtype) - if ctype[-1] == '*': + if ctype[-1] == "*": return ctype else: - return ctype + ' ' + return ctype + " " def c_undefined_value(self, rtype: RType) -> str: if not rtype.is_unboxed: - return 'NULL' + return "NULL" elif isinstance(rtype, RPrimitive): return rtype.c_undefined elif isinstance(rtype, RTuple): @@ -275,67 +300,73 @@ def c_error_value(self, rtype: RType) -> str: return self.c_undefined_value(rtype) def native_function_name(self, fn: FuncDecl) -> str: - return f'{NATIVE_PREFIX}{fn.cname(self.names)}' + return f"{NATIVE_PREFIX}{fn.cname(self.names)}" def tuple_c_declaration(self, rtuple: RTuple) -> List[str]: result = [ - f'#ifndef MYPYC_DECLARED_{rtuple.struct_name}', - f'#define MYPYC_DECLARED_{rtuple.struct_name}', - f'typedef struct {rtuple.struct_name} {{', + f"#ifndef MYPYC_DECLARED_{rtuple.struct_name}", + f"#define MYPYC_DECLARED_{rtuple.struct_name}", + f"typedef struct {rtuple.struct_name} {{", ] if len(rtuple.types) == 0: # empty tuple # Empty tuples contain a flag so that they can still indicate # error values. - result.append('int empty_struct_error_flag;') + result.append("int empty_struct_error_flag;") else: i = 0 for typ in rtuple.types: - result.append(f'{self.ctype_spaced(typ)}f{i};') + result.append(f"{self.ctype_spaced(typ)}f{i};") i += 1 - result.append(f'}} {rtuple.struct_name};') + result.append(f"}} {rtuple.struct_name};") values = self.tuple_undefined_value_helper(rtuple) - result.append('static {} {} = {{ {} }};'.format( - self.ctype(rtuple), self.tuple_undefined_value(rtuple), ''.join(values))) - result.append('#endif') - result.append('') + result.append( + "static {} {} = {{ {} }};".format( + self.ctype(rtuple), self.tuple_undefined_value(rtuple), "".join(values) + ) + ) + result.append("#endif") + result.append("") return result def use_vectorcall(self) -> bool: return use_vectorcall(self.capi_version) - def emit_undefined_attr_check(self, rtype: RType, attr_expr: str, - compare: str, - unlikely: bool = False) -> None: + def emit_undefined_attr_check( + self, rtype: RType, attr_expr: str, compare: str, unlikely: bool = False + ) -> None: if isinstance(rtype, RTuple): - check = '({})'.format(self.tuple_undefined_check_cond( - rtype, attr_expr, self.c_undefined_value, compare) + check = "({})".format( + self.tuple_undefined_check_cond(rtype, attr_expr, self.c_undefined_value, compare) ) else: - check = '({} {} {})'.format( - attr_expr, compare, self.c_undefined_value(rtype) - ) + check = "({} {} {})".format(attr_expr, compare, self.c_undefined_value(rtype)) if unlikely: - check = f'(unlikely{check})' - self.emit_line(f'if {check} {{') + check = f"(unlikely{check})" + self.emit_line(f"if {check} {{") def tuple_undefined_check_cond( - self, rtuple: RTuple, tuple_expr_in_c: str, - c_type_compare_val: Callable[[RType], str], compare: str) -> str: + self, + rtuple: RTuple, + tuple_expr_in_c: str, + c_type_compare_val: Callable[[RType], str], + compare: str, + ) -> str: if len(rtuple.types) == 0: # empty tuple - return '{}.empty_struct_error_flag {} {}'.format( - tuple_expr_in_c, compare, c_type_compare_val(int_rprimitive)) + return "{}.empty_struct_error_flag {} {}".format( + tuple_expr_in_c, compare, c_type_compare_val(int_rprimitive) + ) item_type = rtuple.types[0] if isinstance(item_type, RTuple): return self.tuple_undefined_check_cond( - item_type, tuple_expr_in_c + '.f0', c_type_compare_val, compare) + item_type, tuple_expr_in_c + ".f0", c_type_compare_val, compare + ) else: - return '{}.f0 {} {}'.format( - tuple_expr_in_c, compare, c_type_compare_val(item_type)) + return "{}.f0 {} {}".format(tuple_expr_in_c, compare, c_type_compare_val(item_type)) def tuple_undefined_value(self, rtuple: RTuple) -> str: - return 'tuple_undefined_' + rtuple.unique_id + return "tuple_undefined_" + rtuple.unique_id def tuple_undefined_value_helper(self, rtuple: RTuple) -> List[str]: res = [] @@ -347,10 +378,10 @@ def tuple_undefined_value_helper(self, rtuple: RTuple) -> List[str]: res.append(self.c_undefined_value(item)) else: sub_list = self.tuple_undefined_value_helper(item) - res.append('{ ') + res.append("{ ") res.extend(sub_list) - res.append(' }') - res.append(', ') + res.append(" }") + res.append(", ") return res[:-1] # Higher-level operations @@ -364,9 +395,7 @@ def declare_tuple_struct(self, tuple_type: RTuple) -> None: dependencies.add(typ.struct_name) self.context.declarations[tuple_type.struct_name] = HeaderDeclaration( - self.tuple_c_declaration(tuple_type), - dependencies=dependencies, - is_type=True, + self.tuple_c_declaration(tuple_type), dependencies=dependencies, is_type=True ) def emit_inc_ref(self, dest: str, rtype: RType, *, rare: bool = False) -> None: @@ -379,23 +408,20 @@ def emit_inc_ref(self, dest: str, rtype: RType, *, rare: bool = False) -> None: """ if is_int_rprimitive(rtype): if rare: - self.emit_line('CPyTagged_IncRef(%s);' % dest) + self.emit_line("CPyTagged_IncRef(%s);" % dest) else: - self.emit_line('CPyTagged_INCREF(%s);' % dest) + self.emit_line("CPyTagged_INCREF(%s);" % dest) elif isinstance(rtype, RTuple): for i, item_type in enumerate(rtype.types): - self.emit_inc_ref(f'{dest}.f{i}', item_type) + self.emit_inc_ref(f"{dest}.f{i}", item_type) elif not rtype.is_unboxed: # Always inline, since this is a simple op - self.emit_line('CPy_INCREF(%s);' % dest) + self.emit_line("CPy_INCREF(%s);" % dest) # Otherwise assume it's an unboxed, pointerless value and do nothing. - def emit_dec_ref(self, - dest: str, - rtype: RType, - *, - is_xdec: bool = False, - rare: bool = False) -> None: + def emit_dec_ref( + self, dest: str, rtype: RType, *, is_xdec: bool = False, rare: bool = False + ) -> None: """Decrement reference count of C expression `dest`. For composite unboxed structures (e.g. tuples) recursively @@ -403,41 +429,43 @@ def emit_dec_ref(self, If rare is True, optimize for code size and compilation speed. """ - x = 'X' if is_xdec else '' + x = "X" if is_xdec else "" if is_int_rprimitive(rtype): if rare: - self.emit_line(f'CPyTagged_{x}DecRef({dest});') + self.emit_line(f"CPyTagged_{x}DecRef({dest});") else: # Inlined - self.emit_line(f'CPyTagged_{x}DECREF({dest});') + self.emit_line(f"CPyTagged_{x}DECREF({dest});") elif isinstance(rtype, RTuple): for i, item_type in enumerate(rtype.types): - self.emit_dec_ref(f'{dest}.f{i}', item_type, is_xdec=is_xdec, rare=rare) + self.emit_dec_ref(f"{dest}.f{i}", item_type, is_xdec=is_xdec, rare=rare) elif not rtype.is_unboxed: if rare: - self.emit_line(f'CPy_{x}DecRef({dest});') + self.emit_line(f"CPy_{x}DecRef({dest});") else: # Inlined - self.emit_line(f'CPy_{x}DECREF({dest});') + self.emit_line(f"CPy_{x}DECREF({dest});") # Otherwise assume it's an unboxed, pointerless value and do nothing. def pretty_name(self, typ: RType) -> str: value_type = optional_value_type(typ) if value_type is not None: - return '%s or None' % self.pretty_name(value_type) + return "%s or None" % self.pretty_name(value_type) return str(typ) - def emit_cast(self, - src: str, - dest: str, - typ: RType, - *, - declare_dest: bool = False, - error: Optional[ErrorHandler] = None, - raise_exception: bool = True, - optional: bool = False, - src_type: Optional[RType] = None, - likely: bool = True) -> None: + def emit_cast( + self, + src: str, + dest: str, + typ: RType, + *, + declare_dest: bool = False, + error: Optional[ErrorHandler] = None, + raise_exception: bool = True, + optional: bool = False, + src_type: Optional[RType] = None, + likely: bool = True, + ) -> None: """Emit code for casting a value of given type. Somewhat strangely, this supports unboxed types but only @@ -467,252 +495,256 @@ def emit_cast(self, assert value_type is not None if is_same_type(value_type, typ): if declare_dest: - self.emit_line(f'PyObject *{dest};') - check = '({} != Py_None)' + self.emit_line(f"PyObject *{dest};") + check = "({} != Py_None)" if likely: - check = f'(likely{check})' + check = f"(likely{check})" self.emit_arg_check(src, dest, typ, check.format(src), optional) - self.emit_lines( - f' {dest} = {src};', - 'else {') + self.emit_lines(f" {dest} = {src};", "else {") self.emit_cast_error_handler(error, src, dest, typ, raise_exception) - self.emit_line('}') + self.emit_line("}") return # TODO: Verify refcount handling. - if (is_list_rprimitive(typ) - or is_dict_rprimitive(typ) - or is_set_rprimitive(typ) - or is_str_rprimitive(typ) - or is_range_rprimitive(typ) - or is_float_rprimitive(typ) - or is_int_rprimitive(typ) - or is_bool_rprimitive(typ) - or is_bit_rprimitive(typ) - or is_fixed_width_rtype(typ)): + if ( + is_list_rprimitive(typ) + or is_dict_rprimitive(typ) + or is_set_rprimitive(typ) + or is_str_rprimitive(typ) + or is_range_rprimitive(typ) + or is_float_rprimitive(typ) + or is_int_rprimitive(typ) + or is_bool_rprimitive(typ) + or is_bit_rprimitive(typ) + or is_fixed_width_rtype(typ) + ): if declare_dest: - self.emit_line(f'PyObject *{dest};') + self.emit_line(f"PyObject *{dest};") if is_list_rprimitive(typ): - prefix = 'PyList' + prefix = "PyList" elif is_dict_rprimitive(typ): - prefix = 'PyDict' + prefix = "PyDict" elif is_set_rprimitive(typ): - prefix = 'PySet' + prefix = "PySet" elif is_str_rprimitive(typ): - prefix = 'PyUnicode' + prefix = "PyUnicode" elif is_range_rprimitive(typ): - prefix = 'PyRange' + prefix = "PyRange" elif is_float_rprimitive(typ): - prefix = 'CPyFloat' + prefix = "CPyFloat" elif is_int_rprimitive(typ) or is_fixed_width_rtype(typ): # TODO: Range check for fixed-width types? - prefix = 'PyLong' + prefix = "PyLong" elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ): - prefix = 'PyBool' + prefix = "PyBool" else: - assert False, f'unexpected primitive type: {typ}' - check = '({}_Check({}))' + assert False, f"unexpected primitive type: {typ}" + check = "({}_Check({}))" if likely: - check = f'(likely{check})' + check = f"(likely{check})" self.emit_arg_check(src, dest, typ, check.format(prefix, src), optional) - self.emit_lines( - f' {dest} = {src};', - 'else {') + self.emit_lines(f" {dest} = {src};", "else {") self.emit_cast_error_handler(error, src, dest, typ, raise_exception) - self.emit_line('}') + self.emit_line("}") elif is_bytes_rprimitive(typ): if declare_dest: - self.emit_line(f'PyObject *{dest};') - check = '(PyBytes_Check({}) || PyByteArray_Check({}))' + self.emit_line(f"PyObject *{dest};") + check = "(PyBytes_Check({}) || PyByteArray_Check({}))" if likely: - check = f'(likely{check})' + check = f"(likely{check})" self.emit_arg_check(src, dest, typ, check.format(src, src), optional) - self.emit_lines( - f' {dest} = {src};', - 'else {') + self.emit_lines(f" {dest} = {src};", "else {") self.emit_cast_error_handler(error, src, dest, typ, raise_exception) - self.emit_line('}') + self.emit_line("}") elif is_tuple_rprimitive(typ): if declare_dest: - self.emit_line(f'{self.ctype(typ)} {dest};') - check = '(PyTuple_Check({}))' + self.emit_line(f"{self.ctype(typ)} {dest};") + check = "(PyTuple_Check({}))" if likely: - check = f'(likely{check})' - self.emit_arg_check(src, dest, typ, - check.format(src), optional) - self.emit_lines( - f' {dest} = {src};', - 'else {') + check = f"(likely{check})" + self.emit_arg_check(src, dest, typ, check.format(src), optional) + self.emit_lines(f" {dest} = {src};", "else {") self.emit_cast_error_handler(error, src, dest, typ, raise_exception) - self.emit_line('}') + self.emit_line("}") elif isinstance(typ, RInstance): if declare_dest: - self.emit_line(f'PyObject *{dest};') + self.emit_line(f"PyObject *{dest};") concrete = all_concrete_classes(typ.class_ir) # If there are too many concrete subclasses or we can't find any # (meaning the code ought to be dead or we aren't doing global opts), # fall back to a normal typecheck. # Otherwise check all the subclasses. if not concrete or len(concrete) > FAST_ISINSTANCE_MAX_SUBCLASSES + 1: - check = '(PyObject_TypeCheck({}, {}))'.format( - src, self.type_struct_name(typ.class_ir)) + check = "(PyObject_TypeCheck({}, {}))".format( + src, self.type_struct_name(typ.class_ir) + ) else: - full_str = '(Py_TYPE({src}) == {targets[0]})' + full_str = "(Py_TYPE({src}) == {targets[0]})" for i in range(1, len(concrete)): - full_str += ' || (Py_TYPE({src}) == {targets[%d]})' % i + full_str += " || (Py_TYPE({src}) == {targets[%d]})" % i if len(concrete) > 1: - full_str = '(%s)' % full_str + full_str = "(%s)" % full_str check = full_str.format( - src=src, targets=[self.type_struct_name(ir) for ir in concrete]) + src=src, targets=[self.type_struct_name(ir) for ir in concrete] + ) if likely: - check = f'(likely{check})' + check = f"(likely{check})" self.emit_arg_check(src, dest, typ, check, optional) - self.emit_lines( - f' {dest} = {src};'.format(dest, src), - 'else {') + self.emit_lines(f" {dest} = {src};".format(dest, src), "else {") self.emit_cast_error_handler(error, src, dest, typ, raise_exception) - self.emit_line('}') + self.emit_line("}") elif is_none_rprimitive(typ): if declare_dest: - self.emit_line(f'PyObject *{dest};') - check = '({} == Py_None)' + self.emit_line(f"PyObject *{dest};") + check = "({} == Py_None)" if likely: - check = f'(likely{check})' - self.emit_arg_check(src, dest, typ, - check.format(src), optional) - self.emit_lines( - f' {dest} = {src};', - 'else {') + check = f"(likely{check})" + self.emit_arg_check(src, dest, typ, check.format(src), optional) + self.emit_lines(f" {dest} = {src};", "else {") self.emit_cast_error_handler(error, src, dest, typ, raise_exception) - self.emit_line('}') + self.emit_line("}") elif is_object_rprimitive(typ): if declare_dest: - self.emit_line(f'PyObject *{dest};') - self.emit_arg_check(src, dest, typ, '', optional) - self.emit_line(f'{dest} = {src};') + self.emit_line(f"PyObject *{dest};") + self.emit_arg_check(src, dest, typ, "", optional) + self.emit_line(f"{dest} = {src};") if optional: - self.emit_line('}') + self.emit_line("}") elif isinstance(typ, RUnion): - self.emit_union_cast(src, dest, typ, declare_dest, error, optional, src_type, - raise_exception) + self.emit_union_cast( + src, dest, typ, declare_dest, error, optional, src_type, raise_exception + ) elif isinstance(typ, RTuple): assert not optional self.emit_tuple_cast(src, dest, typ, declare_dest, error, src_type) else: - assert False, 'Cast not implemented: %s' % typ - - def emit_cast_error_handler(self, - error: ErrorHandler, - src: str, - dest: str, - typ: RType, - raise_exception: bool) -> None: + assert False, "Cast not implemented: %s" % typ + + def emit_cast_error_handler( + self, error: ErrorHandler, src: str, dest: str, typ: RType, raise_exception: bool + ) -> None: if raise_exception: if isinstance(error, TracebackAndGotoHandler): # Merge raising and emitting traceback entry into a single call. self.emit_type_error_traceback( - error.source_path, error.module_name, error.traceback_entry, - typ=typ, - src=src) - self.emit_line('goto %s;' % error.label) + error.source_path, error.module_name, error.traceback_entry, typ=typ, src=src + ) + self.emit_line("goto %s;" % error.label) return self.emit_line('CPy_TypeError("{}", {}); '.format(self.pretty_name(typ), src)) if isinstance(error, AssignHandler): - self.emit_line('%s = NULL;' % dest) + self.emit_line("%s = NULL;" % dest) elif isinstance(error, GotoHandler): - self.emit_line('goto %s;' % error.label) + self.emit_line("goto %s;" % error.label) elif isinstance(error, TracebackAndGotoHandler): - self.emit_line('%s = NULL;' % dest) + self.emit_line("%s = NULL;" % dest) self.emit_traceback(error.source_path, error.module_name, error.traceback_entry) - self.emit_line('goto %s;' % error.label) + self.emit_line("goto %s;" % error.label) else: assert isinstance(error, ReturnHandler) - self.emit_line('return %s;' % error.value) - - def emit_union_cast(self, - src: str, - dest: str, - typ: RUnion, - declare_dest: bool, - error: ErrorHandler, - optional: bool, - src_type: Optional[RType], - raise_exception: bool) -> None: + self.emit_line("return %s;" % error.value) + + def emit_union_cast( + self, + src: str, + dest: str, + typ: RUnion, + declare_dest: bool, + error: ErrorHandler, + optional: bool, + src_type: Optional[RType], + raise_exception: bool, + ) -> None: """Emit cast to a union type. The arguments are similar to emit_cast. """ if declare_dest: - self.emit_line(f'PyObject *{dest};') + self.emit_line(f"PyObject *{dest};") good_label = self.new_label() if optional: - self.emit_line(f'if ({src} == NULL) {{') - self.emit_line(f'{dest} = {self.c_error_value(typ)};') - self.emit_line(f'goto {good_label};') - self.emit_line('}') + self.emit_line(f"if ({src} == NULL) {{") + self.emit_line(f"{dest} = {self.c_error_value(typ)};") + self.emit_line(f"goto {good_label};") + self.emit_line("}") for item in typ.items: - self.emit_cast(src, - dest, - item, - declare_dest=False, - raise_exception=False, - optional=False, - likely=False) - self.emit_line(f'if ({dest} != NULL) goto {good_label};') + self.emit_cast( + src, + dest, + item, + declare_dest=False, + raise_exception=False, + optional=False, + likely=False, + ) + self.emit_line(f"if ({dest} != NULL) goto {good_label};") # Handle cast failure. self.emit_cast_error_handler(error, src, dest, typ, raise_exception) self.emit_label(good_label) - def emit_tuple_cast(self, src: str, dest: str, typ: RTuple, declare_dest: bool, - error: ErrorHandler, src_type: Optional[RType]) -> None: + def emit_tuple_cast( + self, + src: str, + dest: str, + typ: RTuple, + declare_dest: bool, + error: ErrorHandler, + src_type: Optional[RType], + ) -> None: """Emit cast to a tuple type. The arguments are similar to emit_cast. """ if declare_dest: - self.emit_line(f'PyObject *{dest};') + self.emit_line(f"PyObject *{dest};") # This reuse of the variable is super dodgy. We don't even # care about the values except to check whether they are # invalid. out_label = self.new_label() self.emit_lines( - 'if (unlikely(!(PyTuple_Check({r}) && PyTuple_GET_SIZE({r}) == {size}))) {{'.format( - r=src, size=len(typ.types)), - f'{dest} = NULL;', - f'goto {out_label};', - '}') + "if (unlikely(!(PyTuple_Check({r}) && PyTuple_GET_SIZE({r}) == {size}))) {{".format( + r=src, size=len(typ.types) + ), + f"{dest} = NULL;", + f"goto {out_label};", + "}", + ) for i, item in enumerate(typ.types): # Since we did the checks above this should never fail - self.emit_cast(f'PyTuple_GET_ITEM({src}, {i})', - dest, - item, - declare_dest=False, - raise_exception=False, - optional=False) - self.emit_line(f'if ({dest} == NULL) goto {out_label};') - - self.emit_line(f'{dest} = {src};') + self.emit_cast( + f"PyTuple_GET_ITEM({src}, {i})", + dest, + item, + declare_dest=False, + raise_exception=False, + optional=False, + ) + self.emit_line(f"if ({dest} == NULL) goto {out_label};") + + self.emit_line(f"{dest} = {src};") self.emit_label(out_label) def emit_arg_check(self, src: str, dest: str, typ: RType, check: str, optional: bool) -> None: if optional: - self.emit_line(f'if ({src} == NULL) {{') - self.emit_line(f'{dest} = {self.c_error_value(typ)};') - if check != '': - self.emit_line('{}if {}'.format('} else ' if optional else '', check)) + self.emit_line(f"if ({src} == NULL) {{") + self.emit_line(f"{dest} = {self.c_error_value(typ)};") + if check != "": + self.emit_line("{}if {}".format("} else " if optional else "", check)) elif optional: - self.emit_line('else {') - - def emit_unbox(self, - src: str, - dest: str, - typ: RType, - *, - declare_dest: bool = False, - error: Optional[ErrorHandler] = None, - raise_exception: bool = True, - optional: bool = False, - borrow: bool = False) -> None: + self.emit_line("else {") + + def emit_unbox( + self, + src: str, + dest: str, + typ: RType, + *, + declare_dest: bool = False, + error: Optional[ErrorHandler] = None, + raise_exception: bool = True, + optional: bool = False, + borrow: bool = False, + ) -> None: """Emit code for unboxing a value of given type (from PyObject *). By default, assign error value to dest if the value has an @@ -734,112 +766,113 @@ def emit_unbox(self, error = error or AssignHandler() # TODO: Verify refcount handling. if isinstance(error, AssignHandler): - failure = f'{dest} = {self.c_error_value(typ)};' + failure = f"{dest} = {self.c_error_value(typ)};" elif isinstance(error, GotoHandler): - failure = 'goto %s;' % error.label + failure = "goto %s;" % error.label else: assert isinstance(error, ReturnHandler) - failure = 'return %s;' % error.value + failure = "return %s;" % error.value if raise_exception: raise_exc = f'CPy_TypeError("{self.pretty_name(typ)}", {src}); ' failure = raise_exc + failure if is_int_rprimitive(typ) or is_short_int_rprimitive(typ): if declare_dest: - self.emit_line(f'CPyTagged {dest};') - self.emit_arg_check(src, dest, typ, f'(likely(PyLong_Check({src})))', - optional) + self.emit_line(f"CPyTagged {dest};") + self.emit_arg_check(src, dest, typ, f"(likely(PyLong_Check({src})))", optional) if borrow: - self.emit_line(f' {dest} = CPyTagged_BorrowFromObject({src});') + self.emit_line(f" {dest} = CPyTagged_BorrowFromObject({src});") else: - self.emit_line(f' {dest} = CPyTagged_FromObject({src});') - self.emit_line('else {') + self.emit_line(f" {dest} = CPyTagged_FromObject({src});") + self.emit_line("else {") self.emit_line(failure) - self.emit_line('}') + self.emit_line("}") elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ): # Whether we are borrowing or not makes no difference. if declare_dest: - self.emit_line(f'char {dest};') - self.emit_arg_check(src, dest, typ, f'(unlikely(!PyBool_Check({src}))) {{', - optional) + self.emit_line(f"char {dest};") + self.emit_arg_check(src, dest, typ, f"(unlikely(!PyBool_Check({src}))) {{", optional) self.emit_line(failure) - self.emit_line('} else') - conversion = f'{src} == Py_True' - self.emit_line(f' {dest} = {conversion};') + self.emit_line("} else") + conversion = f"{src} == Py_True" + self.emit_line(f" {dest} = {conversion};") elif is_none_rprimitive(typ): # Whether we are borrowing or not makes no difference. if declare_dest: - self.emit_line(f'char {dest};') - self.emit_arg_check(src, dest, typ, f'(unlikely({src} != Py_None)) {{', - optional) + self.emit_line(f"char {dest};") + self.emit_arg_check(src, dest, typ, f"(unlikely({src} != Py_None)) {{", optional) self.emit_line(failure) - self.emit_line('} else') - self.emit_line(f' {dest} = 1;') + self.emit_line("} else") + self.emit_line(f" {dest} = 1;") elif is_int64_rprimitive(typ): # Whether we are borrowing or not makes no difference. if declare_dest: - self.emit_line(f'int64_t {dest};') - self.emit_line(f'{dest} = CPyLong_AsInt64({src});') + self.emit_line(f"int64_t {dest};") + self.emit_line(f"{dest} = CPyLong_AsInt64({src});") # TODO: Handle 'optional' # TODO: Handle 'failure' elif is_int32_rprimitive(typ): # Whether we are borrowing or not makes no difference. if declare_dest: - self.emit_line('int32_t {};'.format(dest)) - self.emit_line('{} = CPyLong_AsInt32({});'.format(dest, src)) + self.emit_line("int32_t {};".format(dest)) + self.emit_line("{} = CPyLong_AsInt32({});".format(dest, src)) # TODO: Handle 'optional' # TODO: Handle 'failure' elif isinstance(typ, RTuple): self.declare_tuple_struct(typ) if declare_dest: - self.emit_line(f'{self.ctype(typ)} {dest};') + self.emit_line(f"{self.ctype(typ)} {dest};") # HACK: The error handling for unboxing tuples is busted # and instead of fixing it I am just wrapping it in the # cast code which I think is right. This is not good. if optional: - self.emit_line(f'if ({src} == NULL) {{') - self.emit_line(f'{dest} = {self.c_error_value(typ)};') - self.emit_line('} else {') + self.emit_line(f"if ({src} == NULL) {{") + self.emit_line(f"{dest} = {self.c_error_value(typ)};") + self.emit_line("} else {") cast_temp = self.temp_name() - self.emit_tuple_cast(src, cast_temp, typ, declare_dest=True, error=error, - src_type=None) - self.emit_line(f'if (unlikely({cast_temp} == NULL)) {{') + self.emit_tuple_cast( + src, cast_temp, typ, declare_dest=True, error=error, src_type=None + ) + self.emit_line(f"if (unlikely({cast_temp} == NULL)) {{") # self.emit_arg_check(src, dest, typ, # '(!PyTuple_Check({}) || PyTuple_Size({}) != {}) {{'.format( # src, src, len(typ.types)), optional) self.emit_line(failure) # TODO: Decrease refcount? - self.emit_line('} else {') + self.emit_line("} else {") if not typ.types: - self.emit_line(f'{dest}.empty_struct_error_flag = 0;') + self.emit_line(f"{dest}.empty_struct_error_flag = 0;") for i, item_type in enumerate(typ.types): temp = self.temp_name() # emit_tuple_cast above checks the size, so this should not fail - self.emit_line(f'PyObject *{temp} = PyTuple_GET_ITEM({src}, {i});') + self.emit_line(f"PyObject *{temp} = PyTuple_GET_ITEM({src}, {i});") temp2 = self.temp_name() # Unbox or check the item. if item_type.is_unboxed: - self.emit_unbox(temp, - temp2, - item_type, - raise_exception=raise_exception, - error=error, - declare_dest=True, - borrow=borrow) + self.emit_unbox( + temp, + temp2, + item_type, + raise_exception=raise_exception, + error=error, + declare_dest=True, + borrow=borrow, + ) else: if not borrow: self.emit_inc_ref(temp, object_rprimitive) self.emit_cast(temp, temp2, item_type, declare_dest=True) - self.emit_line(f'{dest}.f{i} = {temp2};') - self.emit_line('}') + self.emit_line(f"{dest}.f{i} = {temp2};") + self.emit_line("}") if optional: - self.emit_line('}') + self.emit_line("}") else: - assert False, 'Unboxing not implemented: %s' % typ + assert False, "Unboxing not implemented: %s" % typ - def emit_box(self, src: str, dest: str, typ: RType, declare_dest: bool = False, - can_borrow: bool = False) -> None: + def emit_box( + self, src: str, dest: str, typ: RType, declare_dest: bool = False, can_borrow: bool = False + ) -> None: """Emit code for boxing a value of given type. Generate a simple assignment if no boxing is needed. @@ -848,60 +881,59 @@ def emit_box(self, src: str, dest: str, typ: RType, declare_dest: bool = False, """ # TODO: Always generate a new reference (if a reference type) if declare_dest: - declaration = 'PyObject *' + declaration = "PyObject *" else: - declaration = '' + declaration = "" if is_int_rprimitive(typ) or is_short_int_rprimitive(typ): # Steal the existing reference if it exists. - self.emit_line(f'{declaration}{dest} = CPyTagged_StealAsObject({src});') + self.emit_line(f"{declaration}{dest} = CPyTagged_StealAsObject({src});") elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ): # N.B: bool is special cased to produce a borrowed value # after boxing, so we don't need to increment the refcount # when this comes directly from a Box op. - self.emit_lines(f'{declaration}{dest} = {src} ? Py_True : Py_False;') + self.emit_lines(f"{declaration}{dest} = {src} ? Py_True : Py_False;") if not can_borrow: self.emit_inc_ref(dest, object_rprimitive) elif is_none_rprimitive(typ): # N.B: None is special cased to produce a borrowed value # after boxing, so we don't need to increment the refcount # when this comes directly from a Box op. - self.emit_lines(f'{declaration}{dest} = Py_None;') + self.emit_lines(f"{declaration}{dest} = Py_None;") if not can_borrow: self.emit_inc_ref(dest, object_rprimitive) elif is_int32_rprimitive(typ): - self.emit_line(f'{declaration}{dest} = PyLong_FromLong({src});') + self.emit_line(f"{declaration}{dest} = PyLong_FromLong({src});") elif is_int64_rprimitive(typ): - self.emit_line(f'{declaration}{dest} = PyLong_FromLongLong({src});') + self.emit_line(f"{declaration}{dest} = PyLong_FromLongLong({src});") elif isinstance(typ, RTuple): self.declare_tuple_struct(typ) - self.emit_line(f'{declaration}{dest} = PyTuple_New({len(typ.types)});') - self.emit_line(f'if (unlikely({dest} == NULL))') - self.emit_line(' CPyError_OutOfMemory();') + self.emit_line(f"{declaration}{dest} = PyTuple_New({len(typ.types)});") + self.emit_line(f"if (unlikely({dest} == NULL))") + self.emit_line(" CPyError_OutOfMemory();") # TODO: Fail if dest is None for i in range(0, len(typ.types)): if not typ.is_unboxed: - self.emit_line(f'PyTuple_SET_ITEM({dest}, {i}, {src}.f{i}') + self.emit_line(f"PyTuple_SET_ITEM({dest}, {i}, {src}.f{i}") else: inner_name = self.temp_name() - self.emit_box(f'{src}.f{i}', inner_name, typ.types[i], - declare_dest=True) - self.emit_line(f'PyTuple_SET_ITEM({dest}, {i}, {inner_name});') + self.emit_box(f"{src}.f{i}", inner_name, typ.types[i], declare_dest=True) + self.emit_line(f"PyTuple_SET_ITEM({dest}, {i}, {inner_name});") else: assert not typ.is_unboxed # Type is boxed -- trivially just assign. - self.emit_line(f'{declaration}{dest} = {src};') + self.emit_line(f"{declaration}{dest} = {src};") def emit_error_check(self, value: str, rtype: RType, failure: str) -> None: """Emit code for checking a native function return value for uncaught exception.""" if not isinstance(rtype, RTuple): - self.emit_line(f'if ({value} == {self.c_error_value(rtype)}) {{') + self.emit_line(f"if ({value} == {self.c_error_value(rtype)}) {{") else: if len(rtype.types) == 0: return # empty tuples can't fail. else: - cond = self.tuple_undefined_check_cond(rtype, value, self.c_error_value, '==') - self.emit_line(f'if ({cond}) {{') - self.emit_lines(failure, '}') + cond = self.tuple_undefined_check_cond(rtype, value, self.c_error_value, "==") + self.emit_line(f"if ({cond}) {{") + self.emit_lines(failure, "}") def emit_gc_visit(self, target: str, rtype: RType) -> None: """Emit code for GC visiting a C variable reference. @@ -912,18 +944,18 @@ def emit_gc_visit(self, target: str, rtype: RType) -> None: if not rtype.is_refcounted: # Not refcounted -> no pointers -> no GC interaction. return - elif isinstance(rtype, RPrimitive) and rtype.name == 'builtins.int': - self.emit_line(f'if (CPyTagged_CheckLong({target})) {{') - self.emit_line(f'Py_VISIT(CPyTagged_LongAsObject({target}));') - self.emit_line('}') + elif isinstance(rtype, RPrimitive) and rtype.name == "builtins.int": + self.emit_line(f"if (CPyTagged_CheckLong({target})) {{") + self.emit_line(f"Py_VISIT(CPyTagged_LongAsObject({target}));") + self.emit_line("}") elif isinstance(rtype, RTuple): for i, item_type in enumerate(rtype.types): - self.emit_gc_visit(f'{target}.f{i}', item_type) - elif self.ctype(rtype) == 'PyObject *': + self.emit_gc_visit(f"{target}.f{i}", item_type) + elif self.ctype(rtype) == "PyObject *": # The simplest case. - self.emit_line(f'Py_VISIT({target});') + self.emit_line(f"Py_VISIT({target});") else: - assert False, 'emit_gc_visit() not implemented for %s' % repr(rtype) + assert False, "emit_gc_visit() not implemented for %s" % repr(rtype) def emit_gc_clear(self, target: str, rtype: RType) -> None: """Emit code for clearing a C attribute reference for GC. @@ -934,58 +966,62 @@ def emit_gc_clear(self, target: str, rtype: RType) -> None: if not rtype.is_refcounted: # Not refcounted -> no pointers -> no GC interaction. return - elif isinstance(rtype, RPrimitive) and rtype.name == 'builtins.int': - self.emit_line(f'if (CPyTagged_CheckLong({target})) {{') - self.emit_line(f'CPyTagged __tmp = {target};') - self.emit_line(f'{target} = {self.c_undefined_value(rtype)};') - self.emit_line('Py_XDECREF(CPyTagged_LongAsObject(__tmp));') - self.emit_line('}') + elif isinstance(rtype, RPrimitive) and rtype.name == "builtins.int": + self.emit_line(f"if (CPyTagged_CheckLong({target})) {{") + self.emit_line(f"CPyTagged __tmp = {target};") + self.emit_line(f"{target} = {self.c_undefined_value(rtype)};") + self.emit_line("Py_XDECREF(CPyTagged_LongAsObject(__tmp));") + self.emit_line("}") elif isinstance(rtype, RTuple): for i, item_type in enumerate(rtype.types): - self.emit_gc_clear(f'{target}.f{i}', item_type) - elif self.ctype(rtype) == 'PyObject *' and self.c_undefined_value(rtype) == 'NULL': + self.emit_gc_clear(f"{target}.f{i}", item_type) + elif self.ctype(rtype) == "PyObject *" and self.c_undefined_value(rtype) == "NULL": # The simplest case. - self.emit_line(f'Py_CLEAR({target});') + self.emit_line(f"Py_CLEAR({target});") else: - assert False, 'emit_gc_clear() not implemented for %s' % repr(rtype) + assert False, "emit_gc_clear() not implemented for %s" % repr(rtype) - def emit_traceback(self, - source_path: str, - module_name: str, - traceback_entry: Tuple[str, int]) -> None: - return self._emit_traceback('CPy_AddTraceback', source_path, module_name, traceback_entry) + def emit_traceback( + self, source_path: str, module_name: str, traceback_entry: Tuple[str, int] + ) -> None: + return self._emit_traceback("CPy_AddTraceback", source_path, module_name, traceback_entry) def emit_type_error_traceback( - self, - source_path: str, - module_name: str, - traceback_entry: Tuple[str, int], - *, - typ: RType, - src: str) -> None: - func = 'CPy_TypeErrorTraceback' + self, + source_path: str, + module_name: str, + traceback_entry: Tuple[str, int], + *, + typ: RType, + src: str, + ) -> None: + func = "CPy_TypeErrorTraceback" type_str = f'"{self.pretty_name(typ)}"' return self._emit_traceback( - func, source_path, module_name, traceback_entry, type_str=type_str, src=src) - - def _emit_traceback(self, - func: str, - source_path: str, - module_name: str, - traceback_entry: Tuple[str, int], - type_str: str = '', - src: str = '') -> None: - globals_static = self.static_name('globals', module_name) + func, source_path, module_name, traceback_entry, type_str=type_str, src=src + ) + + def _emit_traceback( + self, + func: str, + source_path: str, + module_name: str, + traceback_entry: Tuple[str, int], + type_str: str = "", + src: str = "", + ) -> None: + globals_static = self.static_name("globals", module_name) line = '%s("%s", "%s", %d, %s' % ( func, source_path.replace("\\", "\\\\"), traceback_entry[0], traceback_entry[1], - globals_static) + globals_static, + ) if type_str: assert src - line += f', {type_str}, {src}' - line += ');' + line += f", {type_str}, {src}" + line += ");" self.emit_line(line) if DEBUG_ERRORS: self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");') diff --git a/mypyc/codegen/emitclass.py b/mypyc/codegen/emitclass.py index 9c1ad58d11bbe..4666443800a61 100644 --- a/mypyc/codegen/emitclass.py +++ b/mypyc/codegen/emitclass.py @@ -1,30 +1,35 @@ """Code generation for native classes and related wrappers.""" -from typing import Optional, List, Tuple, Dict, Callable, Mapping, Set +from typing import Callable, Dict, List, Mapping, Optional, Set, Tuple from mypy.backports import OrderedDict - -from mypyc.common import PREFIX, NATIVE_PREFIX, REG_PREFIX, use_fastcall from mypyc.codegen.emit import Emitter, HeaderDeclaration, ReturnHandler from mypyc.codegen.emitfunc import native_function_header from mypyc.codegen.emitwrapper import ( - generate_dunder_wrapper, generate_hash_wrapper, generate_richcompare_wrapper, - generate_bool_wrapper, generate_get_wrapper, generate_len_wrapper, - generate_set_del_item_wrapper, generate_contains_wrapper, generate_bin_op_wrapper + generate_bin_op_wrapper, + generate_bool_wrapper, + generate_contains_wrapper, + generate_dunder_wrapper, + generate_get_wrapper, + generate_hash_wrapper, + generate_len_wrapper, + generate_richcompare_wrapper, + generate_set_del_item_wrapper, ) -from mypyc.ir.rtypes import RType, RTuple, object_rprimitive -from mypyc.ir.func_ir import FuncIR, FuncDecl, FUNC_STATICMETHOD, FUNC_CLASSMETHOD +from mypyc.common import NATIVE_PREFIX, PREFIX, REG_PREFIX, use_fastcall from mypyc.ir.class_ir import ClassIR, VTableEntries -from mypyc.sametype import is_same_type +from mypyc.ir.func_ir import FUNC_CLASSMETHOD, FUNC_STATICMETHOD, FuncDecl, FuncIR +from mypyc.ir.rtypes import RTuple, RType, object_rprimitive from mypyc.namegen import NameGenerator +from mypyc.sametype import is_same_type def native_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: - return f'{NATIVE_PREFIX}{fn.cname(emitter.names)}' + return f"{NATIVE_PREFIX}{fn.cname(emitter.names)}" def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: - return f'{PREFIX}{fn.cname(emitter.names)}' + return f"{PREFIX}{fn.cname(emitter.names)}" # We maintain a table from dunder function names to struct slots they @@ -35,95 +40,91 @@ def wrapper_slot(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: SlotTable = Mapping[str, Tuple[str, SlotGenerator]] SLOT_DEFS: SlotTable = { - '__init__': ('tp_init', lambda c, t, e: generate_init_for_class(c, t, e)), - '__call__': ('tp_call', lambda c, t, e: generate_call_wrapper(c, t, e)), - '__str__': ('tp_str', native_slot), - '__repr__': ('tp_repr', native_slot), - '__next__': ('tp_iternext', native_slot), - '__iter__': ('tp_iter', native_slot), - '__hash__': ('tp_hash', generate_hash_wrapper), - '__get__': ('tp_descr_get', generate_get_wrapper), + "__init__": ("tp_init", lambda c, t, e: generate_init_for_class(c, t, e)), + "__call__": ("tp_call", lambda c, t, e: generate_call_wrapper(c, t, e)), + "__str__": ("tp_str", native_slot), + "__repr__": ("tp_repr", native_slot), + "__next__": ("tp_iternext", native_slot), + "__iter__": ("tp_iter", native_slot), + "__hash__": ("tp_hash", generate_hash_wrapper), + "__get__": ("tp_descr_get", generate_get_wrapper), } AS_MAPPING_SLOT_DEFS: SlotTable = { - '__getitem__': ('mp_subscript', generate_dunder_wrapper), - '__setitem__': ('mp_ass_subscript', generate_set_del_item_wrapper), - '__delitem__': ('mp_ass_subscript', generate_set_del_item_wrapper), - '__len__': ('mp_length', generate_len_wrapper), + "__getitem__": ("mp_subscript", generate_dunder_wrapper), + "__setitem__": ("mp_ass_subscript", generate_set_del_item_wrapper), + "__delitem__": ("mp_ass_subscript", generate_set_del_item_wrapper), + "__len__": ("mp_length", generate_len_wrapper), } -AS_SEQUENCE_SLOT_DEFS: SlotTable = { - '__contains__': ('sq_contains', generate_contains_wrapper), -} +AS_SEQUENCE_SLOT_DEFS: SlotTable = {"__contains__": ("sq_contains", generate_contains_wrapper)} AS_NUMBER_SLOT_DEFS: SlotTable = { - '__bool__': ('nb_bool', generate_bool_wrapper), - '__neg__': ('nb_negative', generate_dunder_wrapper), - '__invert__': ('nb_invert', generate_dunder_wrapper), - '__int__': ('nb_int', generate_dunder_wrapper), - '__float__': ('nb_float', generate_dunder_wrapper), - '__add__': ('nb_add', generate_bin_op_wrapper), - '__radd__': ('nb_add', generate_bin_op_wrapper), - '__sub__': ('nb_subtract', generate_bin_op_wrapper), - '__rsub__': ('nb_subtract', generate_bin_op_wrapper), - '__mul__': ('nb_multiply', generate_bin_op_wrapper), - '__rmul__': ('nb_multiply', generate_bin_op_wrapper), - '__mod__': ('nb_remainder', generate_bin_op_wrapper), - '__rmod__': ('nb_remainder', generate_bin_op_wrapper), - '__truediv__': ('nb_true_divide', generate_bin_op_wrapper), - '__rtruediv__': ('nb_true_divide', generate_bin_op_wrapper), - '__floordiv__': ('nb_floor_divide', generate_bin_op_wrapper), - '__rfloordiv__': ('nb_floor_divide', generate_bin_op_wrapper), - '__lshift__': ('nb_lshift', generate_bin_op_wrapper), - '__rlshift__': ('nb_lshift', generate_bin_op_wrapper), - '__rshift__': ('nb_rshift', generate_bin_op_wrapper), - '__rrshift__': ('nb_rshift', generate_bin_op_wrapper), - '__and__': ('nb_and', generate_bin_op_wrapper), - '__rand__': ('nb_and', generate_bin_op_wrapper), - '__or__': ('nb_or', generate_bin_op_wrapper), - '__ror__': ('nb_or', generate_bin_op_wrapper), - '__xor__': ('nb_xor', generate_bin_op_wrapper), - '__rxor__': ('nb_xor', generate_bin_op_wrapper), - '__matmul__': ('nb_matrix_multiply', generate_bin_op_wrapper), - '__rmatmul__': ('nb_matrix_multiply', generate_bin_op_wrapper), - '__iadd__': ('nb_inplace_add', generate_dunder_wrapper), - '__isub__': ('nb_inplace_subtract', generate_dunder_wrapper), - '__imul__': ('nb_inplace_multiply', generate_dunder_wrapper), - '__imod__': ('nb_inplace_remainder', generate_dunder_wrapper), - '__itruediv__': ('nb_inplace_true_divide', generate_dunder_wrapper), - '__ifloordiv__': ('nb_inplace_floor_divide', generate_dunder_wrapper), - '__ilshift__': ('nb_inplace_lshift', generate_dunder_wrapper), - '__irshift__': ('nb_inplace_rshift', generate_dunder_wrapper), - '__iand__': ('nb_inplace_and', generate_dunder_wrapper), - '__ior__': ('nb_inplace_or', generate_dunder_wrapper), - '__ixor__': ('nb_inplace_xor', generate_dunder_wrapper), - '__imatmul__': ('nb_inplace_matrix_multiply', generate_dunder_wrapper), + "__bool__": ("nb_bool", generate_bool_wrapper), + "__neg__": ("nb_negative", generate_dunder_wrapper), + "__invert__": ("nb_invert", generate_dunder_wrapper), + "__int__": ("nb_int", generate_dunder_wrapper), + "__float__": ("nb_float", generate_dunder_wrapper), + "__add__": ("nb_add", generate_bin_op_wrapper), + "__radd__": ("nb_add", generate_bin_op_wrapper), + "__sub__": ("nb_subtract", generate_bin_op_wrapper), + "__rsub__": ("nb_subtract", generate_bin_op_wrapper), + "__mul__": ("nb_multiply", generate_bin_op_wrapper), + "__rmul__": ("nb_multiply", generate_bin_op_wrapper), + "__mod__": ("nb_remainder", generate_bin_op_wrapper), + "__rmod__": ("nb_remainder", generate_bin_op_wrapper), + "__truediv__": ("nb_true_divide", generate_bin_op_wrapper), + "__rtruediv__": ("nb_true_divide", generate_bin_op_wrapper), + "__floordiv__": ("nb_floor_divide", generate_bin_op_wrapper), + "__rfloordiv__": ("nb_floor_divide", generate_bin_op_wrapper), + "__lshift__": ("nb_lshift", generate_bin_op_wrapper), + "__rlshift__": ("nb_lshift", generate_bin_op_wrapper), + "__rshift__": ("nb_rshift", generate_bin_op_wrapper), + "__rrshift__": ("nb_rshift", generate_bin_op_wrapper), + "__and__": ("nb_and", generate_bin_op_wrapper), + "__rand__": ("nb_and", generate_bin_op_wrapper), + "__or__": ("nb_or", generate_bin_op_wrapper), + "__ror__": ("nb_or", generate_bin_op_wrapper), + "__xor__": ("nb_xor", generate_bin_op_wrapper), + "__rxor__": ("nb_xor", generate_bin_op_wrapper), + "__matmul__": ("nb_matrix_multiply", generate_bin_op_wrapper), + "__rmatmul__": ("nb_matrix_multiply", generate_bin_op_wrapper), + "__iadd__": ("nb_inplace_add", generate_dunder_wrapper), + "__isub__": ("nb_inplace_subtract", generate_dunder_wrapper), + "__imul__": ("nb_inplace_multiply", generate_dunder_wrapper), + "__imod__": ("nb_inplace_remainder", generate_dunder_wrapper), + "__itruediv__": ("nb_inplace_true_divide", generate_dunder_wrapper), + "__ifloordiv__": ("nb_inplace_floor_divide", generate_dunder_wrapper), + "__ilshift__": ("nb_inplace_lshift", generate_dunder_wrapper), + "__irshift__": ("nb_inplace_rshift", generate_dunder_wrapper), + "__iand__": ("nb_inplace_and", generate_dunder_wrapper), + "__ior__": ("nb_inplace_or", generate_dunder_wrapper), + "__ixor__": ("nb_inplace_xor", generate_dunder_wrapper), + "__imatmul__": ("nb_inplace_matrix_multiply", generate_dunder_wrapper), } AS_ASYNC_SLOT_DEFS: SlotTable = { - '__await__': ('am_await', native_slot), - '__aiter__': ('am_aiter', native_slot), - '__anext__': ('am_anext', native_slot), + "__await__": ("am_await", native_slot), + "__aiter__": ("am_aiter", native_slot), + "__anext__": ("am_anext", native_slot), } SIDE_TABLES = [ - ('as_mapping', 'PyMappingMethods', AS_MAPPING_SLOT_DEFS), - ('as_sequence', 'PySequenceMethods', AS_SEQUENCE_SLOT_DEFS), - ('as_number', 'PyNumberMethods', AS_NUMBER_SLOT_DEFS), - ('as_async', 'PyAsyncMethods', AS_ASYNC_SLOT_DEFS), + ("as_mapping", "PyMappingMethods", AS_MAPPING_SLOT_DEFS), + ("as_sequence", "PySequenceMethods", AS_SEQUENCE_SLOT_DEFS), + ("as_number", "PyNumberMethods", AS_NUMBER_SLOT_DEFS), + ("as_async", "PyAsyncMethods", AS_ASYNC_SLOT_DEFS), ] # Slots that need to always be filled in because they don't get # inherited right. -ALWAYS_FILL = { - '__hash__', -} +ALWAYS_FILL = {"__hash__"} def generate_call_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: if emitter.use_vectorcall(): # Use vectorcall wrapper if supported (PEP 590). - return 'PyVectorcall_Call' + return "PyVectorcall_Call" else: # On older Pythons use the legacy wrapper. return wrapper_slot(cl, fn, emitter) @@ -134,8 +135,8 @@ def slot_key(attr: str) -> str: Sort reverse operator methods and __delitem__ after others ('x' > '_'). """ - if (attr.startswith('__r') and attr != '__rshift__') or attr == '__delitem__': - return 'x' + attr + if (attr.startswith("__r") and attr != "__rshift__") or attr == "__delitem__": + return "x" + attr return attr @@ -158,14 +159,14 @@ def generate_slots(cl: ClassIR, table: SlotTable, emitter: Emitter) -> Dict[str, return fields -def generate_class_type_decl(cl: ClassIR, c_emitter: Emitter, - external_emitter: Emitter, - emitter: Emitter) -> None: +def generate_class_type_decl( + cl: ClassIR, c_emitter: Emitter, external_emitter: Emitter, emitter: Emitter +) -> None: context = c_emitter.context name = emitter.type_struct_name(cl) context.declarations[name] = HeaderDeclaration( - f'PyTypeObject *{emitter.type_struct_name(cl)};', - needs_export=True) + f"PyTypeObject *{emitter.type_struct_name(cl)};", needs_export=True + ) # If this is a non-extension class, all we want is the type object decl. if not cl.is_ext_class: @@ -175,8 +176,7 @@ def generate_class_type_decl(cl: ClassIR, c_emitter: Emitter, generate_full = not cl.is_trait and not cl.builtin_base if generate_full: context.declarations[emitter.native_function_name(cl.ctor)] = HeaderDeclaration( - f'{native_function_header(cl.ctor, emitter)};', - needs_export=True, + f"{native_function_header(cl.ctor, emitter)};", needs_export=True ) @@ -188,33 +188,33 @@ def generate_class(cl: ClassIR, module: str, emitter: Emitter) -> None: name = cl.name name_prefix = cl.name_prefix(emitter.names) - setup_name = f'{name_prefix}_setup' - new_name = f'{name_prefix}_new' - members_name = f'{name_prefix}_members' - getseters_name = f'{name_prefix}_getseters' - vtable_name = f'{name_prefix}_vtable' - traverse_name = f'{name_prefix}_traverse' - clear_name = f'{name_prefix}_clear' - dealloc_name = f'{name_prefix}_dealloc' - methods_name = f'{name_prefix}_methods' - vtable_setup_name = f'{name_prefix}_trait_vtable_setup' + setup_name = f"{name_prefix}_setup" + new_name = f"{name_prefix}_new" + members_name = f"{name_prefix}_members" + getseters_name = f"{name_prefix}_getseters" + vtable_name = f"{name_prefix}_vtable" + traverse_name = f"{name_prefix}_traverse" + clear_name = f"{name_prefix}_clear" + dealloc_name = f"{name_prefix}_dealloc" + methods_name = f"{name_prefix}_methods" + vtable_setup_name = f"{name_prefix}_trait_vtable_setup" fields: Dict[str, str] = OrderedDict() - fields['tp_name'] = f'"{name}"' + fields["tp_name"] = f'"{name}"' generate_full = not cl.is_trait and not cl.builtin_base needs_getseters = cl.needs_getseters or not cl.is_generated if not cl.builtin_base: - fields['tp_new'] = new_name + fields["tp_new"] = new_name if generate_full: - fields['tp_dealloc'] = f'(destructor){name_prefix}_dealloc' - fields['tp_traverse'] = f'(traverseproc){name_prefix}_traverse' - fields['tp_clear'] = f'(inquiry){name_prefix}_clear' + fields["tp_dealloc"] = f"(destructor){name_prefix}_dealloc" + fields["tp_traverse"] = f"(traverseproc){name_prefix}_traverse" + fields["tp_clear"] = f"(inquiry){name_prefix}_clear" if needs_getseters: - fields['tp_getset'] = getseters_name - fields['tp_methods'] = methods_name + fields["tp_getset"] = getseters_name + fields["tp_methods"] = methods_name def emit_line() -> None: emitter.emit_line() @@ -223,10 +223,10 @@ def emit_line() -> None: # If the class has a method to initialize default attribute # values, we need to call it during initialization. - defaults_fn = cl.get_method('__mypyc_defaults_setup') + defaults_fn = cl.get_method("__mypyc_defaults_setup") # If there is a __init__ method, we'll use it in the native constructor. - init_fn = cl.get_method('__init__') + init_fn = cl.get_method("__init__") # Fill out slots in the type object from dunder methods. fields.update(generate_slots(cl, SLOT_DEFS, emitter)) @@ -236,20 +236,20 @@ def emit_line() -> None: slots = generate_slots(cl, slot_defs, emitter) if slots: table_struct_name = generate_side_table_for_class(cl, table_name, type, slots, emitter) - fields[f'tp_{table_name}'] = f'&{table_struct_name}' + fields[f"tp_{table_name}"] = f"&{table_struct_name}" richcompare_name = generate_richcompare_wrapper(cl, emitter) if richcompare_name: - fields['tp_richcompare'] = richcompare_name + fields["tp_richcompare"] = richcompare_name # If the class inherits from python, make space for a __dict__ struct_name = cl.struct_name(emitter.names) if cl.builtin_base: - base_size = f'sizeof({cl.builtin_base})' + base_size = f"sizeof({cl.builtin_base})" elif cl.is_trait: - base_size = 'sizeof(PyObject)' + base_size = "sizeof(PyObject)" else: - base_size = f'sizeof({struct_name})' + base_size = f"sizeof({struct_name})" # Since our types aren't allocated using type() we need to # populate these fields ourselves if we want them to have correct # values. PyType_Ready will inherit the offsets from tp_base but @@ -259,32 +259,32 @@ def emit_line() -> None: if cl.has_dict: # __dict__ lives right after the struct and __weakref__ lives right after that # TODO: They should get members in the struct instead of doing this nonsense. - weak_offset = f'{base_size} + sizeof(PyObject *)' + weak_offset = f"{base_size} + sizeof(PyObject *)" emitter.emit_lines( - f'PyMemberDef {members_name}[] = {{', + f"PyMemberDef {members_name}[] = {{", f'{{"__dict__", T_OBJECT_EX, {base_size}, 0, NULL}},', f'{{"__weakref__", T_OBJECT_EX, {weak_offset}, 0, NULL}},', - '{0}', - '};', + "{0}", + "};", ) - fields['tp_members'] = members_name - fields['tp_basicsize'] = f'{base_size} + 2*sizeof(PyObject *)' - fields['tp_dictoffset'] = base_size - fields['tp_weaklistoffset'] = weak_offset + fields["tp_members"] = members_name + fields["tp_basicsize"] = f"{base_size} + 2*sizeof(PyObject *)" + fields["tp_dictoffset"] = base_size + fields["tp_weaklistoffset"] = weak_offset else: - fields['tp_basicsize'] = base_size + fields["tp_basicsize"] = base_size if generate_full: # Declare setup method that allocates and initializes an object. type is the # type of the class being initialized, which could be another class if there # is an interpreted subclass. - emitter.emit_line(f'static PyObject *{setup_name}(PyTypeObject *type);') + emitter.emit_line(f"static PyObject *{setup_name}(PyTypeObject *type);") assert cl.ctor is not None - emitter.emit_line(native_function_header(cl.ctor, emitter) + ';') + emitter.emit_line(native_function_header(cl.ctor, emitter) + ";") emit_line() - init_fn = cl.get_method('__init__') + init_fn = cl.get_method("__init__") generate_new_for_class(cl, new_name, vtable_name, setup_name, init_fn, emitter) emit_line() generate_traverse_for_class(cl, traverse_name, emitter) @@ -315,79 +315,77 @@ def emit_line() -> None: generate_methods_table(cl, methods_name, emitter) emit_line() - flags = ['Py_TPFLAGS_DEFAULT', 'Py_TPFLAGS_HEAPTYPE', 'Py_TPFLAGS_BASETYPE'] + flags = ["Py_TPFLAGS_DEFAULT", "Py_TPFLAGS_HEAPTYPE", "Py_TPFLAGS_BASETYPE"] if generate_full: - flags.append('Py_TPFLAGS_HAVE_GC') - if cl.has_method('__call__') and emitter.use_vectorcall(): - fields['tp_vectorcall_offset'] = 'offsetof({}, vectorcall)'.format( - cl.struct_name(emitter.names)) - flags.append('_Py_TPFLAGS_HAVE_VECTORCALL') - if not fields.get('tp_vectorcall'): + flags.append("Py_TPFLAGS_HAVE_GC") + if cl.has_method("__call__") and emitter.use_vectorcall(): + fields["tp_vectorcall_offset"] = "offsetof({}, vectorcall)".format( + cl.struct_name(emitter.names) + ) + flags.append("_Py_TPFLAGS_HAVE_VECTORCALL") + if not fields.get("tp_vectorcall"): # This is just a placeholder to please CPython. It will be # overriden during setup. - fields['tp_call'] = 'PyVectorcall_Call' - fields['tp_flags'] = ' | '.join(flags) + fields["tp_call"] = "PyVectorcall_Call" + fields["tp_flags"] = " | ".join(flags) emitter.emit_line(f"static PyTypeObject {emitter.type_struct_name(cl)}_template_ = {{") emitter.emit_line("PyVarObject_HEAD_INIT(NULL, 0)") for field, value in fields.items(): emitter.emit_line(f".{field} = {value},") emitter.emit_line("};") - emitter.emit_line("static PyTypeObject *{t}_template = &{t}_template_;".format( - t=emitter.type_struct_name(cl))) + emitter.emit_line( + "static PyTypeObject *{t}_template = &{t}_template_;".format( + t=emitter.type_struct_name(cl) + ) + ) emitter.emit_line() if generate_full: generate_setup_for_class( - cl, setup_name, defaults_fn, vtable_name, shadow_vtable_name, emitter) + cl, setup_name, defaults_fn, vtable_name, shadow_vtable_name, emitter + ) emitter.emit_line() - generate_constructor_for_class( - cl, cl.ctor, init_fn, setup_name, vtable_name, emitter) + generate_constructor_for_class(cl, cl.ctor, init_fn, setup_name, vtable_name, emitter) emitter.emit_line() if needs_getseters: generate_getseters(cl, emitter) def getter_name(cl: ClassIR, attribute: str, names: NameGenerator) -> str: - return names.private_name(cl.module_name, f'{cl.name}_get_{attribute}') + return names.private_name(cl.module_name, f"{cl.name}_get_{attribute}") def setter_name(cl: ClassIR, attribute: str, names: NameGenerator) -> str: - return names.private_name(cl.module_name, f'{cl.name}_set_{attribute}') + return names.private_name(cl.module_name, f"{cl.name}_set_{attribute}") def generate_object_struct(cl: ClassIR, emitter: Emitter) -> None: seen_attrs: Set[Tuple[str, RType]] = set() lines: List[str] = [] - lines += ['typedef struct {', - 'PyObject_HEAD', - 'CPyVTableItem *vtable;'] - if cl.has_method('__call__') and emitter.use_vectorcall(): - lines.append('vectorcallfunc vectorcall;') + lines += ["typedef struct {", "PyObject_HEAD", "CPyVTableItem *vtable;"] + if cl.has_method("__call__") and emitter.use_vectorcall(): + lines.append("vectorcallfunc vectorcall;") for base in reversed(cl.base_mro): if not base.is_trait: for attr, rtype in base.attributes.items(): if (attr, rtype) not in seen_attrs: - lines.append('{}{};'.format(emitter.ctype_spaced(rtype), - emitter.attr(attr))) + lines.append("{}{};".format(emitter.ctype_spaced(rtype), emitter.attr(attr))) seen_attrs.add((attr, rtype)) if isinstance(rtype, RTuple): emitter.declare_tuple_struct(rtype) - lines.append(f'}} {cl.struct_name(emitter.names)};') - lines.append('') + lines.append(f"}} {cl.struct_name(emitter.names)};") + lines.append("") emitter.context.declarations[cl.struct_name(emitter.names)] = HeaderDeclaration( - lines, - is_type=True + lines, is_type=True ) -def generate_vtables(base: ClassIR, - vtable_setup_name: str, - vtable_name: str, - emitter: Emitter, - shadow: bool) -> str: +def generate_vtables( + base: ClassIR, vtable_setup_name: str, vtable_name: str, emitter: Emitter, shadow: bool +) -> str: """Emit the vtables and vtable setup functions for a class. This includes both the primary vtable and any trait implementation vtables. @@ -417,38 +415,43 @@ def generate_vtables(base: ClassIR, """ def trait_vtable_name(trait: ClassIR) -> str: - return '{}_{}_trait_vtable{}'.format( - base.name_prefix(emitter.names), trait.name_prefix(emitter.names), - '_shadow' if shadow else '') + return "{}_{}_trait_vtable{}".format( + base.name_prefix(emitter.names), + trait.name_prefix(emitter.names), + "_shadow" if shadow else "", + ) def trait_offset_table_name(trait: ClassIR) -> str: - return '{}_{}_offset_table'.format( + return "{}_{}_offset_table".format( base.name_prefix(emitter.names), trait.name_prefix(emitter.names) ) # Emit array definitions with enough space for all the entries - emitter.emit_line('static CPyVTableItem {}[{}];'.format( - vtable_name, - max(1, len(base.vtable_entries) + 3 * len(base.trait_vtables)))) + emitter.emit_line( + "static CPyVTableItem {}[{}];".format( + vtable_name, max(1, len(base.vtable_entries) + 3 * len(base.trait_vtables)) + ) + ) for trait, vtable in base.trait_vtables.items(): # Trait methods entry (vtable index -> method implementation). - emitter.emit_line('static CPyVTableItem {}[{}];'.format( - trait_vtable_name(trait), - max(1, len(vtable)))) + emitter.emit_line( + "static CPyVTableItem {}[{}];".format(trait_vtable_name(trait), max(1, len(vtable))) + ) # Trait attributes entry (attribute number in trait -> offset in actual struct). - emitter.emit_line('static size_t {}[{}];'.format( - trait_offset_table_name(trait), - max(1, len(trait.attributes))) + emitter.emit_line( + "static size_t {}[{}];".format( + trait_offset_table_name(trait), max(1, len(trait.attributes)) + ) ) # Emit vtable setup function - emitter.emit_line('static bool') - emitter.emit_line(f'{NATIVE_PREFIX}{vtable_setup_name}(void)') - emitter.emit_line('{') + emitter.emit_line("static bool") + emitter.emit_line(f"{NATIVE_PREFIX}{vtable_setup_name}(void)") + emitter.emit_line("{") if base.allow_interpreted_subclasses and not shadow: - emitter.emit_line(f'{NATIVE_PREFIX}{vtable_setup_name}_shadow();') + emitter.emit_line(f"{NATIVE_PREFIX}{vtable_setup_name}_shadow();") subtables = [] for trait, vtable in base.trait_vtables.items(): @@ -460,332 +463,333 @@ def trait_offset_table_name(trait: ClassIR) -> str: generate_vtable(base.vtable_entries, vtable_name, emitter, subtables, shadow) - emitter.emit_line('return 1;') - emitter.emit_line('}') + emitter.emit_line("return 1;") + emitter.emit_line("}") return vtable_name if not subtables else f"{vtable_name} + {len(subtables) * 3}" -def generate_offset_table(trait_offset_table_name: str, - emitter: Emitter, - trait: ClassIR, - cl: ClassIR) -> None: +def generate_offset_table( + trait_offset_table_name: str, emitter: Emitter, trait: ClassIR, cl: ClassIR +) -> None: """Generate attribute offset row of a trait vtable.""" - emitter.emit_line(f'size_t {trait_offset_table_name}_scratch[] = {{') + emitter.emit_line(f"size_t {trait_offset_table_name}_scratch[] = {{") for attr in trait.attributes: - emitter.emit_line('offsetof({}, {}),'.format( - cl.struct_name(emitter.names), emitter.attr(attr) - )) + emitter.emit_line( + "offsetof({}, {}),".format(cl.struct_name(emitter.names), emitter.attr(attr)) + ) if not trait.attributes: # This is for msvc. - emitter.emit_line('0') - emitter.emit_line('};') - emitter.emit_line('memcpy({name}, {name}_scratch, sizeof({name}));'.format( - name=trait_offset_table_name) + emitter.emit_line("0") + emitter.emit_line("};") + emitter.emit_line( + "memcpy({name}, {name}_scratch, sizeof({name}));".format(name=trait_offset_table_name) ) -def generate_vtable(entries: VTableEntries, - vtable_name: str, - emitter: Emitter, - subtables: List[Tuple[ClassIR, str, str]], - shadow: bool) -> None: - emitter.emit_line(f'CPyVTableItem {vtable_name}_scratch[] = {{') +def generate_vtable( + entries: VTableEntries, + vtable_name: str, + emitter: Emitter, + subtables: List[Tuple[ClassIR, str, str]], + shadow: bool, +) -> None: + emitter.emit_line(f"CPyVTableItem {vtable_name}_scratch[] = {{") if subtables: - emitter.emit_line('/* Array of trait vtables */') + emitter.emit_line("/* Array of trait vtables */") for trait, table, offset_table in subtables: emitter.emit_line( - '(CPyVTableItem){}, (CPyVTableItem){}, (CPyVTableItem){},'.format( - emitter.type_struct_name(trait), table, offset_table)) - emitter.emit_line('/* Start of real vtable */') + "(CPyVTableItem){}, (CPyVTableItem){}, (CPyVTableItem){},".format( + emitter.type_struct_name(trait), table, offset_table + ) + ) + emitter.emit_line("/* Start of real vtable */") for entry in entries: method = entry.shadow_method if shadow and entry.shadow_method else entry.method - emitter.emit_line('(CPyVTableItem){}{}{},'.format( - emitter.get_group_prefix(entry.method.decl), - NATIVE_PREFIX, - method.cname(emitter.names))) + emitter.emit_line( + "(CPyVTableItem){}{}{},".format( + emitter.get_group_prefix(entry.method.decl), + NATIVE_PREFIX, + method.cname(emitter.names), + ) + ) # msvc doesn't allow empty arrays; maybe allowing them at all is an extension? if not entries: - emitter.emit_line('NULL') - emitter.emit_line('};') - emitter.emit_line('memcpy({name}, {name}_scratch, sizeof({name}));'.format(name=vtable_name)) + emitter.emit_line("NULL") + emitter.emit_line("};") + emitter.emit_line("memcpy({name}, {name}_scratch, sizeof({name}));".format(name=vtable_name)) -def generate_setup_for_class(cl: ClassIR, - func_name: str, - defaults_fn: Optional[FuncIR], - vtable_name: str, - shadow_vtable_name: Optional[str], - emitter: Emitter) -> None: +def generate_setup_for_class( + cl: ClassIR, + func_name: str, + defaults_fn: Optional[FuncIR], + vtable_name: str, + shadow_vtable_name: Optional[str], + emitter: Emitter, +) -> None: """Generate a native function that allocates an instance of a class.""" - emitter.emit_line('static PyObject *') - emitter.emit_line(f'{func_name}(PyTypeObject *type)') - emitter.emit_line('{') - emitter.emit_line(f'{cl.struct_name(emitter.names)} *self;') - emitter.emit_line('self = ({struct} *)type->tp_alloc(type, 0);'.format( - struct=cl.struct_name(emitter.names))) - emitter.emit_line('if (self == NULL)') - emitter.emit_line(' return NULL;') + emitter.emit_line("static PyObject *") + emitter.emit_line(f"{func_name}(PyTypeObject *type)") + emitter.emit_line("{") + emitter.emit_line(f"{cl.struct_name(emitter.names)} *self;") + emitter.emit_line( + "self = ({struct} *)type->tp_alloc(type, 0);".format(struct=cl.struct_name(emitter.names)) + ) + emitter.emit_line("if (self == NULL)") + emitter.emit_line(" return NULL;") if shadow_vtable_name: - emitter.emit_line(f'if (type != {emitter.type_struct_name(cl)}) {{') - emitter.emit_line(f'self->vtable = {shadow_vtable_name};') - emitter.emit_line('} else {') - emitter.emit_line(f'self->vtable = {vtable_name};') - emitter.emit_line('}') + emitter.emit_line(f"if (type != {emitter.type_struct_name(cl)}) {{") + emitter.emit_line(f"self->vtable = {shadow_vtable_name};") + emitter.emit_line("} else {") + emitter.emit_line(f"self->vtable = {vtable_name};") + emitter.emit_line("}") else: - emitter.emit_line(f'self->vtable = {vtable_name};') + emitter.emit_line(f"self->vtable = {vtable_name};") - if cl.has_method('__call__') and emitter.use_vectorcall(): - name = cl.method_decl('__call__').cname(emitter.names) - emitter.emit_line(f'self->vectorcall = {PREFIX}{name};') + if cl.has_method("__call__") and emitter.use_vectorcall(): + name = cl.method_decl("__call__").cname(emitter.names) + emitter.emit_line(f"self->vectorcall = {PREFIX}{name};") for base in reversed(cl.base_mro): for attr, rtype in base.attributes.items(): - emitter.emit_line(r'self->{} = {};'.format( - emitter.attr(attr), emitter.c_undefined_value(rtype))) + emitter.emit_line( + r"self->{} = {};".format(emitter.attr(attr), emitter.c_undefined_value(rtype)) + ) # Initialize attributes to default values, if necessary if defaults_fn is not None: emitter.emit_lines( - 'if ({}{}((PyObject *)self) == 0) {{'.format( - NATIVE_PREFIX, defaults_fn.cname(emitter.names)), - 'Py_DECREF(self);', - 'return NULL;', - '}') - - emitter.emit_line('return (PyObject *)self;') - emitter.emit_line('}') - - -def generate_constructor_for_class(cl: ClassIR, - fn: FuncDecl, - init_fn: Optional[FuncIR], - setup_name: str, - vtable_name: str, - emitter: Emitter) -> None: + "if ({}{}((PyObject *)self) == 0) {{".format( + NATIVE_PREFIX, defaults_fn.cname(emitter.names) + ), + "Py_DECREF(self);", + "return NULL;", + "}", + ) + + emitter.emit_line("return (PyObject *)self;") + emitter.emit_line("}") + + +def generate_constructor_for_class( + cl: ClassIR, + fn: FuncDecl, + init_fn: Optional[FuncIR], + setup_name: str, + vtable_name: str, + emitter: Emitter, +) -> None: """Generate a native function that allocates and initializes an instance of a class.""" - emitter.emit_line(f'{native_function_header(fn, emitter)}') - emitter.emit_line('{') - emitter.emit_line(f'PyObject *self = {setup_name}({emitter.type_struct_name(cl)});') - emitter.emit_line('if (self == NULL)') - emitter.emit_line(' return NULL;') - args = ', '.join(['self'] + [REG_PREFIX + arg.name for arg in fn.sig.args]) + emitter.emit_line(f"{native_function_header(fn, emitter)}") + emitter.emit_line("{") + emitter.emit_line(f"PyObject *self = {setup_name}({emitter.type_struct_name(cl)});") + emitter.emit_line("if (self == NULL)") + emitter.emit_line(" return NULL;") + args = ", ".join(["self"] + [REG_PREFIX + arg.name for arg in fn.sig.args]) if init_fn is not None: - emitter.emit_line('char res = {}{}{}({});'.format( - emitter.get_group_prefix(init_fn.decl), - NATIVE_PREFIX, init_fn.cname(emitter.names), args)) - emitter.emit_line('if (res == 2) {') - emitter.emit_line('Py_DECREF(self);') - emitter.emit_line('return NULL;') - emitter.emit_line('}') + emitter.emit_line( + "char res = {}{}{}({});".format( + emitter.get_group_prefix(init_fn.decl), + NATIVE_PREFIX, + init_fn.cname(emitter.names), + args, + ) + ) + emitter.emit_line("if (res == 2) {") + emitter.emit_line("Py_DECREF(self);") + emitter.emit_line("return NULL;") + emitter.emit_line("}") # If there is a nontrivial ctor that we didn't define, invoke it via tp_init elif len(fn.sig.args) > 1: - emitter.emit_line( - 'int res = {}->tp_init({});'.format( - emitter.type_struct_name(cl), - args)) + emitter.emit_line("int res = {}->tp_init({});".format(emitter.type_struct_name(cl), args)) - emitter.emit_line('if (res < 0) {') - emitter.emit_line('Py_DECREF(self);') - emitter.emit_line('return NULL;') - emitter.emit_line('}') + emitter.emit_line("if (res < 0) {") + emitter.emit_line("Py_DECREF(self);") + emitter.emit_line("return NULL;") + emitter.emit_line("}") - emitter.emit_line('return self;') - emitter.emit_line('}') + emitter.emit_line("return self;") + emitter.emit_line("}") -def generate_init_for_class(cl: ClassIR, - init_fn: FuncIR, - emitter: Emitter) -> str: +def generate_init_for_class(cl: ClassIR, init_fn: FuncIR, emitter: Emitter) -> str: """Generate an init function suitable for use as tp_init. tp_init needs to be a function that returns an int, and our __init__ methods return a PyObject. Translate NULL to -1, everything else to 0. """ - func_name = f'{cl.name_prefix(emitter.names)}_init' + func_name = f"{cl.name_prefix(emitter.names)}_init" - emitter.emit_line('static int') - emitter.emit_line( - f'{func_name}(PyObject *self, PyObject *args, PyObject *kwds)') - emitter.emit_line('{') + emitter.emit_line("static int") + emitter.emit_line(f"{func_name}(PyObject *self, PyObject *args, PyObject *kwds)") + emitter.emit_line("{") if cl.allow_interpreted_subclasses or cl.builtin_base: - emitter.emit_line('return {}{}(self, args, kwds) != NULL ? 0 : -1;'.format( - PREFIX, init_fn.cname(emitter.names))) + emitter.emit_line( + "return {}{}(self, args, kwds) != NULL ? 0 : -1;".format( + PREFIX, init_fn.cname(emitter.names) + ) + ) else: - emitter.emit_line('return 0;') - emitter.emit_line('}') + emitter.emit_line("return 0;") + emitter.emit_line("}") return func_name -def generate_new_for_class(cl: ClassIR, - func_name: str, - vtable_name: str, - setup_name: str, - init_fn: Optional[FuncIR], - emitter: Emitter) -> None: - emitter.emit_line('static PyObject *') - emitter.emit_line( - f'{func_name}(PyTypeObject *type, PyObject *args, PyObject *kwds)') - emitter.emit_line('{') +def generate_new_for_class( + cl: ClassIR, + func_name: str, + vtable_name: str, + setup_name: str, + init_fn: Optional[FuncIR], + emitter: Emitter, +) -> None: + emitter.emit_line("static PyObject *") + emitter.emit_line(f"{func_name}(PyTypeObject *type, PyObject *args, PyObject *kwds)") + emitter.emit_line("{") # TODO: Check and unbox arguments if not cl.allow_interpreted_subclasses: - emitter.emit_line(f'if (type != {emitter.type_struct_name(cl)}) {{') + emitter.emit_line(f"if (type != {emitter.type_struct_name(cl)}) {{") emitter.emit_line( 'PyErr_SetString(PyExc_TypeError, "interpreted classes cannot inherit from compiled");' ) - emitter.emit_line('return NULL;') - emitter.emit_line('}') + emitter.emit_line("return NULL;") + emitter.emit_line("}") - if (not init_fn - or cl.allow_interpreted_subclasses - or cl.builtin_base - or cl.is_serializable()): + if not init_fn or cl.allow_interpreted_subclasses or cl.builtin_base or cl.is_serializable(): # Match Python semantics -- __new__ doesn't call __init__. - emitter.emit_line(f'return {setup_name}(type);') + emitter.emit_line(f"return {setup_name}(type);") else: # __new__ of a native class implicitly calls __init__ so that we # can enforce that instances are always properly initialized. This # is needed to support always defined attributes. - emitter.emit_line(f'PyObject *self = {setup_name}(type);') - emitter.emit_lines('if (self == NULL)', - ' return NULL;') + emitter.emit_line(f"PyObject *self = {setup_name}(type);") + emitter.emit_lines("if (self == NULL)", " return NULL;") emitter.emit_line( - f'PyObject *ret = {PREFIX}{init_fn.cname(emitter.names)}(self, args, kwds);') - emitter.emit_lines('if (ret == NULL)', - ' return NULL;') - emitter.emit_line('return self;') - emitter.emit_line('}') + f"PyObject *ret = {PREFIX}{init_fn.cname(emitter.names)}(self, args, kwds);" + ) + emitter.emit_lines("if (ret == NULL)", " return NULL;") + emitter.emit_line("return self;") + emitter.emit_line("}") -def generate_new_for_trait(cl: ClassIR, - func_name: str, - emitter: Emitter) -> None: - emitter.emit_line('static PyObject *') +def generate_new_for_trait(cl: ClassIR, func_name: str, emitter: Emitter) -> None: + emitter.emit_line("static PyObject *") + emitter.emit_line(f"{func_name}(PyTypeObject *type, PyObject *args, PyObject *kwds)") + emitter.emit_line("{") + emitter.emit_line(f"if (type != {emitter.type_struct_name(cl)}) {{") emitter.emit_line( - f'{func_name}(PyTypeObject *type, PyObject *args, PyObject *kwds)') - emitter.emit_line('{') - emitter.emit_line(f'if (type != {emitter.type_struct_name(cl)}) {{') - emitter.emit_line( - 'PyErr_SetString(PyExc_TypeError, ' + "PyErr_SetString(PyExc_TypeError, " '"interpreted classes cannot inherit from compiled traits");' ) - emitter.emit_line('} else {') - emitter.emit_line( - 'PyErr_SetString(PyExc_TypeError, "traits may not be directly created");' - ) - emitter.emit_line('}') - emitter.emit_line('return NULL;') - emitter.emit_line('}') + emitter.emit_line("} else {") + emitter.emit_line('PyErr_SetString(PyExc_TypeError, "traits may not be directly created");') + emitter.emit_line("}") + emitter.emit_line("return NULL;") + emitter.emit_line("}") -def generate_traverse_for_class(cl: ClassIR, - func_name: str, - emitter: Emitter) -> None: +def generate_traverse_for_class(cl: ClassIR, func_name: str, emitter: Emitter) -> None: """Emit function that performs cycle GC traversal of an instance.""" - emitter.emit_line('static int') - emitter.emit_line('{}({} *self, visitproc visit, void *arg)'.format( - func_name, - cl.struct_name(emitter.names))) - emitter.emit_line('{') + emitter.emit_line("static int") + emitter.emit_line( + "{}({} *self, visitproc visit, void *arg)".format(func_name, cl.struct_name(emitter.names)) + ) + emitter.emit_line("{") for base in reversed(cl.base_mro): for attr, rtype in base.attributes.items(): - emitter.emit_gc_visit(f'self->{emitter.attr(attr)}', rtype) + emitter.emit_gc_visit(f"self->{emitter.attr(attr)}", rtype) if cl.has_dict: struct_name = cl.struct_name(emitter.names) # __dict__ lives right after the struct and __weakref__ lives right after that - emitter.emit_gc_visit('*((PyObject **)((char *)self + sizeof({})))'.format( - struct_name), object_rprimitive) emitter.emit_gc_visit( - '*((PyObject **)((char *)self + sizeof(PyObject *) + sizeof({})))'.format( - struct_name), - object_rprimitive) - emitter.emit_line('return 0;') - emitter.emit_line('}') - - -def generate_clear_for_class(cl: ClassIR, - func_name: str, - emitter: Emitter) -> None: - emitter.emit_line('static int') - emitter.emit_line(f'{func_name}({cl.struct_name(emitter.names)} *self)') - emitter.emit_line('{') + "*((PyObject **)((char *)self + sizeof({})))".format(struct_name), object_rprimitive + ) + emitter.emit_gc_visit( + "*((PyObject **)((char *)self + sizeof(PyObject *) + sizeof({})))".format(struct_name), + object_rprimitive, + ) + emitter.emit_line("return 0;") + emitter.emit_line("}") + + +def generate_clear_for_class(cl: ClassIR, func_name: str, emitter: Emitter) -> None: + emitter.emit_line("static int") + emitter.emit_line(f"{func_name}({cl.struct_name(emitter.names)} *self)") + emitter.emit_line("{") for base in reversed(cl.base_mro): for attr, rtype in base.attributes.items(): - emitter.emit_gc_clear(f'self->{emitter.attr(attr)}', rtype) + emitter.emit_gc_clear(f"self->{emitter.attr(attr)}", rtype) if cl.has_dict: struct_name = cl.struct_name(emitter.names) # __dict__ lives right after the struct and __weakref__ lives right after that - emitter.emit_gc_clear('*((PyObject **)((char *)self + sizeof({})))'.format( - struct_name), object_rprimitive) emitter.emit_gc_clear( - '*((PyObject **)((char *)self + sizeof(PyObject *) + sizeof({})))'.format( - struct_name), - object_rprimitive) - emitter.emit_line('return 0;') - emitter.emit_line('}') - - -def generate_dealloc_for_class(cl: ClassIR, - dealloc_func_name: str, - clear_func_name: str, - emitter: Emitter) -> None: - emitter.emit_line('static void') - emitter.emit_line(f'{dealloc_func_name}({cl.struct_name(emitter.names)} *self)') - emitter.emit_line('{') - emitter.emit_line('PyObject_GC_UnTrack(self);') + "*((PyObject **)((char *)self + sizeof({})))".format(struct_name), object_rprimitive + ) + emitter.emit_gc_clear( + "*((PyObject **)((char *)self + sizeof(PyObject *) + sizeof({})))".format(struct_name), + object_rprimitive, + ) + emitter.emit_line("return 0;") + emitter.emit_line("}") + + +def generate_dealloc_for_class( + cl: ClassIR, dealloc_func_name: str, clear_func_name: str, emitter: Emitter +) -> None: + emitter.emit_line("static void") + emitter.emit_line(f"{dealloc_func_name}({cl.struct_name(emitter.names)} *self)") + emitter.emit_line("{") + emitter.emit_line("PyObject_GC_UnTrack(self);") # The trashcan is needed to handle deep recursive deallocations - emitter.emit_line(f'CPy_TRASHCAN_BEGIN(self, {dealloc_func_name})') - emitter.emit_line(f'{clear_func_name}(self);') - emitter.emit_line('Py_TYPE(self)->tp_free((PyObject *)self);') - emitter.emit_line('CPy_TRASHCAN_END(self)') - emitter.emit_line('}') + emitter.emit_line(f"CPy_TRASHCAN_BEGIN(self, {dealloc_func_name})") + emitter.emit_line(f"{clear_func_name}(self);") + emitter.emit_line("Py_TYPE(self)->tp_free((PyObject *)self);") + emitter.emit_line("CPy_TRASHCAN_END(self)") + emitter.emit_line("}") -def generate_methods_table(cl: ClassIR, - name: str, - emitter: Emitter) -> None: - emitter.emit_line(f'static PyMethodDef {name}[] = {{') +def generate_methods_table(cl: ClassIR, name: str, emitter: Emitter) -> None: + emitter.emit_line(f"static PyMethodDef {name}[] = {{") for fn in cl.methods.values(): if fn.decl.is_prop_setter or fn.decl.is_prop_getter: continue emitter.emit_line(f'{{"{fn.name}",') - emitter.emit_line(f' (PyCFunction){PREFIX}{fn.cname(emitter.names)},') + emitter.emit_line(f" (PyCFunction){PREFIX}{fn.cname(emitter.names)},") if use_fastcall(emitter.capi_version): - flags = ['METH_FASTCALL'] + flags = ["METH_FASTCALL"] else: - flags = ['METH_VARARGS'] - flags.append('METH_KEYWORDS') + flags = ["METH_VARARGS"] + flags.append("METH_KEYWORDS") if fn.decl.kind == FUNC_STATICMETHOD: - flags.append('METH_STATIC') + flags.append("METH_STATIC") elif fn.decl.kind == FUNC_CLASSMETHOD: - flags.append('METH_CLASS') + flags.append("METH_CLASS") - emitter.emit_line(' {}, NULL}},'.format(' | '.join(flags))) + emitter.emit_line(" {}, NULL}},".format(" | ".join(flags))) # Provide a default __getstate__ and __setstate__ - if not cl.has_method('__setstate__') and not cl.has_method('__getstate__'): + if not cl.has_method("__setstate__") and not cl.has_method("__getstate__"): emitter.emit_lines( '{"__setstate__", (PyCFunction)CPyPickle_SetState, METH_O, NULL},', '{"__getstate__", (PyCFunction)CPyPickle_GetState, METH_NOARGS, NULL},', ) - emitter.emit_line('{NULL} /* Sentinel */') - emitter.emit_line('};') + emitter.emit_line("{NULL} /* Sentinel */") + emitter.emit_line("};") -def generate_side_table_for_class(cl: ClassIR, - name: str, - type: str, - slots: Dict[str, str], - emitter: Emitter) -> Optional[str]: - name = f'{cl.name_prefix(emitter.names)}_{name}' - emitter.emit_line(f'static {type} {name} = {{') +def generate_side_table_for_class( + cl: ClassIR, name: str, type: str, slots: Dict[str, str], emitter: Emitter +) -> Optional[str]: + name = f"{cl.name_prefix(emitter.names)}_{name}" + emitter.emit_line(f"static {type} {name} = {{") for field, value in slots.items(): emitter.emit_line(f".{field} = {value},") emitter.emit_line("};") @@ -795,83 +799,92 @@ def generate_side_table_for_class(cl: ClassIR, def generate_getseter_declarations(cl: ClassIR, emitter: Emitter) -> None: if not cl.is_trait: for attr in cl.attributes: - emitter.emit_line('static PyObject *') - emitter.emit_line('{}({} *self, void *closure);'.format( - getter_name(cl, attr, emitter.names), - cl.struct_name(emitter.names))) - emitter.emit_line('static int') - emitter.emit_line('{}({} *self, PyObject *value, void *closure);'.format( - setter_name(cl, attr, emitter.names), - cl.struct_name(emitter.names))) + emitter.emit_line("static PyObject *") + emitter.emit_line( + "{}({} *self, void *closure);".format( + getter_name(cl, attr, emitter.names), cl.struct_name(emitter.names) + ) + ) + emitter.emit_line("static int") + emitter.emit_line( + "{}({} *self, PyObject *value, void *closure);".format( + setter_name(cl, attr, emitter.names), cl.struct_name(emitter.names) + ) + ) for prop in cl.properties: # Generate getter declaration - emitter.emit_line('static PyObject *') - emitter.emit_line('{}({} *self, void *closure);'.format( - getter_name(cl, prop, emitter.names), - cl.struct_name(emitter.names))) + emitter.emit_line("static PyObject *") + emitter.emit_line( + "{}({} *self, void *closure);".format( + getter_name(cl, prop, emitter.names), cl.struct_name(emitter.names) + ) + ) # Generate property setter declaration if a setter exists if cl.properties[prop][1]: - emitter.emit_line('static int') - emitter.emit_line('{}({} *self, PyObject *value, void *closure);'.format( - setter_name(cl, prop, emitter.names), - cl.struct_name(emitter.names))) + emitter.emit_line("static int") + emitter.emit_line( + "{}({} *self, PyObject *value, void *closure);".format( + setter_name(cl, prop, emitter.names), cl.struct_name(emitter.names) + ) + ) -def generate_getseters_table(cl: ClassIR, - name: str, - emitter: Emitter) -> None: - emitter.emit_line(f'static PyGetSetDef {name}[] = {{') +def generate_getseters_table(cl: ClassIR, name: str, emitter: Emitter) -> None: + emitter.emit_line(f"static PyGetSetDef {name}[] = {{") if not cl.is_trait: for attr in cl.attributes: emitter.emit_line(f'{{"{attr}",') - emitter.emit_line(' (getter){}, (setter){},'.format( - getter_name(cl, attr, emitter.names), setter_name(cl, attr, emitter.names))) - emitter.emit_line(' NULL, NULL},') + emitter.emit_line( + " (getter){}, (setter){},".format( + getter_name(cl, attr, emitter.names), setter_name(cl, attr, emitter.names) + ) + ) + emitter.emit_line(" NULL, NULL},") for prop in cl.properties: emitter.emit_line(f'{{"{prop}",') - emitter.emit_line(f' (getter){getter_name(cl, prop, emitter.names)},') + emitter.emit_line(f" (getter){getter_name(cl, prop, emitter.names)},") setter = cl.properties[prop][1] if setter: - emitter.emit_line(f' (setter){setter_name(cl, prop, emitter.names)},') - emitter.emit_line('NULL, NULL},') + emitter.emit_line(f" (setter){setter_name(cl, prop, emitter.names)},") + emitter.emit_line("NULL, NULL},") else: - emitter.emit_line('NULL, NULL, NULL},') + emitter.emit_line("NULL, NULL, NULL},") - emitter.emit_line('{NULL} /* Sentinel */') - emitter.emit_line('};') + emitter.emit_line("{NULL} /* Sentinel */") + emitter.emit_line("};") def generate_getseters(cl: ClassIR, emitter: Emitter) -> None: if not cl.is_trait: for i, (attr, rtype) in enumerate(cl.attributes.items()): generate_getter(cl, attr, rtype, emitter) - emitter.emit_line('') + emitter.emit_line("") generate_setter(cl, attr, rtype, emitter) if i < len(cl.attributes) - 1: - emitter.emit_line('') + emitter.emit_line("") for prop, (getter, setter) in cl.properties.items(): rtype = getter.sig.ret_type - emitter.emit_line('') + emitter.emit_line("") generate_readonly_getter(cl, prop, rtype, getter, emitter) if setter: arg_type = setter.sig.args[1].type - emitter.emit_line('') + emitter.emit_line("") generate_property_setter(cl, prop, arg_type, setter, emitter) -def generate_getter(cl: ClassIR, - attr: str, - rtype: RType, - emitter: Emitter) -> None: +def generate_getter(cl: ClassIR, attr: str, rtype: RType, emitter: Emitter) -> None: attr_field = emitter.attr(attr) - emitter.emit_line('static PyObject *') - emitter.emit_line('{}({} *self, void *closure)'.format(getter_name(cl, attr, emitter.names), - cl.struct_name(emitter.names))) - emitter.emit_line('{') - attr_expr = f'self->{attr_field}' + emitter.emit_line("static PyObject *") + emitter.emit_line( + "{}({} *self, void *closure)".format( + getter_name(cl, attr, emitter.names), cl.struct_name(emitter.names) + ) + ) + emitter.emit_line("{") + attr_expr = f"self->{attr_field}" # HACK: Don't consider refcounted values as always defined, since it's possible to # access uninitialized values via 'gc.get_objects()'. Accessing non-refcounted @@ -879,37 +892,36 @@ def generate_getter(cl: ClassIR, always_defined = cl.is_always_defined(attr) and not rtype.is_refcounted if not always_defined: - emitter.emit_undefined_attr_check(rtype, attr_expr, '==', unlikely=True) - emitter.emit_line('PyErr_SetString(PyExc_AttributeError,') - emitter.emit_line(' "attribute {} of {} undefined");'.format(repr(attr), - repr(cl.name))) - emitter.emit_line('return NULL;') - emitter.emit_line('}') - emitter.emit_inc_ref(f'self->{attr_field}', rtype) - emitter.emit_box(f'self->{attr_field}', 'retval', rtype, declare_dest=True) - emitter.emit_line('return retval;') - emitter.emit_line('}') - - -def generate_setter(cl: ClassIR, - attr: str, - rtype: RType, - emitter: Emitter) -> None: + emitter.emit_undefined_attr_check(rtype, attr_expr, "==", unlikely=True) + emitter.emit_line("PyErr_SetString(PyExc_AttributeError,") + emitter.emit_line(' "attribute {} of {} undefined");'.format(repr(attr), repr(cl.name))) + emitter.emit_line("return NULL;") + emitter.emit_line("}") + emitter.emit_inc_ref(f"self->{attr_field}", rtype) + emitter.emit_box(f"self->{attr_field}", "retval", rtype, declare_dest=True) + emitter.emit_line("return retval;") + emitter.emit_line("}") + + +def generate_setter(cl: ClassIR, attr: str, rtype: RType, emitter: Emitter) -> None: attr_field = emitter.attr(attr) - emitter.emit_line('static int') - emitter.emit_line('{}({} *self, PyObject *value, void *closure)'.format( - setter_name(cl, attr, emitter.names), - cl.struct_name(emitter.names))) - emitter.emit_line('{') + emitter.emit_line("static int") + emitter.emit_line( + "{}({} *self, PyObject *value, void *closure)".format( + setter_name(cl, attr, emitter.names), cl.struct_name(emitter.names) + ) + ) + emitter.emit_line("{") deletable = cl.is_deletable(attr) if not deletable: - emitter.emit_line('if (value == NULL) {') - emitter.emit_line('PyErr_SetString(PyExc_AttributeError,') - emitter.emit_line(' "{} object attribute {} cannot be deleted");'.format(repr(cl.name), - repr(attr))) - emitter.emit_line('return -1;') - emitter.emit_line('}') + emitter.emit_line("if (value == NULL) {") + emitter.emit_line("PyErr_SetString(PyExc_AttributeError,") + emitter.emit_line( + ' "{} object attribute {} cannot be deleted");'.format(repr(cl.name), repr(attr)) + ) + emitter.emit_line("return -1;") + emitter.emit_line("}") # HACK: Don't consider refcounted values as always defined, since it's possible to # access uninitialized values via 'gc.get_objects()'. Accessing non-refcounted @@ -917,74 +929,78 @@ def generate_setter(cl: ClassIR, always_defined = cl.is_always_defined(attr) and not rtype.is_refcounted if rtype.is_refcounted: - attr_expr = f'self->{attr_field}' + attr_expr = f"self->{attr_field}" if not always_defined: - emitter.emit_undefined_attr_check(rtype, attr_expr, '!=') - emitter.emit_dec_ref('self->{}'.format(attr_field), rtype) + emitter.emit_undefined_attr_check(rtype, attr_expr, "!=") + emitter.emit_dec_ref("self->{}".format(attr_field), rtype) if not always_defined: - emitter.emit_line('}') + emitter.emit_line("}") if deletable: - emitter.emit_line('if (value != NULL) {') + emitter.emit_line("if (value != NULL) {") if rtype.is_unboxed: - emitter.emit_unbox('value', 'tmp', rtype, error=ReturnHandler('-1'), declare_dest=True) + emitter.emit_unbox("value", "tmp", rtype, error=ReturnHandler("-1"), declare_dest=True) elif is_same_type(rtype, object_rprimitive): - emitter.emit_line('PyObject *tmp = value;') + emitter.emit_line("PyObject *tmp = value;") else: - emitter.emit_cast('value', 'tmp', rtype, declare_dest=True) - emitter.emit_lines('if (!tmp)', - ' return -1;') - emitter.emit_inc_ref('tmp', rtype) - emitter.emit_line(f'self->{attr_field} = tmp;') + emitter.emit_cast("value", "tmp", rtype, declare_dest=True) + emitter.emit_lines("if (!tmp)", " return -1;") + emitter.emit_inc_ref("tmp", rtype) + emitter.emit_line(f"self->{attr_field} = tmp;") if deletable: - emitter.emit_line('} else') - emitter.emit_line(' self->{} = {};'.format(attr_field, - emitter.c_undefined_value(rtype))) - emitter.emit_line('return 0;') - emitter.emit_line('}') - - -def generate_readonly_getter(cl: ClassIR, - attr: str, - rtype: RType, - func_ir: FuncIR, - emitter: Emitter) -> None: - emitter.emit_line('static PyObject *') - emitter.emit_line('{}({} *self, void *closure)'.format(getter_name(cl, attr, emitter.names), - cl.struct_name(emitter.names))) - emitter.emit_line('{') + emitter.emit_line("} else") + emitter.emit_line( + " self->{} = {};".format(attr_field, emitter.c_undefined_value(rtype)) + ) + emitter.emit_line("return 0;") + emitter.emit_line("}") + + +def generate_readonly_getter( + cl: ClassIR, attr: str, rtype: RType, func_ir: FuncIR, emitter: Emitter +) -> None: + emitter.emit_line("static PyObject *") + emitter.emit_line( + "{}({} *self, void *closure)".format( + getter_name(cl, attr, emitter.names), cl.struct_name(emitter.names) + ) + ) + emitter.emit_line("{") if rtype.is_unboxed: - emitter.emit_line('{}retval = {}{}((PyObject *) self);'.format( - emitter.ctype_spaced(rtype), NATIVE_PREFIX, func_ir.cname(emitter.names))) - emitter.emit_box('retval', 'retbox', rtype, declare_dest=True) - emitter.emit_line('return retbox;') + emitter.emit_line( + "{}retval = {}{}((PyObject *) self);".format( + emitter.ctype_spaced(rtype), NATIVE_PREFIX, func_ir.cname(emitter.names) + ) + ) + emitter.emit_box("retval", "retbox", rtype, declare_dest=True) + emitter.emit_line("return retbox;") else: - emitter.emit_line('return {}{}((PyObject *) self);'.format(NATIVE_PREFIX, - func_ir.cname(emitter.names))) - emitter.emit_line('}') - - -def generate_property_setter(cl: ClassIR, - attr: str, - arg_type: RType, - func_ir: FuncIR, - emitter: Emitter) -> None: - - emitter.emit_line('static int') - emitter.emit_line('{}({} *self, PyObject *value, void *closure)'.format( - setter_name(cl, attr, emitter.names), - cl.struct_name(emitter.names))) - emitter.emit_line('{') + emitter.emit_line( + "return {}{}((PyObject *) self);".format(NATIVE_PREFIX, func_ir.cname(emitter.names)) + ) + emitter.emit_line("}") + + +def generate_property_setter( + cl: ClassIR, attr: str, arg_type: RType, func_ir: FuncIR, emitter: Emitter +) -> None: + + emitter.emit_line("static int") + emitter.emit_line( + "{}({} *self, PyObject *value, void *closure)".format( + setter_name(cl, attr, emitter.names), cl.struct_name(emitter.names) + ) + ) + emitter.emit_line("{") if arg_type.is_unboxed: - emitter.emit_unbox('value', 'tmp', arg_type, error=ReturnHandler('-1'), - declare_dest=True) - emitter.emit_line('{}{}((PyObject *) self, tmp);'.format( - NATIVE_PREFIX, - func_ir.cname(emitter.names))) + emitter.emit_unbox("value", "tmp", arg_type, error=ReturnHandler("-1"), declare_dest=True) + emitter.emit_line( + "{}{}((PyObject *) self, tmp);".format(NATIVE_PREFIX, func_ir.cname(emitter.names)) + ) else: - emitter.emit_line('{}{}((PyObject *) self, value);'.format( - NATIVE_PREFIX, - func_ir.cname(emitter.names))) - emitter.emit_line('return 0;') - emitter.emit_line('}') + emitter.emit_line( + "{}{}((PyObject *) self, value);".format(NATIVE_PREFIX, func_ir.cname(emitter.names)) + ) + emitter.emit_line("return 0;") + emitter.emit_line("}") diff --git a/mypyc/codegen/emitfunc.py b/mypyc/codegen/emitfunc.py index 683bf3e7a034f..ca93313dbf12a 100644 --- a/mypyc/codegen/emitfunc.py +++ b/mypyc/codegen/emitfunc.py @@ -1,57 +1,100 @@ """Code generation for native function bodies.""" -from typing import List, Union, Optional +from typing import List, Optional, Union + from typing_extensions import Final -from mypyc.common import ( - REG_PREFIX, NATIVE_PREFIX, STATIC_PREFIX, TYPE_PREFIX, MODULE_PREFIX, -) -from mypyc.codegen.emit import Emitter, TracebackAndGotoHandler, DEBUG_ERRORS +from mypyc.analysis.blockfreq import frequently_executed_blocks +from mypyc.codegen.emit import DEBUG_ERRORS, Emitter, TracebackAndGotoHandler +from mypyc.common import MODULE_PREFIX, NATIVE_PREFIX, REG_PREFIX, STATIC_PREFIX, TYPE_PREFIX +from mypyc.ir.class_ir import ClassIR +from mypyc.ir.func_ir import FUNC_CLASSMETHOD, FUNC_STATICMETHOD, FuncDecl, FuncIR, all_values from mypyc.ir.ops import ( - Op, OpVisitor, Goto, Branch, Return, Assign, Integer, LoadErrorValue, GetAttr, SetAttr, - LoadStatic, InitStatic, TupleGet, TupleSet, Call, IncRef, DecRef, Box, Cast, Unbox, - BasicBlock, Value, MethodCall, Unreachable, NAMESPACE_STATIC, NAMESPACE_TYPE, NAMESPACE_MODULE, - RaiseStandardError, CallC, LoadGlobal, Truncate, IntOp, LoadMem, GetElementPtr, - LoadAddress, ComparisonOp, SetMem, Register, LoadLiteral, AssignMulti, KeepAlive, Extend, - ERR_FALSE + ERR_FALSE, + NAMESPACE_MODULE, + NAMESPACE_STATIC, + NAMESPACE_TYPE, + Assign, + AssignMulti, + BasicBlock, + Box, + Branch, + Call, + CallC, + Cast, + ComparisonOp, + DecRef, + Extend, + GetAttr, + GetElementPtr, + Goto, + IncRef, + InitStatic, + Integer, + IntOp, + KeepAlive, + LoadAddress, + LoadErrorValue, + LoadGlobal, + LoadLiteral, + LoadMem, + LoadStatic, + MethodCall, + Op, + OpVisitor, + RaiseStandardError, + Register, + Return, + SetAttr, + SetMem, + Truncate, + TupleGet, + TupleSet, + Unbox, + Unreachable, + Value, ) +from mypyc.ir.pprint import generate_names_for_ir from mypyc.ir.rtypes import ( - RType, RTuple, RArray, is_tagged, is_int32_rprimitive, is_int64_rprimitive, RStruct, - is_pointer_rprimitive, is_int_rprimitive + RArray, + RStruct, + RTuple, + RType, + is_int32_rprimitive, + is_int64_rprimitive, + is_int_rprimitive, + is_pointer_rprimitive, + is_tagged, ) -from mypyc.ir.func_ir import FuncIR, FuncDecl, FUNC_STATICMETHOD, FUNC_CLASSMETHOD, all_values -from mypyc.ir.class_ir import ClassIR -from mypyc.ir.pprint import generate_names_for_ir -from mypyc.analysis.blockfreq import frequently_executed_blocks def native_function_type(fn: FuncIR, emitter: Emitter) -> str: - args = ', '.join(emitter.ctype(arg.type) for arg in fn.args) or 'void' + args = ", ".join(emitter.ctype(arg.type) for arg in fn.args) or "void" ret = emitter.ctype(fn.ret_type) - return f'{ret} (*)({args})' + return f"{ret} (*)({args})" def native_function_header(fn: FuncDecl, emitter: Emitter) -> str: args = [] for arg in fn.sig.args: - args.append(f'{emitter.ctype_spaced(arg.type)}{REG_PREFIX}{arg.name}') + args.append(f"{emitter.ctype_spaced(arg.type)}{REG_PREFIX}{arg.name}") - return '{ret_type}{name}({args})'.format( + return "{ret_type}{name}({args})".format( ret_type=emitter.ctype_spaced(fn.sig.ret_type), name=emitter.native_function_name(fn), - args=', '.join(args) or 'void') + args=", ".join(args) or "void", + ) -def generate_native_function(fn: FuncIR, - emitter: Emitter, - source_path: str, - module_name: str) -> None: +def generate_native_function( + fn: FuncIR, emitter: Emitter, source_path: str, module_name: str +) -> None: declarations = Emitter(emitter.context) names = generate_names_for_ir(fn.arg_regs, fn.blocks) body = Emitter(emitter.context, names) visitor = FunctionEmitterVisitor(body, declarations, source_path, module_name) - declarations.emit_line(f'{native_function_header(fn.decl, emitter)} {{') + declarations.emit_line(f"{native_function_header(fn.decl, emitter)} {{") body.indent() for r in all_values(fn.arg_regs, fn.blocks): @@ -64,11 +107,12 @@ def generate_native_function(fn: FuncIR, continue # Skip the arguments ctype = emitter.ctype_spaced(r.type) - init = '' - declarations.emit_line('{ctype}{prefix}{name}{init};'.format(ctype=ctype, - prefix=REG_PREFIX, - name=names[r], - init=init)) + init = "" + declarations.emit_line( + "{ctype}{prefix}{name}{init};".format( + ctype=ctype, prefix=REG_PREFIX, name=names[r], init=init + ) + ) # Before we emit the blocks, give them all labels blocks = fn.blocks @@ -93,18 +137,16 @@ def generate_native_function(fn: FuncIR, ops[visitor.op_index].accept(visitor) visitor.op_index += 1 - body.emit_line('}') + body.emit_line("}") emitter.emit_from_emitter(declarations) emitter.emit_from_emitter(body) class FunctionEmitterVisitor(OpVisitor[None]): - def __init__(self, - emitter: Emitter, - declarations: Emitter, - source_path: str, - module_name: str) -> None: + def __init__( + self, emitter: Emitter, declarations: Emitter, source_path: str, module_name: str + ) -> None: self.emitter = emitter self.names = emitter.names self.declarations = declarations @@ -124,7 +166,7 @@ def temp_name(self) -> str: def visit_goto(self, op: Goto) -> None: if op.label is not self.next_block: - self.emit_line('goto %s;' % self.label(op.label)) + self.emit_line("goto %s;" % self.label(op.label)) def visit_branch(self, op: Branch) -> None: true, false = op.true, op.false @@ -133,7 +175,7 @@ def visit_branch(self, op: Branch) -> None: if op2.class_type.class_ir.is_always_defined(op2.attr): # Getting an always defined attribute never fails, so the branch can be omitted. if false is not self.next_block: - self.emit_line('goto {};'.format(self.label(false))) + self.emit_line("goto {};".format(self.label(false))) return negated = op.negated negated_rare = False @@ -143,66 +185,58 @@ def visit_branch(self, op: Branch) -> None: negated = not negated negated_rare = True - neg = '!' if negated else '' - cond = '' + neg = "!" if negated else "" + cond = "" if op.op == Branch.BOOL: expr_result = self.reg(op.value) - cond = f'{neg}{expr_result}' + cond = f"{neg}{expr_result}" elif op.op == Branch.IS_ERROR: typ = op.value.type - compare = '!=' if negated else '==' + compare = "!=" if negated else "==" if isinstance(typ, RTuple): # TODO: What about empty tuple? - cond = self.emitter.tuple_undefined_check_cond(typ, - self.reg(op.value), - self.c_error_value, - compare) + cond = self.emitter.tuple_undefined_check_cond( + typ, self.reg(op.value), self.c_error_value, compare + ) else: - cond = '{} {} {}'.format(self.reg(op.value), - compare, - self.c_error_value(typ)) + cond = "{} {} {}".format(self.reg(op.value), compare, self.c_error_value(typ)) else: assert False, "Invalid branch" # For error checks, tell the compiler the branch is unlikely if op.traceback_entry is not None or op.rare: if not negated_rare: - cond = f'unlikely({cond})' + cond = f"unlikely({cond})" else: - cond = f'likely({cond})' + cond = f"likely({cond})" if false is self.next_block: if op.traceback_entry is None: - self.emit_line(f'if ({cond}) goto {self.label(true)};') + self.emit_line(f"if ({cond}) goto {self.label(true)};") else: - self.emit_line(f'if ({cond}) {{') + self.emit_line(f"if ({cond}) {{") self.emit_traceback(op) - self.emit_lines( - 'goto %s;' % self.label(true), - '}' - ) + self.emit_lines("goto %s;" % self.label(true), "}") else: - self.emit_line(f'if ({cond}) {{') + self.emit_line(f"if ({cond}) {{") self.emit_traceback(op) self.emit_lines( - 'goto %s;' % self.label(true), - '} else', - ' goto %s;' % self.label(false) + "goto %s;" % self.label(true), "} else", " goto %s;" % self.label(false) ) def visit_return(self, op: Return) -> None: value_str = self.reg(op.value) - self.emit_line('return %s;' % value_str) + self.emit_line("return %s;" % value_str) def visit_tuple_set(self, op: TupleSet) -> None: dest = self.reg(op) tuple_type = op.tuple_type self.emitter.declare_tuple_struct(tuple_type) if len(op.items) == 0: # empty tuple - self.emit_line(f'{dest}.empty_struct_error_flag = 0;') + self.emit_line(f"{dest}.empty_struct_error_flag = 0;") else: for i, item in enumerate(op.items): - self.emit_line(f'{dest}.f{i} = {self.reg(item)};') + self.emit_line(f"{dest}.f{i} = {self.reg(item)};") self.emit_inc_ref(dest, tuple_type) def visit_assign(self, op: Assign) -> None: @@ -214,8 +248,8 @@ def visit_assign(self, op: Assign) -> None: # We sometimes assign from an integer prepresentation of a pointer # to a real pointer, and C compilers insist on a cast. if op.src.type.is_unboxed and not op.dest.type.is_unboxed: - src = f'(void *){src}' - self.emit_line(f'{dest} = {src};') + src = f"(void *){src}" + self.emit_line(f"{dest} = {src};") def visit_assign_multi(self, op: AssignMulti) -> None: typ = op.dest.type @@ -223,34 +257,36 @@ def visit_assign_multi(self, op: AssignMulti) -> None: dest = self.reg(op.dest) # RArray values can only be assigned to once, so we can always # declare them on initialization. - self.emit_line('%s%s[%d] = {%s};' % ( - self.emitter.ctype_spaced(typ.item_type), - dest, - len(op.src), - ', '.join(self.reg(s) for s in op.src))) + self.emit_line( + "%s%s[%d] = {%s};" + % ( + self.emitter.ctype_spaced(typ.item_type), + dest, + len(op.src), + ", ".join(self.reg(s) for s in op.src), + ) + ) def visit_load_error_value(self, op: LoadErrorValue) -> None: if isinstance(op.type, RTuple): values = [self.c_undefined_value(item) for item in op.type.types] tmp = self.temp_name() - self.emit_line('{} {} = {{ {} }};'.format(self.ctype(op.type), tmp, ', '.join(values))) - self.emit_line(f'{self.reg(op)} = {tmp};') + self.emit_line("{} {} = {{ {} }};".format(self.ctype(op.type), tmp, ", ".join(values))) + self.emit_line(f"{self.reg(op)} = {tmp};") else: - self.emit_line('{} = {};'.format(self.reg(op), - self.c_error_value(op.type))) + self.emit_line("{} = {};".format(self.reg(op), self.c_error_value(op.type))) def visit_load_literal(self, op: LoadLiteral) -> None: index = self.literals.literal_index(op.value) s = repr(op.value) - if not any(x in s for x in ('/*', '*/', '\0')): - ann = ' /* %s */' % s + if not any(x in s for x in ("/*", "*/", "\0")): + ann = " /* %s */" % s else: - ann = '' + ann = "" if not is_int_rprimitive(op.type): - self.emit_line('%s = CPyStatics[%d];%s' % (self.reg(op), index, ann)) + self.emit_line("%s = CPyStatics[%d];%s" % (self.reg(op), index, ann)) else: - self.emit_line('%s = (CPyTagged)CPyStatics[%d] | 1;%s' % ( - self.reg(op), index, ann)) + self.emit_line("%s = (CPyTagged)CPyStatics[%d] | 1;%s" % (self.reg(op), index, ann)) def get_attr_expr(self, obj: str, op: Union[GetAttr, SetAttr], decl_cl: ClassIR) -> str: """Generate attribute accessor for normal (non-property) access. @@ -259,7 +295,7 @@ def get_attr_expr(self, obj: str, op: Union[GetAttr, SetAttr], decl_cl: ClassIR) classes, and *(obj + attr_offset) for attributes defined by traits. We also insert all necessary C casts here. """ - cast = f'({op.class_type.struct_name(self.emitter.names)} *)' + cast = f"({op.class_type.struct_name(self.emitter.names)} *)" if decl_cl.is_trait and op.class_type.class_ir.is_trait: # For pure trait access find the offset first, offsets # are ordered by attribute position in the cl.attributes dict. @@ -267,26 +303,26 @@ def get_attr_expr(self, obj: str, op: Union[GetAttr, SetAttr], decl_cl: ClassIR) trait_attr_index = list(decl_cl.attributes).index(op.attr) # TODO: reuse these names somehow? offset = self.emitter.temp_name() - self.declarations.emit_line(f'size_t {offset};') - self.emitter.emit_line('{} = {};'.format( - offset, - 'CPy_FindAttrOffset({}, {}, {})'.format( - self.emitter.type_struct_name(decl_cl), - f'({cast}{obj})->vtable', - trait_attr_index, + self.declarations.emit_line(f"size_t {offset};") + self.emitter.emit_line( + "{} = {};".format( + offset, + "CPy_FindAttrOffset({}, {}, {})".format( + self.emitter.type_struct_name(decl_cl), + f"({cast}{obj})->vtable", + trait_attr_index, + ), ) - )) - attr_cast = f'({self.ctype(op.class_type.attr_type(op.attr))} *)' - return f'*{attr_cast}((char *){obj} + {offset})' + ) + attr_cast = f"({self.ctype(op.class_type.attr_type(op.attr))} *)" + return f"*{attr_cast}((char *){obj} + {offset})" else: # Cast to something non-trait. Note: for this to work, all struct # members for non-trait classes must obey monotonic linear growth. if op.class_type.class_ir.is_trait: assert not decl_cl.is_trait - cast = f'({decl_cl.struct_name(self.emitter.names)} *)' - return '({}{})->{}'.format( - cast, obj, self.emitter.attr(op.attr) - ) + cast = f"({decl_cl.struct_name(self.emitter.names)} *)" + return "({}{})->{}".format(cast, obj, self.emitter.attr(op.attr)) def visit_get_attr(self, op: GetAttr) -> None: dest = self.reg(op) @@ -296,54 +332,60 @@ def visit_get_attr(self, op: GetAttr) -> None: attr_rtype, decl_cl = cl.attr_details(op.attr) if cl.get_method(op.attr): # Properties are essentially methods, so use vtable access for them. - version = '_TRAIT' if cl.is_trait else '' - self.emit_line('%s = CPY_GET_ATTR%s(%s, %s, %d, %s, %s); /* %s */' % ( - dest, - version, - obj, - self.emitter.type_struct_name(rtype.class_ir), - rtype.getter_index(op.attr), - rtype.struct_name(self.names), - self.ctype(rtype.attr_type(op.attr)), - op.attr)) + version = "_TRAIT" if cl.is_trait else "" + self.emit_line( + "%s = CPY_GET_ATTR%s(%s, %s, %d, %s, %s); /* %s */" + % ( + dest, + version, + obj, + self.emitter.type_struct_name(rtype.class_ir), + rtype.getter_index(op.attr), + rtype.struct_name(self.names), + self.ctype(rtype.attr_type(op.attr)), + op.attr, + ) + ) else: # Otherwise, use direct or offset struct access. attr_expr = self.get_attr_expr(obj, op, decl_cl) - self.emitter.emit_line(f'{dest} = {attr_expr};') + self.emitter.emit_line(f"{dest} = {attr_expr};") always_defined = cl.is_always_defined(op.attr) merged_branch = None if not always_defined: - self.emitter.emit_undefined_attr_check( - attr_rtype, dest, '==', unlikely=True - ) + self.emitter.emit_undefined_attr_check(attr_rtype, dest, "==", unlikely=True) branch = self.next_branch() if branch is not None: - if (branch.value is op - and branch.op == Branch.IS_ERROR - and branch.traceback_entry is not None - and not branch.negated): + if ( + branch.value is op + and branch.op == Branch.IS_ERROR + and branch.traceback_entry is not None + and not branch.negated + ): # Generate code for the following branch here to avoid # redundant branches in the generated code. self.emit_attribute_error(branch, cl.name, op.attr) - self.emit_line('goto %s;' % self.label(branch.true)) + self.emit_line("goto %s;" % self.label(branch.true)) merged_branch = branch - self.emitter.emit_line('}') + self.emitter.emit_line("}") if not merged_branch: - exc_class = 'PyExc_AttributeError' + exc_class = "PyExc_AttributeError" self.emitter.emit_line( 'PyErr_SetString({}, "attribute {} of {} undefined");'.format( - exc_class, repr(op.attr), repr(cl.name))) + exc_class, repr(op.attr), repr(cl.name) + ) + ) if attr_rtype.is_refcounted and not op.is_borrowed: if not merged_branch and not always_defined: - self.emitter.emit_line('} else {') + self.emitter.emit_line("} else {") self.emitter.emit_inc_ref(dest, attr_rtype) if merged_branch: if merged_branch.false is not self.next_block: - self.emit_line('goto %s;' % self.label(merged_branch.false)) + self.emit_line("goto %s;" % self.label(merged_branch.false)) self.op_index += 1 elif not always_defined: - self.emitter.emit_line('}') + self.emitter.emit_line("}") def next_branch(self) -> Optional[Branch]: if self.op_index + 1 < len(self.ops): @@ -362,19 +404,27 @@ def visit_set_attr(self, op: SetAttr) -> None: attr_rtype, decl_cl = cl.attr_details(op.attr) if cl.get_method(op.attr): # Again, use vtable access for properties... - assert not op.is_init and op.error_kind == ERR_FALSE, '%s %d %d %s' % ( - op.attr, op.is_init, op.error_kind, rtype) - version = '_TRAIT' if cl.is_trait else '' - self.emit_line('%s = CPY_SET_ATTR%s(%s, %s, %d, %s, %s, %s); /* %s */' % ( - dest, - version, - obj, - self.emitter.type_struct_name(rtype.class_ir), - rtype.setter_index(op.attr), - src, - rtype.struct_name(self.names), - self.ctype(rtype.attr_type(op.attr)), - op.attr)) + assert not op.is_init and op.error_kind == ERR_FALSE, "%s %d %d %s" % ( + op.attr, + op.is_init, + op.error_kind, + rtype, + ) + version = "_TRAIT" if cl.is_trait else "" + self.emit_line( + "%s = CPY_SET_ATTR%s(%s, %s, %d, %s, %s, %s); /* %s */" + % ( + dest, + version, + obj, + self.emitter.type_struct_name(rtype.class_ir), + rtype.setter_index(op.attr), + src, + rtype.struct_name(self.names), + self.ctype(rtype.attr_type(op.attr)), + op.attr, + ) + ) else: # ...and struct access for normal attributes. attr_expr = self.get_attr_expr(obj, op, decl_cl) @@ -383,14 +433,14 @@ def visit_set_attr(self, op: SetAttr) -> None: # previously undefined), so decref the old value. always_defined = cl.is_always_defined(op.attr) if not always_defined: - self.emitter.emit_undefined_attr_check(attr_rtype, attr_expr, '!=') + self.emitter.emit_undefined_attr_check(attr_rtype, attr_expr, "!=") self.emitter.emit_dec_ref(attr_expr, attr_rtype) if not always_defined: - self.emitter.emit_line('}') + self.emitter.emit_line("}") # This steals the reference to src, so we don't need to increment the arg - self.emitter.emit_line(f'{attr_expr} = {src};') + self.emitter.emit_line(f"{attr_expr} = {src};") if op.error_kind == ERR_FALSE: - self.emitter.emit_line(f'{dest} = 1;') + self.emitter.emit_line(f"{dest} = 1;") PREFIX_MAP: Final = { NAMESPACE_STATIC: STATIC_PREFIX, @@ -403,42 +453,42 @@ def visit_load_static(self, op: LoadStatic) -> None: prefix = self.PREFIX_MAP[op.namespace] name = self.emitter.static_name(op.identifier, op.module_name, prefix) if op.namespace == NAMESPACE_TYPE: - name = '(PyObject *)%s' % name - ann = '' + name = "(PyObject *)%s" % name + ann = "" if op.ann: s = repr(op.ann) - if not any(x in s for x in ('/*', '*/', '\0')): - ann = ' /* %s */' % s - self.emit_line(f'{dest} = {name};{ann}') + if not any(x in s for x in ("/*", "*/", "\0")): + ann = " /* %s */" % s + self.emit_line(f"{dest} = {name};{ann}") def visit_init_static(self, op: InitStatic) -> None: value = self.reg(op.value) prefix = self.PREFIX_MAP[op.namespace] name = self.emitter.static_name(op.identifier, op.module_name, prefix) if op.namespace == NAMESPACE_TYPE: - value = '(PyTypeObject *)%s' % value - self.emit_line(f'{name} = {value};') + value = "(PyTypeObject *)%s" % value + self.emit_line(f"{name} = {value};") self.emit_inc_ref(name, op.value.type) def visit_tuple_get(self, op: TupleGet) -> None: dest = self.reg(op) src = self.reg(op.src) - self.emit_line(f'{dest} = {src}.f{op.index};') + self.emit_line(f"{dest} = {src}.f{op.index};") self.emit_inc_ref(dest, op.type) def get_dest_assign(self, dest: Value) -> str: if not dest.is_void: - return self.reg(dest) + ' = ' + return self.reg(dest) + " = " else: - return '' + return "" def visit_call(self, op: Call) -> None: """Call native function.""" dest = self.get_dest_assign(op) - args = ', '.join(self.reg(arg) for arg in op.args) + args = ", ".join(self.reg(arg) for arg in op.args) lib = self.emitter.get_group_prefix(op.fn) cname = op.fn.cname(self.names) - self.emit_line(f'{dest}{lib}{NATIVE_PREFIX}{cname}({args});') + self.emit_line(f"{dest}{lib}{NATIVE_PREFIX}{cname}({args});") def visit_method_call(self, op: MethodCall) -> None: """Call native method.""" @@ -457,23 +507,37 @@ def visit_method_call(self, op: MethodCall) -> None: # The first argument gets omitted for static methods and # turned into the class for class methods obj_args = ( - [] if method.decl.kind == FUNC_STATICMETHOD else - [f'(PyObject *)Py_TYPE({obj})'] if method.decl.kind == FUNC_CLASSMETHOD else - [obj]) - args = ', '.join(obj_args + [self.reg(arg) for arg in op.args]) + [] + if method.decl.kind == FUNC_STATICMETHOD + else [f"(PyObject *)Py_TYPE({obj})"] + if method.decl.kind == FUNC_CLASSMETHOD + else [obj] + ) + args = ", ".join(obj_args + [self.reg(arg) for arg in op.args]) mtype = native_function_type(method, self.emitter) - version = '_TRAIT' if rtype.class_ir.is_trait else '' + version = "_TRAIT" if rtype.class_ir.is_trait else "" if is_direct: # Directly call method, without going through the vtable. lib = self.emitter.get_group_prefix(method.decl) - self.emit_line('{}{}{}{}({});'.format( - dest, lib, NATIVE_PREFIX, method.cname(self.names), args)) + self.emit_line( + "{}{}{}{}({});".format(dest, lib, NATIVE_PREFIX, method.cname(self.names), args) + ) else: # Call using vtable. method_idx = rtype.method_index(name) - self.emit_line('{}CPY_GET_METHOD{}({}, {}, {}, {}, {})({}); /* {} */'.format( - dest, version, obj, self.emitter.type_struct_name(rtype.class_ir), - method_idx, rtype.struct_name(self.names), mtype, args, op.method)) + self.emit_line( + "{}CPY_GET_METHOD{}({}, {}, {}, {}, {})({}); /* {} */".format( + dest, + version, + obj, + self.emitter.type_struct_name(rtype.class_ir), + method_idx, + rtype.struct_name(self.names), + mtype, + args, + op.method, + ) + ) def visit_inc_ref(self, op: IncRef) -> None: src = self.reg(op.src) @@ -490,51 +554,57 @@ def visit_cast(self, op: Cast) -> None: branch = self.next_branch() handler = None if branch is not None: - if (branch.value is op - and branch.op == Branch.IS_ERROR - and branch.traceback_entry is not None - and not branch.negated - and branch.false is self.next_block): + if ( + branch.value is op + and branch.op == Branch.IS_ERROR + and branch.traceback_entry is not None + and not branch.negated + and branch.false is self.next_block + ): # Generate code also for the following branch here to avoid # redundant branches in the generated code. - handler = TracebackAndGotoHandler(self.label(branch.true), - self.source_path, - self.module_name, - branch.traceback_entry) + handler = TracebackAndGotoHandler( + self.label(branch.true), + self.source_path, + self.module_name, + branch.traceback_entry, + ) self.op_index += 1 - self.emitter.emit_cast(self.reg(op.src), self.reg(op), op.type, - src_type=op.src.type, error=handler) + self.emitter.emit_cast( + self.reg(op.src), self.reg(op), op.type, src_type=op.src.type, error=handler + ) def visit_unbox(self, op: Unbox) -> None: self.emitter.emit_unbox(self.reg(op.src), self.reg(op), op.type) def visit_unreachable(self, op: Unreachable) -> None: - self.emitter.emit_line('CPy_Unreachable();') + self.emitter.emit_line("CPy_Unreachable();") def visit_raise_standard_error(self, op: RaiseStandardError) -> None: # TODO: Better escaping of backspaces and such if op.value is not None: if isinstance(op.value, str): message = op.value.replace('"', '\\"') - self.emitter.emit_line( - f'PyErr_SetString(PyExc_{op.class_name}, "{message}");') + self.emitter.emit_line(f'PyErr_SetString(PyExc_{op.class_name}, "{message}");') elif isinstance(op.value, Value): self.emitter.emit_line( - 'PyErr_SetObject(PyExc_{}, {});'.format(op.class_name, - self.emitter.reg(op.value))) + "PyErr_SetObject(PyExc_{}, {});".format( + op.class_name, self.emitter.reg(op.value) + ) + ) else: - assert False, 'op value type must be either str or Value' + assert False, "op value type must be either str or Value" else: - self.emitter.emit_line(f'PyErr_SetNone(PyExc_{op.class_name});') - self.emitter.emit_line(f'{self.reg(op)} = 0;') + self.emitter.emit_line(f"PyErr_SetNone(PyExc_{op.class_name});") + self.emitter.emit_line(f"{self.reg(op)} = 0;") def visit_call_c(self, op: CallC) -> None: if op.is_void: - dest = '' + dest = "" else: dest = self.get_dest_assign(op) - args = ', '.join(self.reg(arg) for arg in op.args) + args = ", ".join(self.reg(arg) for arg in op.args) self.emitter.emit_line(f"{dest}{op.function_name}({args});") def visit_truncate(self, op: Truncate) -> None: @@ -554,12 +624,12 @@ def visit_extend(self, op: Extend) -> None: def visit_load_global(self, op: LoadGlobal) -> None: dest = self.reg(op) - ann = '' + ann = "" if op.ann: s = repr(op.ann) - if not any(x in s for x in ('/*', '*/', '\0')): - ann = ' /* %s */' % s - self.emit_line(f'{dest} = {op.identifier};{ann}') + if not any(x in s for x in ("/*", "*/", "\0")): + ann = " /* %s */" % s + self.emit_line(f"{dest} = {op.identifier};{ann}") def visit_int_op(self, op: IntOp) -> None: dest = self.reg(op) @@ -569,7 +639,7 @@ def visit_int_op(self, op: IntOp) -> None: # Signed right shift lhs = self.emit_signed_int_cast(op.lhs.type) + lhs rhs = self.emit_signed_int_cast(op.rhs.type) + rhs - self.emit_line(f'{dest} = {lhs} {op.op_str[op.op]} {rhs};') + self.emit_line(f"{dest} = {lhs} {op.op_str[op.op]} {rhs};") def visit_comparison_op(self, op: ComparisonOp) -> None: dest = self.reg(op) @@ -591,15 +661,16 @@ def visit_comparison_op(self, op: ComparisonOp) -> None: elif isinstance(op.rhs, Integer) and op.rhs.value < 0: # Force signed ==/!= with negative operand lhs_cast = self.emit_signed_int_cast(op.lhs.type) - self.emit_line('{} = {}{} {} {}{};'.format(dest, lhs_cast, lhs, - op.op_str[op.op], rhs_cast, rhs)) + self.emit_line( + "{} = {}{} {} {}{};".format(dest, lhs_cast, lhs, op.op_str[op.op], rhs_cast, rhs) + ) def visit_load_mem(self, op: LoadMem) -> None: dest = self.reg(op) src = self.reg(op.src) # TODO: we shouldn't dereference to type that are pointer type so far type = self.ctype(op.type) - self.emit_line(f'{dest} = *({type} *){src};') + self.emit_line(f"{dest} = *({type} *){src};") def visit_set_mem(self, op: SetMem) -> None: dest = self.reg(op.dest) @@ -608,7 +679,7 @@ def visit_set_mem(self, op: SetMem) -> None: # clang whines about self assignment (which we might generate # for some casts), so don't emit it. if dest != src: - self.emit_line(f'*({dest_type} *){dest} = {src};') + self.emit_line(f"*({dest_type} *){dest} = {src};") def visit_get_element_ptr(self, op: GetElementPtr) -> None: dest = self.reg(op) @@ -616,14 +687,17 @@ def visit_get_element_ptr(self, op: GetElementPtr) -> None: # TODO: support tuple type assert isinstance(op.src_type, RStruct) assert op.field in op.src_type.names, "Invalid field name." - self.emit_line('{} = ({})&(({} *){})->{};'.format(dest, op.type._ctype, op.src_type.name, - src, op.field)) + self.emit_line( + "{} = ({})&(({} *){})->{};".format( + dest, op.type._ctype, op.src_type.name, src, op.field + ) + ) def visit_load_address(self, op: LoadAddress) -> None: typ = op.type dest = self.reg(op) src = self.reg(op.src) if isinstance(op.src, Register) else op.src - self.emit_line(f'{dest} = ({typ._ctype})&{src};') + self.emit_line(f"{dest} = ({typ._ctype})&{src};") def visit_keep_alive(self, op: KeepAlive) -> None: # This is a no-op. @@ -643,14 +717,14 @@ def reg(self, reg: Value) -> str: if val >= (1 << 31): # Avoid overflowing signed 32-bit int if val >= (1 << 63): - s += 'ULL' + s += "ULL" else: - s += 'LL' + s += "LL" elif val == -(1 << 63): # Avoid overflowing C integer literal - s = '(-9223372036854775807LL - 1)' + s = "(-9223372036854775807LL - 1)" elif val <= -(1 << 31): - s += 'LL' + s += "LL" return s else: return self.emitter.reg(reg) @@ -685,27 +759,31 @@ def emit_traceback(self, op: Branch) -> None: def emit_attribute_error(self, op: Branch, class_name: str, attr: str) -> None: assert op.traceback_entry is not None - globals_static = self.emitter.static_name('globals', self.module_name) - self.emit_line('CPy_AttributeError("%s", "%s", "%s", "%s", %d, %s);' % ( - self.source_path.replace("\\", "\\\\"), - op.traceback_entry[0], - class_name, - attr, - op.traceback_entry[1], - globals_static)) + globals_static = self.emitter.static_name("globals", self.module_name) + self.emit_line( + 'CPy_AttributeError("%s", "%s", "%s", "%s", %d, %s);' + % ( + self.source_path.replace("\\", "\\\\"), + op.traceback_entry[0], + class_name, + attr, + op.traceback_entry[1], + globals_static, + ) + ) if DEBUG_ERRORS: self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");') def emit_signed_int_cast(self, type: RType) -> str: if is_tagged(type): - return '(Py_ssize_t)' + return "(Py_ssize_t)" else: - return '' + return "" def emit_unsigned_int_cast(self, type: RType) -> str: if is_int32_rprimitive(type): - return '(uint32_t)' + return "(uint32_t)" elif is_int64_rprimitive(type): - return '(uint64_t)' + return "(uint64_t)" else: - return '' + return "" diff --git a/mypyc/codegen/emitmodule.py b/mypyc/codegen/emitmodule.py index 6eea3f1ea8810..d16e0a74b7925 100644 --- a/mypyc/codegen/emitmodule.py +++ b/mypyc/codegen/emitmodule.py @@ -3,50 +3,62 @@ # FIXME: Basically nothing in this file operates on the level of a # single module and it should be renamed. -import os import json -from mypy.backports import OrderedDict -from typing import List, Tuple, Dict, Iterable, Set, TypeVar, Optional +import os +from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar -from mypy.nodes import MypyFile +from mypy.backports import OrderedDict from mypy.build import ( - BuildSource, BuildResult, State, build, sorted_components, get_cache_names, - create_metastore, compute_hash, + BuildResult, + BuildSource, + State, + build, + compute_hash, + create_metastore, + get_cache_names, + sorted_components, ) from mypy.errors import CompileError +from mypy.fscache import FileSystemCache +from mypy.nodes import MypyFile from mypy.options import Options from mypy.plugin import Plugin, ReportConfigContext -from mypy.fscache import FileSystemCache from mypy.util import hash_digest - -from mypyc.irbuild.main import build_ir -from mypyc.irbuild.prepare import load_type_map -from mypyc.irbuild.mapper import Mapper -from mypyc.common import ( - PREFIX, TOP_LEVEL_NAME, MODULE_PREFIX, RUNTIME_C_FILES, short_id_from_name, use_fastcall, - use_vectorcall, shared_lib_name, -) from mypyc.codegen.cstring import c_string_initializer -from mypyc.codegen.literals import Literals -from mypyc.codegen.emit import EmitterContext, Emitter, HeaderDeclaration +from mypyc.codegen.emit import Emitter, EmitterContext, HeaderDeclaration +from mypyc.codegen.emitclass import generate_class, generate_class_type_decl from mypyc.codegen.emitfunc import generate_native_function, native_function_header -from mypyc.codegen.emitclass import generate_class_type_decl, generate_class from mypyc.codegen.emitwrapper import ( - generate_wrapper_function, wrapper_function_header, - generate_legacy_wrapper_function, legacy_wrapper_function_header, + generate_legacy_wrapper_function, + generate_wrapper_function, + legacy_wrapper_function_header, + wrapper_function_header, ) -from mypyc.ir.ops import DeserMaps, LoadLiteral -from mypyc.ir.rtypes import RType, RTuple -from mypyc.ir.func_ir import FuncIR +from mypyc.codegen.literals import Literals +from mypyc.common import ( + MODULE_PREFIX, + PREFIX, + RUNTIME_C_FILES, + TOP_LEVEL_NAME, + shared_lib_name, + short_id_from_name, + use_fastcall, + use_vectorcall, +) +from mypyc.errors import Errors from mypyc.ir.class_ir import ClassIR +from mypyc.ir.func_ir import FuncIR from mypyc.ir.module_ir import ModuleIR, ModuleIRs, deserialize_modules +from mypyc.ir.ops import DeserMaps, LoadLiteral +from mypyc.ir.rtypes import RTuple, RType +from mypyc.irbuild.main import build_ir +from mypyc.irbuild.mapper import Mapper +from mypyc.irbuild.prepare import load_type_map +from mypyc.namegen import NameGenerator, exported_name from mypyc.options import CompilerOptions -from mypyc.transform.uninit import insert_uninit_checks -from mypyc.transform.refcount import insert_ref_count_opcodes from mypyc.transform.exceptions import insert_exception_handling -from mypyc.namegen import NameGenerator, exported_name -from mypyc.errors import Errors - +from mypyc.transform.refcount import insert_ref_count_opcodes +from mypyc.transform.uninit import insert_uninit_checks # All of the modules being compiled are divided into "groups". A group # is a set of modules that are placed into the same shared library. @@ -77,6 +89,7 @@ class MarkedDeclaration: """Add a mark, useful for topological sort.""" + def __init__(self, declaration: HeaderDeclaration, mark: bool) -> None: self.declaration = declaration self.mark = False @@ -95,7 +108,8 @@ class MypycPlugin(Plugin): """ def __init__( - self, options: Options, compiler_options: CompilerOptions, groups: Groups) -> None: + self, options: Options, compiler_options: CompilerOptions, groups: Groups + ) -> None: super().__init__(options) self.group_map: Dict[str, Tuple[Optional[str], List[str]]] = {} for sources, name in groups: @@ -107,7 +121,8 @@ def __init__( self.metastore = create_metastore(options) def report_config_data( - self, ctx: ReportConfigContext) -> Optional[Tuple[Optional[str], List[str]]]: + self, ctx: ReportConfigContext + ) -> Optional[Tuple[Optional[str], List[str]]]: # The config data we report is the group map entry for the module. # If the data is being used to check validity, we do additional checks # that the IR cache exists and matches the metadata cache and all @@ -137,16 +152,16 @@ def report_config_data( ir_data = json.loads(ir_json) # Check that the IR cache matches the metadata cache - if compute_hash(meta_json) != ir_data['meta_hash']: + if compute_hash(meta_json) != ir_data["meta_hash"]: return None # Check that all of the source files are present and as # expected. The main situation where this would come up is the # user deleting the build directory without deleting # .mypy_cache, which we should handle gracefully. - for path, hash in ir_data['src_hashes'].items(): + for path, hash in ir_data["src_hashes"].items(): try: - with open(os.path.join(self.compiler_options.target_dir, path), 'rb') as f: + with open(os.path.join(self.compiler_options.target_dir, path), "rb") as f: contents = f.read() except FileNotFoundError: return None @@ -167,14 +182,16 @@ def parse_and_typecheck( compiler_options: CompilerOptions, groups: Groups, fscache: Optional[FileSystemCache] = None, - alt_lib_path: Optional[str] = None + alt_lib_path: Optional[str] = None, ) -> BuildResult: - assert options.strict_optional, 'strict_optional must be turned on' - result = build(sources=sources, - options=options, - alt_lib_path=alt_lib_path, - fscache=fscache, - extra_plugins=[MypycPlugin(options, compiler_options, groups)]) + assert options.strict_optional, "strict_optional must be turned on" + result = build( + sources=sources, + options=options, + alt_lib_path=alt_lib_path, + fscache=fscache, + extra_plugins=[MypycPlugin(options, compiler_options, groups)], + ) if result.errors: raise CompileError(result.errors) return result @@ -206,9 +223,7 @@ def compile_scc_to_ir( print("Compiling {}".format(", ".join(x.name for x in scc))) # Generate basic IR, with missing exception and refcount handling. - modules = build_ir( - scc, result.graph, result.types, mapper, compiler_options, errors - ) + modules = build_ir(scc, result.graph, result.types, mapper, compiler_options, errors) if errors.num_errors > 0: return modules @@ -229,10 +244,7 @@ def compile_scc_to_ir( def compile_modules_to_ir( - result: BuildResult, - mapper: Mapper, - compiler_options: CompilerOptions, - errors: Errors, + result: BuildResult, mapper: Mapper, compiler_options: CompilerOptions, errors: Errors ) -> ModuleIRs: """Compile a collection of modules into ModuleIRs. @@ -273,8 +285,11 @@ def compile_ir_to_c( Returns a dictionary mapping group names to a list of (file name, file text) pairs. """ - source_paths = {source.module: result.graph[source.module].xpath - for sources, _ in groups for source in sources} + source_paths = { + source.module: result.graph[source.module].xpath + for sources, _ in groups + for source in sources + } names = NameGenerator([[source.module for source in sources] for sources, _ in groups]) @@ -282,15 +297,16 @@ def compile_ir_to_c( # compiled into a separate extension module. ctext: Dict[Optional[str], List[Tuple[str, str]]] = {} for group_sources, group_name in groups: - group_modules = [(source.module, modules[source.module]) for source in group_sources - if source.module in modules] + group_modules = [ + (source.module, modules[source.module]) + for source in group_sources + if source.module in modules + ] if not group_modules: ctext[group_name] = [] continue generator = GroupGenerator( - group_modules, source_paths, - group_name, mapper.group_map, names, - compiler_options + group_modules, source_paths, group_name, mapper.group_map, names, compiler_options ) ctext[group_name] = generator.generate_c_for_modules() @@ -299,7 +315,7 @@ def compile_ir_to_c( def get_ir_cache_name(id: str, path: str, options: Options) -> str: meta_path, _, _ = get_cache_names(id, path, options) - return meta_path.replace('.meta.json', '.ir.json') + return meta_path.replace(".meta.json", ".ir.json") def get_state_ir_cache_name(state: State) -> str: @@ -330,7 +346,7 @@ def write_cache( * The hashes of all of the source file outputs for the group the module is in. This is so that the module will be recompiled if the source outputs are missing. - """ + """ hashes = {} for name, files in ctext.items(): @@ -349,9 +365,9 @@ def write_cache( newpath = get_state_ir_cache_name(st) ir_data = { - 'ir': module.serialize(), - 'meta_hash': compute_hash(meta_data), - 'src_hashes': hashes[group_map[id]], + "ir": module.serialize(), + "meta_hash": compute_hash(meta_data), + "src_hashes": hashes[group_map[id]], } result.manager.metastore.write(newpath, json.dumps(ir_data)) @@ -360,10 +376,7 @@ def write_cache( def load_scc_from_cache( - scc: List[MypyFile], - result: BuildResult, - mapper: Mapper, - ctx: DeserMaps, + scc: List[MypyFile], result: BuildResult, mapper: Mapper, ctx: DeserMaps ) -> ModuleIRs: """Load IR for an SCC of modules from the cache. @@ -372,7 +385,8 @@ def load_scc_from_cache( cache_data = { k.fullname: json.loads( result.manager.metastore.read(get_state_ir_cache_name(result.graph[k.fullname])) - )['ir'] for k in scc + )["ir"] + for k in scc } modules = deserialize_modules(cache_data, ctx) load_type_map(mapper, scc, ctx) @@ -380,10 +394,7 @@ def load_scc_from_cache( def compile_modules_to_c( - result: BuildResult, - compiler_options: CompilerOptions, - errors: Errors, - groups: Groups, + result: BuildResult, compiler_options: CompilerOptions, errors: Errors, groups: Groups ) -> Tuple[ModuleIRs, List[FileContents]]: """Compile Python module(s) to the source of Python C extension modules. @@ -409,7 +420,7 @@ def compile_modules_to_c( # Sometimes when we call back into mypy, there might be errors. # We don't want to crash when that happens. - result.manager.errors.set_file('', module=None, scope=None) + result.manager.errors.set_file("", module=None, scope=None) modules = compile_modules_to_ir(result, mapper, compiler_options, errors) ctext = compile_ir_to_c(groups, modules, result, mapper, compiler_options) @@ -422,41 +433,45 @@ def compile_modules_to_c( def generate_function_declaration(fn: FuncIR, emitter: Emitter) -> None: emitter.context.declarations[emitter.native_function_name(fn.decl)] = HeaderDeclaration( - f'{native_function_header(fn.decl, emitter)};', - needs_export=True) + f"{native_function_header(fn.decl, emitter)};", needs_export=True + ) if fn.name != TOP_LEVEL_NAME: if is_fastcall_supported(fn, emitter.capi_version): emitter.context.declarations[PREFIX + fn.cname(emitter.names)] = HeaderDeclaration( - f'{wrapper_function_header(fn, emitter.names)};') + f"{wrapper_function_header(fn, emitter.names)};" + ) else: emitter.context.declarations[PREFIX + fn.cname(emitter.names)] = HeaderDeclaration( - f'{legacy_wrapper_function_header(fn, emitter.names)};') + f"{legacy_wrapper_function_header(fn, emitter.names)};" + ) def pointerize(decl: str, name: str) -> str: """Given a C decl and its name, modify it to be a declaration to a pointer.""" # This doesn't work in general but does work for all our types... - if '(' in decl: + if "(" in decl: # Function pointer. Stick an * in front of the name and wrap it in parens. - return decl.replace(name, f'(*{name})') + return decl.replace(name, f"(*{name})") else: # Non-function pointer. Just stick an * in front of the name. - return decl.replace(name, f'*{name}') + return decl.replace(name, f"*{name}") def group_dir(group_name: str) -> str: - """Given a group name, return the relative directory path for it. """ - return os.sep.join(group_name.split('.')[:-1]) + """Given a group name, return the relative directory path for it.""" + return os.sep.join(group_name.split(".")[:-1]) class GroupGenerator: - def __init__(self, - modules: List[Tuple[str, ModuleIR]], - source_paths: Dict[str, str], - group_name: Optional[str], - group_map: Dict[str, Optional[str]], - names: NameGenerator, - compiler_options: CompilerOptions) -> None: + def __init__( + self, + modules: List[Tuple[str, ModuleIR]], + source_paths: Dict[str, str], + group_name: Optional[str], + group_map: Dict[str, Optional[str]], + names: NameGenerator, + compiler_options: CompilerOptions, + ) -> None: """Generator for C source for a compilation group. The code for a compilation group contains an internal and an @@ -486,11 +501,11 @@ def __init__(self, @property def group_suffix(self) -> str: - return '_' + exported_name(self.group_name) if self.group_name else '' + return "_" + exported_name(self.group_name) if self.group_name else "" @property def short_group_suffix(self) -> str: - return '_' + exported_name(self.group_name.split('.')[-1]) if self.group_name else '' + return "_" + exported_name(self.group_name.split(".")[-1]) if self.group_name else "" def generate_c_for_modules(self) -> List[Tuple[str, str]]: file_contents = [] @@ -517,8 +532,7 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]: if multi_file: emitter = Emitter(self.context) emitter.emit_line(f'#include "__native{self.short_group_suffix}.h"') - emitter.emit_line( - f'#include "__native_internal{self.short_group_suffix}.h"') + emitter.emit_line(f'#include "__native_internal{self.short_group_suffix}.h"') self.declare_module(module_name, emitter) self.declare_internal_globals(module_name, emitter) @@ -538,32 +552,34 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]: emitter.emit_line() if is_fastcall_supported(fn, emitter.capi_version): generate_wrapper_function( - fn, emitter, self.source_paths[module_name], module_name) + fn, emitter, self.source_paths[module_name], module_name + ) else: generate_legacy_wrapper_function( - fn, emitter, self.source_paths[module_name], module_name) + fn, emitter, self.source_paths[module_name], module_name + ) if multi_file: - name = (f'__native_{emitter.names.private_name(module_name)}.c') - file_contents.append((name, ''.join(emitter.fragments))) + name = f"__native_{emitter.names.private_name(module_name)}.c" + file_contents.append((name, "".join(emitter.fragments))) # The external header file contains type declarations while # the internal contains declarations of functions and objects # (which are shared between shared libraries via dynamic # exports tables and not accessed directly.) ext_declarations = Emitter(self.context) - ext_declarations.emit_line(f'#ifndef MYPYC_NATIVE{self.group_suffix}_H') - ext_declarations.emit_line(f'#define MYPYC_NATIVE{self.group_suffix}_H') - ext_declarations.emit_line('#include ') - ext_declarations.emit_line('#include ') + ext_declarations.emit_line(f"#ifndef MYPYC_NATIVE{self.group_suffix}_H") + ext_declarations.emit_line(f"#define MYPYC_NATIVE{self.group_suffix}_H") + ext_declarations.emit_line("#include ") + ext_declarations.emit_line("#include ") declarations = Emitter(self.context) - declarations.emit_line(f'#ifndef MYPYC_NATIVE_INTERNAL{self.group_suffix}_H') - declarations.emit_line(f'#define MYPYC_NATIVE_INTERNAL{self.group_suffix}_H') - declarations.emit_line('#include ') - declarations.emit_line('#include ') + declarations.emit_line(f"#ifndef MYPYC_NATIVE_INTERNAL{self.group_suffix}_H") + declarations.emit_line(f"#define MYPYC_NATIVE_INTERNAL{self.group_suffix}_H") + declarations.emit_line("#include ") + declarations.emit_line("#include ") declarations.emit_line(f'#include "__native{self.short_group_suffix}.h"') declarations.emit_line() - declarations.emit_line('int CPyGlobalsInit(void);') + declarations.emit_line("int CPyGlobalsInit(void);") declarations.emit_line() for module_name, module in self.modules: @@ -575,12 +591,10 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]: for lib in sorted(self.context.group_deps): elib = exported_name(lib) - short_lib = exported_name(lib.split('.')[-1]) + short_lib = exported_name(lib.split(".")[-1]) declarations.emit_lines( - '#include <{}>'.format( - os.path.join(group_dir(lib), f"__native_{short_lib}.h") - ), - f'struct export_table_{elib} exports_{elib};' + "#include <{}>".format(os.path.join(group_dir(lib), f"__native_{short_lib}.h")), + f"struct export_table_{elib} exports_{elib};", ) sorted_decls = self.toposort_declarations() @@ -593,8 +607,7 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]: for declaration in sorted_decls: decls = ext_declarations if declaration.is_type else declarations if not declaration.is_type: - decls.emit_lines( - f'extern {declaration.decl[0]}', *declaration.decl[1:]) + decls.emit_lines(f"extern {declaration.decl[0]}", *declaration.decl[1:]) # If there is a definition, emit it. Otherwise repeat the declaration # (without an extern). if declaration.defn: @@ -609,17 +622,23 @@ def generate_c_for_modules(self) -> List[Tuple[str, str]]: self.generate_shared_lib_init(emitter) - ext_declarations.emit_line('#endif') - declarations.emit_line('#endif') + ext_declarations.emit_line("#endif") + declarations.emit_line("#endif") - output_dir = group_dir(self.group_name) if self.group_name else '' + output_dir = group_dir(self.group_name) if self.group_name else "" return file_contents + [ - (os.path.join(output_dir, f'__native{self.short_group_suffix}.c'), - ''.join(emitter.fragments)), - (os.path.join(output_dir, f'__native_internal{self.short_group_suffix}.h'), - ''.join(declarations.fragments)), - (os.path.join(output_dir, f'__native{self.short_group_suffix}.h'), - ''.join(ext_declarations.fragments)), + ( + os.path.join(output_dir, f"__native{self.short_group_suffix}.c"), + "".join(emitter.fragments), + ), + ( + os.path.join(output_dir, f"__native_internal{self.short_group_suffix}.h"), + "".join(declarations.fragments), + ), + ( + os.path.join(output_dir, f"__native{self.short_group_suffix}.h"), + "".join(ext_declarations.fragments), + ), ] def generate_literal_tables(self) -> None: @@ -631,25 +650,25 @@ def generate_literal_tables(self) -> None: """ literals = self.context.literals # During module initialization we store all the constructed objects here - self.declare_global('PyObject *[%d]' % literals.num_literals(), 'CPyStatics') + self.declare_global("PyObject *[%d]" % literals.num_literals(), "CPyStatics") # Descriptions of str literals init_str = c_string_array_initializer(literals.encoded_str_values()) - self.declare_global('const char * const []', 'CPyLit_Str', initializer=init_str) + self.declare_global("const char * const []", "CPyLit_Str", initializer=init_str) # Descriptions of bytes literals init_bytes = c_string_array_initializer(literals.encoded_bytes_values()) - self.declare_global('const char * const []', 'CPyLit_Bytes', initializer=init_bytes) + self.declare_global("const char * const []", "CPyLit_Bytes", initializer=init_bytes) # Descriptions of int literals init_int = c_string_array_initializer(literals.encoded_int_values()) - self.declare_global('const char * const []', 'CPyLit_Int', initializer=init_int) + self.declare_global("const char * const []", "CPyLit_Int", initializer=init_int) # Descriptions of float literals init_floats = c_array_initializer(literals.encoded_float_values()) - self.declare_global('const double []', 'CPyLit_Float', initializer=init_floats) + self.declare_global("const double []", "CPyLit_Float", initializer=init_floats) # Descriptions of complex literals init_complex = c_array_initializer(literals.encoded_complex_values()) - self.declare_global('const double []', 'CPyLit_Complex', initializer=init_complex) + self.declare_global("const double []", "CPyLit_Complex", initializer=init_complex) # Descriptions of tuple literals init_tuple = c_array_initializer(literals.encoded_tuple_values()) - self.declare_global('const int []', 'CPyLit_Tuple', initializer=init_tuple) + self.declare_global("const int []", "CPyLit_Tuple", initializer=init_tuple) def generate_export_table(self, decl_emitter: Emitter, code_emitter: Emitter) -> None: """Generate the declaration and definition of the group's export struct. @@ -697,25 +716,19 @@ def generate_export_table(self, decl_emitter: Emitter, code_emitter: Emitter) -> decls = decl_emitter.context.declarations - decl_emitter.emit_lines( - '', - f'struct export_table{self.group_suffix} {{', - ) + decl_emitter.emit_lines("", f"struct export_table{self.group_suffix} {{") for name, decl in decls.items(): if decl.needs_export: - decl_emitter.emit_line(pointerize('\n'.join(decl.decl), name)) + decl_emitter.emit_line(pointerize("\n".join(decl.decl), name)) - decl_emitter.emit_line('};') + decl_emitter.emit_line("};") - code_emitter.emit_lines( - '', - f'static struct export_table{self.group_suffix} exports = {{', - ) + code_emitter.emit_lines("", f"static struct export_table{self.group_suffix} exports = {{") for name, decl in decls.items(): if decl.needs_export: - code_emitter.emit_line(f'&{name},') + code_emitter.emit_line(f"&{name},") - code_emitter.emit_line('};') + code_emitter.emit_line("};") def generate_shared_lib_init(self, emitter: Emitter) -> None: """Generate the init function for a shared library. @@ -735,138 +748,138 @@ def generate_shared_lib_init(self, emitter: Emitter) -> None: emitter.emit_line() emitter.emit_lines( - 'PyMODINIT_FUNC PyInit_{}(void)'.format( - shared_lib_name(self.group_name).split('.')[-1]), - '{', - ('static PyModuleDef def = {{ PyModuleDef_HEAD_INIT, "{}", NULL, -1, NULL, NULL }};' - .format(shared_lib_name(self.group_name))), - 'int res;', - 'PyObject *capsule;', - 'PyObject *tmp;', - 'static PyObject *module;', - 'if (module) {', - 'Py_INCREF(module);', - 'return module;', - '}', - 'module = PyModule_Create(&def);', - 'if (!module) {', - 'goto fail;', - '}', - '', + "PyMODINIT_FUNC PyInit_{}(void)".format( + shared_lib_name(self.group_name).split(".")[-1] + ), + "{", + ( + 'static PyModuleDef def = {{ PyModuleDef_HEAD_INIT, "{}", NULL, -1, NULL, NULL }};'.format( + shared_lib_name(self.group_name) + ) + ), + "int res;", + "PyObject *capsule;", + "PyObject *tmp;", + "static PyObject *module;", + "if (module) {", + "Py_INCREF(module);", + "return module;", + "}", + "module = PyModule_Create(&def);", + "if (!module) {", + "goto fail;", + "}", + "", ) emitter.emit_lines( 'capsule = PyCapsule_New(&exports, "{}.exports", NULL);'.format( - shared_lib_name(self.group_name)), - 'if (!capsule) {', - 'goto fail;', - '}', + shared_lib_name(self.group_name) + ), + "if (!capsule) {", + "goto fail;", + "}", 'res = PyObject_SetAttrString(module, "exports", capsule);', - 'Py_DECREF(capsule);', - 'if (res < 0) {', - 'goto fail;', - '}', - '', + "Py_DECREF(capsule);", + "if (res < 0) {", + "goto fail;", + "}", + "", ) for mod, _ in self.modules: name = exported_name(mod) emitter.emit_lines( - f'extern PyObject *CPyInit_{name}(void);', + f"extern PyObject *CPyInit_{name}(void);", 'capsule = PyCapsule_New((void *)CPyInit_{}, "{}.init_{}", NULL);'.format( - name, shared_lib_name(self.group_name), name), - 'if (!capsule) {', - 'goto fail;', - '}', + name, shared_lib_name(self.group_name), name + ), + "if (!capsule) {", + "goto fail;", + "}", f'res = PyObject_SetAttrString(module, "init_{name}", capsule);', - 'Py_DECREF(capsule);', - 'if (res < 0) {', - 'goto fail;', - '}', - '', + "Py_DECREF(capsule);", + "if (res < 0) {", + "goto fail;", + "}", + "", ) for group in sorted(self.context.group_deps): egroup = exported_name(group) emitter.emit_lines( 'tmp = PyImport_ImportModule("{}"); if (!tmp) goto fail; Py_DECREF(tmp);'.format( - shared_lib_name(group)), + shared_lib_name(group) + ), 'struct export_table_{} *pexports_{} = PyCapsule_Import("{}.exports", 0);'.format( - egroup, egroup, shared_lib_name(group)), - f'if (!pexports_{egroup}) {{', - 'goto fail;', - '}', - 'memcpy(&exports_{group}, pexports_{group}, sizeof(exports_{group}));'.format( - group=egroup), - '', + egroup, egroup, shared_lib_name(group) + ), + f"if (!pexports_{egroup}) {{", + "goto fail;", + "}", + "memcpy(&exports_{group}, pexports_{group}, sizeof(exports_{group}));".format( + group=egroup + ), + "", ) - emitter.emit_lines( - 'return module;', - 'fail:', - 'Py_XDECREF(module);', - 'return NULL;', - '}', - ) + emitter.emit_lines("return module;", "fail:", "Py_XDECREF(module);", "return NULL;", "}") def generate_globals_init(self, emitter: Emitter) -> None: emitter.emit_lines( - '', - 'int CPyGlobalsInit(void)', - '{', - 'static int is_initialized = 0;', - 'if (is_initialized) return 0;', - '' + "", + "int CPyGlobalsInit(void)", + "{", + "static int is_initialized = 0;", + "if (is_initialized) return 0;", + "", ) - emitter.emit_line('CPy_Init();') + emitter.emit_line("CPy_Init();") for symbol, fixup in self.simple_inits: - emitter.emit_line(f'{symbol} = {fixup};') - - values = 'CPyLit_Str, CPyLit_Bytes, CPyLit_Int, CPyLit_Float, CPyLit_Complex, CPyLit_Tuple' - emitter.emit_lines(f'if (CPyStatics_Initialize(CPyStatics, {values}) < 0) {{', - 'return -1;', - '}') + emitter.emit_line(f"{symbol} = {fixup};") + values = "CPyLit_Str, CPyLit_Bytes, CPyLit_Int, CPyLit_Float, CPyLit_Complex, CPyLit_Tuple" emitter.emit_lines( - 'is_initialized = 1;', - 'return 0;', - '}', + f"if (CPyStatics_Initialize(CPyStatics, {values}) < 0) {{", "return -1;", "}" ) + emitter.emit_lines("is_initialized = 1;", "return 0;", "}") + def generate_module_def(self, emitter: Emitter, module_name: str, module: ModuleIR) -> None: """Emit the PyModuleDef struct for a module and the module init function.""" # Emit module methods module_prefix = emitter.names.private_name(module_name) - emitter.emit_line(f'static PyMethodDef {module_prefix}module_methods[] = {{') + emitter.emit_line(f"static PyMethodDef {module_prefix}module_methods[] = {{") for fn in module.functions: if fn.class_name is not None or fn.name == TOP_LEVEL_NAME: continue name = short_id_from_name(fn.name, fn.decl.shortname, fn.line) if is_fastcall_supported(fn, emitter.capi_version): - flag = 'METH_FASTCALL' + flag = "METH_FASTCALL" else: - flag = 'METH_VARARGS' + flag = "METH_VARARGS" emitter.emit_line( - ('{{"{name}", (PyCFunction){prefix}{cname}, {flag} | METH_KEYWORDS, ' - 'NULL /* docstring */}},').format( - name=name, - cname=fn.cname(emitter.names), - prefix=PREFIX, - flag=flag)) - emitter.emit_line('{NULL, NULL, 0, NULL}') - emitter.emit_line('};') + ( + '{{"{name}", (PyCFunction){prefix}{cname}, {flag} | METH_KEYWORDS, ' + "NULL /* docstring */}}," + ).format(name=name, cname=fn.cname(emitter.names), prefix=PREFIX, flag=flag) + ) + emitter.emit_line("{NULL, NULL, 0, NULL}") + emitter.emit_line("};") emitter.emit_line() # Emit module definition struct - emitter.emit_lines(f'static struct PyModuleDef {module_prefix}module = {{', - 'PyModuleDef_HEAD_INIT,', - f'"{module_name}",', - 'NULL, /* docstring */', - '-1, /* size of per-interpreter state of the module,', - ' or -1 if the module keeps state in global variables. */', - f'{module_prefix}module_methods', - '};') + emitter.emit_lines( + f"static struct PyModuleDef {module_prefix}module = {{", + "PyModuleDef_HEAD_INIT,", + f'"{module_name}",', + "NULL, /* docstring */", + "-1, /* size of per-interpreter state of the module,", + " or -1 if the module keeps state in global variables. */", + f"{module_prefix}module_methods", + "};", + ) emitter.emit_line() # Emit module init function. If we are compiling just one module, this # will be the C API init function. If we are compiling 2+ modules, we @@ -874,12 +887,11 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module # the shared library, and in this case we use an internal module # initialized function that will be called by the shim. if not self.use_shared_lib: - declaration = f'PyMODINIT_FUNC PyInit_{module_name}(void)' + declaration = f"PyMODINIT_FUNC PyInit_{module_name}(void)" else: - declaration = f'PyObject *CPyInit_{exported_name(module_name)}(void)' - emitter.emit_lines(declaration, - '{') - emitter.emit_line('PyObject* modname = NULL;') + declaration = f"PyObject *CPyInit_{exported_name(module_name)}(void)" + emitter.emit_lines(declaration, "{") + emitter.emit_line("PyObject* modname = NULL;") # Store the module reference in a static and return it when necessary. # This is separate from the *global* reference to the module that will # be populated when it is imported by a compiled module. We want that @@ -887,22 +899,28 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module # imported, whereas this we want to have to stop a circular import. module_static = self.module_internal_static_name(module_name, emitter) - emitter.emit_lines(f'if ({module_static}) {{', - f'Py_INCREF({module_static});', - f'return {module_static};', - '}') + emitter.emit_lines( + f"if ({module_static}) {{", + f"Py_INCREF({module_static});", + f"return {module_static};", + "}", + ) - emitter.emit_lines(f'{module_static} = PyModule_Create(&{module_prefix}module);', - f'if (unlikely({module_static} == NULL))', - ' goto fail;') + emitter.emit_lines( + f"{module_static} = PyModule_Create(&{module_prefix}module);", + f"if (unlikely({module_static} == NULL))", + " goto fail;", + ) emitter.emit_line( - 'modname = PyObject_GetAttrString((PyObject *){}, "__name__");'.format( - module_static)) + 'modname = PyObject_GetAttrString((PyObject *){}, "__name__");'.format(module_static) + ) - module_globals = emitter.static_name('globals', module_name) - emitter.emit_lines(f'{module_globals} = PyModule_GetDict({module_static});', - f'if (unlikely({module_globals} == NULL))', - ' goto fail;') + module_globals = emitter.static_name("globals", module_name) + emitter.emit_lines( + f"{module_globals} = PyModule_GetDict({module_static});", + f"if (unlikely({module_globals} == NULL))", + " goto fail;", + ) # HACK: Manually instantiate generated classes here type_structs: List[str] = [] @@ -911,34 +929,30 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module type_structs.append(type_struct) if cl.is_generated: emitter.emit_lines( - '{t} = (PyTypeObject *)CPyType_FromTemplate(' - '(PyObject *){t}_template, NULL, modname);' - .format(t=type_struct)) - emitter.emit_lines(f'if (unlikely(!{type_struct}))', - ' goto fail;') + "{t} = (PyTypeObject *)CPyType_FromTemplate(" + "(PyObject *){t}_template, NULL, modname);".format(t=type_struct) + ) + emitter.emit_lines(f"if (unlikely(!{type_struct}))", " goto fail;") - emitter.emit_lines('if (CPyGlobalsInit() < 0)', - ' goto fail;') + emitter.emit_lines("if (CPyGlobalsInit() < 0)", " goto fail;") self.generate_top_level_call(module, emitter) - emitter.emit_lines('Py_DECREF(modname);') + emitter.emit_lines("Py_DECREF(modname);") - emitter.emit_line(f'return {module_static};') - emitter.emit_lines('fail:', - f'Py_CLEAR({module_static});', - 'Py_CLEAR(modname);') + emitter.emit_line(f"return {module_static};") + emitter.emit_lines("fail:", f"Py_CLEAR({module_static});", "Py_CLEAR(modname);") for name, typ in module.final_names: static_name = emitter.static_name(name, module_name) emitter.emit_dec_ref(static_name, typ, is_xdec=True) undef = emitter.c_undefined_value(typ) - emitter.emit_line(f'{static_name} = {undef};') + emitter.emit_line(f"{static_name} = {undef};") # the type objects returned from CPyType_FromTemplate are all new references # so we have to decref them for t in type_structs: - emitter.emit_line(f'Py_CLEAR({t});') - emitter.emit_line('return NULL;') - emitter.emit_line('}') + emitter.emit_line(f"Py_CLEAR({t});") + emitter.emit_line("return NULL;") + emitter.emit_line("}") def generate_top_level_call(self, module: ModuleIR, emitter: Emitter) -> None: """Generate call to function representing module top level.""" @@ -946,9 +960,9 @@ def generate_top_level_call(self, module: ModuleIR, emitter: Emitter) -> None: for fn in reversed(module.functions): if fn.name == TOP_LEVEL_NAME: emitter.emit_lines( - f'char result = {emitter.native_function_name(fn.decl)}();', - 'if (result == 2)', - ' goto fail;', + f"char result = {emitter.native_function_name(fn.decl)}();", + "if (result == 2)", + " goto fail;", ) break @@ -982,31 +996,28 @@ def _toposort_visit(name: str) -> None: return result - def declare_global(self, type_spaced: str, name: str, - *, - initializer: Optional[str] = None) -> None: - if '[' not in type_spaced: - base = f'{type_spaced}{name}' + def declare_global( + self, type_spaced: str, name: str, *, initializer: Optional[str] = None + ) -> None: + if "[" not in type_spaced: + base = f"{type_spaced}{name}" else: - a, b = type_spaced.split('[', 1) - base = f'{a}{name}[{b}' + a, b = type_spaced.split("[", 1) + base = f"{a}{name}[{b}" if not initializer: defn = None else: - defn = [f'{base} = {initializer};'] + defn = [f"{base} = {initializer};"] if name not in self.context.declarations: - self.context.declarations[name] = HeaderDeclaration( - f'{base};', - defn=defn, - ) + self.context.declarations[name] = HeaderDeclaration(f"{base};", defn=defn) def declare_internal_globals(self, module_name: str, emitter: Emitter) -> None: - static_name = emitter.static_name('globals', module_name) - self.declare_global('PyObject *', static_name) + static_name = emitter.static_name("globals", module_name) + self.declare_global("PyObject *", static_name) def module_internal_static_name(self, module_name: str, emitter: Emitter) -> str: - return emitter.static_name(module_name + '_internal', None, prefix=MODULE_PREFIX) + return emitter.static_name(module_name + "_internal", None, prefix=MODULE_PREFIX) def declare_module(self, module_name: str, emitter: Emitter) -> None: # We declare two globals for each module: @@ -1014,38 +1025,39 @@ def declare_module(self, module_name: str, emitter: Emitter) -> None: # and prevent infinite recursion in import cycles, and one used # by other modules to refer to it. internal_static_name = self.module_internal_static_name(module_name, emitter) - self.declare_global('CPyModule *', internal_static_name, initializer='NULL') + self.declare_global("CPyModule *", internal_static_name, initializer="NULL") static_name = emitter.static_name(module_name, None, prefix=MODULE_PREFIX) - self.declare_global('CPyModule *', static_name) - self.simple_inits.append((static_name, 'Py_None')) + self.declare_global("CPyModule *", static_name) + self.simple_inits.append((static_name, "Py_None")) def declare_imports(self, imps: Iterable[str], emitter: Emitter) -> None: for imp in imps: self.declare_module(imp, emitter) def declare_finals( - self, module: str, final_names: Iterable[Tuple[str, RType]], emitter: Emitter) -> None: + self, module: str, final_names: Iterable[Tuple[str, RType]], emitter: Emitter + ) -> None: for name, typ in final_names: static_name = emitter.static_name(name, module) emitter.context.declarations[static_name] = HeaderDeclaration( - f'{emitter.ctype_spaced(typ)}{static_name};', + f"{emitter.ctype_spaced(typ)}{static_name};", [self.final_definition(module, name, typ, emitter)], - needs_export=True) + needs_export=True, + ) - def final_definition( - self, module: str, name: str, typ: RType, emitter: Emitter) -> str: + def final_definition(self, module: str, name: str, typ: RType, emitter: Emitter) -> str: static_name = emitter.static_name(name, module) # Here we rely on the fact that undefined value and error value are always the same if isinstance(typ, RTuple): # We need to inline because initializer must be static - undefined = '{{ {} }}'.format(''.join(emitter.tuple_undefined_value_helper(typ))) + undefined = "{{ {} }}".format("".join(emitter.tuple_undefined_value_helper(typ))) else: undefined = emitter.c_undefined_value(typ) - return f'{emitter.ctype_spaced(typ)}{static_name} = {undefined};' + return f"{emitter.ctype_spaced(typ)}{static_name} = {undefined};" def declare_static_pyobject(self, identifier: str, emitter: Emitter) -> None: symbol = emitter.static_name(identifier, None) - self.declare_global('PyObject *', symbol) + self.declare_global("PyObject *", symbol) def sort_classes(classes: List[Tuple[str, ClassIR]]) -> List[Tuple[str, ClassIR]]: @@ -1062,7 +1074,7 @@ def sort_classes(classes: List[Tuple[str, ClassIR]]) -> List[Tuple[str, ClassIR] return [(mod_name[ir], ir) for ir in sorted_irs] -T = TypeVar('T') +T = TypeVar("T") def toposort(deps: Dict[T, Set[T]]) -> List[T]: @@ -1091,11 +1103,11 @@ def visit(item: T) -> None: def is_fastcall_supported(fn: FuncIR, capi_version: Tuple[int, int]) -> bool: if fn.class_name is not None: - if fn.name == '__call__': + if fn.name == "__call__": # We can use vectorcalls (PEP 590) when supported return use_vectorcall(capi_version) # TODO: Support fastcall for __init__. - return use_fastcall(capi_version) and fn.name != '__init__' + return use_fastcall(capi_version) and fn.name != "__init__" return use_fastcall(capi_version) @@ -1131,21 +1143,21 @@ def c_array_initializer(components: List[str]) -> str: current.append(c) cur_len += len(c) + 2 else: - res.append(', '.join(current)) + res.append(", ".join(current)) current = [c] cur_len = len(c) if not res: # Result fits on a single line - return '{%s}' % ', '.join(current) + return "{%s}" % ", ".join(current) # Multi-line result - res.append(', '.join(current)) - return '{\n ' + ',\n '.join(res) + '\n}' + res.append(", ".join(current)) + return "{\n " + ",\n ".join(res) + "\n}" def c_string_array_initializer(components: List[bytes]) -> str: result = [] - result.append('{\n') + result.append("{\n") for s in components: - result.append(' ' + c_string_initializer(s) + ',\n') - result.append('}') - return ''.join(result) + result.append(" " + c_string_initializer(s) + ",\n") + result.append("}") + return "".join(result) diff --git a/mypyc/codegen/emitwrapper.py b/mypyc/codegen/emitwrapper.py index a68438c5f0dbd..4c60ee34b8d9c 100644 --- a/mypyc/codegen/emitwrapper.py +++ b/mypyc/codegen/emitwrapper.py @@ -10,22 +10,24 @@ or methods in a single compilation unit. """ -from typing import List, Dict, Optional, Sequence +from typing import Dict, List, Optional, Sequence -from mypy.nodes import ArgKind, ARG_POS, ARG_OPT, ARG_NAMED_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2 -from mypy.operators import op_methods_to_symbols, reverse_op_methods, reverse_op_method_names - -from mypyc.common import PREFIX, NATIVE_PREFIX, DUNDER_PREFIX, use_vectorcall -from mypyc.codegen.emit import Emitter, ErrorHandler, GotoHandler, AssignHandler, ReturnHandler +from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, ArgKind +from mypy.operators import op_methods_to_symbols, reverse_op_method_names, reverse_op_methods +from mypyc.codegen.emit import AssignHandler, Emitter, ErrorHandler, GotoHandler, ReturnHandler +from mypyc.common import DUNDER_PREFIX, NATIVE_PREFIX, PREFIX, use_vectorcall +from mypyc.ir.class_ir import ClassIR +from mypyc.ir.func_ir import FUNC_STATICMETHOD, FuncIR, RuntimeArg from mypyc.ir.rtypes import ( - RType, RInstance, is_object_rprimitive, is_int_rprimitive, is_bool_rprimitive, - object_rprimitive + RInstance, + RType, + is_bool_rprimitive, + is_int_rprimitive, + is_object_rprimitive, + object_rprimitive, ) -from mypyc.ir.func_ir import FuncIR, RuntimeArg, FUNC_STATICMETHOD -from mypyc.ir.class_ir import ClassIR from mypyc.namegen import NameGenerator - # Generic vectorcall wrapper functions (Python 3.7+) # # A wrapper function has a signature like this: @@ -51,27 +53,26 @@ def wrapper_function_header(fn: FuncIR, names: NameGenerator) -> str: See comment above for a summary of the arguments. """ return ( - 'PyObject *{prefix}{name}(' - 'PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames)').format( - prefix=PREFIX, - name=fn.cname(names)) + "PyObject *{prefix}{name}(" + "PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames)" + ).format(prefix=PREFIX, name=fn.cname(names)) -def generate_traceback_code(fn: FuncIR, - emitter: Emitter, - source_path: str, - module_name: str) -> str: +def generate_traceback_code( + fn: FuncIR, emitter: Emitter, source_path: str, module_name: str +) -> str: # If we hit an error while processing arguments, then we emit a # traceback frame to make it possible to debug where it happened. # Unlike traceback frames added for exceptions seen in IR, we do this # even if there is no `traceback_name`. This is because the error will # have originated here and so we need it in the traceback. - globals_static = emitter.static_name('globals', module_name) + globals_static = emitter.static_name("globals", module_name) traceback_code = 'CPy_AddTraceback("%s", "%s", %d, %s);' % ( source_path.replace("\\", "\\\\"), fn.traceback_name or fn.name, fn.line, - globals_static) + globals_static, + ) return traceback_code @@ -86,8 +87,8 @@ def reorder_arg_groups(groups: Dict[ArgKind, List[RuntimeArg]]) -> List[RuntimeA def make_static_kwlist(args: List[RuntimeArg]) -> str: - arg_names = ''.join(f'"{arg.name}", ' for arg in args) - return f'static const char * const kwlist[] = {{{arg_names}0}};' + arg_names = "".join(f'"{arg.name}", ' for arg in args) + return f"static const char * const kwlist[] = {{{arg_names}0}};" def make_format_string(func_name: Optional[str], groups: Dict[ArgKind, List[RuntimeArg]]) -> str: @@ -105,37 +106,36 @@ def make_format_string(func_name: Optional[str], groups: Dict[ArgKind, List[Runt These are used by both vectorcall and legacy wrapper functions. """ - format = '' + format = "" if groups[ARG_STAR] or groups[ARG_STAR2]: - format += '%' - format += 'O' * len(groups[ARG_POS]) + format += "%" + format += "O" * len(groups[ARG_POS]) if groups[ARG_OPT] or groups[ARG_NAMED_OPT] or groups[ARG_NAMED]: - format += '|' + 'O' * len(groups[ARG_OPT]) + format += "|" + "O" * len(groups[ARG_OPT]) if groups[ARG_NAMED_OPT] or groups[ARG_NAMED]: - format += '$' + 'O' * len(groups[ARG_NAMED_OPT]) + format += "$" + "O" * len(groups[ARG_NAMED_OPT]) if groups[ARG_NAMED]: - format += '@' + 'O' * len(groups[ARG_NAMED]) + format += "@" + "O" * len(groups[ARG_NAMED]) if func_name is not None: - format += f':{func_name}' + format += f":{func_name}" return format -def generate_wrapper_function(fn: FuncIR, - emitter: Emitter, - source_path: str, - module_name: str) -> None: +def generate_wrapper_function( + fn: FuncIR, emitter: Emitter, source_path: str, module_name: str +) -> None: """Generate a CPython-compatible vectorcall wrapper for a native function. In particular, this handles unboxing the arguments, calling the native function, and then boxing the return value. """ - emitter.emit_line(f'{wrapper_function_header(fn, emitter.names)} {{') + emitter.emit_line(f"{wrapper_function_header(fn, emitter.names)} {{") # If fn is a method, then the first argument is a self param real_args = list(fn.args) if fn.class_name and not fn.decl.kind == FUNC_STATICMETHOD: arg = real_args.pop(0) - emitter.emit_line(f'PyObject *obj_{arg.name} = self;') + emitter.emit_line(f"PyObject *obj_{arg.name} = self;") # Need to order args as: required, optional, kwonly optional, kwonly required # This is because CPyArg_ParseStackAndKeywords format string requires @@ -149,44 +149,50 @@ def generate_wrapper_function(fn: FuncIR, emitter.emit_line(f'static CPyArg_Parser parser = {{"{fmt}", kwlist, 0}};') for arg in real_args: - emitter.emit_line('PyObject *obj_{}{};'.format( - arg.name, ' = NULL' if arg.optional else '')) + emitter.emit_line( + "PyObject *obj_{}{};".format(arg.name, " = NULL" if arg.optional else "") + ) - cleanups = [f'CPy_DECREF(obj_{arg.name});' - for arg in groups[ARG_STAR] + groups[ARG_STAR2]] + cleanups = [f"CPy_DECREF(obj_{arg.name});" for arg in groups[ARG_STAR] + groups[ARG_STAR2]] arg_ptrs: List[str] = [] if groups[ARG_STAR] or groups[ARG_STAR2]: - arg_ptrs += [f'&obj_{groups[ARG_STAR][0].name}' if groups[ARG_STAR] else 'NULL'] - arg_ptrs += [f'&obj_{groups[ARG_STAR2][0].name}' if groups[ARG_STAR2] else 'NULL'] - arg_ptrs += [f'&obj_{arg.name}' for arg in reordered_args] + arg_ptrs += [f"&obj_{groups[ARG_STAR][0].name}" if groups[ARG_STAR] else "NULL"] + arg_ptrs += [f"&obj_{groups[ARG_STAR2][0].name}" if groups[ARG_STAR2] else "NULL"] + arg_ptrs += [f"&obj_{arg.name}" for arg in reordered_args] - if fn.name == '__call__' and use_vectorcall(emitter.capi_version): - nargs = 'PyVectorcall_NARGS(nargs)' + if fn.name == "__call__" and use_vectorcall(emitter.capi_version): + nargs = "PyVectorcall_NARGS(nargs)" else: - nargs = 'nargs' - parse_fn = 'CPyArg_ParseStackAndKeywords' + nargs = "nargs" + parse_fn = "CPyArg_ParseStackAndKeywords" # Special case some common signatures if len(real_args) == 0: # No args - parse_fn = 'CPyArg_ParseStackAndKeywordsNoArgs' + parse_fn = "CPyArg_ParseStackAndKeywordsNoArgs" elif len(real_args) == 1 and len(groups[ARG_POS]) == 1: # Single positional arg - parse_fn = 'CPyArg_ParseStackAndKeywordsOneArg' + parse_fn = "CPyArg_ParseStackAndKeywordsOneArg" elif len(real_args) == len(groups[ARG_POS]) + len(groups[ARG_OPT]): # No keyword-only args, *args or **kwargs - parse_fn = 'CPyArg_ParseStackAndKeywordsSimple' + parse_fn = "CPyArg_ParseStackAndKeywordsSimple" emitter.emit_lines( - 'if (!{}(args, {}, kwnames, &parser{})) {{'.format( - parse_fn, nargs, ''.join(', ' + n for n in arg_ptrs)), - 'return NULL;', - '}') + "if (!{}(args, {}, kwnames, &parser{})) {{".format( + parse_fn, nargs, "".join(", " + n for n in arg_ptrs) + ), + "return NULL;", + "}", + ) traceback_code = generate_traceback_code(fn, emitter, source_path, module_name) - generate_wrapper_core(fn, emitter, groups[ARG_OPT] + groups[ARG_NAMED_OPT], - cleanups=cleanups, - traceback_code=traceback_code) + generate_wrapper_core( + fn, + emitter, + groups[ARG_OPT] + groups[ARG_NAMED_OPT], + cleanups=cleanups, + traceback_code=traceback_code, + ) - emitter.emit_line('}') + emitter.emit_line("}") # Legacy generic wrapper functions @@ -198,27 +204,26 @@ def generate_wrapper_function(fn: FuncIR, def legacy_wrapper_function_header(fn: FuncIR, names: NameGenerator) -> str: - return 'PyObject *{prefix}{name}(PyObject *self, PyObject *args, PyObject *kw)'.format( - prefix=PREFIX, - name=fn.cname(names)) + return "PyObject *{prefix}{name}(PyObject *self, PyObject *args, PyObject *kw)".format( + prefix=PREFIX, name=fn.cname(names) + ) -def generate_legacy_wrapper_function(fn: FuncIR, - emitter: Emitter, - source_path: str, - module_name: str) -> None: +def generate_legacy_wrapper_function( + fn: FuncIR, emitter: Emitter, source_path: str, module_name: str +) -> None: """Generates a CPython-compatible legacy wrapper for a native function. In particular, this handles unboxing the arguments, calling the native function, and then boxing the return value. """ - emitter.emit_line(f'{legacy_wrapper_function_header(fn, emitter.names)} {{') + emitter.emit_line(f"{legacy_wrapper_function_header(fn, emitter.names)} {{") # If fn is a method, then the first argument is a self param real_args = list(fn.args) if fn.class_name and not fn.decl.kind == FUNC_STATICMETHOD: arg = real_args.pop(0) - emitter.emit_line(f'PyObject *obj_{arg.name} = self;') + emitter.emit_line(f"PyObject *obj_{arg.name} = self;") # Need to order args as: required, optional, kwonly optional, kwonly required # This is because CPyArg_ParseTupleAndKeywords format string requires @@ -228,29 +233,35 @@ def generate_legacy_wrapper_function(fn: FuncIR, emitter.emit_line(make_static_kwlist(reordered_args)) for arg in real_args: - emitter.emit_line('PyObject *obj_{}{};'.format( - arg.name, ' = NULL' if arg.optional else '')) + emitter.emit_line( + "PyObject *obj_{}{};".format(arg.name, " = NULL" if arg.optional else "") + ) - cleanups = [f'CPy_DECREF(obj_{arg.name});' - for arg in groups[ARG_STAR] + groups[ARG_STAR2]] + cleanups = [f"CPy_DECREF(obj_{arg.name});" for arg in groups[ARG_STAR] + groups[ARG_STAR2]] arg_ptrs: List[str] = [] if groups[ARG_STAR] or groups[ARG_STAR2]: - arg_ptrs += [f'&obj_{groups[ARG_STAR][0].name}' if groups[ARG_STAR] else 'NULL'] - arg_ptrs += [f'&obj_{groups[ARG_STAR2][0].name}' if groups[ARG_STAR2] else 'NULL'] - arg_ptrs += [f'&obj_{arg.name}' for arg in reordered_args] + arg_ptrs += [f"&obj_{groups[ARG_STAR][0].name}" if groups[ARG_STAR] else "NULL"] + arg_ptrs += [f"&obj_{groups[ARG_STAR2][0].name}" if groups[ARG_STAR2] else "NULL"] + arg_ptrs += [f"&obj_{arg.name}" for arg in reordered_args] emitter.emit_lines( 'if (!CPyArg_ParseTupleAndKeywords(args, kw, "{}", "{}", kwlist{})) {{'.format( - make_format_string(None, groups), fn.name, ''.join(', ' + n for n in arg_ptrs)), - 'return NULL;', - '}') + make_format_string(None, groups), fn.name, "".join(", " + n for n in arg_ptrs) + ), + "return NULL;", + "}", + ) traceback_code = generate_traceback_code(fn, emitter, source_path, module_name) - generate_wrapper_core(fn, emitter, groups[ARG_OPT] + groups[ARG_NAMED_OPT], - cleanups=cleanups, - traceback_code=traceback_code) + generate_wrapper_core( + fn, + emitter, + groups[ARG_OPT] + groups[ARG_NAMED_OPT], + cleanups=cleanups, + traceback_code=traceback_code, + ) - emitter.emit_line('}') + emitter.emit_line("}") # Specialized wrapper functions @@ -280,7 +291,7 @@ def generate_bin_op_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: """ gen = WrapperGenerator(cl, emitter) gen.set_target(fn) - gen.arg_names = ['left', 'right'] + gen.arg_names = ["left", "right"] wrapper_name = gen.wrapper_name() gen.emit_header() @@ -299,13 +310,13 @@ def generate_bin_op_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: return wrapper_name -def generate_bin_op_forward_only_wrapper(fn: FuncIR, - emitter: Emitter, - gen: 'WrapperGenerator') -> None: - gen.emit_arg_processing(error=GotoHandler('typefail'), raise_exception=False) - gen.emit_call(not_implemented_handler='goto typefail;') +def generate_bin_op_forward_only_wrapper( + fn: FuncIR, emitter: Emitter, gen: "WrapperGenerator" +) -> None: + gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False) + gen.emit_call(not_implemented_handler="goto typefail;") gen.emit_error_handling() - emitter.emit_label('typefail') + emitter.emit_label("typefail") # If some argument has an incompatible type, treat this the same as # returning NotImplemented, and try to call the reverse operator method. # @@ -322,31 +333,29 @@ def generate_bin_op_forward_only_wrapper(fn: FuncIR, # return NotImplemented # ... rmethod = reverse_op_methods[fn.name] - emitter.emit_line(f'_Py_IDENTIFIER({rmethod});') + emitter.emit_line(f"_Py_IDENTIFIER({rmethod});") emitter.emit_line( 'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format( - op_methods_to_symbols[fn.name], - rmethod)) + op_methods_to_symbols[fn.name], rmethod + ) + ) gen.finish() -def generate_bin_op_reverse_only_wrapper(emitter: Emitter, - gen: 'WrapperGenerator') -> None: - gen.arg_names = ['right', 'left'] - gen.emit_arg_processing(error=GotoHandler('typefail'), raise_exception=False) +def generate_bin_op_reverse_only_wrapper(emitter: Emitter, gen: "WrapperGenerator") -> None: + gen.arg_names = ["right", "left"] + gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False) gen.emit_call() gen.emit_error_handling() - emitter.emit_label('typefail') - emitter.emit_line('Py_INCREF(Py_NotImplemented);') - emitter.emit_line('return Py_NotImplemented;') + emitter.emit_label("typefail") + emitter.emit_line("Py_INCREF(Py_NotImplemented);") + emitter.emit_line("return Py_NotImplemented;") gen.finish() -def generate_bin_op_both_wrappers(cl: ClassIR, - fn: FuncIR, - fn_rev: FuncIR, - emitter: Emitter, - gen: 'WrapperGenerator') -> None: +def generate_bin_op_both_wrappers( + cl: ClassIR, fn: FuncIR, fn_rev: FuncIR, emitter: Emitter, gen: "WrapperGenerator" +) -> None: # There's both a forward and a reverse operator method. First # check if we should try calling the forward one. If the # argument type check fails, fall back to the reverse method. @@ -355,40 +364,47 @@ def generate_bin_op_both_wrappers(cl: ClassIR, # In regular Python code you'd return NotImplemented if the # operand has the wrong type, but in compiled code we'll never # get to execute the type check. - emitter.emit_line('if (PyObject_IsInstance(obj_left, (PyObject *){})) {{'.format( - emitter.type_struct_name(cl))) - gen.emit_arg_processing(error=GotoHandler('typefail'), raise_exception=False) - gen.emit_call(not_implemented_handler='goto typefail;') + emitter.emit_line( + "if (PyObject_IsInstance(obj_left, (PyObject *){})) {{".format( + emitter.type_struct_name(cl) + ) + ) + gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False) + gen.emit_call(not_implemented_handler="goto typefail;") gen.emit_error_handling() - emitter.emit_line('}') - emitter.emit_label('typefail') - emitter.emit_line('if (PyObject_IsInstance(obj_right, (PyObject *){})) {{'.format( - emitter.type_struct_name(cl))) + emitter.emit_line("}") + emitter.emit_label("typefail") + emitter.emit_line( + "if (PyObject_IsInstance(obj_right, (PyObject *){})) {{".format( + emitter.type_struct_name(cl) + ) + ) gen.set_target(fn_rev) - gen.arg_names = ['right', 'left'] - gen.emit_arg_processing(error=GotoHandler('typefail2'), raise_exception=False) + gen.arg_names = ["right", "left"] + gen.emit_arg_processing(error=GotoHandler("typefail2"), raise_exception=False) gen.emit_call() gen.emit_error_handling() - emitter.emit_line('} else {') - emitter.emit_line(f'_Py_IDENTIFIER({fn_rev.name});') + emitter.emit_line("} else {") + emitter.emit_line(f"_Py_IDENTIFIER({fn_rev.name});") emitter.emit_line( 'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", &PyId_{});'.format( - op_methods_to_symbols[fn.name], - fn_rev.name)) - emitter.emit_line('}') - emitter.emit_label('typefail2') - emitter.emit_line('Py_INCREF(Py_NotImplemented);') - emitter.emit_line('return Py_NotImplemented;') + op_methods_to_symbols[fn.name], fn_rev.name + ) + ) + emitter.emit_line("}") + emitter.emit_label("typefail2") + emitter.emit_line("Py_INCREF(Py_NotImplemented);") + emitter.emit_line("return Py_NotImplemented;") gen.finish() RICHCOMPARE_OPS = { - '__lt__': 'Py_LT', - '__gt__': 'Py_GT', - '__le__': 'Py_LE', - '__ge__': 'Py_GE', - '__eq__': 'Py_EQ', - '__ne__': 'Py_NE', + "__lt__": "Py_LT", + "__gt__": "Py_GT", + "__le__": "Py_LE", + "__ge__": "Py_GE", + "__eq__": "Py_EQ", + "__ne__": "Py_NE", } @@ -399,107 +415,114 @@ def generate_richcompare_wrapper(cl: ClassIR, emitter: Emitter) -> Optional[str] if not matches: return None - name = f'{DUNDER_PREFIX}_RichCompare_{cl.name_prefix(emitter.names)}' + name = f"{DUNDER_PREFIX}_RichCompare_{cl.name_prefix(emitter.names)}" emitter.emit_line( - 'static PyObject *{name}(PyObject *obj_lhs, PyObject *obj_rhs, int op) {{'.format( - name=name) + "static PyObject *{name}(PyObject *obj_lhs, PyObject *obj_rhs, int op) {{".format( + name=name + ) ) - emitter.emit_line('switch (op) {') + emitter.emit_line("switch (op) {") for func in matches: - emitter.emit_line(f'case {RICHCOMPARE_OPS[func]}: {{') + emitter.emit_line(f"case {RICHCOMPARE_OPS[func]}: {{") method = cl.get_method(func) assert method is not None - generate_wrapper_core(method, emitter, arg_names=['lhs', 'rhs']) - emitter.emit_line('}') - emitter.emit_line('}') + generate_wrapper_core(method, emitter, arg_names=["lhs", "rhs"]) + emitter.emit_line("}") + emitter.emit_line("}") - emitter.emit_line('Py_INCREF(Py_NotImplemented);') - emitter.emit_line('return Py_NotImplemented;') + emitter.emit_line("Py_INCREF(Py_NotImplemented);") + emitter.emit_line("return Py_NotImplemented;") - emitter.emit_line('}') + emitter.emit_line("}") return name def generate_get_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: """Generates a wrapper for native __get__ methods.""" - name = f'{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}' + name = f"{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}" emitter.emit_line( - 'static PyObject *{name}(PyObject *self, PyObject *instance, PyObject *owner) {{'. - format(name=name)) - emitter.emit_line('instance = instance ? instance : Py_None;') - emitter.emit_line('return {}{}(self, instance, owner);'.format( - NATIVE_PREFIX, - fn.cname(emitter.names))) - emitter.emit_line('}') + "static PyObject *{name}(PyObject *self, PyObject *instance, PyObject *owner) {{".format( + name=name + ) + ) + emitter.emit_line("instance = instance ? instance : Py_None;") + emitter.emit_line( + "return {}{}(self, instance, owner);".format(NATIVE_PREFIX, fn.cname(emitter.names)) + ) + emitter.emit_line("}") return name def generate_hash_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: """Generates a wrapper for native __hash__ methods.""" - name = f'{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}' - emitter.emit_line('static Py_ssize_t {name}(PyObject *self) {{'.format( - name=name - )) - emitter.emit_line('{}retval = {}{}{}(self);'.format(emitter.ctype_spaced(fn.ret_type), - emitter.get_group_prefix(fn.decl), - NATIVE_PREFIX, - fn.cname(emitter.names))) - emitter.emit_error_check('retval', fn.ret_type, 'return -1;') + name = f"{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}" + emitter.emit_line("static Py_ssize_t {name}(PyObject *self) {{".format(name=name)) + emitter.emit_line( + "{}retval = {}{}{}(self);".format( + emitter.ctype_spaced(fn.ret_type), + emitter.get_group_prefix(fn.decl), + NATIVE_PREFIX, + fn.cname(emitter.names), + ) + ) + emitter.emit_error_check("retval", fn.ret_type, "return -1;") if is_int_rprimitive(fn.ret_type): - emitter.emit_line('Py_ssize_t val = CPyTagged_AsSsize_t(retval);') + emitter.emit_line("Py_ssize_t val = CPyTagged_AsSsize_t(retval);") else: - emitter.emit_line('Py_ssize_t val = PyLong_AsSsize_t(retval);') - emitter.emit_dec_ref('retval', fn.ret_type) - emitter.emit_line('if (PyErr_Occurred()) return -1;') + emitter.emit_line("Py_ssize_t val = PyLong_AsSsize_t(retval);") + emitter.emit_dec_ref("retval", fn.ret_type) + emitter.emit_line("if (PyErr_Occurred()) return -1;") # We can't return -1 from a hash function.. - emitter.emit_line('if (val == -1) return -2;') - emitter.emit_line('return val;') - emitter.emit_line('}') + emitter.emit_line("if (val == -1) return -2;") + emitter.emit_line("return val;") + emitter.emit_line("}") return name def generate_len_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: """Generates a wrapper for native __len__ methods.""" - name = f'{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}' - emitter.emit_line('static Py_ssize_t {name}(PyObject *self) {{'.format( - name=name - )) - emitter.emit_line('{}retval = {}{}{}(self);'.format(emitter.ctype_spaced(fn.ret_type), - emitter.get_group_prefix(fn.decl), - NATIVE_PREFIX, - fn.cname(emitter.names))) - emitter.emit_error_check('retval', fn.ret_type, 'return -1;') + name = f"{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}" + emitter.emit_line("static Py_ssize_t {name}(PyObject *self) {{".format(name=name)) + emitter.emit_line( + "{}retval = {}{}{}(self);".format( + emitter.ctype_spaced(fn.ret_type), + emitter.get_group_prefix(fn.decl), + NATIVE_PREFIX, + fn.cname(emitter.names), + ) + ) + emitter.emit_error_check("retval", fn.ret_type, "return -1;") if is_int_rprimitive(fn.ret_type): - emitter.emit_line('Py_ssize_t val = CPyTagged_AsSsize_t(retval);') + emitter.emit_line("Py_ssize_t val = CPyTagged_AsSsize_t(retval);") else: - emitter.emit_line('Py_ssize_t val = PyLong_AsSsize_t(retval);') - emitter.emit_dec_ref('retval', fn.ret_type) - emitter.emit_line('if (PyErr_Occurred()) return -1;') - emitter.emit_line('return val;') - emitter.emit_line('}') + emitter.emit_line("Py_ssize_t val = PyLong_AsSsize_t(retval);") + emitter.emit_dec_ref("retval", fn.ret_type) + emitter.emit_line("if (PyErr_Occurred()) return -1;") + emitter.emit_line("return val;") + emitter.emit_line("}") return name def generate_bool_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: """Generates a wrapper for native __bool__ methods.""" - name = f'{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}' - emitter.emit_line('static int {name}(PyObject *self) {{'.format( - name=name - )) - emitter.emit_line('{}val = {}{}(self);'.format(emitter.ctype_spaced(fn.ret_type), - NATIVE_PREFIX, - fn.cname(emitter.names))) - emitter.emit_error_check('val', fn.ret_type, 'return -1;') + name = f"{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}" + emitter.emit_line("static int {name}(PyObject *self) {{".format(name=name)) + emitter.emit_line( + "{}val = {}{}(self);".format( + emitter.ctype_spaced(fn.ret_type), NATIVE_PREFIX, fn.cname(emitter.names) + ) + ) + emitter.emit_error_check("val", fn.ret_type, "return -1;") # This wouldn't be that hard to fix but it seems unimportant and # getting error handling and unboxing right would be fiddly. (And # way easier to do in IR!) assert is_bool_rprimitive(fn.ret_type), "Only bool return supported for __bool__" - emitter.emit_line('return val;') - emitter.emit_line('}') + emitter.emit_line("return val;") + emitter.emit_line("}") return name @@ -509,12 +532,11 @@ def generate_del_item_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: This is only called from a combined __delitem__/__setitem__ wrapper. """ - name = '{}{}{}'.format(DUNDER_PREFIX, '__delitem__', cl.name_prefix(emitter.names)) - input_args = ', '.join(f'PyObject *obj_{arg.name}' for arg in fn.args) - emitter.emit_line('static int {name}({input_args}) {{'.format( - name=name, - input_args=input_args, - )) + name = "{}{}{}".format(DUNDER_PREFIX, "__delitem__", cl.name_prefix(emitter.names)) + input_args = ", ".join(f"PyObject *obj_{arg.name}" for arg in fn.args) + emitter.emit_line( + "static int {name}({input_args}) {{".format(name=name, input_args=input_args) + ) generate_set_del_item_wrapper_inner(fn, emitter, fn.args) return name @@ -529,105 +551,107 @@ def generate_set_del_item_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> __setitem__ wrapper call it if the value is NULL. Return the name of the outer (__setitem__) wrapper. """ - method_cls = cl.get_method_and_class('__delitem__') + method_cls = cl.get_method_and_class("__delitem__") del_name = None if method_cls and method_cls[1] == cl: # Generate a separate wrapper for __delitem__ del_name = generate_del_item_wrapper(cl, method_cls[0], emitter) args = fn.args - if fn.name == '__delitem__': + if fn.name == "__delitem__": # Add an extra argument for value that we expect to be NULL. - args = list(args) + [RuntimeArg('___value', object_rprimitive, ARG_POS)] + args = list(args) + [RuntimeArg("___value", object_rprimitive, ARG_POS)] - name = '{}{}{}'.format(DUNDER_PREFIX, '__setitem__', cl.name_prefix(emitter.names)) - input_args = ', '.join(f'PyObject *obj_{arg.name}' for arg in args) - emitter.emit_line('static int {name}({input_args}) {{'.format( - name=name, - input_args=input_args, - )) + name = "{}{}{}".format(DUNDER_PREFIX, "__setitem__", cl.name_prefix(emitter.names)) + input_args = ", ".join(f"PyObject *obj_{arg.name}" for arg in args) + emitter.emit_line( + "static int {name}({input_args}) {{".format(name=name, input_args=input_args) + ) # First check if this is __delitem__ - emitter.emit_line(f'if (obj_{args[2].name} == NULL) {{') + emitter.emit_line(f"if (obj_{args[2].name} == NULL) {{") if del_name is not None: # We have a native implementation, so call it - emitter.emit_line('return {}(obj_{}, obj_{});'.format(del_name, - args[0].name, - args[1].name)) + emitter.emit_line( + "return {}(obj_{}, obj_{});".format(del_name, args[0].name, args[1].name) + ) else: # Try to call superclass method instead - emitter.emit_line( - f'PyObject *super = CPy_Super(CPyModule_builtins, obj_{args[0].name});') - emitter.emit_line('if (super == NULL) return -1;') + emitter.emit_line(f"PyObject *super = CPy_Super(CPyModule_builtins, obj_{args[0].name});") + emitter.emit_line("if (super == NULL) return -1;") emitter.emit_line( 'PyObject *result = PyObject_CallMethod(super, "__delitem__", "O", obj_{});'.format( - args[1].name)) - emitter.emit_line('Py_DECREF(super);') - emitter.emit_line('Py_XDECREF(result);') - emitter.emit_line('return result == NULL ? -1 : 0;') - emitter.emit_line('}') - - method_cls = cl.get_method_and_class('__setitem__') + args[1].name + ) + ) + emitter.emit_line("Py_DECREF(super);") + emitter.emit_line("Py_XDECREF(result);") + emitter.emit_line("return result == NULL ? -1 : 0;") + emitter.emit_line("}") + + method_cls = cl.get_method_and_class("__setitem__") if method_cls and method_cls[1] == cl: generate_set_del_item_wrapper_inner(fn, emitter, args) else: - emitter.emit_line( - f'PyObject *super = CPy_Super(CPyModule_builtins, obj_{args[0].name});') - emitter.emit_line('if (super == NULL) return -1;') - emitter.emit_line('PyObject *result;') + emitter.emit_line(f"PyObject *super = CPy_Super(CPyModule_builtins, obj_{args[0].name});") + emitter.emit_line("if (super == NULL) return -1;") + emitter.emit_line("PyObject *result;") if method_cls is None and cl.builtin_base is None: msg = f"'{cl.name}' object does not support item assignment" - emitter.emit_line( - f'PyErr_SetString(PyExc_TypeError, "{msg}");') - emitter.emit_line('result = NULL;') + emitter.emit_line(f'PyErr_SetString(PyExc_TypeError, "{msg}");') + emitter.emit_line("result = NULL;") else: # A base class may have __setitem__ emitter.emit_line( 'result = PyObject_CallMethod(super, "__setitem__", "OO", obj_{}, obj_{});'.format( - args[1].name, args[2].name)) - emitter.emit_line('Py_DECREF(super);') - emitter.emit_line('Py_XDECREF(result);') - emitter.emit_line('return result == NULL ? -1 : 0;') - emitter.emit_line('}') + args[1].name, args[2].name + ) + ) + emitter.emit_line("Py_DECREF(super);") + emitter.emit_line("Py_XDECREF(result);") + emitter.emit_line("return result == NULL ? -1 : 0;") + emitter.emit_line("}") return name -def generate_set_del_item_wrapper_inner(fn: FuncIR, emitter: Emitter, - args: Sequence[RuntimeArg]) -> None: +def generate_set_del_item_wrapper_inner( + fn: FuncIR, emitter: Emitter, args: Sequence[RuntimeArg] +) -> None: for arg in args: - generate_arg_check(arg.name, arg.type, emitter, GotoHandler('fail')) - native_args = ', '.join(f'arg_{arg.name}' for arg in args) - emitter.emit_line('{}val = {}{}({});'.format(emitter.ctype_spaced(fn.ret_type), - NATIVE_PREFIX, - fn.cname(emitter.names), - native_args)) - emitter.emit_error_check('val', fn.ret_type, 'goto fail;') - emitter.emit_dec_ref('val', fn.ret_type) - emitter.emit_line('return 0;') - emitter.emit_label('fail') - emitter.emit_line('return -1;') - emitter.emit_line('}') + generate_arg_check(arg.name, arg.type, emitter, GotoHandler("fail")) + native_args = ", ".join(f"arg_{arg.name}" for arg in args) + emitter.emit_line( + "{}val = {}{}({});".format( + emitter.ctype_spaced(fn.ret_type), NATIVE_PREFIX, fn.cname(emitter.names), native_args + ) + ) + emitter.emit_error_check("val", fn.ret_type, "goto fail;") + emitter.emit_dec_ref("val", fn.ret_type) + emitter.emit_line("return 0;") + emitter.emit_label("fail") + emitter.emit_line("return -1;") + emitter.emit_line("}") def generate_contains_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: """Generates a wrapper for a native __contains__ method.""" - name = f'{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}' + name = f"{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}" + emitter.emit_line("static int {name}(PyObject *self, PyObject *obj_item) {{".format(name=name)) + generate_arg_check("item", fn.args[1].type, emitter, ReturnHandler("-1")) emitter.emit_line( - 'static int {name}(PyObject *self, PyObject *obj_item) {{'. - format(name=name)) - generate_arg_check('item', fn.args[1].type, emitter, ReturnHandler('-1')) - emitter.emit_line('{}val = {}{}(self, arg_item);'.format(emitter.ctype_spaced(fn.ret_type), - NATIVE_PREFIX, - fn.cname(emitter.names))) - emitter.emit_error_check('val', fn.ret_type, 'return -1;') + "{}val = {}{}(self, arg_item);".format( + emitter.ctype_spaced(fn.ret_type), NATIVE_PREFIX, fn.cname(emitter.names) + ) + ) + emitter.emit_error_check("val", fn.ret_type, "return -1;") if is_bool_rprimitive(fn.ret_type): - emitter.emit_line('return val;') + emitter.emit_line("return val;") else: - emitter.emit_line('int boolval = PyObject_IsTrue(val);') - emitter.emit_dec_ref('val', fn.ret_type) - emitter.emit_line('return boolval;') - emitter.emit_line('}') + emitter.emit_line("int boolval = PyObject_IsTrue(val);") + emitter.emit_dec_ref("val", fn.ret_type) + emitter.emit_line("return boolval;") + emitter.emit_line("}") return name @@ -635,12 +659,14 @@ def generate_contains_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str: # Helpers -def generate_wrapper_core(fn: FuncIR, - emitter: Emitter, - optional_args: Optional[List[RuntimeArg]] = None, - arg_names: Optional[List[str]] = None, - cleanups: Optional[List[str]] = None, - traceback_code: Optional[str] = None) -> None: +def generate_wrapper_core( + fn: FuncIR, + emitter: Emitter, + optional_args: Optional[List[RuntimeArg]] = None, + arg_names: Optional[List[str]] = None, + cleanups: Optional[List[str]] = None, + traceback_code: Optional[str] = None, +) -> None: """Generates the core part of a wrapper function for a native function. This expects each argument as a PyObject * named obj_{arg} as a precondition. @@ -652,21 +678,23 @@ def generate_wrapper_core(fn: FuncIR, gen.arg_names = arg_names or [arg.name for arg in fn.args] gen.cleanups = cleanups or [] gen.optional_args = optional_args or [] - gen.traceback_code = traceback_code or '' + gen.traceback_code = traceback_code or "" - error = ReturnHandler('NULL') if not gen.use_goto() else GotoHandler('fail') + error = ReturnHandler("NULL") if not gen.use_goto() else GotoHandler("fail") gen.emit_arg_processing(error=error) gen.emit_call() gen.emit_error_handling() -def generate_arg_check(name: str, - typ: RType, - emitter: Emitter, - error: Optional[ErrorHandler] = None, - *, - optional: bool = False, - raise_exception: bool = True) -> None: +def generate_arg_check( + name: str, + typ: RType, + emitter: Emitter, + error: Optional[ErrorHandler] = None, + *, + optional: bool = False, + raise_exception: bool = True, +) -> None: """Insert a runtime check for argument and unbox if necessary. The object is named PyObject *obj_{}. This is expected to generate @@ -676,31 +704,35 @@ def generate_arg_check(name: str, error = error or AssignHandler() if typ.is_unboxed: # Borrow when unboxing to avoid reference count manipulation. - emitter.emit_unbox(f'obj_{name}', - f'arg_{name}', - typ, - declare_dest=True, - raise_exception=raise_exception, - error=error, - borrow=True, - optional=optional) + emitter.emit_unbox( + f"obj_{name}", + f"arg_{name}", + typ, + declare_dest=True, + raise_exception=raise_exception, + error=error, + borrow=True, + optional=optional, + ) elif is_object_rprimitive(typ): # Object is trivial since any object is valid if optional: - emitter.emit_line(f'PyObject *arg_{name};') - emitter.emit_line(f'if (obj_{name} == NULL) {{') - emitter.emit_line(f'arg_{name} = {emitter.c_error_value(typ)};') - emitter.emit_lines('} else {', f'arg_{name} = obj_{name}; ', '}') + emitter.emit_line(f"PyObject *arg_{name};") + emitter.emit_line(f"if (obj_{name} == NULL) {{") + emitter.emit_line(f"arg_{name} = {emitter.c_error_value(typ)};") + emitter.emit_lines("} else {", f"arg_{name} = obj_{name}; ", "}") else: - emitter.emit_line(f'PyObject *arg_{name} = obj_{name};') + emitter.emit_line(f"PyObject *arg_{name} = obj_{name};") else: - emitter.emit_cast(f'obj_{name}', - f'arg_{name}', - typ, - declare_dest=True, - raise_exception=raise_exception, - error=error, - optional=optional) + emitter.emit_cast( + f"obj_{name}", + f"arg_{name}", + typ, + declare_dest=True, + raise_exception=raise_exception, + error=error, + optional=optional, + ) class WrapperGenerator: @@ -713,7 +745,7 @@ def __init__(self, cl: Optional[ClassIR], emitter: Emitter) -> None: self.emitter = emitter self.cleanups: List[str] = [] self.optional_args: List[RuntimeArg] = [] - self.traceback_code = '' + self.traceback_code = "" def set_target(self, fn: FuncIR) -> None: """Set the wrapped function. @@ -729,9 +761,11 @@ def set_target(self, fn: FuncIR) -> None: def wrapper_name(self) -> str: """Return the name of the wrapper function.""" - return '{}{}{}'.format(DUNDER_PREFIX, - self.target_name, - self.cl.name_prefix(self.emitter.names) if self.cl else '') + return "{}{}{}".format( + DUNDER_PREFIX, + self.target_name, + self.cl.name_prefix(self.emitter.names) if self.cl else "", + ) def use_goto(self) -> bool: """Do we use a goto for error handling (instead of straight return)?""" @@ -739,84 +773,91 @@ def use_goto(self) -> bool: def emit_header(self) -> None: """Emit the function header of the wrapper implementation.""" - input_args = ', '.join(f'PyObject *obj_{arg}' for arg in self.arg_names) - self.emitter.emit_line('static PyObject *{name}({input_args}) {{'.format( - name=self.wrapper_name(), - input_args=input_args, - )) - - def emit_arg_processing(self, - error: Optional[ErrorHandler] = None, - raise_exception: bool = True) -> None: + input_args = ", ".join(f"PyObject *obj_{arg}" for arg in self.arg_names) + self.emitter.emit_line( + "static PyObject *{name}({input_args}) {{".format( + name=self.wrapper_name(), input_args=input_args + ) + ) + + def emit_arg_processing( + self, error: Optional[ErrorHandler] = None, raise_exception: bool = True + ) -> None: """Emit validation and unboxing of arguments.""" error = error or self.error() for arg_name, arg in zip(self.arg_names, self.args): # Suppress the argument check for *args/**kwargs, since we know it must be right. typ = arg.type if arg.kind not in (ARG_STAR, ARG_STAR2) else object_rprimitive - generate_arg_check(arg_name, - typ, - self.emitter, - error, - raise_exception=raise_exception, - optional=arg in self.optional_args) - - def emit_call(self, not_implemented_handler: str = '') -> None: + generate_arg_check( + arg_name, + typ, + self.emitter, + error, + raise_exception=raise_exception, + optional=arg in self.optional_args, + ) + + def emit_call(self, not_implemented_handler: str = "") -> None: """Emit call to the wrapper function. If not_implemented_handler is non-empty, use this C code to handle a NotImplemented return value (if it's possible based on the return type). """ - native_args = ', '.join(f'arg_{arg}' for arg in self.arg_names) + native_args = ", ".join(f"arg_{arg}" for arg in self.arg_names) ret_type = self.ret_type emitter = self.emitter if ret_type.is_unboxed or self.use_goto(): # TODO: The Py_RETURN macros return the correct PyObject * with reference count # handling. Are they relevant? - emitter.emit_line('{}retval = {}{}({});'.format(emitter.ctype_spaced(ret_type), - NATIVE_PREFIX, - self.target_cname, - native_args)) + emitter.emit_line( + "{}retval = {}{}({});".format( + emitter.ctype_spaced(ret_type), NATIVE_PREFIX, self.target_cname, native_args + ) + ) emitter.emit_lines(*self.cleanups) if ret_type.is_unboxed: - emitter.emit_error_check('retval', ret_type, 'return NULL;') - emitter.emit_box('retval', 'retbox', ret_type, declare_dest=True) + emitter.emit_error_check("retval", ret_type, "return NULL;") + emitter.emit_box("retval", "retbox", ret_type, declare_dest=True) - emitter.emit_line( - 'return {};'.format('retbox' if ret_type.is_unboxed else 'retval')) + emitter.emit_line("return {};".format("retbox" if ret_type.is_unboxed else "retval")) else: if not_implemented_handler and not isinstance(ret_type, RInstance): # The return value type may overlap with NotImplemented. - emitter.emit_line('PyObject *retbox = {}{}({});'.format(NATIVE_PREFIX, - self.target_cname, - native_args)) - emitter.emit_lines('if (retbox == Py_NotImplemented) {', - not_implemented_handler, - '}', - 'return retbox;') + emitter.emit_line( + "PyObject *retbox = {}{}({});".format( + NATIVE_PREFIX, self.target_cname, native_args + ) + ) + emitter.emit_lines( + "if (retbox == Py_NotImplemented) {", + not_implemented_handler, + "}", + "return retbox;", + ) else: - emitter.emit_line('return {}{}({});'.format(NATIVE_PREFIX, - self.target_cname, - native_args)) + emitter.emit_line( + "return {}{}({});".format(NATIVE_PREFIX, self.target_cname, native_args) + ) # TODO: Tracebacks? def error(self) -> ErrorHandler: """Figure out how to deal with errors in the wrapper.""" if self.cleanups or self.traceback_code: # We'll have a label at the end with error handling code. - return GotoHandler('fail') + return GotoHandler("fail") else: # Nothing special needs to done to handle errors, so just return. - return ReturnHandler('NULL') + return ReturnHandler("NULL") def emit_error_handling(self) -> None: """Emit error handling block at the end of the wrapper, if needed.""" emitter = self.emitter if self.use_goto(): - emitter.emit_label('fail') + emitter.emit_label("fail") emitter.emit_lines(*self.cleanups) if self.traceback_code: emitter.emit_line(self.traceback_code) - emitter.emit_line('return NULL;') + emitter.emit_line("return NULL;") def finish(self) -> None: - self.emitter.emit_line('}') + self.emitter.emit_line("}") diff --git a/mypyc/codegen/literals.py b/mypyc/codegen/literals.py index a37e6ef072217..3b01afcb49820 100644 --- a/mypyc/codegen/literals.py +++ b/mypyc/codegen/literals.py @@ -1,8 +1,7 @@ -from typing import Dict, List, Union, Tuple, Any, cast +from typing import Any, Dict, List, Tuple, Union, cast from typing_extensions import Final - # Supported Python literal types. All tuple items must have supported # literal types as well, but we can't represent the type precisely. LiteralValue = Union[str, bytes, int, bool, float, complex, Tuple[object, ...], None] @@ -56,7 +55,7 @@ def record_literal(self, value: LiteralValue) -> None: self.record_literal(cast(Any, item)) tuple_literals[value] = len(tuple_literals) else: - assert False, 'invalid literal: %r' % value + assert False, "invalid literal: %r" % value def literal_index(self, value: LiteralValue) -> int: """Return the index to the literals array for given value.""" @@ -86,13 +85,19 @@ def literal_index(self, value: LiteralValue) -> int: n += len(self.complex_literals) if isinstance(value, tuple): return n + self.tuple_literals[value] - assert False, 'invalid literal: %r' % value + assert False, "invalid literal: %r" % value def num_literals(self) -> int: # The first three are for None, True and False - return (NUM_SINGLETONS + len(self.str_literals) + len(self.bytes_literals) + - len(self.int_literals) + len(self.float_literals) + len(self.complex_literals) + - len(self.tuple_literals)) + return ( + NUM_SINGLETONS + + len(self.str_literals) + + len(self.bytes_literals) + + len(self.int_literals) + + len(self.float_literals) + + len(self.complex_literals) + + len(self.tuple_literals) + ) # The following methods return the C encodings of literal values # of different types @@ -149,14 +154,14 @@ def _encode_str_values(values: Dict[str, int]) -> List[bytes]: c_literal = format_str_literal(value) c_len = len(c_literal) if line_len > 0 and line_len + c_len > 70: - result.append(format_int(len(line)) + b''.join(line)) + result.append(format_int(len(line)) + b"".join(line)) line = [] line_len = 0 line.append(c_literal) line_len += c_len if line: - result.append(format_int(len(line)) + b''.join(line)) - result.append(b'') + result.append(format_int(len(line)) + b"".join(line)) + result.append(b"") return result @@ -170,14 +175,14 @@ def _encode_bytes_values(values: Dict[bytes, int]) -> List[bytes]: c_init = format_int(len(value)) c_len = len(c_init) + len(value) if line_len > 0 and line_len + c_len > 70: - result.append(format_int(len(line)) + b''.join(line)) + result.append(format_int(len(line)) + b"".join(line)) line = [] line_len = 0 line.append(c_init + value) line_len += c_len if line: - result.append(format_int(len(line)) + b''.join(line)) - result.append(b'') + result.append(format_int(len(line)) + b"".join(line)) + result.append(b"") return result @@ -188,7 +193,7 @@ def format_int(n: int) -> bytes: else: a = [] while n > 0: - a.insert(0, n & 0x7f) + a.insert(0, n & 0x7F) n >>= 7 for i in range(len(a) - 1): # If the highest bit is set, more 7-bit digits follow @@ -197,7 +202,7 @@ def format_int(n: int) -> bytes: def format_str_literal(s: str) -> bytes: - utf8 = s.encode('utf-8') + utf8 = s.encode("utf-8") return format_int(len(utf8)) + utf8 @@ -212,26 +217,26 @@ def _encode_int_values(values: Dict[int, int]) -> List[bytes]: line_len = 0 for i in range(len(values)): value = value_by_index[i] - encoded = b'%d' % value + encoded = b"%d" % value if line_len > 0 and line_len + len(encoded) > 70: - result.append(format_int(len(line)) + b'\0'.join(line)) + result.append(format_int(len(line)) + b"\0".join(line)) line = [] line_len = 0 line.append(encoded) line_len += len(encoded) if line: - result.append(format_int(len(line)) + b'\0'.join(line)) - result.append(b'') + result.append(format_int(len(line)) + b"\0".join(line)) + result.append(b"") return result def float_to_c(x: float) -> str: """Return C literal representation of a float value.""" s = str(x) - if s == 'inf': - return 'INFINITY' - elif s == '-inf': - return '-INFINITY' + if s == "inf": + return "INFINITY" + elif s == "-inf": + return "-INFINITY" return s diff --git a/mypyc/common.py b/mypyc/common.py index e07bbe2511cb2..ac238c41e9539 100644 --- a/mypyc/common.py +++ b/mypyc/common.py @@ -1,9 +1,10 @@ -from mypy.util import unnamed_function -from typing import Dict, Any, Optional, Tuple import sys +from typing import Any, Dict, Optional, Tuple from typing_extensions import Final +from mypy.util import unnamed_function + PREFIX: Final = "CPyPy_" # Python wrappers NATIVE_PREFIX: Final = "CPyDef_" # Native functions etc. DUNDER_PREFIX: Final = "CPyDunder_" # Wrappers for exposing dunder methods to the API @@ -46,24 +47,24 @@ # # Note: Assume that the compiled code uses the same bit width as mypyc, except for # Python 3.5 on macOS. -MAX_LITERAL_SHORT_INT: Final = sys.maxsize >> 1 if not IS_MIXED_32_64_BIT_BUILD else 2 ** 30 - 1 +MAX_LITERAL_SHORT_INT: Final = sys.maxsize >> 1 if not IS_MIXED_32_64_BIT_BUILD else 2**30 - 1 MIN_LITERAL_SHORT_INT: Final = -MAX_LITERAL_SHORT_INT - 1 # Runtime C library files RUNTIME_C_FILES: Final = [ - 'init.c', - 'getargs.c', - 'getargsfast.c', - 'int_ops.c', - 'str_ops.c', - 'bytes_ops.c', - 'list_ops.c', - 'dict_ops.c', - 'set_ops.c', - 'tuple_ops.c', - 'exc_ops.c', - 'misc_ops.c', - 'generic_ops.c', + "init.c", + "getargs.c", + "getargsfast.c", + "int_ops.c", + "str_ops.c", + "bytes_ops.c", + "list_ops.c", + "dict_ops.c", + "set_ops.c", + "tuple_ops.c", + "exc_ops.c", + "misc_ops.c", + "generic_ops.c", ] @@ -75,11 +76,11 @@ def shared_lib_name(group_name: str) -> str: (This just adds a suffix to the final component.) """ - return f'{group_name}__mypyc' + return f"{group_name}__mypyc" def short_name(name: str) -> str: - if name.startswith('builtins.'): + if name.startswith("builtins."): return name[9:] return name diff --git a/mypyc/crash.py b/mypyc/crash.py index b248e27bbdb87..0d2efe524e02c 100644 --- a/mypyc/crash.py +++ b/mypyc/crash.py @@ -1,9 +1,9 @@ -from typing import Iterator -from typing_extensions import NoReturn - import sys import traceback from contextlib import contextmanager +from typing import Iterator + +from typing_extensions import NoReturn @contextmanager @@ -14,18 +14,18 @@ def catch_errors(module_path: str, line: int) -> Iterator[None]: crash_report(module_path, line) -def crash_report(module_path: str, line: int) -> 'NoReturn': +def crash_report(module_path: str, line: int) -> "NoReturn": # Adapted from report_internal_error in mypy err = sys.exc_info()[1] tb = traceback.extract_stack()[:-4] # Excise all the traceback from the test runner for i, x in enumerate(tb): - if x.name == 'pytest_runtest_call': - tb = tb[i + 1:] + if x.name == "pytest_runtest_call": + tb = tb[i + 1 :] break tb2 = traceback.extract_tb(sys.exc_info()[2])[1:] - print('Traceback (most recent call last):') + print("Traceback (most recent call last):") for s in traceback.format_list(tb + tb2): - print(s.rstrip('\n')) - print(f'{module_path}:{line}: {type(err).__name__}: {err}') + print(s.rstrip("\n")) + print(f"{module_path}:{line}: {type(err).__name__}: {err}") raise SystemExit(2) diff --git a/mypyc/doc/conf.py b/mypyc/doc/conf.py index fa980bbb1b065..775c4638fe040 100644 --- a/mypyc/doc/conf.py +++ b/mypyc/doc/conf.py @@ -4,28 +4,28 @@ # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html -import sys import os +import sys # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) from mypy.version import __version__ as mypy_version # -- Project information ----------------------------------------------------- -project = 'mypyc' -copyright = '2020-2022, mypyc team' -author = 'mypyc team' +project = "mypyc" +copyright = "2020-2022, mypyc team" +author = "mypyc team" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = mypy_version.split('-')[0] +version = mypy_version.split("-")[0] # The full version, including alpha/beta/rc tags. release = mypy_version @@ -34,25 +34,24 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. -extensions = [ # type: ignore -] +extensions = [] # type: ignore # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'furo' +html_theme = "furo" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] diff --git a/mypyc/errors.py b/mypyc/errors.py index 3d3b8694c9d6b..dd0c5dcbc4cc5 100644 --- a/mypyc/errors.py +++ b/mypyc/errors.py @@ -10,14 +10,14 @@ def __init__(self) -> None: self._errors = mypy.errors.Errors() def error(self, msg: str, path: str, line: int) -> None: - self._errors.report(line, None, msg, severity='error', file=path) + self._errors.report(line, None, msg, severity="error", file=path) self.num_errors += 1 def note(self, msg: str, path: str, line: int) -> None: - self._errors.report(line, None, msg, severity='note', file=path) + self._errors.report(line, None, msg, severity="note", file=path) def warning(self, msg: str, path: str, line: int) -> None: - self._errors.report(line, None, msg, severity='warning', file=path) + self._errors.report(line, None, msg, severity="warning", file=path) self.num_warnings += 1 def new_messages(self) -> List[str]: diff --git a/mypyc/ir/class_ir.py b/mypyc/ir/class_ir.py index 197b267633d71..015fd503ffc7f 100644 --- a/mypyc/ir/class_ir.py +++ b/mypyc/ir/class_ir.py @@ -1,15 +1,13 @@ """Intermediate representation of classes.""" -from typing import List, Optional, Set, Tuple, Dict, NamedTuple -from mypy.backports import OrderedDict +from typing import Dict, List, NamedTuple, Optional, Set, Tuple -from mypyc.common import JsonDict -from mypyc.ir.ops import Value, DeserMaps -from mypyc.ir.rtypes import RType, RInstance, deserialize_type -from mypyc.ir.func_ir import FuncIR, FuncDecl, FuncSignature +from mypy.backports import OrderedDict +from mypyc.common import PROPSET_PREFIX, JsonDict +from mypyc.ir.func_ir import FuncDecl, FuncIR, FuncSignature +from mypyc.ir.ops import DeserMaps, Value +from mypyc.ir.rtypes import RInstance, RType, deserialize_type from mypyc.namegen import NameGenerator, exported_name -from mypyc.common import PROPSET_PREFIX - # Some notes on the vtable layout: Each concrete class has a vtable # that contains function pointers for its methods. So that subclasses @@ -70,10 +68,9 @@ # placed in the class's shadow vtable (if it has one). VTableMethod = NamedTuple( - 'VTableMethod', [('cls', 'ClassIR'), - ('name', str), - ('method', FuncIR), - ('shadow_method', Optional[FuncIR])]) + "VTableMethod", + [("cls", "ClassIR"), ("name", str), ("method", FuncIR), ("shadow_method", Optional[FuncIR])], +) VTableEntries = List[VTableMethod] @@ -85,9 +82,15 @@ class ClassIR: This also describes the runtime structure of native instances. """ - def __init__(self, name: str, module_name: str, is_trait: bool = False, - is_generated: bool = False, is_abstract: bool = False, - is_ext_class: bool = True) -> None: + def __init__( + self, + name: str, + module_name: str, + is_trait: bool = False, + is_generated: bool = False, + is_abstract: bool = False, + is_ext_class: bool = True, + ) -> None: self.name = name self.module_name = module_name self.is_trait = is_trait @@ -185,13 +188,14 @@ def __repr__(self) -> str: "name={self.name}, module_name={self.module_name}, " "is_trait={self.is_trait}, is_generated={self.is_generated}, " "is_abstract={self.is_abstract}, is_ext_class={self.is_ext_class}" - ")".format(self=self)) + ")".format(self=self) + ) @property def fullname(self) -> str: return f"{self.module_name}.{self.name}" - def real_base(self) -> Optional['ClassIR']: + def real_base(self) -> Optional["ClassIR"]: """Return the actual concrete base class, if there is one.""" if len(self.mro) > 1 and not self.mro[1].is_trait: return self.mro[1] @@ -199,16 +203,16 @@ def real_base(self) -> Optional['ClassIR']: def vtable_entry(self, name: str) -> int: assert self.vtable is not None, "vtable not computed yet" - assert name in self.vtable, f'{self.name!r} has no attribute {name!r}' + assert name in self.vtable, f"{self.name!r} has no attribute {name!r}" return self.vtable[name] - def attr_details(self, name: str) -> Tuple[RType, 'ClassIR']: + def attr_details(self, name: str) -> Tuple[RType, "ClassIR"]: for ir in self.mro: if name in ir.attributes: return ir.attributes[name], ir if name in ir.property_types: return ir.property_types[name], ir - raise KeyError(f'{self.name!r} has no attribute {name!r}') + raise KeyError(f"{self.name!r} has no attribute {name!r}") def attr_type(self, name: str) -> RType: return self.attr_details(name)[0] @@ -217,7 +221,7 @@ def method_decl(self, name: str) -> FuncDecl: for ir in self.mro: if name in ir.method_decls: return ir.method_decls[name] - raise KeyError(f'{self.name!r} has no attribute {name!r}') + raise KeyError(f"{self.name!r} has no attribute {name!r}") def method_sig(self, name: str) -> FuncSignature: return self.method_decl(name).sig @@ -266,9 +270,9 @@ def name_prefix(self, names: NameGenerator) -> str: return names.private_name(self.module_name, self.name) def struct_name(self, names: NameGenerator) -> str: - return f'{exported_name(self.fullname)}Object' + return f"{exported_name(self.fullname)}Object" - def get_method_and_class(self, name: str) -> Optional[Tuple[FuncIR, 'ClassIR']]: + def get_method_and_class(self, name: str) -> Optional[Tuple[FuncIR, "ClassIR"]]: for ir in self.mro: if name in ir.methods: return ir.methods[name], ir @@ -279,7 +283,7 @@ def get_method(self, name: str) -> Optional[FuncIR]: res = self.get_method_and_class(name) return res[0] if res else None - def subclasses(self) -> Optional[Set['ClassIR']]: + def subclasses(self) -> Optional[Set["ClassIR"]]: """Return all subclasses of this class, both direct and indirect. Return None if it is impossible to identify all subclasses, for example @@ -296,7 +300,7 @@ def subclasses(self) -> Optional[Set['ClassIR']]: result.update(child_subs) return result - def concrete_subclasses(self) -> Optional[List['ClassIR']]: + def concrete_subclasses(self) -> Optional[List["ClassIR"]]: """Return all concrete (i.e. non-trait and non-abstract) subclasses. Include both direct and indirect subclasses. Place classes with no children first. @@ -315,111 +319,108 @@ def is_serializable(self) -> bool: def serialize(self) -> JsonDict: return { - 'name': self.name, - 'module_name': self.module_name, - 'is_trait': self.is_trait, - 'is_ext_class': self.is_ext_class, - 'is_abstract': self.is_abstract, - 'is_generated': self.is_generated, - 'is_augmented': self.is_augmented, - 'inherits_python': self.inherits_python, - 'has_dict': self.has_dict, - 'allow_interpreted_subclasses': self.allow_interpreted_subclasses, - 'needs_getseters': self.needs_getseters, - '_serializable': self._serializable, - 'builtin_base': self.builtin_base, - 'ctor': self.ctor.serialize(), + "name": self.name, + "module_name": self.module_name, + "is_trait": self.is_trait, + "is_ext_class": self.is_ext_class, + "is_abstract": self.is_abstract, + "is_generated": self.is_generated, + "is_augmented": self.is_augmented, + "inherits_python": self.inherits_python, + "has_dict": self.has_dict, + "allow_interpreted_subclasses": self.allow_interpreted_subclasses, + "needs_getseters": self.needs_getseters, + "_serializable": self._serializable, + "builtin_base": self.builtin_base, + "ctor": self.ctor.serialize(), # We serialize dicts as lists to ensure order is preserved - 'attributes': [(k, t.serialize()) for k, t in self.attributes.items()], + "attributes": [(k, t.serialize()) for k, t in self.attributes.items()], # We try to serialize a name reference, but if the decl isn't in methods # then we can't be sure that will work so we serialize the whole decl. - 'method_decls': [(k, d.id if k in self.methods else d.serialize()) - for k, d in self.method_decls.items()], + "method_decls": [ + (k, d.id if k in self.methods else d.serialize()) + for k, d in self.method_decls.items() + ], # We serialize method fullnames out and put methods in a separate dict - 'methods': [(k, m.id) for k, m in self.methods.items()], - 'glue_methods': [ - ((cir.fullname, k), m.id) - for (cir, k), m in self.glue_methods.items() + "methods": [(k, m.id) for k, m in self.methods.items()], + "glue_methods": [ + ((cir.fullname, k), m.id) for (cir, k), m in self.glue_methods.items() ], - # We serialize properties and property_types separately out of an # abundance of caution about preserving dict ordering... - 'property_types': [(k, t.serialize()) for k, t in self.property_types.items()], - 'properties': list(self.properties), - - 'vtable': self.vtable, - 'vtable_entries': serialize_vtable(self.vtable_entries), - 'trait_vtables': [ + "property_types": [(k, t.serialize()) for k, t in self.property_types.items()], + "properties": list(self.properties), + "vtable": self.vtable, + "vtable_entries": serialize_vtable(self.vtable_entries), + "trait_vtables": [ (cir.fullname, serialize_vtable(v)) for cir, v in self.trait_vtables.items() ], - # References to class IRs are all just names - 'base': self.base.fullname if self.base else None, - 'traits': [cir.fullname for cir in self.traits], - 'mro': [cir.fullname for cir in self.mro], - 'base_mro': [cir.fullname for cir in self.base_mro], - 'children': [ - cir.fullname for cir in self.children - ] if self.children is not None else None, - 'deletable': self.deletable, - 'attrs_with_defaults': sorted(self.attrs_with_defaults), - '_always_initialized_attrs': sorted(self._always_initialized_attrs), - '_sometimes_initialized_attrs': sorted(self._sometimes_initialized_attrs), - 'init_self_leak': self.init_self_leak, + "base": self.base.fullname if self.base else None, + "traits": [cir.fullname for cir in self.traits], + "mro": [cir.fullname for cir in self.mro], + "base_mro": [cir.fullname for cir in self.base_mro], + "children": [cir.fullname for cir in self.children] + if self.children is not None + else None, + "deletable": self.deletable, + "attrs_with_defaults": sorted(self.attrs_with_defaults), + "_always_initialized_attrs": sorted(self._always_initialized_attrs), + "_sometimes_initialized_attrs": sorted(self._sometimes_initialized_attrs), + "init_self_leak": self.init_self_leak, } @classmethod - def deserialize(cls, data: JsonDict, ctx: 'DeserMaps') -> 'ClassIR': - fullname = data['module_name'] + '.' + data['name'] + def deserialize(cls, data: JsonDict, ctx: "DeserMaps") -> "ClassIR": + fullname = data["module_name"] + "." + data["name"] assert fullname in ctx.classes, "Class %s not in deser class map" % fullname ir = ctx.classes[fullname] - ir.is_trait = data['is_trait'] - ir.is_generated = data['is_generated'] - ir.is_abstract = data['is_abstract'] - ir.is_ext_class = data['is_ext_class'] - ir.is_augmented = data['is_augmented'] - ir.inherits_python = data['inherits_python'] - ir.has_dict = data['has_dict'] - ir.allow_interpreted_subclasses = data['allow_interpreted_subclasses'] - ir.needs_getseters = data['needs_getseters'] - ir._serializable = data['_serializable'] - ir.builtin_base = data['builtin_base'] - ir.ctor = FuncDecl.deserialize(data['ctor'], ctx) - ir.attributes = OrderedDict( - (k, deserialize_type(t, ctx)) for k, t in data['attributes'] + ir.is_trait = data["is_trait"] + ir.is_generated = data["is_generated"] + ir.is_abstract = data["is_abstract"] + ir.is_ext_class = data["is_ext_class"] + ir.is_augmented = data["is_augmented"] + ir.inherits_python = data["inherits_python"] + ir.has_dict = data["has_dict"] + ir.allow_interpreted_subclasses = data["allow_interpreted_subclasses"] + ir.needs_getseters = data["needs_getseters"] + ir._serializable = data["_serializable"] + ir.builtin_base = data["builtin_base"] + ir.ctor = FuncDecl.deserialize(data["ctor"], ctx) + ir.attributes = OrderedDict((k, deserialize_type(t, ctx)) for k, t in data["attributes"]) + ir.method_decls = OrderedDict( + (k, ctx.functions[v].decl if isinstance(v, str) else FuncDecl.deserialize(v, ctx)) + for k, v in data["method_decls"] ) - ir.method_decls = OrderedDict((k, ctx.functions[v].decl - if isinstance(v, str) else FuncDecl.deserialize(v, ctx)) - for k, v in data['method_decls']) - ir.methods = OrderedDict((k, ctx.functions[v]) for k, v in data['methods']) + ir.methods = OrderedDict((k, ctx.functions[v]) for k, v in data["methods"]) ir.glue_methods = OrderedDict( - ((ctx.classes[c], k), ctx.functions[v]) for (c, k), v in data['glue_methods'] + ((ctx.classes[c], k), ctx.functions[v]) for (c, k), v in data["glue_methods"] ) ir.property_types = OrderedDict( - (k, deserialize_type(t, ctx)) for k, t in data['property_types'] + (k, deserialize_type(t, ctx)) for k, t in data["property_types"] ) ir.properties = OrderedDict( - (k, (ir.methods[k], ir.methods.get(PROPSET_PREFIX + k))) for k in data['properties'] + (k, (ir.methods[k], ir.methods.get(PROPSET_PREFIX + k))) for k in data["properties"] ) - ir.vtable = data['vtable'] - ir.vtable_entries = deserialize_vtable(data['vtable_entries'], ctx) + ir.vtable = data["vtable"] + ir.vtable_entries = deserialize_vtable(data["vtable_entries"], ctx) ir.trait_vtables = OrderedDict( - (ctx.classes[k], deserialize_vtable(v, ctx)) for k, v in data['trait_vtables'] + (ctx.classes[k], deserialize_vtable(v, ctx)) for k, v in data["trait_vtables"] ) - base = data['base'] + base = data["base"] ir.base = ctx.classes[base] if base else None - ir.traits = [ctx.classes[s] for s in data['traits']] - ir.mro = [ctx.classes[s] for s in data['mro']] - ir.base_mro = [ctx.classes[s] for s in data['base_mro']] - ir.children = data['children'] and [ctx.classes[s] for s in data['children']] - ir.deletable = data['deletable'] - ir.attrs_with_defaults = set(data['attrs_with_defaults']) - ir._always_initialized_attrs = set(data['_always_initialized_attrs']) - ir._sometimes_initialized_attrs = set(data['_sometimes_initialized_attrs']) - ir.init_self_leak = data['init_self_leak'] + ir.traits = [ctx.classes[s] for s in data["traits"]] + ir.mro = [ctx.classes[s] for s in data["mro"]] + ir.base_mro = [ctx.classes[s] for s in data["base_mro"]] + ir.children = data["children"] and [ctx.classes[s] for s in data["children"]] + ir.deletable = data["deletable"] + ir.attrs_with_defaults = set(data["attrs_with_defaults"]) + ir._always_initialized_attrs = set(data["_always_initialized_attrs"]) + ir._sometimes_initialized_attrs = set(data["_sometimes_initialized_attrs"]) + ir.init_self_leak = data["init_self_leak"] return ir @@ -440,11 +441,11 @@ def __init__(self, dict: Value, bases: Value, anns: Value, metaclass: Value) -> def serialize_vtable_entry(entry: VTableMethod) -> JsonDict: return { - '.class': 'VTableMethod', - 'cls': entry.cls.fullname, - 'name': entry.name, - 'method': entry.method.decl.id, - 'shadow_method': entry.shadow_method.decl.id if entry.shadow_method else None, + ".class": "VTableMethod", + "cls": entry.cls.fullname, + "name": entry.name, + "method": entry.method.decl.id, + "shadow_method": entry.shadow_method.decl.id if entry.shadow_method else None, } @@ -452,15 +453,18 @@ def serialize_vtable(vtable: VTableEntries) -> List[JsonDict]: return [serialize_vtable_entry(v) for v in vtable] -def deserialize_vtable_entry(data: JsonDict, ctx: 'DeserMaps') -> VTableMethod: - if data['.class'] == 'VTableMethod': +def deserialize_vtable_entry(data: JsonDict, ctx: "DeserMaps") -> VTableMethod: + if data[".class"] == "VTableMethod": return VTableMethod( - ctx.classes[data['cls']], data['name'], ctx.functions[data['method']], - ctx.functions[data['shadow_method']] if data['shadow_method'] else None) - assert False, "Bogus vtable .class: %s" % data['.class'] + ctx.classes[data["cls"]], + data["name"], + ctx.functions[data["method"]], + ctx.functions[data["shadow_method"]] if data["shadow_method"] else None, + ) + assert False, "Bogus vtable .class: %s" % data[".class"] -def deserialize_vtable(data: List[JsonDict], ctx: 'DeserMaps') -> VTableEntries: +def deserialize_vtable(data: List[JsonDict], ctx: "DeserMaps") -> VTableEntries: return [deserialize_vtable_entry(x, ctx) for x in data] diff --git a/mypyc/ir/func_ir.py b/mypyc/ir/func_ir.py index 6a5a720e309bc..7bc0d879814de 100644 --- a/mypyc/ir/func_ir.py +++ b/mypyc/ir/func_ir.py @@ -1,13 +1,20 @@ """Intermediate representation of functions.""" from typing import List, Optional, Sequence -from typing_extensions import Final -from mypy.nodes import FuncDef, Block, ArgKind, ARG_POS +from typing_extensions import Final +from mypy.nodes import ARG_POS, ArgKind, Block, FuncDef from mypyc.common import JsonDict, get_id_from_name, short_id_from_name from mypyc.ir.ops import ( - DeserMaps, BasicBlock, Value, Register, Assign, AssignMulti, ControlOp, LoadAddress + Assign, + AssignMulti, + BasicBlock, + ControlOp, + DeserMaps, + LoadAddress, + Register, + Value, ) from mypyc.ir.rtypes import RType, deserialize_type from mypyc.namegen import NameGenerator @@ -20,7 +27,8 @@ class RuntimeArg: """ def __init__( - self, name: str, typ: RType, kind: ArgKind = ARG_POS, pos_only: bool = False) -> None: + self, name: str, typ: RType, kind: ArgKind = ARG_POS, pos_only: bool = False + ) -> None: self.name = name self.type = typ self.kind = kind @@ -31,20 +39,25 @@ def optional(self) -> bool: return self.kind.is_optional() def __repr__(self) -> str: - return 'RuntimeArg(name={}, type={}, optional={!r}, pos_only={!r})'.format( - self.name, self.type, self.optional, self.pos_only) + return "RuntimeArg(name={}, type={}, optional={!r}, pos_only={!r})".format( + self.name, self.type, self.optional, self.pos_only + ) def serialize(self) -> JsonDict: - return {'name': self.name, 'type': self.type.serialize(), 'kind': int(self.kind.value), - 'pos_only': self.pos_only} + return { + "name": self.name, + "type": self.type.serialize(), + "kind": int(self.kind.value), + "pos_only": self.pos_only, + } @classmethod - def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> 'RuntimeArg': + def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> "RuntimeArg": return RuntimeArg( - data['name'], - deserialize_type(data['type'], ctx), - ArgKind(data['kind']), - data['pos_only'], + data["name"], + deserialize_type(data["type"], ctx), + ArgKind(data["kind"]), + data["pos_only"], ) @@ -58,16 +71,16 @@ def __init__(self, args: Sequence[RuntimeArg], ret_type: RType) -> None: self.ret_type = ret_type def __repr__(self) -> str: - return f'FuncSignature(args={self.args!r}, ret={self.ret_type!r})' + return f"FuncSignature(args={self.args!r}, ret={self.ret_type!r})" def serialize(self) -> JsonDict: - return {'args': [t.serialize() for t in self.args], 'ret_type': self.ret_type.serialize()} + return {"args": [t.serialize() for t in self.args], "ret_type": self.ret_type.serialize()} @classmethod - def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> 'FuncSignature': + def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> "FuncSignature": return FuncSignature( - [RuntimeArg.deserialize(arg, ctx) for arg in data['args']], - deserialize_type(data['ret_type'], ctx), + [RuntimeArg.deserialize(arg, ctx) for arg in data["args"]], + deserialize_type(data["ret_type"], ctx), ) @@ -83,14 +96,16 @@ class FuncDecl: static method, a class method, or a property getter/setter. """ - def __init__(self, - name: str, - class_name: Optional[str], - module_name: str, - sig: FuncSignature, - kind: int = FUNC_NORMAL, - is_prop_setter: bool = False, - is_prop_getter: bool = False) -> None: + def __init__( + self, + name: str, + class_name: Optional[str], + module_name: str, + sig: FuncSignature, + kind: int = FUNC_NORMAL, + is_prop_setter: bool = False, + is_prop_getter: bool = False, + ) -> None: self.name = name self.class_name = class_name self.module_name = module_name @@ -126,7 +141,7 @@ def id(self) -> str: @staticmethod def compute_shortname(class_name: Optional[str], name: str) -> str: - return class_name + '.' + name if class_name else name + return class_name + "." + name if class_name else name @property def shortname(self) -> str: @@ -134,7 +149,7 @@ def shortname(self) -> str: @property def fullname(self) -> str: - return self.module_name + '.' + self.shortname + return self.module_name + "." + self.shortname def cname(self, names: NameGenerator) -> str: partial_name = short_id_from_name(self.name, self.shortname, self._line) @@ -142,34 +157,34 @@ def cname(self, names: NameGenerator) -> str: def serialize(self) -> JsonDict: return { - 'name': self.name, - 'class_name': self.class_name, - 'module_name': self.module_name, - 'sig': self.sig.serialize(), - 'kind': self.kind, - 'is_prop_setter': self.is_prop_setter, - 'is_prop_getter': self.is_prop_getter, + "name": self.name, + "class_name": self.class_name, + "module_name": self.module_name, + "sig": self.sig.serialize(), + "kind": self.kind, + "is_prop_setter": self.is_prop_setter, + "is_prop_getter": self.is_prop_getter, } # TODO: move this to FuncIR? @staticmethod def get_id_from_json(func_ir: JsonDict) -> str: """Get the id from the serialized FuncIR associated with this FuncDecl""" - decl = func_ir['decl'] - shortname = FuncDecl.compute_shortname(decl['class_name'], decl['name']) - fullname = decl['module_name'] + '.' + shortname - return get_id_from_name(decl['name'], fullname, func_ir['line']) + decl = func_ir["decl"] + shortname = FuncDecl.compute_shortname(decl["class_name"], decl["name"]) + fullname = decl["module_name"] + "." + shortname + return get_id_from_name(decl["name"], fullname, func_ir["line"]) @classmethod - def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> 'FuncDecl': + def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> "FuncDecl": return FuncDecl( - data['name'], - data['class_name'], - data['module_name'], - FuncSignature.deserialize(data['sig'], ctx), - data['kind'], - data['is_prop_setter'], - data['is_prop_getter'], + data["name"], + data["class_name"], + data["module_name"], + FuncSignature.deserialize(data["sig"], ctx), + data["kind"], + data["is_prop_setter"], + data["is_prop_getter"], ) @@ -179,12 +194,14 @@ class FuncIR: Unlike FuncDecl, this includes the IR of the body (basic blocks). """ - def __init__(self, - decl: FuncDecl, - arg_regs: List[Register], - blocks: List[BasicBlock], - line: int = -1, - traceback_name: Optional[str] = None) -> None: + def __init__( + self, + decl: FuncDecl, + arg_regs: List[Register], + blocks: List[BasicBlock], + line: int = -1, + traceback_name: Optional[str] = None, + ) -> None: # Declaration of the function, including the signature self.decl = decl # Registers for all the arguments to the function @@ -234,26 +251,22 @@ def cname(self, names: NameGenerator) -> str: def __repr__(self) -> str: if self.class_name: - return f'' + return f"" else: - return f'' + return f"" def serialize(self) -> JsonDict: # We don't include blocks in the serialized version return { - 'decl': self.decl.serialize(), - 'line': self.line, - 'traceback_name': self.traceback_name, + "decl": self.decl.serialize(), + "line": self.line, + "traceback_name": self.traceback_name, } @classmethod - def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> 'FuncIR': + def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> "FuncIR": return FuncIR( - FuncDecl.deserialize(data['decl'], ctx), - [], - [], - data['line'], - data['traceback_name'], + FuncDecl.deserialize(data["decl"], ctx), [], [], data["line"], data["traceback_name"] ) @@ -279,9 +292,11 @@ def all_values(args: List[Register], blocks: List[BasicBlock]) -> List[Value]: continue else: # If we take the address of a register, it might get initialized. - if (isinstance(op, LoadAddress) - and isinstance(op.src, Register) - and op.src not in seen_registers): + if ( + isinstance(op, LoadAddress) + and isinstance(op.src, Register) + and op.src not in seen_registers + ): values.append(op.src) seen_registers.add(op.src) values.append(op) diff --git a/mypyc/ir/module_ir.py b/mypyc/ir/module_ir.py index 8fa5e522ddf03..bd0ae8226e80c 100644 --- a/mypyc/ir/module_ir.py +++ b/mypyc/ir/module_ir.py @@ -1,24 +1,25 @@ """Intermediate representation of modules.""" -from typing import List, Tuple, Dict +from typing import Dict, List, Tuple from mypyc.common import JsonDict +from mypyc.ir.class_ir import ClassIR +from mypyc.ir.func_ir import FuncDecl, FuncIR from mypyc.ir.ops import DeserMaps from mypyc.ir.rtypes import RType, deserialize_type -from mypyc.ir.func_ir import FuncIR, FuncDecl -from mypyc.ir.class_ir import ClassIR class ModuleIR: """Intermediate representation of a module.""" def __init__( - self, - fullname: str, - imports: List[str], - functions: List[FuncIR], - classes: List[ClassIR], - final_names: List[Tuple[str, RType]]) -> None: + self, + fullname: str, + imports: List[str], + functions: List[FuncIR], + classes: List[ClassIR], + final_names: List[Tuple[str, RType]], + ) -> None: self.fullname = fullname self.imports = imports[:] self.functions = functions @@ -27,21 +28,21 @@ def __init__( def serialize(self) -> JsonDict: return { - 'fullname': self.fullname, - 'imports': self.imports, - 'functions': [f.serialize() for f in self.functions], - 'classes': [c.serialize() for c in self.classes], - 'final_names': [(k, t.serialize()) for k, t in self.final_names], + "fullname": self.fullname, + "imports": self.imports, + "functions": [f.serialize() for f in self.functions], + "classes": [c.serialize() for c in self.classes], + "final_names": [(k, t.serialize()) for k, t in self.final_names], } @classmethod - def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> 'ModuleIR': + def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> "ModuleIR": return ModuleIR( - data['fullname'], - data['imports'], - [ctx.functions[FuncDecl.get_id_from_json(f)] for f in data['functions']], - [ClassIR.deserialize(c, ctx) for c in data['classes']], - [(k, deserialize_type(t, ctx)) for k, t in data['final_names']], + data["fullname"], + data["imports"], + [ctx.functions[FuncDecl.get_id_from_json(f)] for f in data["functions"]], + [ClassIR.deserialize(c, ctx) for c in data["classes"]], + [(k, deserialize_type(t, ctx)) for k, t in data["final_names"]], ) @@ -62,18 +63,19 @@ def deserialize_modules(data: Dict[str, JsonDict], ctx: DeserMaps) -> Dict[str, """ for mod in data.values(): # First create ClassIRs for every class so that we can construct types and whatnot - for cls in mod['classes']: - ir = ClassIR(cls['name'], cls['module_name']) + for cls in mod["classes"]: + ir = ClassIR(cls["name"], cls["module_name"]) assert ir.fullname not in ctx.classes, "Class %s already in map" % ir.fullname ctx.classes[ir.fullname] = ir for mod in data.values(): # Then deserialize all of the functions so that methods are available # to the class deserialization. - for method in mod['functions']: + for method in mod["functions"]: func = FuncIR.deserialize(method, ctx) assert func.decl.id not in ctx.functions, ( - "Method %s already in map" % func.decl.fullname) + "Method %s already in map" % func.decl.fullname + ) ctx.functions[func.decl.id] = func return {k: ModuleIR.deserialize(v, ctx) for k, v in data.items()} diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index 8474b5ab58e2a..dec4018a14cbd 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -10,25 +10,38 @@ """ from abc import abstractmethod -from typing import ( - List, Sequence, Dict, Generic, TypeVar, Optional, NamedTuple, Tuple, Union -) +from typing import Dict, Generic, List, NamedTuple, Optional, Sequence, Tuple, TypeVar, Union -from typing_extensions import Final, TYPE_CHECKING from mypy_extensions import trait +from typing_extensions import TYPE_CHECKING, Final from mypyc.ir.rtypes import ( - RType, RInstance, RTuple, RArray, RVoid, is_bool_rprimitive, is_int_rprimitive, - is_short_int_rprimitive, is_none_rprimitive, object_rprimitive, bool_rprimitive, - short_int_rprimitive, int_rprimitive, void_rtype, pointer_rprimitive, is_pointer_rprimitive, - bit_rprimitive, is_bit_rprimitive, is_fixed_width_rtype + RArray, + RInstance, + RTuple, + RType, + RVoid, + bit_rprimitive, + bool_rprimitive, + int_rprimitive, + is_bit_rprimitive, + is_bool_rprimitive, + is_fixed_width_rtype, + is_int_rprimitive, + is_none_rprimitive, + is_pointer_rprimitive, + is_short_int_rprimitive, + object_rprimitive, + pointer_rprimitive, + short_int_rprimitive, + void_rtype, ) if TYPE_CHECKING: from mypyc.ir.class_ir import ClassIR # noqa - from mypyc.ir.func_ir import FuncIR, FuncDecl # noqa + from mypyc.ir.func_ir import FuncDecl, FuncIR # noqa -T = TypeVar('T') +T = TypeVar("T") class BasicBlock: @@ -76,7 +89,7 @@ def terminated(self) -> bool: return bool(self.ops) and isinstance(self.ops[-1], ControlOp) @property - def terminator(self) -> 'ControlOp': + def terminator(self) -> "ControlOp": """The terminator operation of the block.""" assert bool(self.ops) and isinstance(self.ops[-1], ControlOp) return self.ops[-1] @@ -136,7 +149,7 @@ class Register(Value): to refer to arbitrary Values (for example, in RegisterOp). """ - def __init__(self, type: RType, name: str = '', is_arg: bool = False, line: int = -1) -> None: + def __init__(self, type: RType, name: str = "", is_arg: bool = False, line: int = -1) -> None: self.type = type self.name = name self.is_arg = is_arg @@ -148,7 +161,7 @@ def is_void(self) -> bool: return False def __repr__(self) -> str: - return f'' + return f"" class Integer(Value): @@ -213,12 +226,13 @@ def unique_sources(self) -> List[Value]: return result @abstractmethod - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: pass class BaseAssign(Op): """Base class for ops that assign to a register.""" + def __init__(self, dest: Register, line: int = -1) -> None: super().__init__(line) self.dest = dest @@ -239,7 +253,7 @@ def sources(self) -> List[Value]: def stolen(self) -> List[Value]: return [self.src] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_assign(self) @@ -269,7 +283,7 @@ def sources(self) -> List[Value]: def stolen(self) -> List[Value]: return [] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_assign_multi(self) @@ -302,12 +316,12 @@ def set_target(self, i: int, new: BasicBlock) -> None: self.label = new def __repr__(self) -> str: - return '' % self.label.label + return "" % self.label.label def sources(self) -> List[Value]: return [] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_goto(self) @@ -327,14 +341,16 @@ class Branch(ControlOp): BOOL: Final = 100 IS_ERROR: Final = 101 - def __init__(self, - value: Value, - true_label: BasicBlock, - false_label: BasicBlock, - op: int, - line: int = -1, - *, - rare: bool = False) -> None: + def __init__( + self, + value: Value, + true_label: BasicBlock, + false_label: BasicBlock, + op: int, + line: int = -1, + *, + rare: bool = False, + ) -> None: super().__init__(line) # Target value being checked self.value = value @@ -368,7 +384,7 @@ def sources(self) -> List[Value]: def invert(self) -> None: self.negated = not self.negated - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_branch(self) @@ -387,7 +403,7 @@ def sources(self) -> List[Value]: def stolen(self) -> List[Value]: return [self.value] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_return(self) @@ -415,7 +431,7 @@ def __init__(self, line: int = -1) -> None: def sources(self) -> List[Value]: return [] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_unreachable(self) @@ -438,7 +454,7 @@ class RegisterOp(Op): def __init__(self, line: int) -> None: super().__init__(line) - assert self.error_kind != -1, 'error_kind not defined' + assert self.error_kind != -1, "error_kind not defined" def can_raise(self) -> bool: return self.error_kind != ERR_NEVER @@ -457,7 +473,7 @@ def __init__(self, src: Value, line: int = -1) -> None: def sources(self) -> List[Value]: return [self.src] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_inc_ref(self) @@ -477,12 +493,12 @@ def __init__(self, src: Value, is_xdec: bool = False, line: int = -1) -> None: self.is_xdec = is_xdec def __repr__(self) -> str: - return '<{}DecRef {!r}>'.format('X' if self.is_xdec else '', self.src) + return "<{}DecRef {!r}>".format("X" if self.is_xdec else "", self.src) def sources(self) -> List[Value]: return [self.src] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_dec_ref(self) @@ -492,7 +508,7 @@ class Call(RegisterOp): The call target can be a module-level function or a class. """ - def __init__(self, fn: 'FuncDecl', args: Sequence[Value], line: int) -> None: + def __init__(self, fn: "FuncDecl", args: Sequence[Value], line: int) -> None: self.fn = fn self.args = list(args) assert len(self.args) == len(fn.sig.args) @@ -507,18 +523,14 @@ def __init__(self, fn: 'FuncDecl', args: Sequence[Value], line: int) -> None: def sources(self) -> List[Value]: return list(self.args[:]) - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_call(self) class MethodCall(RegisterOp): """Native method call obj.method(arg, ...)""" - def __init__(self, - obj: Value, - method: str, - args: List[Value], - line: int = -1) -> None: + def __init__(self, obj: Value, method: str, args: List[Value], line: int = -1) -> None: self.obj = obj self.method = method self.args = args @@ -526,7 +538,8 @@ def __init__(self, self.receiver_type = obj.type method_ir = self.receiver_type.class_ir.method_sig(method) assert method_ir is not None, "{} doesn't have method {}".format( - self.receiver_type.name, method) + self.receiver_type.name, method + ) ret_type = method_ir.ret_type self.type = ret_type if not ret_type.error_overlap: @@ -538,7 +551,7 @@ def __init__(self, def sources(self) -> List[Value]: return self.args[:] + [self.obj] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_method_call(self) @@ -551,9 +564,9 @@ class LoadErrorValue(RegisterOp): error_kind = ERR_NEVER - def __init__(self, rtype: RType, line: int = -1, - is_borrowed: bool = False, - undefines: bool = False) -> None: + def __init__( + self, rtype: RType, line: int = -1, is_borrowed: bool = False, undefines: bool = False + ) -> None: super().__init__(line) self.type = rtype self.is_borrowed = is_borrowed @@ -565,7 +578,7 @@ def __init__(self, rtype: RType, line: int = -1, def sources(self) -> List[Value]: return [] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_load_error_value(self) @@ -590,16 +603,18 @@ class LoadLiteral(RegisterOp): error_kind = ERR_NEVER is_borrowed = True - def __init__(self, - value: Union[None, str, bytes, bool, int, float, complex, Tuple[object, ...]], - rtype: RType) -> None: + def __init__( + self, + value: Union[None, str, bytes, bool, int, float, complex, Tuple[object, ...]], + rtype: RType, + ) -> None: self.value = value self.type = rtype def sources(self) -> List[Value]: return [] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_load_literal(self) @@ -612,7 +627,7 @@ def __init__(self, obj: Value, attr: str, line: int, *, borrow: bool = False) -> super().__init__(line) self.obj = obj self.attr = attr - assert isinstance(obj.type, RInstance), 'Attribute access not supported: %s' % obj.type + assert isinstance(obj.type, RInstance), "Attribute access not supported: %s" % obj.type self.class_type = obj.type attr_type = obj.type.attr_type(attr) self.type = attr_type @@ -623,7 +638,7 @@ def __init__(self, obj: Value, attr: str, line: int, *, borrow: bool = False) -> def sources(self) -> List[Value]: return [self.obj] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_get_attr(self) @@ -640,7 +655,7 @@ def __init__(self, obj: Value, attr: str, src: Value, line: int) -> None: self.obj = obj self.attr = attr self.src = src - assert isinstance(obj.type, RInstance), 'Attribute access not supported: %s' % obj.type + assert isinstance(obj.type, RInstance), "Attribute access not supported: %s" % obj.type self.class_type = obj.type self.type = bool_rprimitive # If True, we can safely assume that the attribute is previously undefined @@ -658,7 +673,7 @@ def sources(self) -> List[Value]: def stolen(self) -> List[Value]: return [self.src] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_set_attr(self) @@ -686,13 +701,15 @@ class LoadStatic(RegisterOp): error_kind = ERR_NEVER is_borrowed = True - def __init__(self, - type: RType, - identifier: str, - module_name: Optional[str] = None, - namespace: str = NAMESPACE_STATIC, - line: int = -1, - ann: object = None) -> None: + def __init__( + self, + type: RType, + identifier: str, + module_name: Optional[str] = None, + namespace: str = NAMESPACE_STATIC, + line: int = -1, + ann: object = None, + ) -> None: super().__init__(line) self.identifier = identifier self.module_name = module_name @@ -703,7 +720,7 @@ def __init__(self, def sources(self) -> List[Value]: return [] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_load_static(self) @@ -715,12 +732,14 @@ class InitStatic(RegisterOp): error_kind = ERR_NEVER - def __init__(self, - value: Value, - identifier: str, - module_name: Optional[str] = None, - namespace: str = NAMESPACE_STATIC, - line: int = -1) -> None: + def __init__( + self, + value: Value, + identifier: str, + module_name: Optional[str] = None, + namespace: str = NAMESPACE_STATIC, + line: int = -1, + ) -> None: super().__init__(line) self.identifier = identifier self.module_name = module_name @@ -730,7 +749,7 @@ def __init__(self, def sources(self) -> List[Value]: return [self.value] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_init_static(self) @@ -746,14 +765,17 @@ def __init__(self, items: List[Value], line: int) -> None: # is put into a tuple, since we don't properly implement # runtime subtyping for tuples. self.tuple_type = RTuple( - [arg.type if not is_short_int_rprimitive(arg.type) else int_rprimitive - for arg in items]) + [ + arg.type if not is_short_int_rprimitive(arg.type) else int_rprimitive + for arg in items + ] + ) self.type = self.tuple_type def sources(self) -> List[Value]: return self.items[:] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_tuple_set(self) @@ -773,7 +795,7 @@ def __init__(self, src: Value, index: int, line: int) -> None: def sources(self) -> List[Value]: return [self.src] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_tuple_get(self) @@ -801,7 +823,7 @@ def stolen(self) -> List[Value]: return [] return [self.src] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_cast(self) @@ -819,9 +841,11 @@ def __init__(self, src: Value, line: int = -1) -> None: self.src = src self.type = object_rprimitive # When we box None and bool values, we produce a borrowed result - if (is_none_rprimitive(self.src.type) - or is_bool_rprimitive(self.src.type) - or is_bit_rprimitive(self.src.type)): + if ( + is_none_rprimitive(self.src.type) + or is_bool_rprimitive(self.src.type) + or is_bit_rprimitive(self.src.type) + ): self.is_borrowed = True def sources(self) -> List[Value]: @@ -830,7 +854,7 @@ def sources(self) -> List[Value]: def stolen(self) -> List[Value]: return [self.src] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_box(self) @@ -853,7 +877,7 @@ def __init__(self, src: Value, typ: RType, line: int) -> None: def sources(self) -> List[Value]: return [self.src] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_unbox(self) @@ -884,7 +908,7 @@ def __init__(self, class_name: str, value: Optional[Union[str, Value]], line: in def sources(self) -> List[Value]: return [] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_raise_standard_error(self) @@ -900,15 +924,17 @@ class CallC(RegisterOp): functions. """ - def __init__(self, - function_name: str, - args: List[Value], - ret_type: RType, - steals: StealsDescription, - is_borrowed: bool, - error_kind: int, - line: int, - var_arg_idx: int = -1) -> None: + def __init__( + self, + function_name: str, + args: List[Value], + ret_type: RType, + steals: StealsDescription, + is_borrowed: bool, + error_kind: int, + line: int, + var_arg_idx: int = -1, + ) -> None: self.error_kind = error_kind super().__init__(line) self.function_name = function_name @@ -929,7 +955,7 @@ def stolen(self) -> List[Value]: else: return [] if not self.steals else self.sources() - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_call_c(self) @@ -944,10 +970,7 @@ class Truncate(RegisterOp): error_kind = ERR_NEVER - def __init__(self, - src: Value, - dst_type: RType, - line: int = -1) -> None: + def __init__(self, src: Value, dst_type: RType, line: int = -1) -> None: super().__init__(line) self.src = src self.type = dst_type @@ -959,7 +982,7 @@ def sources(self) -> List[Value]: def stolen(self) -> List[Value]: return [] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_truncate(self) @@ -977,11 +1000,7 @@ class Extend(RegisterOp): error_kind = ERR_NEVER - def __init__(self, - src: Value, - dst_type: RType, - signed: bool, - line: int = -1) -> None: + def __init__(self, src: Value, dst_type: RType, signed: bool, line: int = -1) -> None: super().__init__(line) self.src = src self.type = dst_type @@ -994,7 +1013,7 @@ def sources(self) -> List[Value]: def stolen(self) -> List[Value]: return [] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_extend(self) @@ -1009,11 +1028,7 @@ class LoadGlobal(RegisterOp): error_kind = ERR_NEVER is_borrowed = True - def __init__(self, - type: RType, - identifier: str, - line: int = -1, - ann: object = None) -> None: + def __init__(self, type: RType, identifier: str, line: int = -1, ann: object = None) -> None: super().__init__(line) self.identifier = identifier self.type = type @@ -1022,7 +1037,7 @@ def __init__(self, def sources(self) -> List[Value]: return [] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_load_global(self) @@ -1056,16 +1071,16 @@ class IntOp(RegisterOp): RIGHT_SHIFT: Final = 204 op_str: Final = { - ADD: '+', - SUB: '-', - MUL: '*', - DIV: '/', - MOD: '%', - AND: '&', - OR: '|', - XOR: '^', - LEFT_SHIFT: '<<', - RIGHT_SHIFT: '>>', + ADD: "+", + SUB: "-", + MUL: "*", + DIV: "/", + MOD: "%", + AND: "&", + OR: "|", + XOR: "^", + LEFT_SHIFT: "<<", + RIGHT_SHIFT: ">>", } def __init__(self, type: RType, lhs: Value, rhs: Value, op: int, line: int = -1) -> None: @@ -1078,7 +1093,7 @@ def __init__(self, type: RType, lhs: Value, rhs: Value, op: int, line: int = -1) def sources(self) -> List[Value]: return [self.lhs, self.rhs] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_int_op(self) @@ -1116,26 +1131,19 @@ class ComparisonOp(RegisterOp): UGE: Final = 109 op_str: Final = { - EQ: '==', - NEQ: '!=', - SLT: '<', - SGT: '>', - SLE: '<=', - SGE: '>=', - ULT: '<', - UGT: '>', - ULE: '<=', - UGE: '>=', + EQ: "==", + NEQ: "!=", + SLT: "<", + SGT: ">", + SLE: "<=", + SGE: ">=", + ULT: "<", + UGT: ">", + ULE: "<=", + UGE: ">=", } - signed_ops: Final = { - '==': EQ, - '!=': NEQ, - '<': SLT, - '>': SGT, - '<=': SLE, - '>=': SGE, - } + signed_ops: Final = {"==": EQ, "!=": NEQ, "<": SLT, ">": SGT, "<=": SLE, ">=": SGE} def __init__(self, lhs: Value, rhs: Value, op: int, line: int = -1) -> None: super().__init__(line) @@ -1147,7 +1155,7 @@ def __init__(self, lhs: Value, rhs: Value, op: int, line: int = -1) -> None: def sources(self) -> List[Value]: return [self.lhs, self.rhs] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_comparison_op(self) @@ -1173,7 +1181,7 @@ def __init__(self, type: RType, src: Value, line: int = -1) -> None: def sources(self) -> List[Value]: return [self.src] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_load_mem(self) @@ -1188,11 +1196,7 @@ class SetMem(Op): error_kind = ERR_NEVER - def __init__(self, - type: RType, - dest: Value, - src: Value, - line: int = -1) -> None: + def __init__(self, type: RType, dest: Value, src: Value, line: int = -1) -> None: super().__init__(line) self.type = void_rtype self.dest_type = type @@ -1205,7 +1209,7 @@ def sources(self) -> List[Value]: def stolen(self) -> List[Value]: return [self.src] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_set_mem(self) @@ -1228,7 +1232,7 @@ def __init__(self, src: Value, src_type: RType, field: str, line: int = -1) -> N def sources(self) -> List[Value]: return [self.src] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_get_element_ptr(self) @@ -1255,7 +1259,7 @@ def sources(self) -> List[Value]: else: return [] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_load_address(self) @@ -1286,7 +1290,7 @@ def __init__(self, src: List[Value]) -> None: def sources(self) -> List[Value]: return self.src[:] - def accept(self, visitor: 'OpVisitor[T]') -> T: + def accept(self, visitor: "OpVisitor[T]") -> T: return visitor.visit_keep_alive(self) @@ -1449,5 +1453,6 @@ def visit_keep_alive(self, op: KeepAlive) -> T: # # (Serialization and deserialization *will* be used for incremental # compilation but so far it is not hooked up to anything.) -DeserMaps = NamedTuple('DeserMaps', - [('classes', Dict[str, 'ClassIR']), ('functions', Dict[str, 'FuncIR'])]) +DeserMaps = NamedTuple( + "DeserMaps", [("classes", Dict[str, "ClassIR"]), ("functions", Dict[str, "FuncIR"])] +) diff --git a/mypyc/ir/pprint.py b/mypyc/ir/pprint.py index e6cd721e4c27f..252499bb7fc77 100644 --- a/mypyc/ir/pprint.py +++ b/mypyc/ir/pprint.py @@ -1,21 +1,57 @@ """Utilities for pretty-printing IR in a human-readable form.""" from collections import defaultdict -from typing import Any, Dict, List, Union, Sequence, Tuple +from typing import Any, Dict, List, Sequence, Tuple, Union from typing_extensions import Final from mypyc.common import short_name -from mypyc.ir.ops import ( - Goto, Branch, Return, Unreachable, Assign, Integer, LoadErrorValue, GetAttr, SetAttr, - LoadStatic, InitStatic, TupleGet, TupleSet, IncRef, DecRef, Call, MethodCall, Cast, Box, Unbox, - RaiseStandardError, CallC, Truncate, LoadGlobal, IntOp, ComparisonOp, LoadMem, SetMem, - GetElementPtr, LoadAddress, Register, Value, OpVisitor, BasicBlock, ControlOp, LoadLiteral, - AssignMulti, KeepAlive, Op, Extend, ERR_NEVER -) from mypyc.ir.func_ir import FuncIR, all_values_full from mypyc.ir.module_ir import ModuleIRs -from mypyc.ir.rtypes import is_bool_rprimitive, is_int_rprimitive, RType +from mypyc.ir.ops import ( + ERR_NEVER, + Assign, + AssignMulti, + BasicBlock, + Box, + Branch, + Call, + CallC, + Cast, + ComparisonOp, + ControlOp, + DecRef, + Extend, + GetAttr, + GetElementPtr, + Goto, + IncRef, + InitStatic, + Integer, + IntOp, + KeepAlive, + LoadAddress, + LoadErrorValue, + LoadGlobal, + LoadLiteral, + LoadMem, + LoadStatic, + MethodCall, + Op, + OpVisitor, + RaiseStandardError, + Register, + Return, + SetAttr, + SetMem, + Truncate, + TupleGet, + TupleSet, + Unbox, + Unreachable, + Value, +) +from mypyc.ir.rtypes import RType, is_bool_rprimitive, is_int_rprimitive ErrorSource = Union[BasicBlock, Op] @@ -30,161 +66,156 @@ def __init__(self, names: Dict[Value, str]) -> None: self.names = names def visit_goto(self, op: Goto) -> str: - return self.format('goto %l', op.label) + return self.format("goto %l", op.label) - branch_op_names: Final = { - Branch.BOOL: ('%r', 'bool'), - Branch.IS_ERROR: ('is_error(%r)', ''), - } + branch_op_names: Final = {Branch.BOOL: ("%r", "bool"), Branch.IS_ERROR: ("is_error(%r)", "")} def visit_branch(self, op: Branch) -> str: fmt, typ = self.branch_op_names[op.op] if op.negated: - fmt = f'not {fmt}' + fmt = f"not {fmt}" cond = self.format(fmt, op.value) - tb = '' + tb = "" if op.traceback_entry: - tb = ' (error at %s:%d)' % op.traceback_entry - fmt = f'if {cond} goto %l{tb} else goto %l' + tb = " (error at %s:%d)" % op.traceback_entry + fmt = f"if {cond} goto %l{tb} else goto %l" if typ: - fmt += f' :: {typ}' + fmt += f" :: {typ}" return self.format(fmt, op.true, op.false) def visit_return(self, op: Return) -> str: - return self.format('return %r', op.value) + return self.format("return %r", op.value) def visit_unreachable(self, op: Unreachable) -> str: return "unreachable" def visit_assign(self, op: Assign) -> str: - return self.format('%r = %r', op.dest, op.src) + return self.format("%r = %r", op.dest, op.src) def visit_assign_multi(self, op: AssignMulti) -> str: - return self.format('%r = [%s]', - op.dest, - ', '.join(self.format('%r', v) for v in op.src)) + return self.format("%r = [%s]", op.dest, ", ".join(self.format("%r", v) for v in op.src)) def visit_load_error_value(self, op: LoadErrorValue) -> str: - return self.format('%r = :: %s', op, op.type) + return self.format("%r = :: %s", op, op.type) def visit_load_literal(self, op: LoadLiteral) -> str: - prefix = '' + prefix = "" # For values that have a potential unboxed representation, make # it explicit that this is a Python object. if isinstance(op.value, int): - prefix = 'object ' - return self.format('%r = %s%s', op, prefix, repr(op.value)) + prefix = "object " + return self.format("%r = %s%s", op, prefix, repr(op.value)) def visit_get_attr(self, op: GetAttr) -> str: - return self.format('%r = %s%r.%s', op, self.borrow_prefix(op), op.obj, op.attr) + return self.format("%r = %s%r.%s", op, self.borrow_prefix(op), op.obj, op.attr) def borrow_prefix(self, op: Op) -> str: if op.is_borrowed: - return 'borrow ' - return '' + return "borrow " + return "" def visit_set_attr(self, op: SetAttr) -> str: if op.is_init: assert op.error_kind == ERR_NEVER # Initialization and direct struct access can never fail - return self.format('%r.%s = %r', op.obj, op.attr, op.src) + return self.format("%r.%s = %r", op.obj, op.attr, op.src) else: - return self.format('%r.%s = %r; %r = is_error', op.obj, op.attr, op.src, op) + return self.format("%r.%s = %r; %r = is_error", op.obj, op.attr, op.src, op) def visit_load_static(self, op: LoadStatic) -> str: - ann = f' ({repr(op.ann)})' if op.ann else '' + ann = f" ({repr(op.ann)})" if op.ann else "" name = op.identifier if op.module_name is not None: - name = f'{op.module_name}.{name}' - return self.format('%r = %s :: %s%s', op, name, op.namespace, ann) + name = f"{op.module_name}.{name}" + return self.format("%r = %s :: %s%s", op, name, op.namespace, ann) def visit_init_static(self, op: InitStatic) -> str: name = op.identifier if op.module_name is not None: - name = f'{op.module_name}.{name}' - return self.format('%s = %r :: %s', name, op.value, op.namespace) + name = f"{op.module_name}.{name}" + return self.format("%s = %r :: %s", name, op.value, op.namespace) def visit_tuple_get(self, op: TupleGet) -> str: - return self.format('%r = %r[%d]', op, op.src, op.index) + return self.format("%r = %r[%d]", op, op.src, op.index) def visit_tuple_set(self, op: TupleSet) -> str: - item_str = ', '.join(self.format('%r', item) for item in op.items) - return self.format('%r = (%s)', op, item_str) + item_str = ", ".join(self.format("%r", item) for item in op.items) + return self.format("%r = (%s)", op, item_str) def visit_inc_ref(self, op: IncRef) -> str: - s = self.format('inc_ref %r', op.src) + s = self.format("inc_ref %r", op.src) # TODO: Remove bool check (it's unboxed) if is_bool_rprimitive(op.src.type) or is_int_rprimitive(op.src.type): - s += f' :: {short_name(op.src.type.name)}' + s += f" :: {short_name(op.src.type.name)}" return s def visit_dec_ref(self, op: DecRef) -> str: - s = self.format('%sdec_ref %r', 'x' if op.is_xdec else '', op.src) + s = self.format("%sdec_ref %r", "x" if op.is_xdec else "", op.src) # TODO: Remove bool check (it's unboxed) if is_bool_rprimitive(op.src.type) or is_int_rprimitive(op.src.type): - s += f' :: {short_name(op.src.type.name)}' + s += f" :: {short_name(op.src.type.name)}" return s def visit_call(self, op: Call) -> str: - args = ', '.join(self.format('%r', arg) for arg in op.args) + args = ", ".join(self.format("%r", arg) for arg in op.args) # TODO: Display long name? short_name = op.fn.shortname - s = f'{short_name}({args})' + s = f"{short_name}({args})" if not op.is_void: - s = self.format('%r = ', op) + s + s = self.format("%r = ", op) + s return s def visit_method_call(self, op: MethodCall) -> str: - args = ', '.join(self.format('%r', arg) for arg in op.args) - s = self.format('%r.%s(%s)', op.obj, op.method, args) + args = ", ".join(self.format("%r", arg) for arg in op.args) + s = self.format("%r.%s(%s)", op.obj, op.method, args) if not op.is_void: - s = self.format('%r = ', op) + s + s = self.format("%r = ", op) + s return s def visit_cast(self, op: Cast) -> str: - return self.format('%r = %scast(%s, %r)', op, self.borrow_prefix(op), op.type, op.src) + return self.format("%r = %scast(%s, %r)", op, self.borrow_prefix(op), op.type, op.src) def visit_box(self, op: Box) -> str: - return self.format('%r = box(%s, %r)', op, op.src.type, op.src) + return self.format("%r = box(%s, %r)", op, op.src.type, op.src) def visit_unbox(self, op: Unbox) -> str: - return self.format('%r = unbox(%s, %r)', op, op.type, op.src) + return self.format("%r = unbox(%s, %r)", op, op.type, op.src) def visit_raise_standard_error(self, op: RaiseStandardError) -> str: if op.value is not None: if isinstance(op.value, str): - return self.format('%r = raise %s(%s)', op, op.class_name, repr(op.value)) + return self.format("%r = raise %s(%s)", op, op.class_name, repr(op.value)) elif isinstance(op.value, Value): - return self.format('%r = raise %s(%r)', op, op.class_name, op.value) + return self.format("%r = raise %s(%r)", op, op.class_name, op.value) else: - assert False, 'value type must be either str or Value' + assert False, "value type must be either str or Value" else: - return self.format('%r = raise %s', op, op.class_name) + return self.format("%r = raise %s", op, op.class_name) def visit_call_c(self, op: CallC) -> str: - args_str = ', '.join(self.format('%r', arg) for arg in op.args) + args_str = ", ".join(self.format("%r", arg) for arg in op.args) if op.is_void: - return self.format('%s(%s)', op.function_name, args_str) + return self.format("%s(%s)", op.function_name, args_str) else: - return self.format('%r = %s(%s)', op, op.function_name, args_str) + return self.format("%r = %s(%s)", op, op.function_name, args_str) def visit_truncate(self, op: Truncate) -> str: return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type) def visit_extend(self, op: Extend) -> str: if op.signed: - extra = ' signed' + extra = " signed" else: - extra = '' + extra = "" return self.format("%r = extend%s %r: %t to %t", op, extra, op.src, op.src_type, op.type) def visit_load_global(self, op: LoadGlobal) -> str: - ann = f' ({repr(op.ann)})' if op.ann else '' - return self.format('%r = load_global %s :: static%s', op, op.identifier, ann) + ann = f" ({repr(op.ann)})" if op.ann else "" + return self.format("%r = load_global %s :: static%s", op, op.identifier, ann) def visit_int_op(self, op: IntOp) -> str: - return self.format('%r = %r %s %r', op, op.lhs, IntOp.op_str[op.op], op.rhs) + return self.format("%r = %r %s %r", op, op.lhs, IntOp.op_str[op.op], op.rhs) def visit_comparison_op(self, op: ComparisonOp) -> str: if op.op in (ComparisonOp.SLT, ComparisonOp.SGT, ComparisonOp.SLE, ComparisonOp.SGE): @@ -193,8 +224,9 @@ def visit_comparison_op(self, op: ComparisonOp) -> str: sign_format = " :: unsigned" else: sign_format = "" - return self.format('%r = %r %s %r%s', op, op.lhs, ComparisonOp.op_str[op.op], - op.rhs, sign_format) + return self.format( + "%r = %r %s %r%s", op, op.lhs, ComparisonOp.op_str[op.op], op.rhs, sign_format + ) def visit_load_mem(self, op: LoadMem) -> str: return self.format("%r = load_mem %r :: %t*", op, op.src, op.type) @@ -212,8 +244,7 @@ def visit_load_address(self, op: LoadAddress) -> str: return self.format("%r = load_address %s", op, op.src) def visit_keep_alive(self, op: KeepAlive) -> str: - return self.format('keep_alive %s' % ', '.join(self.format('%r', v) - for v in op.src)) + return self.format("keep_alive %s" % ", ".join(self.format("%r", v) for v in op.src)) # Helpers @@ -233,47 +264,46 @@ def format(self, fmt: str, *args: Any) -> str: i = 0 arglist = list(args) while i < len(fmt): - n = fmt.find('%', i) + n = fmt.find("%", i) if n < 0: n = len(fmt) result.append(fmt[i:n]) if n < len(fmt): typespec = fmt[n + 1] arg = arglist.pop(0) - if typespec == 'r': + if typespec == "r": # Register/value assert isinstance(arg, Value) if isinstance(arg, Integer): result.append(str(arg.value)) else: result.append(self.names[arg]) - elif typespec == 'd': + elif typespec == "d": # Integer - result.append('%d' % arg) - elif typespec == 'f': + result.append("%d" % arg) + elif typespec == "f": # Float - result.append('%f' % arg) - elif typespec == 'l': + result.append("%f" % arg) + elif typespec == "l": # Basic block (label) assert isinstance(arg, BasicBlock) - result.append('L%s' % arg.label) - elif typespec == 't': + result.append("L%s" % arg.label) + elif typespec == "t": # RType assert isinstance(arg, RType) result.append(arg.name) - elif typespec == 's': + elif typespec == "s": # String result.append(str(arg)) else: - raise ValueError(f'Invalid format sequence %{typespec}') + raise ValueError(f"Invalid format sequence %{typespec}") i = n + 2 else: i = n - return ''.join(result) + return "".join(result) -def format_registers(func_ir: FuncIR, - names: Dict[Value, str]) -> List[str]: +def format_registers(func_ir: FuncIR, names: Dict[Value, str]) -> List[str]: result = [] i = 0 regs = all_values_full(func_ir.arg_regs, func_ir.blocks) @@ -284,13 +314,15 @@ def format_registers(func_ir: FuncIR, i += 1 group.append(names[regs[i]]) i += 1 - result.append('{} :: {}'.format(', '.join(group), regs[i0].type)) + result.append("{} :: {}".format(", ".join(group), regs[i0].type)) return result -def format_blocks(blocks: List[BasicBlock], - names: Dict[Value, str], - source_to_error: Dict[ErrorSource, List[str]]) -> List[str]: +def format_blocks( + blocks: List[BasicBlock], + names: Dict[Value, str], + source_to_error: Dict[ErrorSource, List[str]], +) -> List[str]: """Format a list of IR basic blocks into a human-readable form.""" # First label all of the blocks for i, block in enumerate(blocks): @@ -305,24 +337,27 @@ def format_blocks(blocks: List[BasicBlock], lines = [] for i, block in enumerate(blocks): - handler_msg = '' + handler_msg = "" if block in handler_map: - labels = sorted('L%d' % b.label for b in handler_map[block]) - handler_msg = ' (handler for {})'.format(', '.join(labels)) + labels = sorted("L%d" % b.label for b in handler_map[block]) + handler_msg = " (handler for {})".format(", ".join(labels)) - lines.append('L%d:%s' % (block.label, handler_msg)) + lines.append("L%d:%s" % (block.label, handler_msg)) if block in source_to_error: for error in source_to_error[block]: lines.append(f" ERR: {error}") ops = block.ops - if (isinstance(ops[-1], Goto) and i + 1 < len(blocks) - and ops[-1].label == blocks[i + 1] - and not source_to_error.get(ops[-1], [])): + if ( + isinstance(ops[-1], Goto) + and i + 1 < len(blocks) + and ops[-1].label == blocks[i + 1] + and not source_to_error.get(ops[-1], []) + ): # Hide the last goto if it just goes to the next basic block, # and there are no assocatiated errors with the op. ops = ops[:-1] for op in ops: - line = ' ' + op.accept(visitor) + line = " " + op.accept(visitor) lines.append(line) if op in source_to_error: for error in source_to_error[op]: @@ -330,18 +365,19 @@ def format_blocks(blocks: List[BasicBlock], if not isinstance(block.ops[-1], (Goto, Branch, Return, Unreachable)): # Each basic block needs to exit somewhere. - lines.append(' [MISSING BLOCK EXIT OPCODE]') + lines.append(" [MISSING BLOCK EXIT OPCODE]") return lines def format_func(fn: FuncIR, errors: Sequence[Tuple[ErrorSource, str]] = ()) -> List[str]: lines = [] - cls_prefix = fn.class_name + '.' if fn.class_name else '' - lines.append('def {}{}({}):'.format(cls_prefix, fn.name, - ', '.join(arg.name for arg in fn.args))) + cls_prefix = fn.class_name + "." if fn.class_name else "" + lines.append( + "def {}{}({}):".format(cls_prefix, fn.name, ", ".join(arg.name for arg in fn.args)) + ) names = generate_names_for_ir(fn.arg_regs, fn.blocks) for line in format_registers(fn, names): - lines.append(' ' + line) + lines.append(" " + line) source_to_error = defaultdict(list) for source, error in errors: @@ -357,7 +393,7 @@ def format_modules(modules: ModuleIRs) -> List[str]: for module in modules.values(): for fn in module.functions: ops.extend(format_func(fn)) - ops.append('') + ops.append("") return ops @@ -399,14 +435,14 @@ def generate_names_for_ir(args: List[Register], blocks: List[BasicBlock]) -> Dic elif isinstance(value, Integer): continue else: - name = 'r%d' % temp_index + name = "r%d" % temp_index temp_index += 1 # Append _2, _3, ... if needed to make the name unique. if name in used_names: n = 2 while True: - candidate = '%s_%d' % (name, n) + candidate = "%s_%d" % (name, n) if candidate not in used_names: name = candidate break diff --git a/mypyc/ir/rtypes.py b/mypyc/ir/rtypes.py index 010e25976f1c1..3247c2fb95f2c 100644 --- a/mypyc/ir/rtypes.py +++ b/mypyc/ir/rtypes.py @@ -21,18 +21,18 @@ """ from abc import abstractmethod -from typing import Optional, Union, List, Dict, Generic, TypeVar, Tuple +from typing import Dict, Generic, List, Optional, Tuple, TypeVar, Union -from typing_extensions import Final, ClassVar, TYPE_CHECKING +from typing_extensions import TYPE_CHECKING, ClassVar, Final -from mypyc.common import JsonDict, short_name, IS_32_BIT_PLATFORM, PLATFORM_SIZE +from mypyc.common import IS_32_BIT_PLATFORM, PLATFORM_SIZE, JsonDict, short_name from mypyc.namegen import NameGenerator if TYPE_CHECKING: - from mypyc.ir.ops import DeserMaps from mypyc.ir.class_ir import ClassIR + from mypyc.ir.ops import DeserMaps -T = TypeVar('T') +T = TypeVar("T") class RType: @@ -62,7 +62,7 @@ class RType: error_overlap = False @abstractmethod - def accept(self, visitor: 'RTypeVisitor[T]') -> T: + def accept(self, visitor: "RTypeVisitor[T]") -> T: raise NotImplementedError def short_name(self) -> str: @@ -72,13 +72,13 @@ def __str__(self) -> str: return short_name(self.name) def __repr__(self) -> str: - return '<%s>' % self.__class__.__name__ + return "<%s>" % self.__class__.__name__ def serialize(self) -> Union[JsonDict, str]: - raise NotImplementedError(f'Cannot serialize {self.__class__.__name__} instance') + raise NotImplementedError(f"Cannot serialize {self.__class__.__name__} instance") -def deserialize_type(data: Union[JsonDict, str], ctx: 'DeserMaps') -> 'RType': +def deserialize_type(data: Union[JsonDict, str], ctx: "DeserMaps") -> "RType": """Deserialize a JSON-serialized RType. Arguments: @@ -97,42 +97,42 @@ def deserialize_type(data: Union[JsonDict, str], ctx: 'DeserMaps') -> 'RType': return RVoid() else: assert False, f"Can't find class {data}" - elif data['.class'] == 'RTuple': + elif data[".class"] == "RTuple": return RTuple.deserialize(data, ctx) - elif data['.class'] == 'RUnion': + elif data[".class"] == "RUnion": return RUnion.deserialize(data, ctx) - raise NotImplementedError('unexpected .class {}'.format(data['.class'])) + raise NotImplementedError("unexpected .class {}".format(data[".class"])) class RTypeVisitor(Generic[T]): """Generic visitor over RTypes (uses the visitor design pattern).""" @abstractmethod - def visit_rprimitive(self, typ: 'RPrimitive') -> T: + def visit_rprimitive(self, typ: "RPrimitive") -> T: raise NotImplementedError @abstractmethod - def visit_rinstance(self, typ: 'RInstance') -> T: + def visit_rinstance(self, typ: "RInstance") -> T: raise NotImplementedError @abstractmethod - def visit_runion(self, typ: 'RUnion') -> T: + def visit_runion(self, typ: "RUnion") -> T: raise NotImplementedError @abstractmethod - def visit_rtuple(self, typ: 'RTuple') -> T: + def visit_rtuple(self, typ: "RTuple") -> T: raise NotImplementedError @abstractmethod - def visit_rstruct(self, typ: 'RStruct') -> T: + def visit_rstruct(self, typ: "RStruct") -> T: raise NotImplementedError @abstractmethod - def visit_rarray(self, typ: 'RArray') -> T: + def visit_rarray(self, typ: "RArray") -> T: raise NotImplementedError @abstractmethod - def visit_rvoid(self, typ: 'RVoid') -> T: + def visit_rvoid(self, typ: "RVoid") -> T: raise NotImplementedError @@ -144,14 +144,14 @@ class RVoid(RType): """ is_unboxed = False - name = 'void' - ctype = 'void' + name = "void" + ctype = "void" - def accept(self, visitor: 'RTypeVisitor[T]') -> T: + def accept(self, visitor: "RTypeVisitor[T]") -> T: return visitor.visit_rvoid(self) def serialize(self) -> str: - return 'void' + return "void" def __eq__(self, other: object) -> bool: return isinstance(other, RVoid) @@ -181,16 +181,18 @@ class RPrimitive(RType): # Map from primitive names to primitive types and is used by deserialization primitive_map: ClassVar[Dict[str, "RPrimitive"]] = {} - def __init__(self, - name: str, - *, - is_unboxed: bool, - is_refcounted: bool, - is_native_int: bool = False, - is_signed: bool = False, - ctype: str = 'PyObject *', - size: int = PLATFORM_SIZE, - error_overlap: bool = False) -> None: + def __init__( + self, + name: str, + *, + is_unboxed: bool, + is_refcounted: bool, + is_native_int: bool = False, + is_signed: bool = False, + ctype: str = "PyObject *", + size: int = PLATFORM_SIZE, + error_overlap: bool = False, + ) -> None: RPrimitive.primitive_map[name] = self self.name = name @@ -201,34 +203,34 @@ def __init__(self, self._ctype = ctype self.size = size self.error_overlap = error_overlap - if ctype == 'CPyTagged': - self.c_undefined = 'CPY_INT_TAG' - elif ctype in ('int32_t', 'int64_t'): + if ctype == "CPyTagged": + self.c_undefined = "CPY_INT_TAG" + elif ctype in ("int32_t", "int64_t"): # This is basically an arbitrary value that is pretty # unlikely to overlap with a real value. - self.c_undefined = '-113' - elif ctype in ('CPyPtr', 'uint32_t', 'uint64_t'): + self.c_undefined = "-113" + elif ctype in ("CPyPtr", "uint32_t", "uint64_t"): # TODO: For low-level integers, we need to invent an overlapping # error value, similar to int64_t above. - self.c_undefined = '0' - elif ctype == 'PyObject *': + self.c_undefined = "0" + elif ctype == "PyObject *": # Boxed types use the null pointer as the error value. - self.c_undefined = 'NULL' - elif ctype == 'char': - self.c_undefined = '2' - elif ctype in ('PyObject **', 'void *'): - self.c_undefined = 'NULL' + self.c_undefined = "NULL" + elif ctype == "char": + self.c_undefined = "2" + elif ctype in ("PyObject **", "void *"): + self.c_undefined = "NULL" else: - assert False, 'Unrecognized ctype: %r' % ctype + assert False, "Unrecognized ctype: %r" % ctype - def accept(self, visitor: 'RTypeVisitor[T]') -> T: + def accept(self, visitor: "RTypeVisitor[T]") -> T: return visitor.visit_rprimitive(self) def serialize(self) -> str: return self.name def __repr__(self) -> str: - return '' % self.name + return "" % self.name def __eq__(self, other: object) -> bool: return isinstance(other, RPrimitive) and other.name == self.name @@ -330,23 +332,23 @@ def __hash__(self) -> int: if IS_32_BIT_PLATFORM: c_size_t_rprimitive = uint32_rprimitive c_pyssize_t_rprimitive = RPrimitive( - 'native_int', + "native_int", is_unboxed=True, is_refcounted=False, is_native_int=True, is_signed=True, - ctype='int32_t', + ctype="int32_t", size=4, ) else: c_size_t_rprimitive = uint64_rprimitive c_pyssize_t_rprimitive = RPrimitive( - 'native_int', + "native_int", is_unboxed=True, is_refcounted=False, is_native_int=True, is_signed=True, - ctype='int64_t', + ctype="int64_t", size=8, ) @@ -354,8 +356,9 @@ def __hash__(self) -> int: pointer_rprimitive: Final = RPrimitive("ptr", is_unboxed=True, is_refcounted=False, ctype="CPyPtr") # Untyped pointer, represented as void * in the C backend -c_pointer_rprimitive: Final = RPrimitive("c_ptr", is_unboxed=False, is_refcounted=False, - ctype="void *") +c_pointer_rprimitive: Final = RPrimitive( + "c_ptr", is_unboxed=False, is_refcounted=False, ctype="void *" +) # Floats are represent as 'float' PyObject * values. (In the future # we'll likely switch to a more efficient, unboxed representation.) @@ -394,7 +397,7 @@ def __hash__(self) -> int: str_rprimitive: Final = RPrimitive("builtins.str", is_unboxed=False, is_refcounted=True) # Python bytes object. -bytes_rprimitive: Final = RPrimitive('builtins.bytes', is_unboxed=False, is_refcounted=True) +bytes_rprimitive: Final = RPrimitive("builtins.bytes", is_unboxed=False, is_refcounted=True) # Tuple of an arbitrary length (corresponds to Tuple[t, ...], with # explicit '...'). @@ -417,13 +420,15 @@ def is_short_int_rprimitive(rtype: RType) -> bool: def is_int32_rprimitive(rtype: RType) -> bool: - return (rtype is int32_rprimitive or - (rtype is c_pyssize_t_rprimitive and rtype._ctype == 'int32_t')) + return rtype is int32_rprimitive or ( + rtype is c_pyssize_t_rprimitive and rtype._ctype == "int32_t" + ) def is_int64_rprimitive(rtype: RType) -> bool: - return (rtype is int64_rprimitive or - (rtype is c_pyssize_t_rprimitive and rtype._ctype == 'int64_t')) + return rtype is int64_rprimitive or ( + rtype is c_pyssize_t_rprimitive and rtype._ctype == "int64_t" + ) def is_fixed_width_rtype(rtype: RType) -> bool: @@ -447,51 +452,51 @@ def is_pointer_rprimitive(rtype: RType) -> bool: def is_float_rprimitive(rtype: RType) -> bool: - return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.float' + return isinstance(rtype, RPrimitive) and rtype.name == "builtins.float" def is_bool_rprimitive(rtype: RType) -> bool: - return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.bool' + return isinstance(rtype, RPrimitive) and rtype.name == "builtins.bool" def is_bit_rprimitive(rtype: RType) -> bool: - return isinstance(rtype, RPrimitive) and rtype.name == 'bit' + return isinstance(rtype, RPrimitive) and rtype.name == "bit" def is_object_rprimitive(rtype: RType) -> bool: - return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.object' + return isinstance(rtype, RPrimitive) and rtype.name == "builtins.object" def is_none_rprimitive(rtype: RType) -> bool: - return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.None' + return isinstance(rtype, RPrimitive) and rtype.name == "builtins.None" def is_list_rprimitive(rtype: RType) -> bool: - return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.list' + return isinstance(rtype, RPrimitive) and rtype.name == "builtins.list" def is_dict_rprimitive(rtype: RType) -> bool: - return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.dict' + return isinstance(rtype, RPrimitive) and rtype.name == "builtins.dict" def is_set_rprimitive(rtype: RType) -> bool: - return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.set' + return isinstance(rtype, RPrimitive) and rtype.name == "builtins.set" def is_str_rprimitive(rtype: RType) -> bool: - return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.str' + return isinstance(rtype, RPrimitive) and rtype.name == "builtins.str" def is_bytes_rprimitive(rtype: RType) -> bool: - return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.bytes' + return isinstance(rtype, RPrimitive) and rtype.name == "builtins.bytes" def is_tuple_rprimitive(rtype: RType) -> bool: - return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.tuple' + return isinstance(rtype, RPrimitive) and rtype.name == "builtins.tuple" def is_range_rprimitive(rtype: RType) -> bool: - return isinstance(rtype, RPrimitive) and rtype.name == 'builtins.range' + return isinstance(rtype, RPrimitive) and rtype.name == "builtins.range" def is_sequence_rprimitive(rtype: RType) -> bool: @@ -503,35 +508,35 @@ def is_sequence_rprimitive(rtype: RType) -> bool: class TupleNameVisitor(RTypeVisitor[str]): """Produce a tuple name based on the concrete representations of types.""" - def visit_rinstance(self, t: 'RInstance') -> str: + def visit_rinstance(self, t: "RInstance") -> str: return "O" - def visit_runion(self, t: 'RUnion') -> str: + def visit_runion(self, t: "RUnion") -> str: return "O" - def visit_rprimitive(self, t: 'RPrimitive') -> str: - if t._ctype == 'CPyTagged': - return 'I' - elif t._ctype == 'char': - return 'C' - elif t._ctype == 'int64_t': - return '8' # "8 byte integer" - elif t._ctype == 'int32_t': - return '4' # "4 byte integer" + def visit_rprimitive(self, t: "RPrimitive") -> str: + if t._ctype == "CPyTagged": + return "I" + elif t._ctype == "char": + return "C" + elif t._ctype == "int64_t": + return "8" # "8 byte integer" + elif t._ctype == "int32_t": + return "4" # "4 byte integer" assert not t.is_unboxed, f"{t} unexpected unboxed type" - return 'O' + return "O" - def visit_rtuple(self, t: 'RTuple') -> str: + def visit_rtuple(self, t: "RTuple") -> str: parts = [elem.accept(self) for elem in t.types] - return 'T{}{}'.format(len(parts), ''.join(parts)) + return "T{}{}".format(len(parts), "".join(parts)) - def visit_rstruct(self, t: 'RStruct') -> str: - assert False, 'RStruct not supported in tuple' + def visit_rstruct(self, t: "RStruct") -> str: + assert False, "RStruct not supported in tuple" - def visit_rarray(self, t: 'RArray') -> str: - assert False, 'RArray not supported in tuple' + def visit_rarray(self, t: "RArray") -> str: + assert False, "RArray not supported in tuple" - def visit_rvoid(self, t: 'RVoid') -> str: + def visit_rvoid(self, t: "RVoid") -> str: assert False, "rvoid in tuple?" @@ -553,7 +558,7 @@ class RTuple(RType): is_unboxed = True def __init__(self, types: List[RType]) -> None: - self.name = 'tuple' + self.name = "tuple" self.types = tuple(types) self.is_refcounted = any(t.is_refcounted for t in self.types) # Generate a unique id which is used in naming corresponding C identifiers. @@ -561,17 +566,17 @@ def __init__(self, types: List[RType]) -> None: # in the same way python can just assign a Tuple[int, bool] to a Tuple[int, bool]. self.unique_id = self.accept(TupleNameVisitor()) # Nominally the max c length is 31 chars, but I'm not honestly worried about this. - self.struct_name = f'tuple_{self.unique_id}' - self._ctype = f'{self.struct_name}' + self.struct_name = f"tuple_{self.unique_id}" + self._ctype = f"{self.struct_name}" - def accept(self, visitor: 'RTypeVisitor[T]') -> T: + def accept(self, visitor: "RTypeVisitor[T]") -> T: return visitor.visit_rtuple(self) def __str__(self) -> str: - return 'tuple[%s]' % ', '.join(str(typ) for typ in self.types) + return "tuple[%s]" % ", ".join(str(typ) for typ in self.types) def __repr__(self) -> str: - return '' % ', '.join(repr(typ) for typ in self.types) + return "" % ", ".join(repr(typ) for typ in self.types) def __eq__(self, other: object) -> bool: return isinstance(other, RTuple) and self.types == other.types @@ -581,11 +586,11 @@ def __hash__(self) -> int: def serialize(self) -> JsonDict: types = [x.serialize() for x in self.types] - return {'.class': 'RTuple', 'types': types} + return {".class": "RTuple", "types": types} @classmethod - def deserialize(cls, data: JsonDict, ctx: 'DeserMaps') -> 'RTuple': - types = [deserialize_type(t, ctx) for t in data['types']] + def deserialize(cls, data: JsonDict, ctx: "DeserMaps") -> "RTuple": + types = [deserialize_type(t, ctx) for t in data["types"]] return RTuple(types) @@ -598,9 +603,7 @@ def deserialize(cls, data: JsonDict, ctx: 'DeserMaps') -> 'RTuple': [bool_rprimitive, short_int_rprimitive, object_rprimitive, object_rprimitive] ) # Same as above but just for key or value. -dict_next_rtuple_single = RTuple( - [bool_rprimitive, short_int_rprimitive, object_rprimitive] -) +dict_next_rtuple_single = RTuple([bool_rprimitive, short_int_rprimitive, object_rprimitive]) def compute_rtype_alignment(typ: RType) -> int: @@ -676,37 +679,40 @@ def compute_aligned_offsets_and_size(types: List[RType]) -> Tuple[List[int], int class RStruct(RType): """C struct type""" - def __init__(self, - name: str, - names: List[str], - types: List[RType]) -> None: + def __init__(self, name: str, names: List[str], types: List[RType]) -> None: self.name = name self.names = names self.types = types # generate dummy names if len(self.names) < len(self.types): for i in range(len(self.types) - len(self.names)): - self.names.append('_item' + str(i)) + self.names.append("_item" + str(i)) self.offsets, self.size = compute_aligned_offsets_and_size(types) self._ctype = name - def accept(self, visitor: 'RTypeVisitor[T]') -> T: + def accept(self, visitor: "RTypeVisitor[T]") -> T: return visitor.visit_rstruct(self) def __str__(self) -> str: # if not tuple(unnamed structs) - return '{}{{{}}}'.format(self.name, ', '.join(name + ":" + str(typ) - for name, typ in zip(self.names, self.types))) + return "{}{{{}}}".format( + self.name, + ", ".join(name + ":" + str(typ) for name, typ in zip(self.names, self.types)), + ) def __repr__(self) -> str: - return ''.format( - self.name, ', '.join(name + ":" + repr(typ) - for name, typ in zip(self.names, self.types)) + return "".format( + self.name, + ", ".join(name + ":" + repr(typ) for name, typ in zip(self.names, self.types)), ) def __eq__(self, other: object) -> bool: - return (isinstance(other, RStruct) and self.name == other.name - and self.names == other.names and self.types == other.types) + return ( + isinstance(other, RStruct) + and self.name == other.name + and self.names == other.names + and self.types == other.types + ) def __hash__(self) -> int: return hash((self.name, tuple(self.names), tuple(self.types))) @@ -715,7 +721,7 @@ def serialize(self) -> JsonDict: assert False @classmethod - def deserialize(cls, data: JsonDict, ctx: 'DeserMaps') -> 'RStruct': + def deserialize(cls, data: JsonDict, ctx: "DeserMaps") -> "RStruct": assert False @@ -737,14 +743,14 @@ class RInstance(RType): is_unboxed = False - def __init__(self, class_ir: 'ClassIR') -> None: + def __init__(self, class_ir: "ClassIR") -> None: # name is used for formatting the name in messages and debug output # so we want the fullname for precision. self.name = class_ir.fullname self.class_ir = class_ir - self._ctype = 'PyObject *' + self._ctype = "PyObject *" - def accept(self, visitor: 'RTypeVisitor[T]') -> T: + def accept(self, visitor: "RTypeVisitor[T]") -> T: return visitor.visit_rinstance(self) def struct_name(self, names: NameGenerator) -> str: @@ -763,7 +769,7 @@ def attr_type(self, name: str) -> RType: return self.class_ir.attr_type(name) def __repr__(self) -> str: - return '' % self.name + return "" % self.name def __eq__(self, other: object) -> bool: return isinstance(other, RInstance) and other.name == self.name @@ -781,34 +787,34 @@ class RUnion(RType): is_unboxed = False def __init__(self, items: List[RType]) -> None: - self.name = 'union' + self.name = "union" self.items = items self.items_set = frozenset(items) - self._ctype = 'PyObject *' + self._ctype = "PyObject *" - def accept(self, visitor: 'RTypeVisitor[T]') -> T: + def accept(self, visitor: "RTypeVisitor[T]") -> T: return visitor.visit_runion(self) def __repr__(self) -> str: - return '' % ', '.join(str(item) for item in self.items) + return "" % ", ".join(str(item) for item in self.items) def __str__(self) -> str: - return 'union[%s]' % ', '.join(str(item) for item in self.items) + return "union[%s]" % ", ".join(str(item) for item in self.items) # We compare based on the set because order in a union doesn't matter def __eq__(self, other: object) -> bool: return isinstance(other, RUnion) and self.items_set == other.items_set def __hash__(self) -> int: - return hash(('union', self.items_set)) + return hash(("union", self.items_set)) def serialize(self) -> JsonDict: types = [x.serialize() for x in self.items] - return {'.class': 'RUnion', 'types': types} + return {".class": "RUnion", "types": types} @classmethod - def deserialize(cls, data: JsonDict, ctx: 'DeserMaps') -> 'RUnion': - types = [deserialize_type(t, ctx) for t in data['types']] + def deserialize(cls, data: JsonDict, ctx: "DeserMaps") -> "RUnion": + types = [deserialize_type(t, ctx) for t in data["types"]] return RUnion(types) @@ -837,26 +843,27 @@ class RArray(RType): be only used for local variables that are initialized in one location. """ - def __init__(self, - item_type: RType, - length: int) -> None: + def __init__(self, item_type: RType, length: int) -> None: self.item_type = item_type # Number of items self.length = length self.is_refcounted = False - def accept(self, visitor: 'RTypeVisitor[T]') -> T: + def accept(self, visitor: "RTypeVisitor[T]") -> T: return visitor.visit_rarray(self) def __str__(self) -> str: - return f'{self.item_type}[{self.length}]' + return f"{self.item_type}[{self.length}]" def __repr__(self) -> str: - return f'' + return f"" def __eq__(self, other: object) -> bool: - return (isinstance(other, RArray) and self.item_type == other.item_type - and self.length == other.length) + return ( + isinstance(other, RArray) + and self.item_type == other.item_type + and self.length == other.length + ) def __hash__(self) -> int: return hash((self.item_type, self.length)) @@ -865,40 +872,54 @@ def serialize(self) -> JsonDict: assert False @classmethod - def deserialize(cls, data: JsonDict, ctx: 'DeserMaps') -> 'RArray': + def deserialize(cls, data: JsonDict, ctx: "DeserMaps") -> "RArray": assert False PyObject = RStruct( - name='PyObject', - names=['ob_refcnt', 'ob_type'], - types=[c_pyssize_t_rprimitive, pointer_rprimitive]) + name="PyObject", + names=["ob_refcnt", "ob_type"], + types=[c_pyssize_t_rprimitive, pointer_rprimitive], +) PyVarObject = RStruct( - name='PyVarObject', - names=['ob_base', 'ob_size'], - types=[PyObject, c_pyssize_t_rprimitive]) + name="PyVarObject", names=["ob_base", "ob_size"], types=[PyObject, c_pyssize_t_rprimitive] +) setentry = RStruct( - name='setentry', - names=['key', 'hash'], - types=[pointer_rprimitive, c_pyssize_t_rprimitive]) + name="setentry", names=["key", "hash"], types=[pointer_rprimitive, c_pyssize_t_rprimitive] +) -smalltable = RStruct( - name='smalltable', - names=[], - types=[setentry] * 8) +smalltable = RStruct(name="smalltable", names=[], types=[setentry] * 8) PySetObject = RStruct( - name='PySetObject', - names=['ob_base', 'fill', 'used', 'mask', 'table', 'hash', 'finger', - 'smalltable', 'weakreflist'], - types=[PyObject, c_pyssize_t_rprimitive, c_pyssize_t_rprimitive, c_pyssize_t_rprimitive, - pointer_rprimitive, c_pyssize_t_rprimitive, c_pyssize_t_rprimitive, smalltable, - pointer_rprimitive]) + name="PySetObject", + names=[ + "ob_base", + "fill", + "used", + "mask", + "table", + "hash", + "finger", + "smalltable", + "weakreflist", + ], + types=[ + PyObject, + c_pyssize_t_rprimitive, + c_pyssize_t_rprimitive, + c_pyssize_t_rprimitive, + pointer_rprimitive, + c_pyssize_t_rprimitive, + c_pyssize_t_rprimitive, + smalltable, + pointer_rprimitive, + ], +) PyListObject = RStruct( - name='PyListObject', - names=['ob_base', 'ob_item', 'allocated'], - types=[PyVarObject, pointer_rprimitive, c_pyssize_t_rprimitive] + name="PyListObject", + names=["ob_base", "ob_item", "allocated"], + types=[PyVarObject, pointer_rprimitive, c_pyssize_t_rprimitive], ) diff --git a/mypyc/irbuild/ast_helpers.py b/mypyc/irbuild/ast_helpers.py index 8c9ca186e46ae..0c2506d90f255 100644 --- a/mypyc/irbuild/ast_helpers.py +++ b/mypyc/irbuild/ast_helpers.py @@ -5,8 +5,18 @@ """ from mypy.nodes import ( - Expression, MemberExpr, Var, IntExpr, FloatExpr, StrExpr, BytesExpr, NameExpr, OpExpr, - UnaryExpr, ComparisonExpr, LDEF + LDEF, + BytesExpr, + ComparisonExpr, + Expression, + FloatExpr, + IntExpr, + MemberExpr, + NameExpr, + OpExpr, + StrExpr, + UnaryExpr, + Var, ) from mypyc.ir.ops import BasicBlock from mypyc.ir.rtypes import is_tagged @@ -14,10 +24,11 @@ from mypyc.irbuild.constant_fold import constant_fold_expr -def process_conditional(self: IRBuilder, e: Expression, true: BasicBlock, - false: BasicBlock) -> None: - if isinstance(e, OpExpr) and e.op in ['and', 'or']: - if e.op == 'and': +def process_conditional( + self: IRBuilder, e: Expression, true: BasicBlock, false: BasicBlock +) -> None: + if isinstance(e, OpExpr) and e.op in ["and", "or"]: + if e.op == "and": # Short circuit 'and' in a conditional context. new = BasicBlock() process_conditional(self, e.left, new, false) @@ -29,7 +40,7 @@ def process_conditional(self: IRBuilder, e: Expression, true: BasicBlock, process_conditional(self, e.left, true, new) self.activate_block(new) process_conditional(self, e.right, true, false) - elif isinstance(e, UnaryExpr) and e.op == 'not': + elif isinstance(e, UnaryExpr) and e.op == "not": process_conditional(self, e.expr, false, true) else: res = maybe_process_conditional_comparison(self, e, true, false) @@ -40,10 +51,9 @@ def process_conditional(self: IRBuilder, e: Expression, true: BasicBlock, self.add_bool_branch(reg, true, false) -def maybe_process_conditional_comparison(self: IRBuilder, - e: Expression, - true: BasicBlock, - false: BasicBlock) -> bool: +def maybe_process_conditional_comparison( + self: IRBuilder, e: Expression, true: BasicBlock, false: BasicBlock +) -> bool: """Transform simple tagged integer comparisons in a conditional context. Return True if the operation is supported (and was transformed). Otherwise, @@ -61,7 +71,7 @@ def maybe_process_conditional_comparison(self: IRBuilder, if not is_tagged(ltype) or not is_tagged(rtype): return False op = e.operators[0] - if op not in ('==', '!=', '<', '<=', '>', '>='): + if op not in ("==", "!=", "<", "<=", ">", ">="): return False left_expr = e.operands[0] right_expr = e.operands[1] @@ -81,8 +91,10 @@ def is_borrow_friendly_expr(self: IRBuilder, expr: Expression) -> bool: if isinstance(expr, (IntExpr, FloatExpr, StrExpr, BytesExpr)): # Literals are immortal and can always be borrowed return True - if (isinstance(expr, (UnaryExpr, OpExpr, NameExpr, MemberExpr)) and - constant_fold_expr(self, expr) is not None): + if ( + isinstance(expr, (UnaryExpr, OpExpr, NameExpr, MemberExpr)) + and constant_fold_expr(self, expr) is not None + ): # Literal expressions are similar to literals return True if isinstance(expr, NameExpr): diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index c1662d2fdac2c..d62c1700c78a3 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -11,65 +11,115 @@ functions are transformed in mypyc.irbuild.function. """ from contextlib import contextmanager +from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union -from mypyc.irbuild.prepare import RegisterImplInfo -from typing import Callable, Dict, List, Tuple, Optional, Union, Sequence, Set, Any, Iterator -from typing_extensions import overload, Final -from mypy.backports import OrderedDict +from typing_extensions import Final, overload +from mypy.backports import OrderedDict from mypy.build import Graph +from mypy.maptype import map_instance_to_supertype from mypy.nodes import ( - MypyFile, SymbolNode, Statement, OpExpr, IntExpr, NameExpr, LDEF, Var, UnaryExpr, - CallExpr, IndexExpr, Expression, MemberExpr, RefExpr, Lvalue, TupleExpr, - TypeInfo, Decorator, OverloadedFuncDef, StarExpr, - GDEF, ArgKind, ARG_POS, ARG_NAMED, FuncDef, -) -from mypy.types import ( - Type, Instance, TupleType, UninhabitedType, get_proper_type + ARG_NAMED, + ARG_POS, + GDEF, + LDEF, + ArgKind, + CallExpr, + Decorator, + Expression, + FuncDef, + IndexExpr, + IntExpr, + Lvalue, + MemberExpr, + MypyFile, + NameExpr, + OpExpr, + OverloadedFuncDef, + RefExpr, + StarExpr, + Statement, + SymbolNode, + TupleExpr, + TypeInfo, + UnaryExpr, + Var, ) -from mypy.maptype import map_instance_to_supertype -from mypy.visitor import ExpressionVisitor, StatementVisitor +from mypy.types import Instance, TupleType, Type, UninhabitedType, get_proper_type from mypy.util import split_target - -from mypyc.common import TEMP_ATTR_NAME, SELF_NAME -from mypyc.irbuild.prebuildvisitor import PreBuildVisitor +from mypy.visitor import ExpressionVisitor, StatementVisitor +from mypyc.common import SELF_NAME, TEMP_ATTR_NAME +from mypyc.crash import catch_errors +from mypyc.errors import Errors +from mypyc.ir.class_ir import ClassIR, NonExtClassInfo +from mypyc.ir.func_ir import INVALID_FUNC_DEF, FuncDecl, FuncIR, FuncSignature, RuntimeArg from mypyc.ir.ops import ( - BasicBlock, Integer, Value, Register, Op, Assign, Branch, Unreachable, TupleGet, GetAttr, - SetAttr, LoadStatic, InitStatic, NAMESPACE_MODULE, RaiseStandardError + NAMESPACE_MODULE, + Assign, + BasicBlock, + Branch, + GetAttr, + InitStatic, + Integer, + LoadStatic, + Op, + RaiseStandardError, + Register, + SetAttr, + TupleGet, + Unreachable, + Value, ) from mypyc.ir.rtypes import ( - RType, RTuple, RInstance, c_int_rprimitive, int_rprimitive, dict_rprimitive, - none_rprimitive, is_none_rprimitive, object_rprimitive, is_object_rprimitive, - str_rprimitive, is_list_rprimitive, is_tuple_rprimitive, c_pyssize_t_rprimitive -) -from mypyc.ir.func_ir import FuncIR, INVALID_FUNC_DEF, RuntimeArg, FuncSignature, FuncDecl -from mypyc.ir.class_ir import ClassIR, NonExtClassInfo -from mypyc.primitives.registry import CFunctionDescription, function_ops -from mypyc.primitives.list_ops import to_list, list_pop_last, list_get_item_unsafe_op -from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op -from mypyc.primitives.generic_ops import py_setattr_op, iter_op, next_op -from mypyc.primitives.misc_ops import ( - import_op, check_unpack_count_op, get_module_dict_op, import_extra_args_op + RInstance, + RTuple, + RType, + c_int_rprimitive, + c_pyssize_t_rprimitive, + dict_rprimitive, + int_rprimitive, + is_list_rprimitive, + is_none_rprimitive, + is_object_rprimitive, + is_tuple_rprimitive, + none_rprimitive, + object_rprimitive, + str_rprimitive, ) -from mypyc.crash import catch_errors -from mypyc.options import CompilerOptions -from mypyc.errors import Errors +from mypyc.irbuild.context import FuncInfo, ImplicitClass +from mypyc.irbuild.ll_builder import LowLevelIRBuilder +from mypyc.irbuild.mapper import Mapper from mypyc.irbuild.nonlocalcontrol import ( - NonlocalControl, BaseNonlocalControl, LoopNonlocalControl, GeneratorNonlocalControl + BaseNonlocalControl, + GeneratorNonlocalControl, + LoopNonlocalControl, + NonlocalControl, ) +from mypyc.irbuild.prebuildvisitor import PreBuildVisitor +from mypyc.irbuild.prepare import RegisterImplInfo from mypyc.irbuild.targets import ( - AssignmentTarget, AssignmentTargetRegister, AssignmentTargetIndex, AssignmentTargetAttr, - AssignmentTargetTuple + AssignmentTarget, + AssignmentTargetAttr, + AssignmentTargetIndex, + AssignmentTargetRegister, + AssignmentTargetTuple, ) -from mypyc.irbuild.context import FuncInfo, ImplicitClass -from mypyc.irbuild.mapper import Mapper -from mypyc.irbuild.ll_builder import LowLevelIRBuilder from mypyc.irbuild.util import is_constant - +from mypyc.options import CompilerOptions +from mypyc.primitives.dict_ops import dict_get_item_op, dict_set_item_op +from mypyc.primitives.generic_ops import iter_op, next_op, py_setattr_op +from mypyc.primitives.list_ops import list_get_item_unsafe_op, list_pop_last, to_list +from mypyc.primitives.misc_ops import ( + check_unpack_count_op, + get_module_dict_op, + import_extra_args_op, + import_op, +) +from mypyc.primitives.registry import CFunctionDescription, function_ops # These int binary operations can borrow their operands safely, since the # primitives take this into consideration. -int_borrow_friendly_op: Final = {'+', '-', '==', '!=', '<', '<=', '>', '>='} +int_borrow_friendly_op: Final = {"+", "-", "==", "!=", "<", "<=", ">", ">="} class IRVisitor(ExpressionVisitor[Value], StatementVisitor[None]): @@ -84,16 +134,18 @@ class UnsupportedException(Exception): class IRBuilder: - def __init__(self, - current_module: str, - types: Dict[Expression, Type], - graph: Graph, - errors: Errors, - mapper: Mapper, - pbv: PreBuildVisitor, - visitor: IRVisitor, - options: CompilerOptions, - singledispatch_impls: Dict[FuncDef, List[RegisterImplInfo]]) -> None: + def __init__( + self, + current_module: str, + types: Dict[Expression, Type], + graph: Graph, + errors: Errors, + mapper: Mapper, + pbv: PreBuildVisitor, + visitor: IRVisitor, + options: CompilerOptions, + singledispatch_impls: Dict[FuncDef, List[RegisterImplInfo]], + ) -> None: self.builder = LowLevelIRBuilder(current_module, mapper, options) self.builders = [self.builder] self.symtables: List[OrderedDict[SymbolNode, SymbolTarget]] = [OrderedDict()] @@ -133,7 +185,7 @@ def __init__(self, # and information about that function (e.g. whether it is nested, its environment class to # be generated) is stored in that FuncInfo instance. When the function is done being # generated, its corresponding FuncInfo is popped off the stack. - self.fn_info = FuncInfo(INVALID_FUNC_DEF, '', '') + self.fn_info = FuncInfo(INVALID_FUNC_DEF, "", "") self.fn_infos: List[FuncInfo] = [self.fn_info] # This list operates as a stack of constructs that modify the @@ -159,13 +211,16 @@ def set_module(self, module_name: str, module_path: str) -> None: self.module_path = module_path @overload - def accept(self, node: Expression, *, can_borrow: bool = False) -> Value: ... + def accept(self, node: Expression, *, can_borrow: bool = False) -> Value: + ... @overload - def accept(self, node: Statement) -> None: ... + def accept(self, node: Statement) -> None: + ... - def accept(self, node: Union[Statement, Expression], *, - can_borrow: bool = False) -> Optional[Value]: + def accept( + self, node: Union[Statement, Expression], *, can_borrow: bool = False + ) -> Optional[Value]: """Transform an expression or a statement. If can_borrow is true, prefer to generate a borrowed reference. @@ -229,7 +284,7 @@ def load_bytes_from_str_literal(self, value: str) -> Value: are stored in BytesExpr.value, whose type is 'str' not 'bytes'. Thus we perform a special conversion here. """ - bytes_value = bytes(value, 'utf8').decode('unicode-escape').encode('raw-unicode-escape') + bytes_value = bytes(value, "utf8").decode("unicode-escape").encode("raw-unicode-escape") return self.builder.load_bytes(bytes_value) def load_int(self, value: int) -> Value: @@ -262,19 +317,17 @@ def new_list_op(self, values: List[Value], line: int) -> Value: def new_set_op(self, values: List[Value], line: int) -> Value: return self.builder.new_set_op(values, line) - def translate_is_op(self, - lreg: Value, - rreg: Value, - expr_op: str, - line: int) -> Value: + def translate_is_op(self, lreg: Value, rreg: Value, expr_op: str, line: int) -> Value: return self.builder.translate_is_op(lreg, rreg, expr_op, line) - def py_call(self, - function: Value, - arg_values: List[Value], - line: int, - arg_kinds: Optional[List[ArgKind]] = None, - arg_names: Optional[Sequence[Optional[str]]] = None) -> Value: + def py_call( + self, + function: Value, + arg_values: List[Value], + line: int, + arg_kinds: Optional[List[ArgKind]] = None, + arg_names: Optional[Sequence[Optional[str]]] = None, + ) -> Value: return self.builder.py_call(function, arg_values, line, arg_kinds, arg_names) def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) -> None: @@ -283,14 +336,16 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) -> def load_native_type_object(self, fullname: str) -> Value: return self.builder.load_native_type_object(fullname) - def gen_method_call(self, - base: Value, - name: str, - arg_values: List[Value], - result_type: Optional[RType], - line: int, - arg_kinds: Optional[List[ArgKind]] = None, - arg_names: Optional[List[Optional[str]]] = None) -> Value: + def gen_method_call( + self, + base: Value, + name: str, + arg_values: List[Value], + result_type: Optional[RType], + line: int, + arg_kinds: Optional[List[ArgKind]] = None, + arg_names: Optional[List[Optional[str]]] = None, + ) -> Value: return self.builder.gen_method_call( base, name, arg_values, result_type, line, arg_kinds, arg_names, self.can_borrow ) @@ -318,14 +373,16 @@ def new_tuple(self, items: List[Value], line: int) -> Value: # Helpers for IR building - def add_to_non_ext_dict(self, non_ext: NonExtClassInfo, - key: str, val: Value, line: int) -> None: + def add_to_non_ext_dict( + self, non_ext: NonExtClassInfo, key: str, val: Value, line: int + ) -> None: # Add an attribute entry into the class dict of a non-extension class. key_unicode = self.load_str(key) self.call_c(dict_set_item_op, [non_ext.dict, key_unicode, val], line) - def gen_import_from(self, id: str, globals_dict: Value, - imported: List[str], line: int) -> Value: + def gen_import_from( + self, id: str, globals_dict: Value, imported: List[str], line: int + ) -> Value: self.imports[id] = None null_dict = Integer(0, dict_rprimitive, line) @@ -350,8 +407,9 @@ def gen_import(self, id: str, line: int) -> None: self.add(InitStatic(value, id, namespace=NAMESPACE_MODULE)) self.goto_and_activate(out) - def check_if_module_loaded(self, id: str, line: int, - needs_import: BasicBlock, out: BasicBlock) -> None: + def check_if_module_loaded( + self, id: str, line: int, needs_import: BasicBlock, out: BasicBlock + ) -> None: """Generate code that checks if the module `id` has been loaded yet. Arguments: @@ -360,15 +418,14 @@ def check_if_module_loaded(self, id: str, line: int, needs_import: the BasicBlock that is run if the module has not been loaded yet out: the BasicBlock that is run if the module has already been loaded""" first_load = self.load_module(id) - comparison = self.translate_is_op(first_load, self.none_object(), 'is not', line) + comparison = self.translate_is_op(first_load, self.none_object(), "is not", line) self.add_bool_branch(comparison, out, needs_import) def get_module(self, module: str, line: int) -> Value: # Python 3.7 has a nice 'PyImport_GetModule' function that we can't use :( mod_dict = self.call_c(get_module_dict_op, [], line) # Get module object from modules dict. - return self.call_c(dict_get_item_op, - [mod_dict, self.load_str(module)], line) + return self.call_c(dict_get_item_op, [mod_dict, self.load_str(module)], line) def get_module_attr(self, module: str, attr: str, line: int) -> Value: """Look up an attribute of a module without storing it in the local namespace. @@ -382,8 +439,7 @@ def get_module_attr(self, module: str, attr: str, line: int) -> Value: module_obj = self.get_module(module, line) return self.py_get_attr(module_obj, attr, line) - def assign_if_null(self, target: Register, - get_val: Callable[[], Value], line: int) -> None: + def assign_if_null(self, target: Register, get_val: Callable[[], Value], line: int) -> None: """If target is NULL, assign value produced by get_val to it.""" error_block, body_block = BasicBlock(), BasicBlock() self.add(Branch(target, error_block, body_block, Branch.IS_ERROR)) @@ -415,48 +471,56 @@ def disallow_class_assignments(self, lvalues: List[Lvalue], line: int) -> None: # miscompile the interaction between instance and class # variables. for lvalue in lvalues: - if (isinstance(lvalue, MemberExpr) - and isinstance(lvalue.expr, RefExpr) - and isinstance(lvalue.expr.node, TypeInfo)): + if ( + isinstance(lvalue, MemberExpr) + and isinstance(lvalue.expr, RefExpr) + and isinstance(lvalue.expr.node, TypeInfo) + ): var = lvalue.expr.node[lvalue.name].node if isinstance(var, Var) and not var.is_classvar: - self.error( - "Only class variables defined as ClassVar can be assigned to", - line) + self.error("Only class variables defined as ClassVar can be assigned to", line) def non_function_scope(self) -> bool: # Currently the stack always has at least two items: dummy and top-level. return len(self.fn_infos) <= 2 - def init_final_static(self, - lvalue: Lvalue, - rvalue_reg: Value, - class_name: Optional[str] = None, - *, - type_override: Optional[RType] = None) -> None: + def init_final_static( + self, + lvalue: Lvalue, + rvalue_reg: Value, + class_name: Optional[str] = None, + *, + type_override: Optional[RType] = None, + ) -> None: assert isinstance(lvalue, NameExpr) assert isinstance(lvalue.node, Var) if lvalue.node.final_value is None: if class_name is None: name = lvalue.name else: - name = f'{class_name}.{lvalue.name}' + name = f"{class_name}.{lvalue.name}" assert name is not None, "Full name not set for variable" coerced = self.coerce(rvalue_reg, type_override or self.node_type(lvalue), lvalue.line) self.final_names.append((name, coerced.type)) self.add(InitStatic(coerced, name, self.module_name)) - def load_final_static(self, fullname: str, typ: RType, line: int, - error_name: Optional[str] = None) -> Value: + def load_final_static( + self, fullname: str, typ: RType, line: int, error_name: Optional[str] = None + ) -> Value: split_name = split_target(self.graph, fullname) assert split_name is not None module, name = split_name return self.builder.load_static_checked( - typ, name, module, line=line, - error_msg=f'value for final name "{error_name}" was not set') + typ, + name, + module, + line=line, + error_msg=f'value for final name "{error_name}" was not set', + ) - def load_final_literal_value(self, val: Union[int, str, bytes, float, bool], - line: int) -> Value: + def load_final_literal_value( + self, val: Union[int, str, bytes, float, bool], line: int + ) -> Value: """Load value of a final name or class-level attribute.""" if isinstance(val, bool): if val: @@ -476,8 +540,7 @@ def load_final_literal_value(self, val: Union[int, str, bytes, float, bool], else: assert False, "Unsupported final literal value" - def get_assignment_target(self, lvalue: Lvalue, - line: int = -1) -> AssignmentTarget: + def get_assignment_target(self, lvalue: Lvalue, line: int = -1) -> AssignmentTarget: if isinstance(lvalue, NameExpr): # If we are visiting a decorator, then the SymbolNode we really want to be looking at # is the function that is decorated, not the entire Decorator node itself. @@ -496,9 +559,12 @@ def get_assignment_target(self, lvalue: Lvalue, # target to the table containing class environment variables, as well as the # current environment. if self.fn_info.is_generator: - return self.add_var_to_env_class(symbol, self.node_type(lvalue), - self.fn_info.generator_class, - reassign=False) + return self.add_var_to_env_class( + symbol, + self.node_type(lvalue), + self.fn_info.generator_class, + reassign=False, + ) # Otherwise define a new local variable. return self.add_local_reg(symbol, self.node_type(lvalue)) @@ -538,19 +604,19 @@ def get_assignment_target(self, lvalue: Lvalue, elif isinstance(lvalue, StarExpr): return self.get_assignment_target(lvalue.expr) - assert False, 'Unsupported lvalue: %r' % lvalue + assert False, "Unsupported lvalue: %r" % lvalue - def read(self, - target: Union[Value, AssignmentTarget], - line: int = -1, - can_borrow: bool = False) -> Value: + def read( + self, target: Union[Value, AssignmentTarget], line: int = -1, can_borrow: bool = False + ) -> Value: if isinstance(target, Value): return target if isinstance(target, AssignmentTargetRegister): return target.register if isinstance(target, AssignmentTargetIndex): reg = self.gen_method_call( - target.base, '__getitem__', [target.index], target.type, line) + target.base, "__getitem__", [target.index], target.type, line + ) if reg is not None: return reg assert False, target.base.type @@ -561,12 +627,11 @@ def read(self, else: return self.py_get_attr(target.obj, target.attr, line) - assert False, 'Unsupported lvalue: %r' % target + assert False, "Unsupported lvalue: %r" % target - def assign(self, - target: Union[Register, AssignmentTarget], - rvalue_reg: Value, - line: int) -> None: + def assign( + self, target: Union[Register, AssignmentTarget], rvalue_reg: Value, line: int + ) -> None: if isinstance(target, Register): self.add(Assign(target, self.coerce(rvalue_reg, target.type, line))) elif isinstance(target, AssignmentTargetRegister): @@ -582,7 +647,8 @@ def assign(self, self.call_c(py_setattr_op, [target.obj, key, boxed_reg], line) elif isinstance(target, AssignmentTargetIndex): target_reg2 = self.gen_method_call( - target.base, '__setitem__', [target.index, rvalue_reg], None, line) + target.base, "__setitem__", [target.index, rvalue_reg], None, line + ) assert target_reg2 is not None, target.base.type elif isinstance(target, AssignmentTargetTuple): if isinstance(rvalue_reg.type, RTuple) and target.star_idx is None: @@ -591,18 +657,18 @@ def assign(self, for i in range(len(rtypes)): item_value = self.add(TupleGet(rvalue_reg, i, line)) self.assign(target.items[i], item_value, line) - elif ((is_list_rprimitive(rvalue_reg.type) or is_tuple_rprimitive(rvalue_reg.type)) - and target.star_idx is None): + elif ( + is_list_rprimitive(rvalue_reg.type) or is_tuple_rprimitive(rvalue_reg.type) + ) and target.star_idx is None: self.process_sequence_assignment(target, rvalue_reg, line) else: self.process_iterator_tuple_assignment(target, rvalue_reg, line) else: - assert False, 'Unsupported assignment target' + assert False, "Unsupported assignment target" - def process_sequence_assignment(self, - target: AssignmentTargetTuple, - rvalue: Value, - line: int) -> None: + def process_sequence_assignment( + self, target: AssignmentTargetTuple, rvalue: Value, line: int + ) -> None: """Process assignment like 'x, y = s', where s is a variable-length list or tuple.""" # Check the length of sequence. expected_len = Integer(len(target.items), c_pyssize_t_rprimitive) @@ -617,31 +683,32 @@ def process_sequence_assignment(self, item_value = self.call_c(list_get_item_unsafe_op, [rvalue, index], line) else: item_value = self.builder.gen_method_call( - rvalue, '__getitem__', [index], item.type, line) + rvalue, "__getitem__", [index], item.type, line + ) values.append(item_value) # Assign sequence items to the target lvalues. for lvalue, value in zip(target.items, values): self.assign(lvalue, value, line) - def process_iterator_tuple_assignment_helper(self, - litem: AssignmentTarget, - ritem: Value, line: int) -> None: + def process_iterator_tuple_assignment_helper( + self, litem: AssignmentTarget, ritem: Value, line: int + ) -> None: error_block, ok_block = BasicBlock(), BasicBlock() self.add(Branch(ritem, error_block, ok_block, Branch.IS_ERROR)) self.activate_block(error_block) - self.add(RaiseStandardError(RaiseStandardError.VALUE_ERROR, - 'not enough values to unpack', line)) + self.add( + RaiseStandardError(RaiseStandardError.VALUE_ERROR, "not enough values to unpack", line) + ) self.add(Unreachable()) self.activate_block(ok_block) self.assign(litem, ritem, line) - def process_iterator_tuple_assignment(self, - target: AssignmentTargetTuple, - rvalue_reg: Value, - line: int) -> None: + def process_iterator_tuple_assignment( + self, target: AssignmentTargetTuple, rvalue_reg: Value, line: int + ) -> None: iterator = self.call_c(iter_op, [rvalue_reg], line) @@ -655,8 +722,11 @@ def process_iterator_tuple_assignment(self, self.add(Branch(ritem, error_block, ok_block, Branch.IS_ERROR)) self.activate_block(error_block) - self.add(RaiseStandardError(RaiseStandardError.VALUE_ERROR, - 'not enough values to unpack', line)) + self.add( + RaiseStandardError( + RaiseStandardError.VALUE_ERROR, "not enough values to unpack", line + ) + ) self.add(Unreachable()) self.activate_block(ok_block) @@ -665,18 +735,21 @@ def process_iterator_tuple_assignment(self, # Assign the starred value and all values after it if target.star_idx is not None: - post_star_vals = target.items[split_idx + 1:] + post_star_vals = target.items[split_idx + 1 :] iter_list = self.call_c(to_list, [iterator], line) iter_list_len = self.builtin_len(iter_list, line) post_star_len = Integer(len(post_star_vals)) - condition = self.binary_op(post_star_len, iter_list_len, '<=', line) + condition = self.binary_op(post_star_len, iter_list_len, "<=", line) error_block, ok_block = BasicBlock(), BasicBlock() self.add(Branch(condition, ok_block, error_block, Branch.BOOL)) self.activate_block(error_block) - self.add(RaiseStandardError(RaiseStandardError.VALUE_ERROR, - 'not enough values to unpack', line)) + self.add( + RaiseStandardError( + RaiseStandardError.VALUE_ERROR, "not enough values to unpack", line + ) + ) self.add(Unreachable()) self.activate_block(ok_block) @@ -696,22 +769,26 @@ def process_iterator_tuple_assignment(self, self.add(Branch(extra, ok_block, error_block, Branch.IS_ERROR)) self.activate_block(error_block) - self.add(RaiseStandardError(RaiseStandardError.VALUE_ERROR, - 'too many values to unpack', line)) + self.add( + RaiseStandardError( + RaiseStandardError.VALUE_ERROR, "too many values to unpack", line + ) + ) self.add(Unreachable()) self.activate_block(ok_block) def push_loop_stack(self, continue_block: BasicBlock, break_block: BasicBlock) -> None: self.nonlocal_control.append( - LoopNonlocalControl(self.nonlocal_control[-1], continue_block, break_block)) + LoopNonlocalControl(self.nonlocal_control[-1], continue_block, break_block) + ) def pop_loop_stack(self) -> None: self.nonlocal_control.pop() def spill(self, value: Value) -> AssignmentTarget: """Moves a given Value instance into the generator class' environment class.""" - name = f'{TEMP_ATTR_NAME}{self.temp_counter}' + name = f"{TEMP_ATTR_NAME}{self.temp_counter}" self.temp_counter += 1 target = self.add_var_to_env_class(Var(name), value.type, self.fn_info.generator_class) # Shouldn't be able to fail, so -1 for line @@ -752,7 +829,7 @@ def maybe_spill_assignable(self, value: Value) -> Union[Register, AssignmentTarg def extract_int(self, e: Expression) -> Optional[int]: if isinstance(e, IntExpr): return e.value - elif isinstance(e, UnaryExpr) and e.op == '-' and isinstance(e.expr, IntExpr): + elif isinstance(e, UnaryExpr) and e.op == "-" and isinstance(e.expr, IntExpr): return -e.expr.value else: return None @@ -760,7 +837,7 @@ def extract_int(self, e: Expression) -> Optional[int]: def get_sequence_type(self, expr: Expression) -> RType: target_type = get_proper_type(self.types[expr]) assert isinstance(target_type, Instance) - if target_type.type.fullname == 'builtins.str': + if target_type.type.fullname == "builtins.str": return str_rprimitive else: return self.type_to_rtype(target_type.args[0]) @@ -772,8 +849,7 @@ def get_dict_base_type(self, expr: Expression) -> Instance: """ target_type = get_proper_type(self.types[expr]) assert isinstance(target_type, Instance) - dict_base = next(base for base in target_type.type.mro - if base.fullname == 'builtins.dict') + dict_base = next(base for base in target_type.type.mro if base.fullname == "builtins.dict") return map_instance_to_supertype(target_type, dict_base) def get_dict_key_type(self, expr: Expression) -> RType: @@ -794,9 +870,10 @@ def _analyze_iterable_item_type(self, expr: Expression) -> Type: # This logic is copied from mypy's TypeChecker.analyze_iterable_item_type. iterable = get_proper_type(self.types[expr]) echk = self.graph[self.module_name].type_checker().expr_checker - iterator = echk.check_method_call_by_name('__iter__', iterable, [], [], expr)[0] + iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], expr)[0] from mypy.join import join_types + if isinstance(iterable, TupleType): joined: Type = UninhabitedType() for item in iterable.items: @@ -804,7 +881,7 @@ def _analyze_iterable_item_type(self, expr: Expression) -> Type: return joined else: # Non-tuple iterable. - return echk.check_method_call_by_name('__next__', iterator, [], [], expr)[0] + return echk.check_method_call_by_name("__next__", iterator, [], [], expr)[0] def is_native_module(self, module: str) -> bool: """Is the given module one compiled by mypyc?""" @@ -813,8 +890,8 @@ def is_native_module(self, module: str) -> bool: def is_native_ref_expr(self, expr: RefExpr) -> bool: if expr.node is None: return False - if '.' in expr.node.fullname: - return self.is_native_module(expr.node.fullname.rpartition('.')[0]) + if "." in expr.node.fullname: + return self.is_native_module(expr.node.fullname.rpartition(".")[0]) return True def is_native_module_ref_expr(self, expr: RefExpr) -> bool: @@ -840,10 +917,10 @@ def get_final_ref(self, expr: MemberExpr) -> Optional[Tuple[str, Var, bool]]: if sym and isinstance(sym.node, Var): # Enum attribute are treated as final since they are added to the global cache expr_fullname = expr.expr.node.bases[0].type.fullname - is_final = sym.node.is_final or expr_fullname == 'enum.Enum' + is_final = sym.node.is_final or expr_fullname == "enum.Enum" if is_final: final_var = sym.node - fullname = f'{sym.node.info.fullname}.{final_var.name}' + fullname = f"{sym.node.info.fullname}.{final_var.name}" native = self.is_native_module(expr.expr.node.module_name) elif self.is_module_member_expr(expr): # a module attribute @@ -855,8 +932,9 @@ def get_final_ref(self, expr: MemberExpr) -> Optional[Tuple[str, Var, bool]]: return fullname, final_var, native return None - def emit_load_final(self, final_var: Var, fullname: str, - name: str, native: bool, typ: Type, line: int) -> Optional[Value]: + def emit_load_final( + self, final_var: Var, fullname: str, name: str, native: bool, typ: Type, line: int + ) -> Optional[Value]: """Emit code for loading value of a final name (if possible). Args: @@ -870,8 +948,7 @@ def emit_load_final(self, final_var: Var, fullname: str, if final_var.final_value is not None: # this is safe even for non-native names return self.load_final_literal_value(final_var.final_value, line) elif native: - return self.load_final_static(fullname, self.mapper.type_to_rtype(typ), - line, name) + return self.load_final_static(fullname, self.mapper.type_to_rtype(typ), line, name) else: return None @@ -879,13 +956,15 @@ def is_module_member_expr(self, expr: MemberExpr) -> bool: return isinstance(expr.expr, RefExpr) and isinstance(expr.expr.node, MypyFile) def call_refexpr_with_args( - self, expr: CallExpr, callee: RefExpr, arg_values: List[Value]) -> Value: + self, expr: CallExpr, callee: RefExpr, arg_values: List[Value] + ) -> Value: # Handle data-driven special-cased primitive call ops. if callee.fullname is not None and expr.arg_kinds == [ARG_POS] * len(arg_values): call_c_ops_candidates = function_ops.get(callee.fullname, []) - target = self.builder.matching_call_c(call_c_ops_candidates, arg_values, - expr.line, self.node_type(expr)) + target = self.builder.matching_call_c( + call_c_ops_candidates, arg_values, expr.line, self.node_type(expr) + ) if target: return target @@ -904,24 +983,28 @@ def call_refexpr_with_args( and callee_node.func in self.singledispatch_impls ): callee_node = callee_node.func - if (callee_node is not None - and callee.fullname is not None - and callee_node in self.mapper.func_to_decl - and all(kind in (ARG_POS, ARG_NAMED) for kind in expr.arg_kinds)): + if ( + callee_node is not None + and callee.fullname is not None + and callee_node in self.mapper.func_to_decl + and all(kind in (ARG_POS, ARG_NAMED) for kind in expr.arg_kinds) + ): decl = self.mapper.func_to_decl[callee_node] return self.builder.call(decl, arg_values, expr.arg_kinds, expr.arg_names, expr.line) # Fall back to a Python call function = self.accept(callee) - return self.py_call(function, arg_values, expr.line, - arg_kinds=expr.arg_kinds, arg_names=expr.arg_names) + return self.py_call( + function, arg_values, expr.line, arg_kinds=expr.arg_kinds, arg_names=expr.arg_names + ) def shortcircuit_expr(self, expr: OpExpr) -> Value: return self.builder.shortcircuit_helper( - expr.op, self.node_type(expr), + expr.op, + self.node_type(expr), lambda: self.accept(expr.left), lambda: self.accept(expr.right), - expr.line + expr.line, ) # Basic helpers @@ -949,7 +1032,7 @@ def flatten_classes(self, arg: Union[RefExpr, TupleExpr]) -> Optional[List[Class return None return res - def enter(self, fn_info: Union[FuncInfo, str] = '') -> None: + def enter(self, fn_info: Union[FuncInfo, str] = "") -> None: if isinstance(fn_info, str): fn_info = FuncInfo(name=fn_info) self.builder = LowLevelIRBuilder(self.current_module, self.mapper, self.options) @@ -977,12 +1060,14 @@ def leave(self) -> Tuple[List[Register], List[RuntimeArg], List[BasicBlock], RTy return builder.args, runtime_args, builder.blocks, ret_type, fn_info @contextmanager - def enter_method(self, - class_ir: ClassIR, - name: str, - ret_type: RType, - fn_info: Union[FuncInfo, str] = '', - self_type: Optional[RType] = None) -> Iterator[None]: + def enter_method( + self, + class_ir: ClassIR, + name: str, + ret_type: RType, + fn_info: Union[FuncInfo, str] = "", + self_type: Optional[RType] = None, + ) -> Iterator[None]: """Generate IR for a method. If the method takes arguments, you should immediately afterwards call @@ -1030,7 +1115,7 @@ def add_argument(self, var: Union[str, Var], typ: RType, kind: ArgKind = ARG_POS def lookup(self, symbol: SymbolNode) -> SymbolTarget: return self.symtables[-1][symbol] - def add_local(self, symbol: SymbolNode, typ: RType, is_arg: bool = False) -> 'Register': + def add_local(self, symbol: SymbolNode, typ: RType, is_arg: bool = False) -> "Register": """Add register that represents a symbol to the symbol table. Args: @@ -1038,20 +1123,16 @@ def add_local(self, symbol: SymbolNode, typ: RType, is_arg: bool = False) -> 'Re """ assert isinstance(symbol, SymbolNode) reg = Register( - typ, - remangle_redefinition_name(symbol.name), - is_arg=is_arg, - line=symbol.line, + typ, remangle_redefinition_name(symbol.name), is_arg=is_arg, line=symbol.line ) self.symtables[-1][symbol] = AssignmentTargetRegister(reg) if is_arg: self.builder.args.append(reg) return reg - def add_local_reg(self, - symbol: SymbolNode, - typ: RType, - is_arg: bool = False) -> AssignmentTargetRegister: + def add_local_reg( + self, symbol: SymbolNode, typ: RType, is_arg: bool = False + ) -> AssignmentTargetRegister: """Like add_local, but return an assignment target instead of value.""" self.add_local(symbol, typ, is_arg) target = self.symtables[-1][symbol] @@ -1081,11 +1162,13 @@ def node_type(self, node: Expression) -> RType: mypy_type = self.types[node] return self.type_to_rtype(mypy_type) - def add_var_to_env_class(self, - var: SymbolNode, - rtype: RType, - base: Union[FuncInfo, ImplicitClass], - reassign: bool = False) -> AssignmentTarget: + def add_var_to_env_class( + self, + var: SymbolNode, + rtype: RType, + base: Union[FuncInfo, ImplicitClass], + reassign: bool = False, + ) -> AssignmentTarget: # First, define the variable name as an attribute of the environment class, and then # construct a target for that attribute. self.fn_info.env_class.attributes[var.name] = rtype @@ -1103,7 +1186,7 @@ def add_var_to_env_class(self, def is_builtin_ref_expr(self, expr: RefExpr) -> bool: assert expr.node, "RefExpr not resolved" - return '.' in expr.node.fullname and expr.node.fullname.split('.')[0] == 'builtins' + return "." in expr.node.fullname and expr.node.fullname.split(".")[0] == "builtins" def load_global(self, expr: NameExpr) -> Value: """Loads a Python-level global. @@ -1115,8 +1198,11 @@ def load_global(self, expr: NameExpr) -> Value: if self.is_builtin_ref_expr(expr): assert expr.node, "RefExpr not resolved" return self.load_module_attr_by_fullname(expr.node.fullname, expr.line) - if (self.is_native_module_ref_expr(expr) and isinstance(expr.node, TypeInfo) - and not self.is_synthetic_type(expr.node)): + if ( + self.is_native_module_ref_expr(expr) + and isinstance(expr.node, TypeInfo) + and not self.is_synthetic_type(expr.node) + ): assert expr.fullname is not None return self.load_native_type_object(expr.fullname) return self.load_global_str(expr.name, expr.line) @@ -1127,20 +1213,22 @@ def load_global_str(self, name: str, line: int) -> Value: return self.call_c(dict_get_item_op, [_globals, reg], line) def load_globals_dict(self) -> Value: - return self.add(LoadStatic(dict_rprimitive, 'globals', self.module_name)) + return self.add(LoadStatic(dict_rprimitive, "globals", self.module_name)) def load_module_attr_by_fullname(self, fullname: str, line: int) -> Value: - module, _, name = fullname.rpartition('.') + module, _, name = fullname.rpartition(".") left = self.load_module(module) return self.py_get_attr(left, name, line) def is_native_attr_ref(self, expr: MemberExpr) -> bool: """Is expr a direct reference to a native (struct) attribute of an instance?""" obj_rtype = self.node_type(expr.expr) - return (isinstance(obj_rtype, RInstance) - and obj_rtype.class_ir.is_ext_class - and obj_rtype.class_ir.has_attr(expr.name) - and not obj_rtype.class_ir.get_method(expr.name)) + return ( + isinstance(obj_rtype, RInstance) + and obj_rtype.class_ir.is_ext_class + and obj_rtype.class_ir.has_attr(expr.name) + and not obj_rtype.class_ir.get_method(expr.name) + ) # Lacks a good type because there wasn't a reasonable type in 3.5 :( def catch_errors(self, line: int) -> Any: @@ -1177,14 +1265,16 @@ def get_default() -> Value: # Because gen_arg_defaults runs before calculate_arg_defaults, we # add the static/attribute to final_names/the class here. elif not builder.fn_info.is_nested: - name = fitem.fullname + '.' + arg.variable.name + name = fitem.fullname + "." + arg.variable.name builder.final_names.append((name, target.type)) return builder.add(LoadStatic(target.type, name, builder.module_name)) else: name = arg.variable.name builder.fn_info.callable_class.ir.attributes[name] = target.type return builder.add( - GetAttr(builder.fn_info.callable_class.self_reg, name, arg.line)) + GetAttr(builder.fn_info.callable_class.self_reg, name, arg.line) + ) + assert isinstance(target, AssignmentTargetRegister) builder.assign_if_null(target.register, get_default, arg.initializer.line) diff --git a/mypyc/irbuild/callable_class.py b/mypyc/irbuild/callable_class.py index fe561cfc531d5..d2ac7fcd584e6 100644 --- a/mypyc/irbuild/callable_class.py +++ b/mypyc/irbuild/callable_class.py @@ -6,11 +6,11 @@ from typing import List -from mypyc.common import SELF_NAME, ENV_ATTR_NAME -from mypyc.ir.ops import BasicBlock, Return, Call, SetAttr, Value, Register -from mypyc.ir.rtypes import RInstance, object_rprimitive -from mypyc.ir.func_ir import FuncIR, FuncSignature, RuntimeArg, FuncDecl +from mypyc.common import ENV_ATTR_NAME, SELF_NAME from mypyc.ir.class_ir import ClassIR +from mypyc.ir.func_ir import FuncDecl, FuncIR, FuncSignature, RuntimeArg +from mypyc.ir.ops import BasicBlock, Call, Register, Return, SetAttr, Value +from mypyc.ir.rtypes import RInstance, object_rprimitive from mypyc.irbuild.builder import IRBuilder from mypyc.irbuild.context import FuncInfo, ImplicitClass from mypyc.primitives.misc_ops import method_new_op @@ -45,10 +45,10 @@ class for the nested function. # else: # def foo(): ----> foo_obj_0() # return False - name = base_name = f'{builder.fn_info.namespaced_name()}_obj' + name = base_name = f"{builder.fn_info.namespaced_name()}_obj" count = 0 while name in builder.callable_class_names: - name = base_name + '_' + str(count) + name = base_name + "_" + str(count) count += 1 builder.callable_class_names.add(name) @@ -67,9 +67,7 @@ class for the nested function. # If the enclosing class doesn't contain nested (which will happen if # this is a toplevel lambda), don't set up an environment. if builder.fn_infos[-2].contains_nested: - callable_class_ir.attributes[ENV_ATTR_NAME] = RInstance( - builder.fn_infos[-2].env_class - ) + callable_class_ir.attributes[ENV_ATTR_NAME] = RInstance(builder.fn_infos[-2].env_class) callable_class_ir.mro = [callable_class_ir] builder.fn_info.callable_class = ImplicitClass(callable_class_ir) builder.classes.append(callable_class_ir) @@ -80,11 +78,13 @@ class for the nested function. builder.fn_info.callable_class.self_reg = builder.read(self_target, builder.fn_info.fitem.line) -def add_call_to_callable_class(builder: IRBuilder, - args: List[Register], - blocks: List[BasicBlock], - sig: FuncSignature, - fn_info: FuncInfo) -> FuncIR: +def add_call_to_callable_class( + builder: IRBuilder, + args: List[Register], + blocks: List[BasicBlock], + sig: FuncSignature, + fn_info: FuncInfo, +) -> FuncIR: """Generate a '__call__' method for a callable class representing a nested function. This takes the blocks and signature associated with a function @@ -93,11 +93,12 @@ def add_call_to_callable_class(builder: IRBuilder, """ # Since we create a method, we also add a 'self' parameter. sig = FuncSignature((RuntimeArg(SELF_NAME, object_rprimitive),) + sig.args, sig.ret_type) - call_fn_decl = FuncDecl('__call__', fn_info.callable_class.ir.name, builder.module_name, sig) - call_fn_ir = FuncIR(call_fn_decl, args, blocks, - fn_info.fitem.line, traceback_name=fn_info.fitem.name) - fn_info.callable_class.ir.methods['__call__'] = call_fn_ir - fn_info.callable_class.ir.method_decls['__call__'] = call_fn_decl + call_fn_decl = FuncDecl("__call__", fn_info.callable_class.ir.name, builder.module_name, sig) + call_fn_ir = FuncIR( + call_fn_decl, args, blocks, fn_info.fitem.line, traceback_name=fn_info.fitem.name + ) + fn_info.callable_class.ir.methods["__call__"] = call_fn_ir + fn_info.callable_class.ir.method_decls["__call__"] = call_fn_decl return call_fn_ir @@ -105,17 +106,21 @@ def add_get_to_callable_class(builder: IRBuilder, fn_info: FuncInfo) -> None: """Generate the '__get__' method for a callable class.""" line = fn_info.fitem.line with builder.enter_method( - fn_info.callable_class.ir, '__get__', object_rprimitive, fn_info, - self_type=object_rprimitive): - instance = builder.add_argument('instance', object_rprimitive) - builder.add_argument('owner', object_rprimitive) + fn_info.callable_class.ir, + "__get__", + object_rprimitive, + fn_info, + self_type=object_rprimitive, + ): + instance = builder.add_argument("instance", object_rprimitive) + builder.add_argument("owner", object_rprimitive) # If accessed through the class, just return the callable # object. If accessed through an object, create a new bound # instance method object. instance_block, class_block = BasicBlock(), BasicBlock() comparison = builder.translate_is_op( - builder.read(instance), builder.none_object(), 'is', line + builder.read(instance), builder.none_object(), "is", line ) builder.add_bool_branch(comparison, class_block, instance_block) @@ -123,8 +128,9 @@ def add_get_to_callable_class(builder: IRBuilder, fn_info: FuncInfo) -> None: builder.add(Return(builder.self())) builder.activate_block(instance_block) - builder.add(Return(builder.call_c(method_new_op, - [builder.self(), builder.read(instance)], line))) + builder.add( + Return(builder.call_c(method_new_op, [builder.self(), builder.read(instance)], line)) + ) def instantiate_callable_class(builder: IRBuilder, fn_info: FuncInfo) -> Value: diff --git a/mypyc/irbuild/classdef.py b/mypyc/irbuild/classdef.py index 7cc08b73494fb..113741a021d3e 100644 --- a/mypyc/irbuild/classdef.py +++ b/mypyc/irbuild/classdef.py @@ -2,34 +2,67 @@ from abc import abstractmethod from typing import Callable, List, Optional, Set, Tuple + from typing_extensions import Final from mypy.nodes import ( - ClassDef, FuncDef, OverloadedFuncDef, PassStmt, AssignmentStmt, CallExpr, NameExpr, StrExpr, - ExpressionStmt, TempNode, Decorator, Lvalue, MemberExpr, RefExpr, TypeInfo, is_class_var + AssignmentStmt, + CallExpr, + ClassDef, + Decorator, + ExpressionStmt, + FuncDef, + Lvalue, + MemberExpr, + NameExpr, + OverloadedFuncDef, + PassStmt, + RefExpr, + StrExpr, + TempNode, + TypeInfo, + is_class_var, ) -from mypy.types import Instance, get_proper_type, ENUM_REMOVED_PROPS +from mypy.types import ENUM_REMOVED_PROPS, Instance, get_proper_type +from mypyc.ir.class_ir import ClassIR, NonExtClassInfo +from mypyc.ir.func_ir import FuncDecl, FuncSignature from mypyc.ir.ops import ( - Value, Register, Call, LoadErrorValue, LoadStatic, InitStatic, TupleSet, SetAttr, Return, - BasicBlock, Branch, MethodCall, NAMESPACE_TYPE, LoadAddress + NAMESPACE_TYPE, + BasicBlock, + Branch, + Call, + InitStatic, + LoadAddress, + LoadErrorValue, + LoadStatic, + MethodCall, + Register, + Return, + SetAttr, + TupleSet, + Value, ) from mypyc.ir.rtypes import ( - RType, object_rprimitive, bool_rprimitive, dict_rprimitive, is_optional_type, - is_object_rprimitive, is_none_rprimitive -) -from mypyc.ir.func_ir import FuncDecl, FuncSignature -from mypyc.ir.class_ir import ClassIR, NonExtClassInfo -from mypyc.primitives.generic_ops import py_setattr_op, py_hasattr_op -from mypyc.primitives.misc_ops import ( - dataclass_sleight_of_hand, pytype_from_template_op, py_calc_meta_op, type_object_op, - not_implemented_op -) -from mypyc.primitives.dict_ops import dict_set_item_op, dict_new_op -from mypyc.irbuild.util import ( - is_dataclass_decorator, get_func_def, is_constant, dataclass_type + RType, + bool_rprimitive, + dict_rprimitive, + is_none_rprimitive, + is_object_rprimitive, + is_optional_type, + object_rprimitive, ) from mypyc.irbuild.builder import IRBuilder from mypyc.irbuild.function import handle_ext_method, handle_non_ext_method, load_type +from mypyc.irbuild.util import dataclass_type, get_func_def, is_constant, is_dataclass_decorator +from mypyc.primitives.dict_ops import dict_new_op, dict_set_item_op +from mypyc.primitives.generic_ops import py_hasattr_op, py_setattr_op +from mypyc.primitives.misc_ops import ( + dataclass_sleight_of_hand, + not_implemented_op, + py_calc_meta_op, + pytype_from_template_op, + type_object_op, +) def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None: @@ -49,7 +82,7 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None: # We do this check here because the base field of parent # classes aren't necessarily populated yet at # prepare_class_def time. - if any(ir.base_mro[i].base != ir. base_mro[i + 1] for i in range(len(ir.base_mro) - 1)): + if any(ir.base_mro[i].base != ir.base_mro[i + 1] for i in range(len(ir.base_mro) - 1)): builder.error("Non-trait MRO must be linear", cdef.line) if ir.allow_interpreted_subclasses: @@ -57,7 +90,10 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None: if not parent.allow_interpreted_subclasses: builder.error( 'Base class "{}" does not allow interpreted subclasses'.format( - parent.fullname), cdef.line) + parent.fullname + ), + cdef.line, + ) # Currently, we only create non-extension classes for classes that are # decorated or inherit from Enum. Classes decorated with @trait do not @@ -66,9 +102,9 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None: cls_type = dataclass_type(cdef) if cls_type is None: cls_builder: ClassBuilder = ExtClassBuilder(builder, cdef) - elif cls_type in ['dataclasses', 'attr-auto']: + elif cls_type in ["dataclasses", "attr-auto"]: cls_builder = DataClassBuilder(builder, cdef) - elif cls_type == 'attr': + elif cls_type == "attr": cls_builder = AttrsClassBuilder(builder, cdef) else: raise ValueError(cls_type) @@ -80,8 +116,7 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None: if isinstance(cls_builder, NonExtClassBuilder): # properties with both getters and setters in non_extension # classes not supported - builder.error("Property setters not supported in non-extension classes", - stmt.line) + builder.error("Property setters not supported in non-extension classes", stmt.line) for item in stmt.items: with builder.catch_errors(stmt.line): cls_builder.add_method(get_func_def(item)) @@ -101,8 +136,9 @@ def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None: continue lvalue = stmt.lvalues[0] if not isinstance(lvalue, NameExpr): - builder.error("Only assignment to variables is supported in class bodies", - stmt.line) + builder.error( + "Only assignment to variables is supported in class bodies", stmt.line + ) continue # We want to collect class variables in a dictionary for both real # non-extension classes and fake dataclass ones. @@ -149,8 +185,9 @@ def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None: def create_non_ext_info(self) -> NonExtClassInfo: non_ext_bases = populate_non_ext_bases(self.builder, self.cdef) non_ext_metaclass = find_non_ext_metaclass(self.builder, self.cdef, non_ext_bases) - non_ext_dict = setup_non_ext_dict(self.builder, self.cdef, non_ext_metaclass, - non_ext_bases) + non_ext_dict = setup_non_ext_dict( + self.builder, self.cdef, non_ext_metaclass, non_ext_bases + ) # We populate __annotations__ for non-extension classes # because dataclasses uses it to determine which attributes to compute on. # TODO: Maybe generate more precise types for annotations @@ -162,8 +199,9 @@ def add_method(self, fdef: FuncDef) -> None: def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None: add_non_ext_class_attr_ann(self.builder, self.non_ext, lvalue, stmt) - add_non_ext_class_attr(self.builder, self.non_ext, lvalue, stmt, self.cdef, - self.attrs_to_cache) + add_non_ext_class_attr( + self.builder, self.non_ext, lvalue, stmt, self.cdef, self.attrs_to_cache + ) def finalize(self, ir: ClassIR) -> None: # Dynamically create the class via the type constructor @@ -171,16 +209,20 @@ def finalize(self, ir: ClassIR) -> None: non_ext_class = load_decorated_class(self.builder, self.cdef, non_ext_class) # Save the decorated class - self.builder.add(InitStatic(non_ext_class, self.cdef.name, self.builder.module_name, - NAMESPACE_TYPE)) + self.builder.add( + InitStatic(non_ext_class, self.cdef.name, self.builder.module_name, NAMESPACE_TYPE) + ) # Add the non-extension class to the dict - self.builder.call_c(dict_set_item_op, - [ - self.builder.load_globals_dict(), - self.builder.load_str(self.cdef.name), - non_ext_class - ], self.cdef.line) + self.builder.call_c( + dict_set_item_op, + [ + self.builder.load_globals_dict(), + self.builder.load_str(self.cdef.name), + non_ext_class, + ], + self.cdef.line, + ) # Cache any cacheable class attributes cache_class_attrs(self.builder, self.attrs_to_cache, self.cdef) @@ -209,13 +251,15 @@ def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None: typ = self.builder.load_native_type_object(self.cdef.fullname) value = self.builder.accept(stmt.rvalue) self.builder.call_c( - py_setattr_op, [typ, self.builder.load_str(lvalue.name), value], stmt.line) + py_setattr_op, [typ, self.builder.load_str(lvalue.name), value], stmt.line + ) if self.builder.non_function_scope() and stmt.is_final_def: self.builder.init_final_static(lvalue, value, self.cdef.name) def finalize(self, ir: ClassIR) -> None: attrs_with_defaults, default_assignments = find_attr_initializers( - self.builder, self.cdef, self.skip_attr_default) + self.builder, self.cdef, self.skip_attr_default + ) ir.attrs_with_defaults.update(attrs_with_defaults) generate_attr_defaults_init(self.builder, self.cdef, default_assignments) create_ne_from_eq(self.builder, self.cdef) @@ -242,7 +286,7 @@ def create_non_ext_info(self) -> NonExtClassInfo: self.builder.call_c(dict_new_op, [], self.cdef.line), self.builder.add(TupleSet([], self.cdef.line)), self.builder.call_c(dict_new_op, [], self.cdef.line), - self.builder.add(LoadAddress(type_object_op.type, type_object_op.src, self.cdef.line)) + self.builder.add(LoadAddress(type_object_op.type, type_object_op.src, self.cdef.line)), ) def skip_attr_default(self, name: str, stmt: AssignmentStmt) -> bool: @@ -257,10 +301,12 @@ def get_type_annotation(self, stmt: AssignmentStmt) -> Optional[TypeInfo]: return None def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None: - add_non_ext_class_attr_ann(self.builder, self.non_ext, lvalue, stmt, - self.get_type_annotation) - add_non_ext_class_attr(self.builder, self.non_ext, lvalue, stmt, self.cdef, - self.attrs_to_cache) + add_non_ext_class_attr_ann( + self.builder, self.non_ext, lvalue, stmt, self.get_type_annotation + ) + add_non_ext_class_attr( + self.builder, self.non_ext, lvalue, stmt, self.cdef, self.attrs_to_cache + ) super().add_attr(lvalue, stmt) def finalize(self, ir: ClassIR) -> None: @@ -283,13 +329,17 @@ def finalize(self, ir: ClassIR) -> None: """ super().finalize(ir) assert self.type_obj - add_dunders_to_non_ext_dict(self.builder, self.non_ext, self.cdef.line, - self.add_annotations_to_dict) + add_dunders_to_non_ext_dict( + self.builder, self.non_ext, self.cdef.line, self.add_annotations_to_dict + ) dec = self.builder.accept( - next(d for d in self.cdef.decorators if is_dataclass_decorator(d))) + next(d for d in self.cdef.decorators if is_dataclass_decorator(d)) + ) self.builder.call_c( - dataclass_sleight_of_hand, [dec, self.type_obj, self.non_ext.dict, self.non_ext.anns], - self.cdef.line) + dataclass_sleight_of_hand, + [dec, self.type_obj, self.non_ext.dict, self.non_ext.anns], + self.cdef.line, + ) class AttrsClassBuilder(DataClassBuilder): @@ -310,10 +360,12 @@ def get_type_annotation(self, stmt: AssignmentStmt) -> Optional[TypeInfo]: if isinstance(stmt.rvalue, CallExpr): # find the type arg in `attr.ib(type=str)` callee = stmt.rvalue.callee - if (isinstance(callee, MemberExpr) and - callee.fullname in ['attr.ib', 'attr.attr'] and - 'type' in stmt.rvalue.arg_names): - index = stmt.rvalue.arg_names.index('type') + if ( + isinstance(callee, MemberExpr) + and callee.fullname in ["attr.ib", "attr.attr"] + and "type" in stmt.rvalue.arg_names + ): + index = stmt.rvalue.arg_names.index("type") type_name = stmt.rvalue.args[index] if isinstance(type_name, NameExpr) and isinstance(type_name.node, TypeInfo): lvalue = stmt.lvalues[0] @@ -331,33 +383,44 @@ def allocate_class(builder: IRBuilder, cdef: ClassDef) -> Value: else: tp_bases = builder.add(LoadErrorValue(object_rprimitive, is_borrowed=True)) modname = builder.load_str(builder.module_name) - template = builder.add(LoadStatic(object_rprimitive, cdef.name + "_template", - builder.module_name, NAMESPACE_TYPE)) + template = builder.add( + LoadStatic(object_rprimitive, cdef.name + "_template", builder.module_name, NAMESPACE_TYPE) + ) # Create the class - tp = builder.call_c(pytype_from_template_op, - [template, tp_bases, modname], cdef.line) + tp = builder.call_c(pytype_from_template_op, [template, tp_bases, modname], cdef.line) # Immediately fix up the trait vtables, before doing anything with the class. ir = builder.mapper.type_to_ir[cdef.info] if not ir.is_trait and not ir.builtin_base: - builder.add(Call( - FuncDecl(cdef.name + '_trait_vtable_setup', - None, builder.module_name, - FuncSignature([], bool_rprimitive)), [], -1)) + builder.add( + Call( + FuncDecl( + cdef.name + "_trait_vtable_setup", + None, + builder.module_name, + FuncSignature([], bool_rprimitive), + ), + [], + -1, + ) + ) # Populate a '__mypyc_attrs__' field containing the list of attrs - builder.call_c(py_setattr_op, [ - tp, builder.load_str('__mypyc_attrs__'), - create_mypyc_attrs_tuple(builder, builder.mapper.type_to_ir[cdef.info], cdef.line)], - cdef.line) + builder.call_c( + py_setattr_op, + [ + tp, + builder.load_str("__mypyc_attrs__"), + create_mypyc_attrs_tuple(builder, builder.mapper.type_to_ir[cdef.info], cdef.line), + ], + cdef.line, + ) # Save the class builder.add(InitStatic(tp, cdef.name, builder.module_name, NAMESPACE_TYPE)) # Add it to the dict - builder.call_c(dict_set_item_op, - [builder.load_globals_dict(), - builder.load_str(cdef.name), - tp], - cdef.line) + builder.call_c( + dict_set_item_op, [builder.load_globals_dict(), builder.load_str(cdef.name), tp], cdef.line + ) return tp @@ -365,8 +428,8 @@ def allocate_class(builder: IRBuilder, cdef: ClassDef) -> Value: # Mypy uses these internally as base classes of TypedDict classes. These are # lies and don't have any runtime equivalent. MAGIC_TYPED_DICT_CLASSES: Final[Tuple[str, ...]] = ( - 'typing._TypedDict', - 'typing_extensions._TypedDict', + "typing._TypedDict", + "typing_extensions._TypedDict", ) @@ -379,13 +442,15 @@ def populate_non_ext_bases(builder: IRBuilder, cdef: ClassDef) -> Value: ir = builder.mapper.type_to_ir[cdef.info] bases = [] for cls in cdef.info.mro[1:]: - if cls.fullname == 'builtins.object': + if cls.fullname == "builtins.object": continue - if is_named_tuple and cls.fullname in ('typing.Sequence', - 'typing.Iterable', - 'typing.Collection', - 'typing.Reversible', - 'typing.Container'): + if is_named_tuple and cls.fullname in ( + "typing.Sequence", + "typing.Iterable", + "typing.Collection", + "typing.Reversible", + "typing.Container", + ): # HAX: Synthesized base classes added by mypy don't exist at runtime, so skip them. # This could break if they were added explicitly, though... continue @@ -398,23 +463,23 @@ def populate_non_ext_bases(builder: IRBuilder, cdef: ClassDef) -> Value: if cls.fullname in MAGIC_TYPED_DICT_CLASSES: # HAX: Mypy internally represents TypedDict classes differently from what # should happen at runtime. Replace with something that works. - module = 'typing' + module = "typing" if builder.options.capi_version < (3, 9): - name = 'TypedDict' + name = "TypedDict" if builder.options.capi_version < (3, 8): # TypedDict was added to typing in Python 3.8. - module = 'typing_extensions' + module = "typing_extensions" else: # In Python 3.9 TypedDict is not a real type. - name = '_TypedDict' + name = "_TypedDict" base = builder.get_module_attr(module, name, cdef.line) - elif is_named_tuple and cls.fullname == 'builtins.tuple': + elif is_named_tuple and cls.fullname == "builtins.tuple": if builder.options.capi_version < (3, 9): - name = 'NamedTuple' + name = "NamedTuple" else: # This was changed in Python 3.9. - name = '_NamedTuple' - base = builder.get_module_attr('typing', name, cdef.line) + name = "_NamedTuple" + base = builder.get_module_attr("typing", name, cdef.line) else: base = builder.load_global_str(cls.name, cdef.line) bases.append(base) @@ -425,46 +490,46 @@ def populate_non_ext_bases(builder: IRBuilder, cdef: ClassDef) -> Value: def find_non_ext_metaclass(builder: IRBuilder, cdef: ClassDef, bases: Value) -> Value: - """Find the metaclass of a class from its defs and bases. """ + """Find the metaclass of a class from its defs and bases.""" if cdef.metaclass: declared_metaclass = builder.accept(cdef.metaclass) else: if cdef.info.typeddict_type is not None and builder.options.capi_version >= (3, 9): # In Python 3.9, the metaclass for class-based TypedDict is typing._TypedDictMeta. # We can't easily calculate it generically, so special case it. - return builder.get_module_attr('typing', '_TypedDictMeta', cdef.line) + return builder.get_module_attr("typing", "_TypedDictMeta", cdef.line) elif cdef.info.is_named_tuple and builder.options.capi_version >= (3, 9): # In Python 3.9, the metaclass for class-based NamedTuple is typing.NamedTupleMeta. # We can't easily calculate it generically, so special case it. - return builder.get_module_attr('typing', 'NamedTupleMeta', cdef.line) + return builder.get_module_attr("typing", "NamedTupleMeta", cdef.line) - declared_metaclass = builder.add(LoadAddress(type_object_op.type, - type_object_op.src, cdef.line)) + declared_metaclass = builder.add( + LoadAddress(type_object_op.type, type_object_op.src, cdef.line) + ) return builder.call_c(py_calc_meta_op, [declared_metaclass, bases], cdef.line) -def setup_non_ext_dict(builder: IRBuilder, - cdef: ClassDef, - metaclass: Value, - bases: Value) -> Value: +def setup_non_ext_dict( + builder: IRBuilder, cdef: ClassDef, metaclass: Value, bases: Value +) -> Value: """Initialize the class dictionary for a non-extension class. This class dictionary is passed to the metaclass constructor. """ # Check if the metaclass defines a __prepare__ method, and if so, call it. - has_prepare = builder.call_c(py_hasattr_op, - [metaclass, - builder.load_str('__prepare__')], cdef.line) + has_prepare = builder.call_c( + py_hasattr_op, [metaclass, builder.load_str("__prepare__")], cdef.line + ) non_ext_dict = Register(dict_rprimitive) - true_block, false_block, exit_block, = BasicBlock(), BasicBlock(), BasicBlock() + true_block, false_block, exit_block = (BasicBlock(), BasicBlock(), BasicBlock()) builder.add_bool_branch(has_prepare, true_block, false_block) builder.activate_block(true_block) cls_name = builder.load_str(cdef.name) - prepare_meth = builder.py_get_attr(metaclass, '__prepare__', cdef.line) + prepare_meth = builder.py_get_attr(metaclass, "__prepare__", cdef.line) prepare_dict = builder.py_call(prepare_meth, [cls_name, bases], cdef.line) builder.assign(non_ext_dict, prepare_dict, cdef.line) builder.goto(exit_block) @@ -477,13 +542,13 @@ def setup_non_ext_dict(builder: IRBuilder, return non_ext_dict -def add_non_ext_class_attr_ann(builder: IRBuilder, - non_ext: NonExtClassInfo, - lvalue: NameExpr, - stmt: AssignmentStmt, - get_type_info: Optional[Callable[[AssignmentStmt], - Optional[TypeInfo]]] = None - ) -> None: +def add_non_ext_class_attr_ann( + builder: IRBuilder, + non_ext: NonExtClassInfo, + lvalue: NameExpr, + stmt: AssignmentStmt, + get_type_info: Optional[Callable[[AssignmentStmt], Optional[TypeInfo]]] = None, +) -> None: """Add a class attribute to __annotations__ of a non-extension class.""" typ: Optional[Value] = None if get_type_info is not None: @@ -503,12 +568,14 @@ def add_non_ext_class_attr_ann(builder: IRBuilder, builder.call_c(dict_set_item_op, [non_ext.anns, key, typ], stmt.line) -def add_non_ext_class_attr(builder: IRBuilder, - non_ext: NonExtClassInfo, - lvalue: NameExpr, - stmt: AssignmentStmt, - cdef: ClassDef, - attr_to_cache: List[Tuple[Lvalue, RType]]) -> None: +def add_non_ext_class_attr( + builder: IRBuilder, + non_ext: NonExtClassInfo, + lvalue: NameExpr, + stmt: AssignmentStmt, + cdef: ClassDef, + attr_to_cache: List[Tuple[Lvalue, RType]], +) -> None: """Add a class attribute to __dict__ of a non-extension class.""" # Only add the attribute to the __dict__ if the assignment is of the form: # x: type = value (don't add attributes of the form 'x: type' to the __dict__). @@ -519,7 +586,7 @@ def add_non_ext_class_attr(builder: IRBuilder, # are final. if ( cdef.info.bases - and cdef.info.bases[0].type.fullname == 'enum.Enum' + and cdef.info.bases[0].type.fullname == "enum.Enum" # Skip these since Enum will remove it and lvalue.name not in ENUM_REMOVED_PROPS ): @@ -527,10 +594,11 @@ def add_non_ext_class_attr(builder: IRBuilder, attr_to_cache.append((lvalue, object_rprimitive)) -def find_attr_initializers(builder: IRBuilder, - cdef: ClassDef, - skip: Optional[Callable[[str, AssignmentStmt], bool]] = None, - ) -> Tuple[Set[str], List[AssignmentStmt]]: +def find_attr_initializers( + builder: IRBuilder, + cdef: ClassDef, + skip: Optional[Callable[[str, AssignmentStmt], bool]] = None, +) -> Tuple[Set[str], List[AssignmentStmt]]: """Find initializers of attributes in a class body. If provided, the skip arg should be a callable which will return whether @@ -550,15 +618,17 @@ def find_attr_initializers(builder: IRBuilder, if info not in builder.mapper.type_to_ir: continue for stmt in info.defn.defs.body: - if (isinstance(stmt, AssignmentStmt) - and isinstance(stmt.lvalues[0], NameExpr) - and not is_class_var(stmt.lvalues[0]) - and not isinstance(stmt.rvalue, TempNode)): + if ( + isinstance(stmt, AssignmentStmt) + and isinstance(stmt.lvalues[0], NameExpr) + and not is_class_var(stmt.lvalues[0]) + and not isinstance(stmt.rvalue, TempNode) + ): name = stmt.lvalues[0].name - if name == '__slots__': + if name == "__slots__": continue - if name == '__deletable__': + if name == "__deletable__": check_deletable_declaration(builder, cls, stmt.line) continue @@ -569,9 +639,12 @@ def find_attr_initializers(builder: IRBuilder, # If the attribute is initialized to None and type isn't optional, # doesn't initialize it to anything (special case for "# type:" comments). - if isinstance(stmt.rvalue, RefExpr) and stmt.rvalue.fullname == 'builtins.None': - if (not is_optional_type(attr_type) and not is_object_rprimitive(attr_type) - and not is_none_rprimitive(attr_type)): + if isinstance(stmt.rvalue, RefExpr) and stmt.rvalue.fullname == "builtins.None": + if ( + not is_optional_type(attr_type) + and not is_object_rprimitive(attr_type) + and not is_none_rprimitive(attr_type) + ): continue attrs_with_defaults.add(name) @@ -580,9 +653,9 @@ def find_attr_initializers(builder: IRBuilder, return attrs_with_defaults, default_assignments -def generate_attr_defaults_init(builder: IRBuilder, - cdef: ClassDef, - default_assignments: List[AssignmentStmt]) -> None: +def generate_attr_defaults_init( + builder: IRBuilder, cdef: ClassDef, default_assignments: List[AssignmentStmt] +) -> None: """Generate an initialization method for default attr values (from class vars).""" if not default_assignments: return @@ -590,13 +663,13 @@ def generate_attr_defaults_init(builder: IRBuilder, if cls.builtin_base: return - with builder.enter_method(cls, '__mypyc_defaults_setup', bool_rprimitive): + with builder.enter_method(cls, "__mypyc_defaults_setup", bool_rprimitive): self_var = builder.self() for stmt in default_assignments: lvalue = stmt.lvalues[0] assert isinstance(lvalue, NameExpr) if not stmt.is_final_def and not is_constant(stmt.rvalue): - builder.warning('Unsupported default attribute value', stmt.rvalue.line) + builder.warning("Unsupported default attribute value", stmt.rvalue.line) attr_type = cls.attr_type(lvalue.name) val = builder.coerce(builder.accept(stmt.rvalue), attr_type, stmt.line) @@ -619,55 +692,58 @@ def check_deletable_declaration(builder: IRBuilder, cl: ClassIR, line: int) -> N break else: _, base = cl.attr_details(attr) - builder.error(('Attribute "{}" not defined in "{}" ' + - '(defined in "{}")').format(attr, cl.name, base.name), line) + builder.error( + ('Attribute "{}" not defined in "{}" ' + '(defined in "{}")').format( + attr, cl.name, base.name + ), + line, + ) def create_ne_from_eq(builder: IRBuilder, cdef: ClassDef) -> None: """Create a "__ne__" method from a "__eq__" method (if only latter exists).""" cls = builder.mapper.type_to_ir[cdef.info] - if cls.has_method('__eq__') and not cls.has_method('__ne__'): + if cls.has_method("__eq__") and not cls.has_method("__ne__"): gen_glue_ne_method(builder, cls, cdef.line) def gen_glue_ne_method(builder: IRBuilder, cls: ClassIR, line: int) -> None: - """Generate a "__ne__" method from a "__eq__" method. """ - with builder.enter_method(cls, '__ne__', object_rprimitive): - rhs_arg = builder.add_argument('rhs', object_rprimitive) + """Generate a "__ne__" method from a "__eq__" method.""" + with builder.enter_method(cls, "__ne__", object_rprimitive): + rhs_arg = builder.add_argument("rhs", object_rprimitive) # If __eq__ returns NotImplemented, then __ne__ should also not_implemented_block, regular_block = BasicBlock(), BasicBlock() - eqval = builder.add(MethodCall(builder.self(), '__eq__', [rhs_arg], line)) - not_implemented = builder.add(LoadAddress(not_implemented_op.type, - not_implemented_op.src, line)) - builder.add(Branch( - builder.translate_is_op(eqval, not_implemented, 'is', line), - not_implemented_block, - regular_block, - Branch.BOOL)) + eqval = builder.add(MethodCall(builder.self(), "__eq__", [rhs_arg], line)) + not_implemented = builder.add( + LoadAddress(not_implemented_op.type, not_implemented_op.src, line) + ) + builder.add( + Branch( + builder.translate_is_op(eqval, not_implemented, "is", line), + not_implemented_block, + regular_block, + Branch.BOOL, + ) + ) builder.activate_block(regular_block) - retval = builder.coerce( - builder.unary_op(eqval, 'not', line), object_rprimitive, line - ) + retval = builder.coerce(builder.unary_op(eqval, "not", line), object_rprimitive, line) builder.add(Return(retval)) builder.activate_block(not_implemented_block) builder.add(Return(not_implemented)) -def load_non_ext_class(builder: IRBuilder, - ir: ClassIR, - non_ext: NonExtClassInfo, - line: int) -> Value: +def load_non_ext_class( + builder: IRBuilder, ir: ClassIR, non_ext: NonExtClassInfo, line: int +) -> Value: cls_name = builder.load_str(ir.name) add_dunders_to_non_ext_dict(builder, non_ext, line) class_type_obj = builder.py_call( - non_ext.metaclass, - [cls_name, non_ext.bases, non_ext.dict], - line + non_ext.metaclass, [cls_name, non_ext.bases, non_ext.dict], line ) return class_type_obj @@ -690,9 +766,9 @@ def load_decorated_class(builder: IRBuilder, cdef: ClassDef, type_obj: Value) -> return dec_class -def cache_class_attrs(builder: IRBuilder, - attrs_to_cache: List[Tuple[Lvalue, RType]], - cdef: ClassDef) -> None: +def cache_class_attrs( + builder: IRBuilder, attrs_to_cache: List[Tuple[Lvalue, RType]], cdef: ClassDef +) -> None: """Add class attributes to be cached to the global cache.""" typ = builder.load_native_type_object(cdef.info.fullname) for lval, rtype in attrs_to_cache: @@ -704,22 +780,21 @@ def cache_class_attrs(builder: IRBuilder, def create_mypyc_attrs_tuple(builder: IRBuilder, ir: ClassIR, line: int) -> Value: attrs = [name for ancestor in ir.mro for name in ancestor.attributes] if ir.inherits_python: - attrs.append('__dict__') + attrs.append("__dict__") items = [builder.load_str(attr) for attr in attrs] return builder.new_tuple(items, line) -def add_dunders_to_non_ext_dict(builder: IRBuilder, non_ext: NonExtClassInfo, - line: int, add_annotations: bool = True) -> None: +def add_dunders_to_non_ext_dict( + builder: IRBuilder, non_ext: NonExtClassInfo, line: int, add_annotations: bool = True +) -> None: if add_annotations: # Add __annotations__ to the class dict. - builder.add_to_non_ext_dict(non_ext, '__annotations__', non_ext.anns, line) + builder.add_to_non_ext_dict(non_ext, "__annotations__", non_ext.anns, line) # We add a __doc__ attribute so if the non-extension class is decorated with the # dataclass decorator, dataclass will not try to look for __text_signature__. # https://github.com/python/cpython/blob/3.7/Lib/dataclasses.py#L957 - filler_doc_str = 'mypyc filler docstring' - builder.add_to_non_ext_dict( - non_ext, '__doc__', builder.load_str(filler_doc_str), line) - builder.add_to_non_ext_dict( - non_ext, '__module__', builder.load_str(builder.module_name), line) + filler_doc_str = "mypyc filler docstring" + builder.add_to_non_ext_dict(non_ext, "__doc__", builder.load_str(filler_doc_str), line) + builder.add_to_non_ext_dict(non_ext, "__module__", builder.load_str(builder.module_name), line) diff --git a/mypyc/irbuild/constant_fold.py b/mypyc/irbuild/constant_fold.py index 21e9ea939a3ef..9ded13f405866 100644 --- a/mypyc/irbuild/constant_fold.py +++ b/mypyc/irbuild/constant_fold.py @@ -4,12 +4,12 @@ """ from typing import Optional, Union + from typing_extensions import Final -from mypy.nodes import Expression, IntExpr, StrExpr, OpExpr, UnaryExpr, NameExpr, MemberExpr, Var +from mypy.nodes import Expression, IntExpr, MemberExpr, NameExpr, OpExpr, StrExpr, UnaryExpr, Var from mypyc.irbuild.builder import IRBuilder - # All possible result types of constant folding ConstantValue = Union[int, str] CONST_TYPES: Final = (int, str) @@ -53,47 +53,47 @@ def constant_fold_expr(builder: IRBuilder, expr: Expression) -> Optional[Constan def constant_fold_binary_int_op(op: str, left: int, right: int) -> Optional[int]: - if op == '+': + if op == "+": return left + right - if op == '-': + if op == "-": return left - right - elif op == '*': + elif op == "*": return left * right - elif op == '//': + elif op == "//": if right != 0: return left // right - elif op == '%': + elif op == "%": if right != 0: return left % right - elif op == '&': + elif op == "&": return left & right - elif op == '|': + elif op == "|": return left | right - elif op == '^': + elif op == "^": return left ^ right - elif op == '<<': + elif op == "<<": if right >= 0: return left << right - elif op == '>>': + elif op == ">>": if right >= 0: return left >> right - elif op == '**': + elif op == "**": if right >= 0: - return left ** right + return left**right return None def constant_fold_unary_int_op(op: str, value: int) -> Optional[int]: - if op == '-': + if op == "-": return -value - elif op == '~': + elif op == "~": return ~value - elif op == '+': + elif op == "+": return value return None def constant_fold_binary_str_op(op: str, left: str, right: str) -> Optional[str]: - if op == '+': + if op == "+": return left + right return None diff --git a/mypyc/irbuild/context.py b/mypyc/irbuild/context.py index 307ce84ab584f..cfeb96110bacf 100644 --- a/mypyc/irbuild/context.py +++ b/mypyc/irbuild/context.py @@ -3,25 +3,26 @@ from typing import List, Optional, Tuple from mypy.nodes import FuncItem - -from mypyc.ir.ops import Value, BasicBlock -from mypyc.ir.func_ir import INVALID_FUNC_DEF from mypyc.ir.class_ir import ClassIR +from mypyc.ir.func_ir import INVALID_FUNC_DEF +from mypyc.ir.ops import BasicBlock, Value from mypyc.irbuild.targets import AssignmentTarget class FuncInfo: """Contains information about functions as they are generated.""" - def __init__(self, - fitem: FuncItem = INVALID_FUNC_DEF, - name: str = '', - class_name: Optional[str] = None, - namespace: str = '', - is_nested: bool = False, - contains_nested: bool = False, - is_decorated: bool = False, - in_non_ext: bool = False) -> None: + def __init__( + self, + fitem: FuncItem = INVALID_FUNC_DEF, + name: str = "", + class_name: Optional[str] = None, + namespace: str = "", + is_nested: bool = False, + contains_nested: bool = False, + is_decorated: bool = False, + in_non_ext: bool = False, + ) -> None: self.fitem = fitem self.name = name self.class_name = class_name @@ -50,7 +51,7 @@ def __init__(self, # TODO: add field for ret_type: RType = none_rprimitive def namespaced_name(self) -> str: - return '_'.join(x for x in [self.name, self.class_name, self.ns] if x) + return "_".join(x for x in [self.name, self.class_name, self.ns] if x) @property def is_generator(self) -> bool: @@ -61,12 +62,12 @@ def is_coroutine(self) -> bool: return self.fitem.is_coroutine @property - def callable_class(self) -> 'ImplicitClass': + def callable_class(self) -> "ImplicitClass": assert self._callable_class is not None return self._callable_class @callable_class.setter - def callable_class(self, cls: 'ImplicitClass') -> None: + def callable_class(self, cls: "ImplicitClass") -> None: self._callable_class = cls @property @@ -79,12 +80,12 @@ def env_class(self, ir: ClassIR) -> None: self._env_class = ir @property - def generator_class(self) -> 'GeneratorClass': + def generator_class(self) -> "GeneratorClass": assert self._generator_class is not None return self._generator_class @generator_class.setter - def generator_class(self, cls: 'GeneratorClass') -> None: + def generator_class(self, cls: "GeneratorClass") -> None: self._generator_class = cls @property diff --git a/mypyc/irbuild/env_class.py b/mypyc/irbuild/env_class.py index 9ed764c8bcca4..c31df44eeba09 100644 --- a/mypyc/irbuild/env_class.py +++ b/mypyc/irbuild/env_class.py @@ -18,14 +18,13 @@ def g() -> int: from typing import Dict, Optional, Union from mypy.nodes import FuncDef, SymbolNode - -from mypyc.common import SELF_NAME, ENV_ATTR_NAME +from mypyc.common import ENV_ATTR_NAME, SELF_NAME +from mypyc.ir.class_ir import ClassIR from mypyc.ir.ops import Call, GetAttr, SetAttr, Value from mypyc.ir.rtypes import RInstance, object_rprimitive -from mypyc.ir.class_ir import ClassIR from mypyc.irbuild.builder import IRBuilder, SymbolTarget +from mypyc.irbuild.context import FuncInfo, GeneratorClass, ImplicitClass from mypyc.irbuild.targets import AssignmentTargetAttr -from mypyc.irbuild.context import FuncInfo, ImplicitClass, GeneratorClass def setup_env_class(builder: IRBuilder) -> ClassIR: @@ -43,8 +42,9 @@ class is generated, the function environment has not yet been Return a ClassIR representing an environment for a function containing a nested function. """ - env_class = ClassIR(f'{builder.fn_info.namespaced_name()}_env', - builder.module_name, is_generated=True) + env_class = ClassIR( + f"{builder.fn_info.namespaced_name()}_env", builder.module_name, is_generated=True + ) env_class.attributes[SELF_NAME] = RInstance(env_class) if builder.fn_info.is_nested: # If the function is nested, its environment class must contain an environment @@ -77,10 +77,14 @@ def instantiate_env_class(builder: IRBuilder) -> Value: if builder.fn_info.is_nested: builder.fn_info.callable_class._curr_env_reg = curr_env_reg - builder.add(SetAttr(curr_env_reg, - ENV_ATTR_NAME, - builder.fn_info.callable_class.prev_env_reg, - builder.fn_info.fitem.line)) + builder.add( + SetAttr( + curr_env_reg, + ENV_ATTR_NAME, + builder.fn_info.callable_class.prev_env_reg, + builder.fn_info.fitem.line, + ) + ) else: builder.fn_info._curr_env_reg = curr_env_reg @@ -107,9 +111,9 @@ def load_env_registers(builder: IRBuilder) -> None: setup_func_for_recursive_call(builder, fitem, fn_info.callable_class) -def load_outer_env(builder: IRBuilder, - base: Value, - outer_env: Dict[SymbolNode, SymbolTarget]) -> Value: +def load_outer_env( + builder: IRBuilder, base: Value, outer_env: Dict[SymbolNode, SymbolTarget] +) -> Value: """Load the environment class for a given base into a register. Additionally, iterates through all of the SymbolNode and @@ -122,7 +126,7 @@ def load_outer_env(builder: IRBuilder, Returns the register where the environment class was loaded. """ env = builder.add(GetAttr(base, ENV_ATTR_NAME, builder.fn_info.fitem.line)) - assert isinstance(env.type, RInstance), f'{env} must be of type RInstance' + assert isinstance(env.type, RInstance), f"{env} must be of type RInstance" for symbol, target in outer_env.items(): env.type.class_ir.attributes[symbol.name] = target.type @@ -155,10 +159,12 @@ def load_outer_envs(builder: IRBuilder, base: ImplicitClass) -> None: index -= 1 -def add_args_to_env(builder: IRBuilder, - local: bool = True, - base: Optional[Union[FuncInfo, ImplicitClass]] = None, - reassign: bool = True) -> None: +def add_args_to_env( + builder: IRBuilder, + local: bool = True, + base: Optional[Union[FuncInfo, ImplicitClass]] = None, + reassign: bool = True, +) -> None: fn_info = builder.fn_info if local: for arg in fn_info.fitem.arguments: @@ -168,7 +174,7 @@ def add_args_to_env(builder: IRBuilder, for arg in fn_info.fitem.arguments: if is_free_variable(builder, arg.variable) or fn_info.is_generator: rtype = builder.type_to_rtype(arg.variable.type) - assert base is not None, 'base cannot be None for adding nonlocal args' + assert base is not None, "base cannot be None for adding nonlocal args" builder.add_var_to_env_class(arg.variable, rtype, base, reassign=reassign) @@ -201,7 +207,4 @@ def setup_func_for_recursive_call(builder: IRBuilder, fdef: FuncDef, base: Impli def is_free_variable(builder: IRBuilder, symbol: SymbolNode) -> bool: fitem = builder.fn_info.fitem - return ( - fitem in builder.free_variables - and symbol in builder.free_variables[fitem] - ) + return fitem in builder.free_variables and symbol in builder.free_variables[fitem] diff --git a/mypyc/irbuild/expression.py b/mypyc/irbuild/expression.py index 49a5dd38089a5..00516959c4dc5 100644 --- a/mypyc/irbuild/expression.py +++ b/mypyc/irbuild/expression.py @@ -4,70 +4,117 @@ and mypyc.irbuild.builder. """ -from typing import List, Optional, Union, Callable, cast +from typing import Callable, List, Optional, Union, cast from mypy.nodes import ( - Expression, NameExpr, MemberExpr, SuperExpr, CallExpr, UnaryExpr, OpExpr, IndexExpr, - ConditionalExpr, ComparisonExpr, IntExpr, FloatExpr, ComplexExpr, StrExpr, - BytesExpr, EllipsisExpr, ListExpr, TupleExpr, DictExpr, SetExpr, ListComprehension, - SetComprehension, DictionaryComprehension, SliceExpr, GeneratorExpr, CastExpr, StarExpr, - AssignmentExpr, AssertTypeExpr, - Var, RefExpr, MypyFile, TypeInfo, TypeApplication, LDEF, ARG_POS + ARG_POS, + LDEF, + AssertTypeExpr, + AssignmentExpr, + BytesExpr, + CallExpr, + CastExpr, + ComparisonExpr, + ComplexExpr, + ConditionalExpr, + DictExpr, + DictionaryComprehension, + EllipsisExpr, + Expression, + FloatExpr, + GeneratorExpr, + IndexExpr, + IntExpr, + ListComprehension, + ListExpr, + MemberExpr, + MypyFile, + NameExpr, + OpExpr, + RefExpr, + SetComprehension, + SetExpr, + SliceExpr, + StarExpr, + StrExpr, + SuperExpr, + TupleExpr, + TypeApplication, + TypeInfo, + UnaryExpr, + Var, ) -from mypy.types import TupleType, Instance, TypeType, ProperType, get_proper_type - +from mypy.types import Instance, ProperType, TupleType, TypeType, get_proper_type from mypyc.common import MAX_SHORT_INT +from mypyc.ir.func_ir import FUNC_CLASSMETHOD, FUNC_STATICMETHOD from mypyc.ir.ops import ( - Value, Register, TupleGet, TupleSet, BasicBlock, Assign, LoadAddress, RaiseStandardError + Assign, + BasicBlock, + LoadAddress, + RaiseStandardError, + Register, + TupleGet, + TupleSet, + Value, ) from mypyc.ir.rtypes import ( - RTuple, object_rprimitive, is_none_rprimitive, int_rprimitive, is_int_rprimitive, - is_list_rprimitive + RTuple, + int_rprimitive, + is_int_rprimitive, + is_list_rprimitive, + is_none_rprimitive, + object_rprimitive, +) +from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional +from mypyc.irbuild.builder import IRBuilder, int_borrow_friendly_op +from mypyc.irbuild.constant_fold import constant_fold_expr +from mypyc.irbuild.for_helpers import ( + comprehension_helper, + translate_list_comprehension, + translate_set_comprehension, ) -from mypyc.ir.func_ir import FUNC_CLASSMETHOD, FUNC_STATICMETHOD from mypyc.irbuild.format_str_tokenizer import ( - tokenizer_printf_style, join_formatted_strings, convert_format_expr_to_str, - convert_format_expr_to_bytes, join_formatted_bytes + convert_format_expr_to_bytes, + convert_format_expr_to_str, + join_formatted_bytes, + join_formatted_strings, + tokenizer_printf_style, ) +from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization from mypyc.primitives.bytes_ops import bytes_slice_op -from mypyc.primitives.registry import CFunctionDescription, builtin_names +from mypyc.primitives.dict_ops import dict_get_item_op, dict_new_op, dict_set_item_op from mypyc.primitives.generic_ops import iter_op -from mypyc.primitives.misc_ops import new_slice_op, ellipsis_op, type_op, get_module_dict_op +from mypyc.primitives.int_ops import int_comparison_op_mapping from mypyc.primitives.list_ops import list_append_op, list_extend_op, list_slice_op -from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op -from mypyc.primitives.dict_ops import dict_new_op, dict_set_item_op, dict_get_item_op +from mypyc.primitives.misc_ops import ellipsis_op, get_module_dict_op, new_slice_op, type_op +from mypyc.primitives.registry import CFunctionDescription, builtin_names from mypyc.primitives.set_ops import set_add_op, set_update_op from mypyc.primitives.str_ops import str_slice_op -from mypyc.primitives.int_ops import int_comparison_op_mapping -from mypyc.irbuild.specialize import apply_function_specialization, apply_method_specialization -from mypyc.irbuild.builder import IRBuilder, int_borrow_friendly_op -from mypyc.irbuild.for_helpers import ( - translate_list_comprehension, translate_set_comprehension, - comprehension_helper -) -from mypyc.irbuild.constant_fold import constant_fold_expr -from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional - +from mypyc.primitives.tuple_ops import list_tuple_op, tuple_slice_op # Name and attribute references def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value: if expr.node is None: - builder.add(RaiseStandardError(RaiseStandardError.RUNTIME_ERROR, - "mypyc internal error: should be unreachable", - expr.line)) + builder.add( + RaiseStandardError( + RaiseStandardError.RUNTIME_ERROR, + "mypyc internal error: should be unreachable", + expr.line, + ) + ) return builder.none() fullname = expr.node.fullname if fullname in builtin_names: typ, src = builtin_names[fullname] return builder.add(LoadAddress(typ, src, expr.line)) # special cases - if fullname == 'builtins.None': + if fullname == "builtins.None": return builder.none() - if fullname == 'builtins.True': + if fullname == "builtins.True": return builder.true() - if fullname == 'builtins.False': + if fullname == "builtins.False": return builder.false() if isinstance(expr.node, Var) and expr.node.is_final: @@ -89,16 +136,20 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value: # assignment target and return it. Otherwise if the expression is a global, load it from # the globals dictionary. # Except for imports, that currently always happens in the global namespace. - if expr.kind == LDEF and not (isinstance(expr.node, Var) - and expr.node.is_suppressed_import): + if expr.kind == LDEF and not (isinstance(expr.node, Var) and expr.node.is_suppressed_import): # Try to detect and error when we hit the irritating mypy bug # where a local variable is cast to None. (#5423) - if (isinstance(expr.node, Var) and is_none_rprimitive(builder.node_type(expr)) - and expr.node.is_inferred): + if ( + isinstance(expr.node, Var) + and is_none_rprimitive(builder.node_type(expr)) + and expr.node.is_inferred + ): builder.error( 'Local variable "{}" has inferred type None; add an annotation'.format( - expr.node.name), - expr.node.line) + expr.node.name + ), + expr.node.line, + ) # TODO: Behavior currently only defined for Var, FuncDef and MypyFile node types. if isinstance(expr.node, MypyFile): @@ -109,9 +160,9 @@ def transform_name_expr(builder: IRBuilder, expr: NameExpr) -> Value: # AST doesn't include a Var node for the module. We # instead load the module separately on each access. mod_dict = builder.call_c(get_module_dict_op, [], expr.line) - obj = builder.call_c(dict_get_item_op, - [mod_dict, builder.load_str(expr.node.fullname)], - expr.line) + obj = builder.call_c( + dict_get_item_op, [mod_dict, builder.load_str(expr.node.fullname)], expr.line + ) return obj else: return builder.read(builder.get_assignment_target(expr), expr.line) @@ -124,8 +175,9 @@ def transform_member_expr(builder: IRBuilder, expr: MemberExpr) -> Value: final = builder.get_final_ref(expr) if final is not None: fullname, final_var, native = final - value = builder.emit_load_final(final_var, fullname, final_var.name, native, - builder.types[expr], expr.line) + value = builder.emit_load_final( + final_var, fullname, final_var.name, native, builder.types[expr], expr.line + ) if value is not None: return value @@ -139,10 +191,10 @@ def transform_member_expr(builder: IRBuilder, expr: MemberExpr) -> Value: # Special case: for named tuples transform attribute access to faster index access. typ = get_proper_type(builder.types.get(expr.expr)) if isinstance(typ, TupleType) and typ.partial_fallback.type.is_named_tuple: - fields = typ.partial_fallback.type.metadata['namedtuple']['fields'] + fields = typ.partial_fallback.type.metadata["namedtuple"]["fields"] if expr.name in fields: index = builder.builder.load_int(fields.index(expr.name)) - return builder.gen_method_call(obj, '__getitem__', [index], rtype, expr.line) + return builder.gen_method_call(obj, "__getitem__", [index], rtype, expr.line) check_instance_attribute_access_through_class(builder, expr, typ) @@ -150,9 +202,9 @@ def transform_member_expr(builder: IRBuilder, expr: MemberExpr) -> Value: return builder.builder.get_attr(obj, expr.name, rtype, expr.line, borrow=borrow) -def check_instance_attribute_access_through_class(builder: IRBuilder, - expr: MemberExpr, - typ: Optional[ProperType]) -> None: +def check_instance_attribute_access_through_class( + builder: IRBuilder, expr: MemberExpr, typ: Optional[ProperType] +) -> None: """Report error if accessing an instance attribute through class object.""" if isinstance(expr.expr, RefExpr): node = expr.expr.node @@ -163,25 +215,28 @@ def check_instance_attribute_access_through_class(builder: IRBuilder, class_ir = builder.mapper.type_to_ir.get(node) if class_ir is not None and class_ir.is_ext_class: sym = node.get(expr.name) - if (sym is not None - and isinstance(sym.node, Var) - and not sym.node.is_classvar - and not sym.node.is_final): + if ( + sym is not None + and isinstance(sym.node, Var) + and not sym.node.is_classvar + and not sym.node.is_final + ): builder.error( 'Cannot access instance attribute "{}" through class object'.format( - expr.name), - expr.line + expr.name + ), + expr.line, ) builder.note( '(Hint: Use "x: Final = ..." or "x: ClassVar = ..." to define ' - 'a class attribute)', - expr.line + "a class attribute)", + expr.line, ) def transform_super_expr(builder: IRBuilder, o: SuperExpr) -> Value: # warning(builder, 'can not optimize super() expression', o.line) - sup_val = builder.load_module_attr_by_fullname('builtins.super', o.line) + sup_val = builder.load_module_attr_by_fullname("builtins.super", o.line) if o.call.args: args = [builder.accept(arg) for arg in o.call.args] else: @@ -217,8 +272,9 @@ def transform_call_expr(builder: IRBuilder, expr: CallExpr) -> Value: callee = callee.analyzed.expr # Unwrap type application if isinstance(callee, MemberExpr): - return apply_method_specialization(builder, expr, callee) or \ - translate_method_call(builder, expr, callee) + return apply_method_specialization(builder, expr, callee) or translate_method_call( + builder, expr, callee + ) elif isinstance(callee, SuperExpr): return translate_super_method_call(builder, expr, callee) else: @@ -228,13 +284,15 @@ def transform_call_expr(builder: IRBuilder, expr: CallExpr) -> Value: def translate_call(builder: IRBuilder, expr: CallExpr, callee: Expression) -> Value: # The common case of calls is refexprs if isinstance(callee, RefExpr): - return apply_function_specialization(builder, expr, callee) or \ - translate_refexpr_call(builder, expr, callee) + return apply_function_specialization(builder, expr, callee) or translate_refexpr_call( + builder, expr, callee + ) function = builder.accept(callee) args = [builder.accept(arg) for arg in expr.args] - return builder.py_call(function, args, expr.line, - arg_kinds=expr.arg_kinds, arg_names=expr.arg_names) + return builder.py_call( + function, args, expr.line, arg_kinds=expr.arg_kinds, arg_names=expr.arg_names + ) def translate_refexpr_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value: @@ -276,20 +334,23 @@ def translate_method_call(builder: IRBuilder, expr: CallExpr, callee: MemberExpr return builder.builder.call(decl, args, arg_kinds, arg_names, expr.line) else: obj = builder.accept(callee.expr) - return builder.gen_method_call(obj, - callee.name, - args, - builder.node_type(expr), - expr.line, - expr.arg_kinds, - expr.arg_names) + return builder.gen_method_call( + obj, + callee.name, + args, + builder.node_type(expr), + expr.line, + expr.arg_kinds, + expr.arg_names, + ) elif builder.is_module_member_expr(callee): # Fall back to a PyCall for non-native module calls function = builder.accept(callee) args = [builder.accept(arg) for arg in expr.args] - return builder.py_call(function, args, expr.line, - arg_kinds=expr.arg_kinds, arg_names=expr.arg_names) + return builder.py_call( + function, args, expr.line, arg_kinds=expr.arg_kinds, arg_names=expr.arg_names + ) else: receiver_typ = builder.node_type(callee.expr) @@ -301,13 +362,15 @@ def translate_method_call(builder: IRBuilder, expr: CallExpr, callee: MemberExpr obj = builder.accept(callee.expr) args = [builder.accept(arg) for arg in expr.args] - return builder.gen_method_call(obj, - callee.name, - args, - builder.node_type(expr), - expr.line, - expr.arg_kinds, - expr.arg_names) + return builder.gen_method_call( + obj, + callee.name, + args, + builder.node_type(expr), + expr.line, + expr.arg_kinds, + expr.arg_names, + ) def translate_super_method_call(builder: IRBuilder, expr: CallExpr, callee: SuperExpr) -> Value: @@ -340,11 +403,13 @@ def translate_super_method_call(builder: IRBuilder, expr: CallExpr, callee: Supe if callee.name in base.method_decls: break else: - if (ir.is_ext_class - and ir.builtin_base is None - and not ir.inherits_python - and callee.name == '__init__' - and len(expr.args) == 0): + if ( + ir.is_ext_class + and ir.builtin_base is None + and not ir.inherits_python + and callee.name == "__init__" + and len(expr.args) == 0 + ): # Call translates to object.__init__(self), which is a # no-op, so omit the call. return builder.none() @@ -392,11 +457,11 @@ def transform_unary_expr(builder: IRBuilder, expr: UnaryExpr) -> Value: def transform_op_expr(builder: IRBuilder, expr: OpExpr) -> Value: - if expr.op in ('and', 'or'): + if expr.op in ("and", "or"): return builder.shortcircuit_expr(expr) # Special case for string formatting - if expr.op == '%' and (isinstance(expr.left, StrExpr) or isinstance(expr.left, BytesExpr)): + if expr.op == "%" and (isinstance(expr.left, StrExpr) or isinstance(expr.left, BytesExpr)): ret = translate_printf_style_formatting(builder, expr.left, expr.right) if ret is not None: return ret @@ -406,9 +471,10 @@ def transform_op_expr(builder: IRBuilder, expr: OpExpr) -> Value: return folded # Special case some int ops to allow borrowing operands. - if (is_int_rprimitive(builder.node_type(expr.left)) - and is_int_rprimitive(builder.node_type(expr.right))): - if expr.op == '//': + if is_int_rprimitive(builder.node_type(expr.left)) and is_int_rprimitive( + builder.node_type(expr.right) + ): + if expr.op == "//": expr = try_optimize_int_floor_divide(expr) if expr.op in int_borrow_friendly_op: borrow_left = is_borrow_friendly_expr(builder, expr.right) @@ -428,7 +494,7 @@ def try_optimize_int_floor_divide(expr: OpExpr) -> OpExpr: divisor = expr.right.value shift = divisor.bit_length() - 1 if 0 < shift < 28 and divisor == (1 << shift): - return OpExpr('>>', expr.left, IntExpr(shift)) + return OpExpr(">>", expr.left, IntExpr(shift)) return expr @@ -450,7 +516,8 @@ def transform_index_expr(builder: IRBuilder, expr: IndexExpr) -> Value: index_reg = builder.accept(expr.index, can_borrow=is_list) return builder.gen_method_call( - base, '__getitem__', [index_reg], builder.node_type(expr), expr.line) + base, "__getitem__", [index_reg], builder.node_type(expr), expr.line + ) def try_constant_fold(builder: IRBuilder, expr: Expression) -> Optional[Value]: @@ -533,25 +600,27 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: # x in (...)/[...] # x not in (...)/[...] first_op = e.operators[0] - if (first_op in ['in', 'not in'] - and len(e.operators) == 1 - and isinstance(e.operands[1], (TupleExpr, ListExpr))): + if ( + first_op in ["in", "not in"] + and len(e.operators) == 1 + and isinstance(e.operands[1], (TupleExpr, ListExpr)) + ): items = e.operands[1].items n_items = len(items) # x in y -> x == y[0] or ... or x == y[n] # x not in y -> x != y[0] and ... and x != y[n] # 16 is arbitrarily chosen to limit code size if 1 < n_items < 16: - if e.operators[0] == 'in': - bin_op = 'or' - cmp_op = '==' + if e.operators[0] == "in": + bin_op = "or" + cmp_op = "==" else: - bin_op = 'and' - cmp_op = '!=' + bin_op = "and" + cmp_op = "!=" lhs = e.operands[0] - mypy_file = builder.graph['builtins'].tree + mypy_file = builder.graph["builtins"].tree assert mypy_file is not None - bool_type = Instance(cast(TypeInfo, mypy_file.names['bool'].node), []) + bool_type = Instance(cast(TypeInfo, mypy_file.names["bool"].node), []) exprs = [] for item in items: expr = ComparisonExpr([cmp_op], [lhs, item]) @@ -566,27 +635,27 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: # x in [y]/(y) -> x == y # x not in [y]/(y) -> x != y elif n_items == 1: - if e.operators[0] == 'in': - cmp_op = '==' + if e.operators[0] == "in": + cmp_op = "==" else: - cmp_op = '!=' + cmp_op = "!=" e.operators = [cmp_op] e.operands[1] = items[0] # x in []/() -> False # x not in []/() -> True elif n_items == 0: - if e.operators[0] == 'in': + if e.operators[0] == "in": return builder.false() else: return builder.true() if len(e.operators) == 1: # Special some common simple cases - if first_op in ('is', 'is not'): + if first_op in ("is", "is not"): right_expr = e.operands[1] - if isinstance(right_expr, NameExpr) and right_expr.fullname == 'builtins.None': + if isinstance(right_expr, NameExpr) and right_expr.fullname == "builtins.None": # Special case 'is None' / 'is not None'. - return translate_is_none(builder, e.operands[0], negated=first_op != 'is') + return translate_is_none(builder, e.operands[0], negated=first_op != "is") left_expr = e.operands[0] if is_int_rprimitive(builder.node_type(left_expr)): right_expr = e.operands[1] @@ -608,48 +677,51 @@ def transform_comparison_expr(builder: IRBuilder, e: ComparisonExpr) -> Value: def go(i: int, prev: Value) -> Value: if i == len(e.operators) - 1: return transform_basic_comparison( - builder, e.operators[i], prev, builder.accept(e.operands[i + 1]), e.line) + builder, e.operators[i], prev, builder.accept(e.operands[i + 1]), e.line + ) next = builder.accept(e.operands[i + 1]) return builder.builder.shortcircuit_helper( - 'and', expr_type, - lambda: transform_basic_comparison( - builder, e.operators[i], prev, next, e.line), + "and", + expr_type, + lambda: transform_basic_comparison(builder, e.operators[i], prev, next, e.line), lambda: go(i + 1, next), - e.line) + e.line, + ) return go(0, builder.accept(e.operands[0])) def translate_is_none(builder: IRBuilder, expr: Expression, negated: bool) -> Value: v = builder.accept(expr, can_borrow=True) - return builder.binary_op(v, builder.none_object(), 'is not' if negated else 'is', expr.line) + return builder.binary_op(v, builder.none_object(), "is not" if negated else "is", expr.line) -def transform_basic_comparison(builder: IRBuilder, - op: str, - left: Value, - right: Value, - line: int) -> Value: - if (is_int_rprimitive(left.type) and is_int_rprimitive(right.type) - and op in int_comparison_op_mapping.keys()): +def transform_basic_comparison( + builder: IRBuilder, op: str, left: Value, right: Value, line: int +) -> Value: + if ( + is_int_rprimitive(left.type) + and is_int_rprimitive(right.type) + and op in int_comparison_op_mapping.keys() + ): return builder.compare_tagged(left, right, op, line) negate = False - if op == 'is not': - op, negate = 'is', True - elif op == 'not in': - op, negate = 'in', True + if op == "is not": + op, negate = "is", True + elif op == "not in": + op, negate = "in", True target = builder.binary_op(left, right, op, line) if negate: - target = builder.unary_op(target, 'not', line) + target = builder.unary_op(target, "not", line) return target -def translate_printf_style_formatting(builder: IRBuilder, - format_expr: Union[StrExpr, BytesExpr], - rhs: Expression) -> Optional[Value]: +def translate_printf_style_formatting( + builder: IRBuilder, format_expr: Union[StrExpr, BytesExpr], rhs: Expression +) -> Optional[Value]: tokens = tokenizer_printf_style(format_expr.value) if tokens is not None: literals, format_ops = tokens @@ -661,13 +733,15 @@ def translate_printf_style_formatting(builder: IRBuilder, exprs.append(rhs) if isinstance(format_expr, BytesExpr): - substitutions = convert_format_expr_to_bytes(builder, format_ops, - exprs, format_expr.line) + substitutions = convert_format_expr_to_bytes( + builder, format_ops, exprs, format_expr.line + ) if substitutions is not None: return join_formatted_bytes(builder, literals, substitutions, format_expr.line) else: - substitutions = convert_format_expr_to_str(builder, format_ops, - exprs, format_expr.line) + substitutions = convert_format_expr_to_str( + builder, format_ops, exprs, format_expr.line + ) if substitutions is not None: return join_formatted_strings(builder, literals, substitutions, format_expr.line) @@ -710,13 +784,7 @@ def transform_list_expr(builder: IRBuilder, expr: ListExpr) -> Value: def _visit_list_display(builder: IRBuilder, items: List[Expression], line: int) -> Value: return _visit_display( - builder, - items, - builder.new_list_op, - list_append_op, - list_extend_op, - line, - True + builder, items, builder.new_list_op, list_append_op, list_extend_op, line, True ) @@ -729,8 +797,11 @@ def transform_tuple_expr(builder: IRBuilder, expr: TupleExpr) -> Value: tuple_type = builder.node_type(expr) # When handling NamedTuple et. al we might not have proper type info, # so make some up if we need it. - types = (tuple_type.types if isinstance(tuple_type, RTuple) - else [object_rprimitive] * len(expr.items)) + types = ( + tuple_type.types + if isinstance(tuple_type, RTuple) + else [object_rprimitive] * len(expr.items) + ) items = [] for item_expr, item_type in zip(expr.items, types): @@ -758,24 +829,19 @@ def transform_dict_expr(builder: IRBuilder, expr: DictExpr) -> Value: def transform_set_expr(builder: IRBuilder, expr: SetExpr) -> Value: return _visit_display( - builder, - expr.items, - builder.new_set_op, - set_add_op, - set_update_op, - expr.line, - False + builder, expr.items, builder.new_set_op, set_add_op, set_update_op, expr.line, False ) -def _visit_display(builder: IRBuilder, - items: List[Expression], - constructor_op: Callable[[List[Value], int], Value], - append_op: CFunctionDescription, - extend_op: CFunctionDescription, - line: int, - is_list: bool - ) -> Value: +def _visit_display( + builder: IRBuilder, + items: List[Expression], + constructor_op: Callable[[List[Value], int], Value], + append_op: CFunctionDescription, + extend_op: CFunctionDescription, + line: int, + is_list: bool, +) -> Value: accepted_items = [] for item in items: if isinstance(item, StarExpr): @@ -806,19 +872,19 @@ def _visit_display(builder: IRBuilder, def transform_list_comprehension(builder: IRBuilder, o: ListComprehension) -> Value: if any(o.generator.is_async): - builder.error('async comprehensions are unimplemented', o.line) + builder.error("async comprehensions are unimplemented", o.line) return translate_list_comprehension(builder, o.generator) def transform_set_comprehension(builder: IRBuilder, o: SetComprehension) -> Value: if any(o.generator.is_async): - builder.error('async comprehensions are unimplemented', o.line) + builder.error("async comprehensions are unimplemented", o.line) return translate_set_comprehension(builder, o.generator) def transform_dictionary_comprehension(builder: IRBuilder, o: DictionaryComprehension) -> Value: if any(o.is_async): - builder.error('async comprehensions are unimplemented', o.line) + builder.error("async comprehensions are unimplemented", o.line) d = builder.call_c(dict_new_op, [], o.line) loop_params = list(zip(o.indices, o.sequences, o.condlists)) @@ -842,20 +908,16 @@ def get_arg(arg: Optional[Expression]) -> Value: else: return builder.accept(arg) - args = [get_arg(expr.begin_index), - get_arg(expr.end_index), - get_arg(expr.stride)] + args = [get_arg(expr.begin_index), get_arg(expr.end_index), get_arg(expr.stride)] return builder.call_c(new_slice_op, args, expr.line) def transform_generator_expr(builder: IRBuilder, o: GeneratorExpr) -> Value: if any(o.is_async): - builder.error('async comprehensions are unimplemented', o.line) + builder.error("async comprehensions are unimplemented", o.line) - builder.warning('Treating generator comprehension as list', o.line) - return builder.call_c( - iter_op, [translate_list_comprehension(builder, o)], o.line - ) + builder.warning("Treating generator comprehension as list", o.line) + return builder.call_c(iter_op, [translate_list_comprehension(builder, o)], o.line) def transform_assignment_expr(builder: IRBuilder, o: AssignmentExpr) -> Value: diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index ae592ae91087e..19cc383ace609 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -5,38 +5,62 @@ such special case. """ -from typing import Union, List, Optional, Tuple, Callable -from typing_extensions import Type, ClassVar +from typing import Callable, List, Optional, Tuple, Union + +from typing_extensions import ClassVar, Type from mypy.nodes import ( - Lvalue, Expression, TupleExpr, CallExpr, RefExpr, GeneratorExpr, ARG_POS, MemberExpr, TypeAlias -) -from mypyc.ir.ops import ( - Value, BasicBlock, Integer, Branch, Register, TupleGet, TupleSet, IntOp + ARG_POS, + CallExpr, + Expression, + GeneratorExpr, + Lvalue, + MemberExpr, + RefExpr, + TupleExpr, + TypeAlias, ) +from mypyc.ir.ops import BasicBlock, Branch, Integer, IntOp, Register, TupleGet, TupleSet, Value from mypyc.ir.rtypes import ( - RType, is_short_int_rprimitive, is_list_rprimitive, is_sequence_rprimitive, - is_tuple_rprimitive, is_dict_rprimitive, is_str_rprimitive, - RTuple, short_int_rprimitive, int_rprimitive + RTuple, + RType, + int_rprimitive, + is_dict_rprimitive, + is_list_rprimitive, + is_sequence_rprimitive, + is_short_int_rprimitive, + is_str_rprimitive, + is_tuple_rprimitive, + short_int_rprimitive, ) -from mypyc.primitives.registry import CFunctionDescription +from mypyc.irbuild.builder import IRBuilder +from mypyc.irbuild.targets import AssignmentTarget, AssignmentTargetTuple from mypyc.primitives.dict_ops import ( - dict_next_key_op, dict_next_value_op, dict_next_item_op, dict_check_size_op, - dict_key_iter_op, dict_value_iter_op, dict_item_iter_op + dict_check_size_op, + dict_item_iter_op, + dict_key_iter_op, + dict_next_item_op, + dict_next_key_op, + dict_next_value_op, + dict_value_iter_op, ) +from mypyc.primitives.exc_ops import no_err_occurred_op +from mypyc.primitives.generic_ops import iter_op, next_op from mypyc.primitives.list_ops import list_append_op, list_get_item_unsafe_op, new_list_set_item_op +from mypyc.primitives.registry import CFunctionDescription from mypyc.primitives.set_ops import set_add_op -from mypyc.primitives.generic_ops import iter_op, next_op -from mypyc.primitives.exc_ops import no_err_occurred_op -from mypyc.irbuild.builder import IRBuilder -from mypyc.irbuild.targets import AssignmentTarget, AssignmentTargetTuple GenFunc = Callable[[], None] -def for_loop_helper(builder: IRBuilder, index: Lvalue, expr: Expression, - body_insts: GenFunc, else_insts: Optional[GenFunc], - line: int) -> None: +def for_loop_helper( + builder: IRBuilder, + index: Lvalue, + expr: Expression, + body_insts: GenFunc, + else_insts: Optional[GenFunc], + line: int, +) -> None: """Generate IR for a loop. Args: @@ -88,11 +112,14 @@ def for_loop_helper(builder: IRBuilder, index: Lvalue, expr: Expression, builder.activate_block(exit_block) -def for_loop_helper_with_index(builder: IRBuilder, - index: Lvalue, - expr: Expression, - expr_reg: Value, - body_insts: Callable[[Value], None], line: int) -> None: +def for_loop_helper_with_index( + builder: IRBuilder, + index: Lvalue, + expr: Expression, + expr_reg: Value, + body_insts: Callable[[Value], None], + line: int, +) -> None: """Generate IR for a sequence iteration. This function only works for sequence type. Compared to for_loop_helper, @@ -135,10 +162,11 @@ def for_loop_helper_with_index(builder: IRBuilder, def sequence_from_generator_preallocate_helper( - builder: IRBuilder, - gen: GeneratorExpr, - empty_op_llbuilder: Callable[[Value, int], Value], - set_item_op: CFunctionDescription) -> Optional[Value]: + builder: IRBuilder, + gen: GeneratorExpr, + empty_op_llbuilder: Callable[[Value, int], Value], + set_item_op: CFunctionDescription, +) -> Optional[Value]: """Generate a new tuple or list from a simple generator expression. Currently we only optimize for simplest generator expression, which means that @@ -164,8 +192,7 @@ def sequence_from_generator_preallocate_helper( """ if len(gen.sequences) == 1 and len(gen.indices) == 1 and len(gen.condlists[0]) == 0: rtype = builder.node_type(gen.sequences[0]) - if (is_list_rprimitive(rtype) or is_tuple_rprimitive(rtype) - or is_str_rprimitive(rtype)): + if is_list_rprimitive(rtype) or is_tuple_rprimitive(rtype) or is_str_rprimitive(rtype): sequence = builder.accept(gen.sequences[0]) length = builder.builder.builtin_len(sequence, gen.line, use_pyssize_t=True) target_op = empty_op_llbuilder(length, gen.line) @@ -174,8 +201,9 @@ def set_item(item_index: Value) -> None: e = builder.accept(gen.left_expr) builder.call_c(set_item_op, [target_op, item_index, e], gen.line) - for_loop_helper_with_index(builder, gen.indices[0], gen.sequences[0], sequence, - set_item, gen.line) + for_loop_helper_with_index( + builder, gen.indices[0], gen.sequences[0], sequence, set_item, gen.line + ) return target_op return None @@ -184,9 +212,11 @@ def set_item(item_index: Value) -> None: def translate_list_comprehension(builder: IRBuilder, gen: GeneratorExpr) -> Value: # Try simplest list comprehension, otherwise fall back to general one val = sequence_from_generator_preallocate_helper( - builder, gen, + builder, + gen, empty_op_llbuilder=builder.builder.new_list_op_with_length, - set_item_op=new_list_set_item_op) + set_item_op=new_list_set_item_op, + ) if val is not None: return val @@ -213,10 +243,12 @@ def gen_inner_stmts() -> None: return set_ops -def comprehension_helper(builder: IRBuilder, - loop_params: List[Tuple[Lvalue, Expression, List[Expression]]], - gen_inner_stmts: Callable[[], None], - line: int) -> None: +def comprehension_helper( + builder: IRBuilder, + loop_params: List[Tuple[Lvalue, Expression, List[Expression]]], + gen_inner_stmts: Callable[[], None], + line: int, +) -> None: """Helper function for list comprehensions. Args: @@ -227,6 +259,7 @@ def comprehension_helper(builder: IRBuilder, that must all be true for the loop body to be executed gen_inner_stmts: function to generate the IR for the body of the innermost loop """ + def handle_loop(loop_params: List[Tuple[Lvalue, Expression, List[Expression]]]) -> None: """Generate IR for a loop. @@ -234,13 +267,13 @@ def handle_loop(loop_params: List[Tuple[Lvalue, Expression, List[Expression]]]) for the nested loops the list defines. """ index, expr, conds = loop_params[0] - for_loop_helper(builder, index, expr, - lambda: loop_contents(conds, loop_params[1:]), - None, line) + for_loop_helper( + builder, index, expr, lambda: loop_contents(conds, loop_params[1:]), None, line + ) def loop_contents( - conds: List[Expression], - remaining_loop_params: List[Tuple[Lvalue, Expression, List[Expression]]], + conds: List[Expression], + remaining_loop_params: List[Tuple[Lvalue, Expression, List[Expression]]], ) -> None: """Generate the body of the loop. @@ -272,17 +305,22 @@ def loop_contents( def is_range_ref(expr: RefExpr) -> bool: - return (expr.fullname == 'builtins.range' - or isinstance(expr.node, TypeAlias) and expr.fullname == 'six.moves.xrange') - - -def make_for_loop_generator(builder: IRBuilder, - index: Lvalue, - expr: Expression, - body_block: BasicBlock, - loop_exit: BasicBlock, - line: int, - nested: bool = False) -> 'ForGenerator': + return ( + expr.fullname == "builtins.range" + or isinstance(expr.node, TypeAlias) + and expr.fullname == "six.moves.xrange" + ) + + +def make_for_loop_generator( + builder: IRBuilder, + index: Lvalue, + expr: Expression, + body_block: BasicBlock, + loop_exit: BasicBlock, + line: int, + nested: bool = False, +) -> "ForGenerator": """Return helper object for generating a for loop over an iterable. If "nested" is True, this is a nested iterator such as "e" in "enumerate(e)". @@ -307,13 +345,15 @@ def make_for_loop_generator(builder: IRBuilder, for_dict.init(expr_reg, target_type) return for_dict - if (isinstance(expr, CallExpr) - and isinstance(expr.callee, RefExpr)): - if (is_range_ref(expr.callee) - and (len(expr.args) <= 2 - or (len(expr.args) == 3 - and builder.extract_int(expr.args[2]) is not None)) - and set(expr.arg_kinds) == {ARG_POS}): + if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr): + if ( + is_range_ref(expr.callee) + and ( + len(expr.args) <= 2 + or (len(expr.args) == 3 and builder.extract_int(expr.args[2]) is not None) + ) + and set(expr.arg_kinds) == {ARG_POS} + ): # Special case "for x in range(...)". # We support the 3 arg form but only for int literals, since it doesn't # seem worth the hassle of supporting dynamically determining which @@ -336,33 +376,38 @@ def make_for_loop_generator(builder: IRBuilder, for_range.init(start_reg, end_reg, step) return for_range - elif (expr.callee.fullname == 'builtins.enumerate' - and len(expr.args) == 1 - and expr.arg_kinds == [ARG_POS] - and isinstance(index, TupleExpr) - and len(index.items) == 2): + elif ( + expr.callee.fullname == "builtins.enumerate" + and len(expr.args) == 1 + and expr.arg_kinds == [ARG_POS] + and isinstance(index, TupleExpr) + and len(index.items) == 2 + ): # Special case "for i, x in enumerate(y)". lvalue1 = index.items[0] lvalue2 = index.items[1] - for_enumerate = ForEnumerate(builder, index, body_block, loop_exit, line, - nested) + for_enumerate = ForEnumerate(builder, index, body_block, loop_exit, line, nested) for_enumerate.init(lvalue1, lvalue2, expr.args[0]) return for_enumerate - elif (expr.callee.fullname == 'builtins.zip' - and len(expr.args) >= 2 - and set(expr.arg_kinds) == {ARG_POS} - and isinstance(index, TupleExpr) - and len(index.items) == len(expr.args)): + elif ( + expr.callee.fullname == "builtins.zip" + and len(expr.args) >= 2 + and set(expr.arg_kinds) == {ARG_POS} + and isinstance(index, TupleExpr) + and len(index.items) == len(expr.args) + ): # Special case "for x, y in zip(a, b)". for_zip = ForZip(builder, index, body_block, loop_exit, line, nested) for_zip.init(index.items, expr.args) return for_zip - if (expr.callee.fullname == 'builtins.reversed' - and len(expr.args) == 1 - and expr.arg_kinds == [ARG_POS] - and is_sequence_rprimitive(builder.node_type(expr.args[0]))): + if ( + expr.callee.fullname == "builtins.reversed" + and len(expr.args) == 1 + and expr.arg_kinds == [ARG_POS] + and is_sequence_rprimitive(builder.node_type(expr.args[0])) + ): # Special case "for x in reversed()". expr_reg = builder.accept(expr.args[0]) target_type = builder.get_sequence_type(expr) @@ -370,19 +415,16 @@ def make_for_loop_generator(builder: IRBuilder, for_list = ForSequence(builder, index, body_block, loop_exit, line, nested) for_list.init(expr_reg, target_type, reverse=True) return for_list - if (isinstance(expr, CallExpr) - and isinstance(expr.callee, MemberExpr) - and not expr.args): + if isinstance(expr, CallExpr) and isinstance(expr.callee, MemberExpr) and not expr.args: # Special cases for dictionary iterator methods, like dict.items(). rtype = builder.node_type(expr.callee.expr) - if (is_dict_rprimitive(rtype) - and expr.callee.name in ('keys', 'values', 'items')): + if is_dict_rprimitive(rtype) and expr.callee.name in ("keys", "values", "items"): expr_reg = builder.accept(expr.callee.expr) for_dict_type: Optional[Type[ForGenerator]] = None - if expr.callee.name == 'keys': + if expr.callee.name == "keys": target_type = builder.get_dict_key_type(expr.callee.expr) for_dict_type = ForDictionaryKeys - elif expr.callee.name == 'values': + elif expr.callee.name == "values": target_type = builder.get_dict_value_type(expr.callee.expr) for_dict_type = ForDictionaryValues else: @@ -404,13 +446,15 @@ def make_for_loop_generator(builder: IRBuilder, class ForGenerator: """Abstract base class for generating for loops.""" - def __init__(self, - builder: IRBuilder, - index: Lvalue, - body_block: BasicBlock, - loop_exit: BasicBlock, - line: int, - nested: bool) -> None: + def __init__( + self, + builder: IRBuilder, + index: Lvalue, + body_block: BasicBlock, + loop_exit: BasicBlock, + line: int, + nested: bool, + ) -> None: self.builder = builder self.index = index self.body_block = body_block @@ -504,9 +548,7 @@ def gen_cleanup(self) -> None: self.builder.call_c(no_err_occurred_op, [], self.line) -def unsafe_index( - builder: IRBuilder, target: Value, index: Value, line: int -) -> Value: +def unsafe_index(builder: IRBuilder, target: Value, index: Value, line: int) -> Value: """Emit a potentially unsafe index into a target.""" # This doesn't really fit nicely into any of our data-driven frameworks # since we want to use __getitem__ if we don't have an unsafe version, @@ -514,7 +556,7 @@ def unsafe_index( if is_list_rprimitive(target.type): return builder.call_c(list_get_item_unsafe_op, [target, index], line) else: - return builder.gen_method_call(target, '__getitem__', [index], None, line) + return builder.gen_method_call(target, "__getitem__", [index], None, line) class ForSequence(ForGenerator): @@ -533,8 +575,9 @@ def init(self, expr_reg: Value, target_type: RType, reverse: bool) -> None: if not reverse: index_reg: Value = Integer(0) else: - index_reg = builder.binary_op(self.load_len(self.expr_target), - Integer(1), '-', self.line) + index_reg = builder.binary_op( + self.load_len(self.expr_target), Integer(1), "-", self.line + ) self.index_target = builder.maybe_spill_assignable(index_reg) self.target_type = target_type @@ -547,15 +590,16 @@ def gen_condition(self) -> None: # to check that the index is still positive. Somewhat less # obviously we still need to check against the length, # since it could shrink out from under us. - comparison = builder.binary_op(builder.read(self.index_target, line), - Integer(0), '>=', line) + comparison = builder.binary_op( + builder.read(self.index_target, line), Integer(0), ">=", line + ) second_check = BasicBlock() builder.add_bool_branch(comparison, second_check, self.loop_exit) builder.activate_block(second_check) # For compatibility with python semantics we recalculate the length # at every iteration. len_reg = self.load_len(self.expr_target) - comparison = builder.binary_op(builder.read(self.index_target, line), len_reg, '<', line) + comparison = builder.binary_op(builder.read(self.index_target, line), len_reg, "<", line) builder.add_bool_branch(comparison, self.body_block, self.loop_exit) def begin_body(self) -> None: @@ -566,23 +610,30 @@ def begin_body(self) -> None: builder, builder.read(self.expr_target, line), builder.read(self.index_target, line), - line + line, ) assert value_box # We coerce to the type of list elements here so that # iterating with tuple unpacking generates a tuple based # unpack instead of an iterator based one. - builder.assign(builder.get_assignment_target(self.index), - builder.coerce(value_box, self.target_type, line), line) + builder.assign( + builder.get_assignment_target(self.index), + builder.coerce(value_box, self.target_type, line), + line, + ) def gen_step(self) -> None: # Step to the next item. builder = self.builder line = self.line step = 1 if not self.reverse else -1 - add = builder.int_op(short_int_rprimitive, - builder.read(self.index_target, line), - Integer(step), IntOp.ADD, line) + add = builder.int_op( + short_int_rprimitive, + builder.read(self.index_target, line), + Integer(step), + IntOp.ADD, + line, + ) builder.assign(self.index_target, add, line) @@ -629,17 +680,17 @@ def gen_condition(self) -> None: builder = self.builder line = self.line self.next_tuple = self.builder.call_c( - self.dict_next_op, [builder.read(self.iter_target, line), - builder.read(self.offset_target, line)], line) + self.dict_next_op, + [builder.read(self.iter_target, line), builder.read(self.offset_target, line)], + line, + ) # Do this here instead of in gen_step() to minimize variables in environment. new_offset = builder.add(TupleGet(self.next_tuple, 1, line)) builder.assign(self.offset_target, new_offset, line) should_continue = builder.add(TupleGet(self.next_tuple, 0, line)) - builder.add( - Branch(should_continue, self.body_block, self.loop_exit, Branch.BOOL) - ) + builder.add(Branch(should_continue, self.body_block, self.loop_exit, Branch.BOOL)) def gen_step(self) -> None: """Check that dictionary didn't change size during iteration. @@ -649,9 +700,11 @@ def gen_step(self) -> None: builder = self.builder line = self.line # Technically, we don't need a new primitive for this, but it is simpler. - builder.call_c(dict_check_size_op, - [builder.read(self.expr_target, line), - builder.read(self.size, line)], line) + builder.call_c( + dict_check_size_op, + [builder.read(self.expr_target, line), builder.read(self.size, line)], + line, + ) def gen_cleanup(self) -> None: # Same as for generic ForIterable. @@ -660,6 +713,7 @@ def gen_cleanup(self) -> None: class ForDictionaryKeys(ForDictionaryCommon): """Generate optimized IR for a for loop over dictionary keys.""" + dict_next_op = dict_next_key_op dict_iter_op = dict_key_iter_op @@ -669,12 +723,16 @@ def begin_body(self) -> None: # Key is stored at the third place in the tuple. key = builder.add(TupleGet(self.next_tuple, 2, line)) - builder.assign(builder.get_assignment_target(self.index), - builder.coerce(key, self.target_type, line), line) + builder.assign( + builder.get_assignment_target(self.index), + builder.coerce(key, self.target_type, line), + line, + ) class ForDictionaryValues(ForDictionaryCommon): """Generate optimized IR for a for loop over dictionary values.""" + dict_next_op = dict_next_value_op dict_iter_op = dict_value_iter_op @@ -684,12 +742,16 @@ def begin_body(self) -> None: # Value is stored at the third place in the tuple. value = builder.add(TupleGet(self.next_tuple, 2, line)) - builder.assign(builder.get_assignment_target(self.index), - builder.coerce(value, self.target_type, line), line) + builder.assign( + builder.get_assignment_target(self.index), + builder.coerce(value, self.target_type, line), + line, + ) class ForDictionaryItems(ForDictionaryCommon): """Generate optimized IR for a for loop over dictionary items.""" + dict_next_op = dict_next_item_op dict_iter_op = dict_item_iter_op @@ -743,9 +805,10 @@ def gen_condition(self) -> None: builder = self.builder line = self.line # Add loop condition check. - cmp = '<' if self.step > 0 else '>' - comparison = builder.binary_op(builder.read(self.index_reg, line), - builder.read(self.end_target, line), cmp, line) + cmp = "<" if self.step > 0 else ">" + comparison = builder.binary_op( + builder.read(self.index_reg, line), builder.read(self.end_target, line), cmp, line + ) builder.add_bool_branch(comparison, self.body_block, self.loop_exit) def gen_step(self) -> None: @@ -754,15 +817,21 @@ def gen_step(self) -> None: # Increment index register. If the range is known to fit in short ints, use # short ints. - if (is_short_int_rprimitive(self.start_reg.type) - and is_short_int_rprimitive(self.end_reg.type)): - new_val = builder.int_op(short_int_rprimitive, - builder.read(self.index_reg, line), - Integer(self.step), IntOp.ADD, line) + if is_short_int_rprimitive(self.start_reg.type) and is_short_int_rprimitive( + self.end_reg.type + ): + new_val = builder.int_op( + short_int_rprimitive, + builder.read(self.index_reg, line), + Integer(self.step), + IntOp.ADD, + line, + ) else: new_val = builder.binary_op( - builder.read(self.index_reg, line), Integer(self.step), '+', line) + builder.read(self.index_reg, line), Integer(self.step), "+", line + ) builder.assign(self.index_reg, new_val, line) builder.assign(self.index_target, new_val, line) @@ -787,9 +856,9 @@ def gen_step(self) -> None: # We can safely assume that the integer is short, since we are not going to wrap # around a 63-bit integer. # NOTE: This would be questionable if short ints could be 32 bits. - new_val = builder.int_op(short_int_rprimitive, - builder.read(self.index_reg, line), - Integer(1), IntOp.ADD, line) + new_val = builder.int_op( + short_int_rprimitive, builder.read(self.index_reg, line), Integer(1), IntOp.ADD, line + ) builder.assign(self.index_reg, new_val, line) builder.assign(self.index_target, new_val, line) @@ -805,20 +874,13 @@ def need_cleanup(self) -> bool: def init(self, index1: Lvalue, index2: Lvalue, expr: Expression) -> None: # Count from 0 to infinity (for the index lvalue). self.index_gen = ForInfiniteCounter( - self.builder, - index1, - self.body_block, - self.loop_exit, - self.line, nested=True) + self.builder, index1, self.body_block, self.loop_exit, self.line, nested=True + ) self.index_gen.init() # Iterate over the actual iterable. self.main_gen = make_for_loop_generator( - self.builder, - index2, - expr, - self.body_block, - self.loop_exit, - self.line, nested=True) + self.builder, index2, expr, self.body_block, self.loop_exit, self.line, nested=True + ) def gen_condition(self) -> None: # No need for a check for the index generator, since it's unconditional. @@ -853,12 +915,8 @@ def init(self, indexes: List[Lvalue], exprs: List[Expression]) -> None: self.gens: List[ForGenerator] = [] for index, expr, next_block in zip(indexes, exprs, self.cond_blocks): gen = make_for_loop_generator( - self.builder, - index, - expr, - next_block, - self.loop_exit, - self.line, nested=True) + self.builder, index, expr, next_block, self.loop_exit, self.line, nested=True + ) self.gens.append(gen) def gen_condition(self) -> None: diff --git a/mypyc/irbuild/format_str_tokenizer.py b/mypyc/irbuild/format_str_tokenizer.py index 721f28dbe3855..8c28621927fb7 100644 --- a/mypyc/irbuild/format_str_tokenizer.py +++ b/mypyc/irbuild/format_str_tokenizer.py @@ -1,20 +1,25 @@ """Tokenizers for three string formatting methods""" -from typing import List, Tuple, Optional -from typing_extensions import Final from enum import Enum, unique +from typing import List, Optional, Tuple + +from typing_extensions import Final from mypy.checkstrformat import ( - parse_format_value, ConversionSpecifier, parse_conversion_specifiers + ConversionSpecifier, + parse_conversion_specifiers, + parse_format_value, ) from mypy.errors import Errors from mypy.messages import MessageBuilder from mypy.nodes import Context, Expression - -from mypyc.ir.ops import Value, Integer +from mypyc.ir.ops import Integer, Value from mypyc.ir.rtypes import ( - c_pyssize_t_rprimitive, is_str_rprimitive, is_int_rprimitive, is_short_int_rprimitive, - is_bytes_rprimitive + c_pyssize_t_rprimitive, + is_bytes_rprimitive, + is_int_rprimitive, + is_short_int_rprimitive, + is_str_rprimitive, ) from mypyc.irbuild.builder import IRBuilder from mypyc.primitives.bytes_ops import bytes_build_op @@ -32,9 +37,10 @@ class FormatOp(Enum): ConversionSpecifier may have several representations, like '%s', '{}' or '{:{}}'. However, there would only exist one corresponding FormatOp. """ - STR = 's' - INT = 'd' - BYTES = 'b' + + STR = "s" + INT = "d" + BYTES = "b" def generate_format_ops(specifiers: List[ConversionSpecifier]) -> Optional[List[FormatOp]]: @@ -45,11 +51,11 @@ def generate_format_ops(specifiers: List[ConversionSpecifier]) -> Optional[List[ format_ops = [] for spec in specifiers: # TODO: Match specifiers instead of using whole_seq - if spec.whole_seq == '%s' or spec.whole_seq == '{:{}}': + if spec.whole_seq == "%s" or spec.whole_seq == "{:{}}": format_op = FormatOp.STR - elif spec.whole_seq == '%d': + elif spec.whole_seq == "%d": format_op = FormatOp.INT - elif spec.whole_seq == '%b': + elif spec.whole_seq == "%b": format_op = FormatOp.BYTES elif spec.whole_seq: return None @@ -86,8 +92,7 @@ def tokenizer_printf_style(format_str: str) -> Optional[Tuple[List[str], List[Fo EMPTY_CONTEXT: Final = Context() -def tokenizer_format_call( - format_str: str) -> Optional[Tuple[List[str], List[FormatOp]]]: +def tokenizer_format_call(format_str: str) -> Optional[Tuple[List[str], List[FormatOp]]]: """Tokenize a str.format() format string. The core function parse_format_value() is shared with mypy. @@ -103,8 +108,7 @@ def tokenizer_format_call( """ # Creates an empty MessageBuilder here. # It wouldn't be used since the code has passed the type-checking. - specifiers = parse_format_value(format_str, EMPTY_CONTEXT, - MessageBuilder(Errors(), {})) + specifiers = parse_format_value(format_str, EMPTY_CONTEXT, MessageBuilder(Errors(), {})) if specifiers is None: return None format_ops = generate_format_ops(specifiers) @@ -115,17 +119,18 @@ def tokenizer_format_call( last_end = 0 for spec in specifiers: # Skip { and } - literals.append(format_str[last_end:spec.start_pos - 1]) + literals.append(format_str[last_end : spec.start_pos - 1]) last_end = spec.start_pos + len(spec.whole_seq) + 1 literals.append(format_str[last_end:]) # Deal with escaped {{ - literals = [x.replace('{{', '{').replace('}}', '}') for x in literals] + literals = [x.replace("{{", "{").replace("}}", "}") for x in literals] return literals, format_ops -def convert_format_expr_to_str(builder: IRBuilder, format_ops: List[FormatOp], - exprs: List[Expression], line: int) -> Optional[List[Value]]: +def convert_format_expr_to_str( + builder: IRBuilder, format_ops: List[FormatOp], exprs: List[Expression], line: int +) -> Optional[List[Value]]: """Convert expressions into string literal objects with the guidance of FormatOps. Return None when fails.""" if len(format_ops) != len(exprs): @@ -152,8 +157,9 @@ def convert_format_expr_to_str(builder: IRBuilder, format_ops: List[FormatOp], return converted -def join_formatted_strings(builder: IRBuilder, literals: Optional[List[str]], - substitutions: List[Value], line: int) -> Value: +def join_formatted_strings( + builder: IRBuilder, literals: Optional[List[str]], substitutions: List[Value], line: int +) -> Value: """Merge the list of literals and the list of substitutions alternatively using 'str_build_op'. @@ -194,8 +200,9 @@ def join_formatted_strings(builder: IRBuilder, literals: Optional[List[str]], return builder.call_c(str_build_op, result_list, line) -def convert_format_expr_to_bytes(builder: IRBuilder, format_ops: List[FormatOp], - exprs: List[Expression], line: int) -> Optional[List[Value]]: +def convert_format_expr_to_bytes( + builder: IRBuilder, format_ops: List[FormatOp], exprs: List[Expression], line: int +) -> Optional[List[Value]]: """Convert expressions into bytes literal objects with the guidance of FormatOps. Return None when fails.""" if len(format_ops) != len(exprs): @@ -216,8 +223,9 @@ def convert_format_expr_to_bytes(builder: IRBuilder, format_ops: List[FormatOp], return converted -def join_formatted_bytes(builder: IRBuilder, literals: List[str], - substitutions: List[Value], line: int) -> Value: +def join_formatted_bytes( + builder: IRBuilder, literals: List[str], substitutions: List[Value], line: int +) -> Value: """Merge the list of literals and the list of substitutions alternatively using 'bytes_build_op'.""" result_list: List[Value] = [Integer(0, c_pyssize_t_rprimitive)] @@ -231,7 +239,7 @@ def join_formatted_bytes(builder: IRBuilder, literals: List[str], # Special case for empty bytes and literal if len(result_list) == 1: - return builder.load_bytes_from_str_literal('') + return builder.load_bytes_from_str_literal("") if not substitutions and len(result_list) == 2: return result_list[1] diff --git a/mypyc/irbuild/function.py b/mypyc/irbuild/function.py index 2c771df08809c..f6e5854f1e5b6 100644 --- a/mypyc/irbuild/function.py +++ b/mypyc/irbuild/function.py @@ -10,57 +10,96 @@ instance of the callable class. """ -from typing import ( - DefaultDict, NamedTuple, Optional, List, Sequence, Tuple, Union, Dict -) +from collections import defaultdict +from typing import DefaultDict, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union from mypy.nodes import ( - ClassDef, FuncDef, OverloadedFuncDef, Decorator, Var, YieldFromExpr, AwaitExpr, YieldExpr, - FuncItem, LambdaExpr, SymbolNode, ArgKind, TypeInfo + ArgKind, + AwaitExpr, + ClassDef, + Decorator, + FuncDef, + FuncItem, + LambdaExpr, + OverloadedFuncDef, + SymbolNode, + TypeInfo, + Var, + YieldExpr, + YieldFromExpr, ) from mypy.types import CallableType, get_proper_type - +from mypyc.common import LAMBDA_NAME, SELF_NAME +from mypyc.ir.class_ir import ClassIR, NonExtClassInfo +from mypyc.ir.func_ir import ( + FUNC_CLASSMETHOD, + FUNC_NORMAL, + FUNC_STATICMETHOD, + FuncDecl, + FuncIR, + FuncSignature, + RuntimeArg, +) from mypyc.ir.ops import ( - BasicBlock, Value, Register, Return, SetAttr, Integer, GetAttr, Branch, InitStatic, - LoadAddress, LoadLiteral, Unbox, Unreachable, + BasicBlock, + Branch, + GetAttr, + InitStatic, + Integer, + LoadAddress, + LoadLiteral, + Register, + Return, + SetAttr, + Unbox, + Unreachable, + Value, ) from mypyc.ir.rtypes import ( - object_rprimitive, RInstance, object_pointer_rprimitive, dict_rprimitive, int_rprimitive, + RInstance, bool_rprimitive, + dict_rprimitive, + int_rprimitive, + object_pointer_rprimitive, + object_rprimitive, ) -from mypyc.ir.func_ir import ( - FuncIR, FuncSignature, RuntimeArg, FuncDecl, FUNC_CLASSMETHOD, FUNC_STATICMETHOD, FUNC_NORMAL -) -from mypyc.ir.class_ir import ClassIR, NonExtClassInfo -from mypyc.primitives.generic_ops import py_setattr_op, next_raw_op, iter_op -from mypyc.primitives.misc_ops import ( - check_stop_op, yield_from_except_op, coro_op, send_op, register_function -) -from mypyc.primitives.dict_ops import dict_set_item_op, dict_new_op, dict_get_method_with_none -from mypyc.common import SELF_NAME, LAMBDA_NAME -from mypyc.sametype import is_same_method_signature -from mypyc.irbuild.util import is_constant -from mypyc.irbuild.context import FuncInfo, ImplicitClass -from mypyc.irbuild.targets import AssignmentTarget -from mypyc.irbuild.statement import transform_try_except from mypyc.irbuild.builder import IRBuilder, SymbolTarget, gen_arg_defaults from mypyc.irbuild.callable_class import ( - setup_callable_class, add_call_to_callable_class, add_get_to_callable_class, - instantiate_callable_class + add_call_to_callable_class, + add_get_to_callable_class, + instantiate_callable_class, + setup_callable_class, +) +from mypyc.irbuild.context import FuncInfo, ImplicitClass +from mypyc.irbuild.env_class import ( + finalize_env_class, + load_env_registers, + load_outer_envs, + setup_env_class, + setup_func_for_recursive_call, ) from mypyc.irbuild.generator import ( - gen_generator_func, setup_env_for_generator_class, create_switch_for_generator_class, - add_raise_exception_blocks_to_generator_class, populate_switch_for_generator_class, - add_methods_to_generator_class + add_methods_to_generator_class, + add_raise_exception_blocks_to_generator_class, + create_switch_for_generator_class, + gen_generator_func, + populate_switch_for_generator_class, + setup_env_for_generator_class, ) -from mypyc.irbuild.env_class import ( - setup_env_class, load_outer_envs, load_env_registers, finalize_env_class, - setup_func_for_recursive_call +from mypyc.irbuild.statement import transform_try_except +from mypyc.irbuild.targets import AssignmentTarget +from mypyc.irbuild.util import is_constant +from mypyc.primitives.dict_ops import dict_get_method_with_none, dict_new_op, dict_set_item_op +from mypyc.primitives.generic_ops import iter_op, next_raw_op, py_setattr_op +from mypyc.primitives.misc_ops import ( + check_stop_op, + coro_op, + register_function, + send_op, + yield_from_except_op, ) - from mypyc.primitives.registry import builtin_names -from collections import defaultdict - +from mypyc.sametype import is_same_method_signature # Top-level transform functions @@ -84,10 +123,7 @@ def transform_overloaded_func_def(builder: IRBuilder, o: OverloadedFuncDef) -> N def transform_decorator(builder: IRBuilder, dec: Decorator) -> None: func_ir, func_reg = gen_func_item( - builder, - dec.func, - dec.func.name, - builder.mapper.fdef_to_sig(dec.func) + builder, dec.func, dec.func.name, builder.mapper.fdef_to_sig(dec.func) ) decorated_func: Optional[Value] = None if func_reg: @@ -99,7 +135,7 @@ def transform_decorator(builder: IRBuilder, dec: Decorator) -> None: # treat this function as a regular function, not a decorated function elif dec.func in builder.fdefs_to_decorators: # Obtain the the function name in order to construct the name of the helper function. - name = dec.func.fullname.split('.')[-1] + name = dec.func.fullname.split(".")[-1] # Load the callable object representing the non-decorated function, and decorate it. orig_func = builder.load_global_str(name, dec.line) @@ -107,10 +143,11 @@ def transform_decorator(builder: IRBuilder, dec: Decorator) -> None: if decorated_func is not None: # Set the callable object representing the decorated function as a global. - builder.call_c(dict_set_item_op, - [builder.load_globals_dict(), - builder.load_str(dec.func.name), decorated_func], - decorated_func.line) + builder.call_c( + dict_set_item_op, + [builder.load_globals_dict(), builder.load_str(dec.func.name), decorated_func], + decorated_func.line, + ) maybe_insert_into_registry_dict(builder, dec.func) @@ -125,12 +162,13 @@ def transform_lambda_expr(builder: IRBuilder, expr: LambdaExpr) -> Value: for arg, arg_type in zip(expr.arguments, typ.arg_types): arg.variable.type = arg_type runtime_args.append( - RuntimeArg(arg.variable.name, builder.type_to_rtype(arg_type), arg.kind)) + RuntimeArg(arg.variable.name, builder.type_to_rtype(arg_type), arg.kind) + ) ret_type = builder.type_to_rtype(typ.ret_type) fsig = FuncSignature(runtime_args, ret_type) - fname = f'{LAMBDA_NAME}{builder.lambda_counter}' + fname = f"{LAMBDA_NAME}{builder.lambda_counter}" builder.lambda_counter += 1 func_ir, func_reg = gen_func_item(builder, expr, fname, fsig) assert func_reg is not None @@ -141,7 +179,7 @@ def transform_lambda_expr(builder: IRBuilder, expr: LambdaExpr) -> Value: def transform_yield_expr(builder: IRBuilder, expr: YieldExpr) -> Value: if builder.fn_info.is_coroutine: - builder.error('async generators are unimplemented', expr.line) + builder.error("async generators are unimplemented", expr.line) if expr.expr: retval = builder.accept(expr.expr) @@ -161,12 +199,13 @@ def transform_await_expr(builder: IRBuilder, o: AwaitExpr) -> Value: # Internal functions -def gen_func_item(builder: IRBuilder, - fitem: FuncItem, - name: str, - sig: FuncSignature, - cdef: Optional[ClassDef] = None, - ) -> Tuple[FuncIR, Optional[Value]]: +def gen_func_item( + builder: IRBuilder, + fitem: FuncItem, + name: str, + sig: FuncSignature, + cdef: Optional[ClassDef] = None, +) -> Tuple[FuncIR, Optional[Value]]: """Generate and return the FuncIR for a given FuncDef. If the given FuncItem is a nested function, then we generate a @@ -223,8 +262,18 @@ def c() -> None: func_name = singledispatch_main_func_name(name) else: func_name = name - builder.enter(FuncInfo(fitem, func_name, class_name, gen_func_ns(builder), - is_nested, contains_nested, is_decorated, in_non_ext)) + builder.enter( + FuncInfo( + fitem, + func_name, + class_name, + gen_func_ns(builder), + is_nested, + contains_nested, + is_decorated, + in_non_ext, + ) + ) # Functions that contain nested functions need an environment class to store variables that # are free in their nested functions. Generator functions need an environment class to @@ -241,7 +290,7 @@ def c() -> None: gen_generator_func(builder) args, _, blocks, ret_type, fn_info = builder.leave() func_ir, func_reg = gen_func_ir( - builder, args, blocks, sig, fn_info, cdef, is_singledispatch, + builder, args, blocks, sig, fn_info, cdef, is_singledispatch ) # Re-enter the FuncItem and visit the body of the function this time. @@ -275,8 +324,7 @@ def c() -> None: if builder.fn_info.fitem in builder.free_variables: # Sort the variables to keep things deterministic - for var in sorted(builder.free_variables[builder.fn_info.fitem], - key=lambda x: x.name): + for var in sorted(builder.free_variables[builder.fn_info.fitem], key=lambda x: x.name): if isinstance(var, Var): rtype = builder.type_to_rtype(var.type) builder.add_var_to_env_class(var, rtype, env_for_func, reassign=False) @@ -306,11 +354,10 @@ def c() -> None: args, _, blocks, ret_type, fn_info = builder.leave() if fn_info.is_generator: - add_methods_to_generator_class( - builder, fn_info, sig, args, blocks, fitem.is_coroutine) + add_methods_to_generator_class(builder, fn_info, sig, args, blocks, fitem.is_coroutine) else: func_ir, func_reg = gen_func_ir( - builder, args, blocks, sig, fn_info, cdef, is_singledispatch, + builder, args, blocks, sig, fn_info, cdef, is_singledispatch ) # Evaluate argument defaults in the surrounding scope, since we @@ -327,13 +374,15 @@ def c() -> None: return func_ir, func_reg -def gen_func_ir(builder: IRBuilder, - args: List[Register], - blocks: List[BasicBlock], - sig: FuncSignature, - fn_info: FuncInfo, - cdef: Optional[ClassDef], - is_singledispatch_main_func: bool = False) -> Tuple[FuncIR, Optional[Value]]: +def gen_func_ir( + builder: IRBuilder, + args: List[Register], + blocks: List[BasicBlock], + sig: FuncSignature, + fn_info: FuncInfo, + cdef: Optional[ClassDef], + is_singledispatch_main_func: bool = False, +) -> Tuple[FuncIR, Optional[Value]]: """Generate the FuncIR for a function. This takes the basic blocks and function info of a particular @@ -351,14 +400,22 @@ def gen_func_ir(builder: IRBuilder, func_decl = builder.mapper.func_to_decl[fn_info.fitem] if fn_info.is_decorated or is_singledispatch_main_func: class_name = None if cdef is None else cdef.name - func_decl = FuncDecl(fn_info.name, class_name, builder.module_name, sig, - func_decl.kind, - func_decl.is_prop_getter, func_decl.is_prop_setter) - func_ir = FuncIR(func_decl, args, blocks, fn_info.fitem.line, - traceback_name=fn_info.fitem.name) + func_decl = FuncDecl( + fn_info.name, + class_name, + builder.module_name, + sig, + func_decl.kind, + func_decl.is_prop_getter, + func_decl.is_prop_setter, + ) + func_ir = FuncIR( + func_decl, args, blocks, fn_info.fitem.line, traceback_name=fn_info.fitem.name + ) else: - func_ir = FuncIR(func_decl, args, blocks, - fn_info.fitem.line, traceback_name=fn_info.fitem.name) + func_ir = FuncIR( + func_decl, args, blocks, fn_info.fitem.line, traceback_name=fn_info.fitem.name + ) return (func_ir, func_reg) @@ -371,7 +428,7 @@ def handle_ext_method(builder: IRBuilder, cdef: ClassDef, fdef: FuncDef) -> None if is_decorated(builder, fdef): # Obtain the the function name in order to construct the name of the helper function. - _, _, name = fdef.fullname.rpartition('.') + _, _, name = fdef.fullname.rpartition(".") # Read the PyTypeObject representing the class, get the callable object # representing the non-decorated method typ = builder.load_native_type_object(cdef.fullname) @@ -382,9 +439,7 @@ def handle_ext_method(builder: IRBuilder, cdef: ClassDef, fdef: FuncDef) -> None # Set the callable object representing the decorated method as an attribute of the # extension class. - builder.call_c(py_setattr_op, - [typ, builder.load_str(name), decorated_func], - fdef.line) + builder.call_c(py_setattr_op, [typ, builder.load_str(name), decorated_func], fdef.line) if fdef.is_property: # If there is a property setter, it will be processed after the getter, @@ -403,9 +458,13 @@ def handle_ext_method(builder: IRBuilder, cdef: ClassDef, fdef: FuncDef) -> None # If this overrides a parent class method with a different type, we need # to generate a glue method to mediate between them. for base in class_ir.mro[1:]: - if (name in base.method_decls and name != '__init__' - and not is_same_method_signature(class_ir.method_decls[name].sig, - base.method_decls[name].sig)): + if ( + name in base.method_decls + and name != "__init__" + and not is_same_method_signature( + class_ir.method_decls[name].sig, base.method_decls[name].sig + ) + ): # TODO: Support contravariant subtyping in the input argument for # property setters. Need to make a special glue method for handling this, @@ -426,7 +485,8 @@ def handle_ext_method(builder: IRBuilder, cdef: ClassDef, fdef: FuncDef) -> None def handle_non_ext_method( - builder: IRBuilder, non_ext: NonExtClassInfo, cdef: ClassDef, fdef: FuncDef) -> None: + builder: IRBuilder, non_ext: NonExtClassInfo, cdef: ClassDef, fdef: FuncDef +) -> None: # Perform the function of visit_method for methods inside non-extension classes. name = fdef.name func_ir, func_reg = gen_func_item(builder, fdef, name, builder.mapper.fdef_to_sig(fdef), cdef) @@ -440,26 +500,26 @@ def handle_non_ext_method( # TODO: Support property setters in non-extension classes if fdef.is_property: - prop = builder.load_module_attr_by_fullname('builtins.property', fdef.line) + prop = builder.load_module_attr_by_fullname("builtins.property", fdef.line) func_reg = builder.py_call(prop, [func_reg], fdef.line) elif builder.mapper.func_to_decl[fdef].kind == FUNC_CLASSMETHOD: - cls_meth = builder.load_module_attr_by_fullname('builtins.classmethod', fdef.line) + cls_meth = builder.load_module_attr_by_fullname("builtins.classmethod", fdef.line) func_reg = builder.py_call(cls_meth, [func_reg], fdef.line) elif builder.mapper.func_to_decl[fdef].kind == FUNC_STATICMETHOD: - stat_meth = builder.load_module_attr_by_fullname( - 'builtins.staticmethod', fdef.line - ) + stat_meth = builder.load_module_attr_by_fullname("builtins.staticmethod", fdef.line) func_reg = builder.py_call(stat_meth, [func_reg], fdef.line) builder.add_to_non_ext_dict(non_ext, name, func_reg, fdef.line) -def calculate_arg_defaults(builder: IRBuilder, - fn_info: FuncInfo, - func_reg: Optional[Value], - symtable: Dict[SymbolNode, SymbolTarget]) -> None: +def calculate_arg_defaults( + builder: IRBuilder, + fn_info: FuncInfo, + func_reg: Optional[Value], + symtable: Dict[SymbolNode, SymbolTarget], +) -> None: """Calculate default argument values and store them. They are stored in statics for top level functions and in @@ -471,12 +531,10 @@ def calculate_arg_defaults(builder: IRBuilder, # Constant values don't get stored but just recomputed if arg.initializer and not is_constant(arg.initializer): value = builder.coerce( - builder.accept(arg.initializer), - symtable[arg.variable].type, - arg.line + builder.accept(arg.initializer), symtable[arg.variable].type, arg.line ) if not fn_info.is_nested: - name = fitem.fullname + '.' + arg.variable.name + name = fitem.fullname + "." + arg.variable.name builder.add(InitStatic(value, name, builder.module_name)) else: assert func_reg is not None @@ -485,9 +543,11 @@ def calculate_arg_defaults(builder: IRBuilder, def gen_func_ns(builder: IRBuilder) -> str: """Generate a namespace for a nested function using its outer function names.""" - return '_'.join(info.name + ('' if not info.class_name else '_' + info.class_name) - for info in builder.fn_infos - if info.name and info.name != '') + return "_".join( + info.name + ("" if not info.class_name else "_" + info.class_name) + for info in builder.fn_infos + if info.name and info.name != "" + ) def emit_yield(builder: IRBuilder, val: Value, line: int) -> Value: @@ -553,8 +613,9 @@ def except_body() -> None: # indicating whether to break or yield (or raise an exception). val = Register(object_rprimitive) val_address = builder.add(LoadAddress(object_pointer_rprimitive, val)) - to_stop = builder.call_c(yield_from_except_op, - [builder.read(iter_reg), val_address], o.line) + to_stop = builder.call_c( + yield_from_except_op, [builder.read(iter_reg), val_address], o.line + ) ok, stop = BasicBlock(), BasicBlock() builder.add(Branch(to_stop, stop, ok, Branch.BOOL)) @@ -572,9 +633,7 @@ def except_body() -> None: def else_body() -> None: # Do a next() or a .send(). It will return NULL on exception # but it won't automatically propagate. - _y = builder.call_c( - send_op, [builder.read(iter_reg), builder.read(received_reg)], o.line - ) + _y = builder.call_c(send_op, [builder.read(iter_reg), builder.read(received_reg)], o.line) ok, stop = BasicBlock(), BasicBlock() builder.add(Branch(_y, stop, ok, Branch.IS_ERROR)) @@ -590,9 +649,7 @@ def else_body() -> None: builder.nonlocal_control[-1].gen_break(builder, o.line) builder.push_loop_stack(loop_block, done_block) - transform_try_except( - builder, try_body, [(None, None, except_body)], else_body, o.line - ) + transform_try_except(builder, try_body, [(None, None, except_body)], else_body, o.line) builder.pop_loop_stack() builder.goto_and_activate(done_block) @@ -625,11 +682,16 @@ def is_decorated(builder: IRBuilder, fdef: FuncDef) -> bool: return fdef in builder.fdefs_to_decorators -def gen_glue(builder: IRBuilder, sig: FuncSignature, target: FuncIR, - cls: ClassIR, base: ClassIR, fdef: FuncItem, - *, - do_py_ops: bool = False - ) -> FuncIR: +def gen_glue( + builder: IRBuilder, + sig: FuncSignature, + target: FuncIR, + cls: ClassIR, + base: ClassIR, + fdef: FuncItem, + *, + do_py_ops: bool = False, +) -> FuncIR: """Generate glue methods that mediate between different method types in subclasses. Works on both properties and methods. See gen_glue_methods below @@ -654,19 +716,27 @@ class ArgInfo(NamedTuple): def get_args(builder: IRBuilder, rt_args: Sequence[RuntimeArg], line: int) -> ArgInfo: # The environment operates on Vars, so we make some up fake_vars = [(Var(arg.name), arg.type) for arg in rt_args] - args = [builder.read(builder.add_local_reg(var, type, is_arg=True), line) - for var, type in fake_vars] - arg_names = [arg.name - if arg.kind.is_named() or (arg.kind.is_optional() and not arg.pos_only) else None - for arg in rt_args] + args = [ + builder.read(builder.add_local_reg(var, type, is_arg=True), line) + for var, type in fake_vars + ] + arg_names = [ + arg.name if arg.kind.is_named() or (arg.kind.is_optional() and not arg.pos_only) else None + for arg in rt_args + ] arg_kinds = [arg.kind for arg in rt_args] return ArgInfo(args, arg_names, arg_kinds) -def gen_glue_method(builder: IRBuilder, sig: FuncSignature, target: FuncIR, - cls: ClassIR, base: ClassIR, line: int, - do_pycall: bool, - ) -> FuncIR: +def gen_glue_method( + builder: IRBuilder, + sig: FuncSignature, + target: FuncIR, + cls: ClassIR, + base: ClassIR, + line: int, + do_pycall: bool, +) -> FuncIR: """Generate glue methods that mediate between different method types in subclasses. For example, if we have: @@ -704,9 +774,8 @@ def f(builder: IRBuilder, x: object) -> int: ... # We can do a passthrough *args/**kwargs with a native call, but if the # args need to get distributed out to arguments, we just let python handle it - if ( - any(kind.is_star() for kind in arg_kinds) - and any(not arg.kind.is_star() for arg in target.decl.sig.args) + if any(kind.is_star() for kind in arg_kinds) and any( + not arg.kind.is_star() for arg in target.decl.sig.args ): do_pycall = True @@ -719,7 +788,8 @@ def f(builder: IRBuilder, x: object) -> int: ... first = args[0] st = 1 retval = builder.builder.py_method_call( - first, target.name, args[st:], line, arg_kinds[st:], arg_names[st:]) + first, target.name, args[st:], line, arg_kinds[st:], arg_names[st:] + ) else: retval = builder.builder.call(target.decl, args, arg_kinds, arg_names, line) retval = builder.coerce(retval, sig.ret_type, line) @@ -727,20 +797,27 @@ def f(builder: IRBuilder, x: object) -> int: ... arg_regs, _, blocks, ret_type, _ = builder.leave() return FuncIR( - FuncDecl(target.name + '__' + base.name + '_glue', - cls.name, builder.module_name, - FuncSignature(rt_args, ret_type), - target.decl.kind), - arg_regs, blocks) - - -def gen_glue_property(builder: IRBuilder, - sig: FuncSignature, - target: FuncIR, - cls: ClassIR, - base: ClassIR, - line: int, - do_pygetattr: bool) -> FuncIR: + FuncDecl( + target.name + "__" + base.name + "_glue", + cls.name, + builder.module_name, + FuncSignature(rt_args, ret_type), + target.decl.kind, + ), + arg_regs, + blocks, + ) + + +def gen_glue_property( + builder: IRBuilder, + sig: FuncSignature, + target: FuncIR, + cls: ClassIR, + base: ClassIR, + line: int, + do_pygetattr: bool, +) -> FuncIR: """Generate glue methods for properties that mediate between different subclass types. Similarly to methods, properties of derived types can be covariantly subtyped. Thus, @@ -765,9 +842,15 @@ def gen_glue_property(builder: IRBuilder, args, _, blocks, return_type, _ = builder.leave() return FuncIR( - FuncDecl(target.name + '__' + base.name + '_glue', - cls.name, builder.module_name, FuncSignature([rt_arg], return_type)), - args, blocks) + FuncDecl( + target.name + "__" + base.name + "_glue", + cls.name, + builder.module_name, + FuncSignature([rt_arg], return_type), + ), + args, + blocks, + ) def get_func_target(builder: IRBuilder, fdef: FuncDef) -> AssignmentTarget: @@ -806,7 +889,7 @@ def load_func(builder: IRBuilder, func_name: str, fullname: Optional[str], line: # We can't use load_module_attr_by_fullname here because we need to load the function using # func_name, not the name specified by fullname (which can be different for underscore # function) - module = fullname.rsplit('.')[0] + module = fullname.rsplit(".")[0] loaded_module = builder.load_module(module) func = builder.py_get_attr(loaded_module, func_name, line) @@ -816,9 +899,7 @@ def load_func(builder: IRBuilder, func_name: str, fullname: Optional[str], line: def generate_singledispatch_dispatch_function( - builder: IRBuilder, - main_singledispatch_function_name: str, - fitem: FuncDef, + builder: IRBuilder, main_singledispatch_function_name: str, fitem: FuncDef ) -> None: line = fitem.line current_func_decl = builder.mapper.func_to_decl[fitem] @@ -828,11 +909,11 @@ def generate_singledispatch_dispatch_function( arg_type = builder.builder.get_type_of_obj(arg_info.args[0], line) dispatch_cache = builder.builder.get_attr( - dispatch_func_obj, 'dispatch_cache', dict_rprimitive, line + dispatch_func_obj, "dispatch_cache", dict_rprimitive, line ) call_find_impl, use_cache, call_func = BasicBlock(), BasicBlock(), BasicBlock() get_result = builder.call_c(dict_get_method_with_none, [dispatch_cache, arg_type], line) - is_not_none = builder.translate_is_op(get_result, builder.none_object(), 'is not', line) + is_not_none = builder.translate_is_op(get_result, builder.none_object(), "is not", line) impl_to_use = Register(object_rprimitive) builder.add_bool_branch(is_not_none, use_cache, call_find_impl) @@ -841,7 +922,7 @@ def generate_singledispatch_dispatch_function( builder.goto(call_func) builder.activate_block(call_find_impl) - find_impl = builder.load_module_attr_by_fullname('functools._find_impl', line) + find_impl = builder.load_module_attr_by_fullname("functools._find_impl", line) registry = load_singledispatch_registry(builder, dispatch_func_obj, line) uncached_impl = builder.py_call(find_impl, [arg_type, registry], line) builder.call_c(dict_set_item_op, [dispatch_cache, arg_type, uncached_impl], line) @@ -853,11 +934,7 @@ def generate_singledispatch_dispatch_function( def gen_calls_to_correct_impl( - builder: IRBuilder, - impl_to_use: Value, - arg_info: ArgInfo, - fitem: FuncDef, - line: int, + builder: IRBuilder, impl_to_use: Value, arg_info: ArgInfo, fitem: FuncDef, line: int ) -> None: current_func_decl = builder.mapper.func_to_decl[fitem] @@ -869,7 +946,7 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None: coerced = builder.coerce(ret_val, current_func_decl.sig.ret_type, line) builder.add(Return(coerced)) - typ, src = builtin_names['builtins.int'] + typ, src = builtin_names["builtins.int"] int_type_obj = builder.add(LoadAddress(typ, src, line)) is_int = builder.builder.type_is_op(impl_to_use, int_type_obj, line) @@ -885,12 +962,7 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None: current_id = builder.load_int(i) builder.builder.compare_tagged_condition( - passed_id, - current_id, - '==', - call_impl, - next_impl, - line, + passed_id, current_id, "==", call_impl, next_impl, line ) # Call the registered implementation @@ -911,19 +983,15 @@ def gen_native_func_call_and_return(fdef: FuncDef) -> None: def gen_dispatch_func_ir( - builder: IRBuilder, - fitem: FuncDef, - main_func_name: str, - dispatch_name: str, - sig: FuncSignature, + builder: IRBuilder, fitem: FuncDef, main_func_name: str, dispatch_name: str, sig: FuncSignature ) -> Tuple[FuncIR, Value]: """Create a dispatch function (a function that checks the first argument type and dispatches to the correct implementation) """ builder.enter(FuncInfo(fitem, dispatch_name)) setup_callable_class(builder) - builder.fn_info.callable_class.ir.attributes['registry'] = dict_rprimitive - builder.fn_info.callable_class.ir.attributes['dispatch_cache'] = dict_rprimitive + builder.fn_info.callable_class.ir.attributes["registry"] = dict_rprimitive + builder.fn_info.callable_class.ir.attributes["dispatch_cache"] = dict_rprimitive builder.fn_info.callable_class.ir.has_dict = True builder.fn_info.callable_class.ir.needs_getseters = True generate_singledispatch_callable_class_ctor(builder) @@ -943,10 +1011,7 @@ def gen_dispatch_func_ir( def generate_dispatch_glue_native_function( - builder: IRBuilder, - fitem: FuncDef, - callable_class_decl: FuncDecl, - dispatch_name: str, + builder: IRBuilder, fitem: FuncDef, callable_class_decl: FuncDecl, dispatch_name: str ) -> FuncIR: line = fitem.line builder.enter() @@ -957,7 +1022,7 @@ def generate_dispatch_glue_native_function( args = [callable_class] + arg_info.args arg_kinds = [ArgKind.ARG_POS] + arg_info.arg_kinds arg_names = arg_info.arg_names - arg_names.insert(0, 'self') + arg_names.insert(0, "self") ret_val = builder.builder.call(callable_class_decl, args, arg_kinds, arg_names, line) builder.add(Return(ret_val)) arg_regs, _, blocks, _, fn_info = builder.leave() @@ -968,11 +1033,11 @@ def generate_singledispatch_callable_class_ctor(builder: IRBuilder) -> None: """Create an __init__ that sets registry and dispatch_cache to empty dicts""" line = -1 class_ir = builder.fn_info.callable_class.ir - with builder.enter_method(class_ir, '__init__', bool_rprimitive): + with builder.enter_method(class_ir, "__init__", bool_rprimitive): empty_dict = builder.call_c(dict_new_op, [], line) - builder.add(SetAttr(builder.self(), 'registry', empty_dict, line)) + builder.add(SetAttr(builder.self(), "registry", empty_dict, line)) cache_dict = builder.call_c(dict_new_op, [], line) - dispatch_cache_str = builder.load_str('dispatch_cache') + dispatch_cache_str = builder.load_str("dispatch_cache") # use the py_setattr_op instead of SetAttr so that it also gets added to our __dict__ builder.call_c(py_setattr_op, [builder.self(), dispatch_cache_str, cache_dict], line) # the generated C code seems to expect that __init__ returns a char, so just return 1 @@ -981,23 +1046,23 @@ def generate_singledispatch_callable_class_ctor(builder: IRBuilder) -> None: def add_register_method_to_callable_class(builder: IRBuilder, fn_info: FuncInfo) -> None: line = -1 - with builder.enter_method(fn_info.callable_class.ir, 'register', object_rprimitive): - cls_arg = builder.add_argument('cls', object_rprimitive) - func_arg = builder.add_argument('func', object_rprimitive, ArgKind.ARG_OPT) + with builder.enter_method(fn_info.callable_class.ir, "register", object_rprimitive): + cls_arg = builder.add_argument("cls", object_rprimitive) + func_arg = builder.add_argument("func", object_rprimitive, ArgKind.ARG_OPT) ret_val = builder.call_c(register_function, [builder.self(), cls_arg, func_arg], line) builder.add(Return(ret_val, line)) def load_singledispatch_registry(builder: IRBuilder, dispatch_func_obj: Value, line: int) -> Value: - return builder.builder.get_attr(dispatch_func_obj, 'registry', dict_rprimitive, line) + return builder.builder.get_attr(dispatch_func_obj, "registry", dict_rprimitive, line) def singledispatch_main_func_name(orig_name: str) -> str: - return f'__mypyc_singledispatch_main_function_{orig_name}__' + return f"__mypyc_singledispatch_main_function_{orig_name}__" def get_registry_identifier(fitem: FuncDef) -> str: - return f'__mypyc_singledispatch_registry_{fitem.fullname}__' + return f"__mypyc_singledispatch_registry_{fitem.fullname}__" def maybe_insert_into_registry_dict(builder: IRBuilder, fitem: FuncDef) -> None: @@ -1017,12 +1082,12 @@ def maybe_insert_into_registry_dict(builder: IRBuilder, fitem: FuncDef) -> None: main_func_name = singledispatch_main_func_name(fitem.name) main_func_obj = load_func(builder, main_func_name, fitem.fullname, line) - loaded_object_type = builder.load_module_attr_by_fullname('builtins.object', line) + loaded_object_type = builder.load_module_attr_by_fullname("builtins.object", line) registry_dict = builder.builder.make_dict([(loaded_object_type, main_func_obj)], line) dispatch_func_obj = builder.load_global_str(fitem.name, line) builder.call_c( - py_setattr_op, [dispatch_func_obj, builder.load_str('registry'), registry_dict], line + py_setattr_op, [dispatch_func_obj, builder.load_str("registry"), registry_dict], line ) for singledispatch_func, types in to_register.items(): @@ -1044,9 +1109,9 @@ def maybe_insert_into_registry_dict(builder: IRBuilder, fitem: FuncDef) -> None: loaded_type = load_type(builder, typ, line) builder.call_c(dict_set_item_op, [registry, loaded_type, to_insert], line) dispatch_cache = builder.builder.get_attr( - dispatch_func_obj, 'dispatch_cache', dict_rprimitive, line + dispatch_func_obj, "dispatch_cache", dict_rprimitive, line ) - builder.gen_method_call(dispatch_cache, 'clear', [], None, line) + builder.gen_method_call(dispatch_cache, "clear", [], None, line) def get_native_impl_ids(builder: IRBuilder, singledispatch_func: FuncDef) -> Dict[FuncDef, int]: diff --git a/mypyc/irbuild/generator.py b/mypyc/irbuild/generator.py index 7a96d390e156d..742c0da0337a4 100644 --- a/mypyc/irbuild/generator.py +++ b/mypyc/irbuild/generator.py @@ -10,26 +10,43 @@ from typing import List -from mypy.nodes import Var, ARG_OPT - -from mypyc.common import SELF_NAME, NEXT_LABEL_ATTR_NAME, ENV_ATTR_NAME +from mypy.nodes import ARG_OPT, Var +from mypyc.common import ENV_ATTR_NAME, NEXT_LABEL_ATTR_NAME, SELF_NAME +from mypyc.ir.class_ir import ClassIR +from mypyc.ir.func_ir import FuncDecl, FuncIR, FuncSignature, RuntimeArg from mypyc.ir.ops import ( - BasicBlock, Call, Return, Goto, Integer, SetAttr, Unreachable, RaiseStandardError, - Value, Register, MethodCall, TupleSet, Branch, NO_TRACEBACK_LINE_NO + NO_TRACEBACK_LINE_NO, + BasicBlock, + Branch, + Call, + Goto, + Integer, + MethodCall, + RaiseStandardError, + Register, + Return, + SetAttr, + TupleSet, + Unreachable, + Value, ) from mypyc.ir.rtypes import RInstance, int_rprimitive, object_rprimitive -from mypyc.ir.func_ir import FuncIR, FuncDecl, FuncSignature, RuntimeArg -from mypyc.ir.class_ir import ClassIR +from mypyc.irbuild.builder import IRBuilder, gen_arg_defaults +from mypyc.irbuild.context import FuncInfo, GeneratorClass +from mypyc.irbuild.env_class import ( + add_args_to_env, + finalize_env_class, + load_env_registers, + load_outer_env, +) from mypyc.irbuild.nonlocalcontrol import ExceptNonlocalControl from mypyc.primitives.exc_ops import ( - raise_exception_with_tb_op, error_catch_op, exc_matches_op, reraise_exception_op, - restore_exc_info_op -) -from mypyc.irbuild.env_class import ( - add_args_to_env, load_outer_env, load_env_registers, finalize_env_class + error_catch_op, + exc_matches_op, + raise_exception_with_tb_op, + reraise_exception_op, + restore_exc_info_op, ) -from mypyc.irbuild.builder import IRBuilder, gen_arg_defaults -from mypyc.irbuild.context import FuncInfo, GeneratorClass def gen_generator_func(builder: IRBuilder) -> None: @@ -64,7 +81,7 @@ def instantiate_generator_class(builder: IRBuilder) -> Value: def setup_generator_class(builder: IRBuilder) -> ClassIR: - name = f'{builder.fn_info.namespaced_name()}_gen' + name = f"{builder.fn_info.namespaced_name()}_gen" generator_class_ir = ClassIR(name, builder.module_name, is_generated=True) generator_class_ir.attributes[ENV_ATTR_NAME] = RInstance(builder.fn_info.env_class) @@ -89,9 +106,7 @@ def populate_switch_for_generator_class(builder: IRBuilder) -> None: builder.activate_block(cls.switch_block) for label, true_block in enumerate(cls.continuation_blocks): false_block = BasicBlock() - comparison = builder.binary_op( - cls.next_label_reg, Integer(label), '==', line - ) + comparison = builder.binary_op(cls.next_label_reg, Integer(label), "==", line) builder.add_bool_branch(comparison, true_block, false_block) builder.activate_block(false_block) @@ -113,7 +128,7 @@ def add_raise_exception_blocks_to_generator_class(builder: IRBuilder, line: int) # Check to see if an exception was raised. error_block = BasicBlock() ok_block = BasicBlock() - comparison = builder.translate_is_op(exc_type, builder.none_object(), 'is not', line) + comparison = builder.translate_is_op(exc_type, builder.none_object(), "is not", line) builder.add_bool_branch(comparison, error_block, ok_block) builder.activate_block(error_block) @@ -122,12 +137,14 @@ def add_raise_exception_blocks_to_generator_class(builder: IRBuilder, line: int) builder.goto_and_activate(ok_block) -def add_methods_to_generator_class(builder: IRBuilder, - fn_info: FuncInfo, - sig: FuncSignature, - arg_regs: List[Register], - blocks: List[BasicBlock], - is_coroutine: bool) -> None: +def add_methods_to_generator_class( + builder: IRBuilder, + fn_info: FuncInfo, + sig: FuncSignature, + arg_regs: List[Register], + blocks: List[BasicBlock], + is_coroutine: bool, +) -> None: helper_fn_decl = add_helper_to_generator_class(builder, arg_regs, blocks, sig, fn_info) add_next_to_generator_class(builder, fn_info, helper_fn_decl, sig) add_send_to_generator_class(builder, fn_info, helper_fn_decl, sig) @@ -138,74 +155,84 @@ def add_methods_to_generator_class(builder: IRBuilder, add_await_to_generator_class(builder, fn_info) -def add_helper_to_generator_class(builder: IRBuilder, - arg_regs: List[Register], - blocks: List[BasicBlock], - sig: FuncSignature, - fn_info: FuncInfo) -> FuncDecl: +def add_helper_to_generator_class( + builder: IRBuilder, + arg_regs: List[Register], + blocks: List[BasicBlock], + sig: FuncSignature, + fn_info: FuncInfo, +) -> FuncDecl: """Generates a helper method for a generator class, called by '__next__' and 'throw'.""" - sig = FuncSignature((RuntimeArg(SELF_NAME, object_rprimitive), - RuntimeArg('type', object_rprimitive), - RuntimeArg('value', object_rprimitive), - RuntimeArg('traceback', object_rprimitive), - RuntimeArg('arg', object_rprimitive) - ), sig.ret_type) - helper_fn_decl = FuncDecl('__mypyc_generator_helper__', fn_info.generator_class.ir.name, - builder.module_name, sig) - helper_fn_ir = FuncIR(helper_fn_decl, arg_regs, blocks, - fn_info.fitem.line, traceback_name=fn_info.fitem.name) - fn_info.generator_class.ir.methods['__mypyc_generator_helper__'] = helper_fn_ir + sig = FuncSignature( + ( + RuntimeArg(SELF_NAME, object_rprimitive), + RuntimeArg("type", object_rprimitive), + RuntimeArg("value", object_rprimitive), + RuntimeArg("traceback", object_rprimitive), + RuntimeArg("arg", object_rprimitive), + ), + sig.ret_type, + ) + helper_fn_decl = FuncDecl( + "__mypyc_generator_helper__", fn_info.generator_class.ir.name, builder.module_name, sig + ) + helper_fn_ir = FuncIR( + helper_fn_decl, arg_regs, blocks, fn_info.fitem.line, traceback_name=fn_info.fitem.name + ) + fn_info.generator_class.ir.methods["__mypyc_generator_helper__"] = helper_fn_ir builder.functions.append(helper_fn_ir) return helper_fn_decl def add_iter_to_generator_class(builder: IRBuilder, fn_info: FuncInfo) -> None: """Generates the '__iter__' method for a generator class.""" - with builder.enter_method(fn_info.generator_class.ir, '__iter__', object_rprimitive, fn_info): + with builder.enter_method(fn_info.generator_class.ir, "__iter__", object_rprimitive, fn_info): builder.add(Return(builder.self())) -def add_next_to_generator_class(builder: IRBuilder, - fn_info: FuncInfo, - fn_decl: FuncDecl, - sig: FuncSignature) -> None: +def add_next_to_generator_class( + builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl, sig: FuncSignature +) -> None: """Generates the '__next__' method for a generator class.""" - with builder.enter_method(fn_info.generator_class.ir, '__next__', - object_rprimitive, fn_info): + with builder.enter_method(fn_info.generator_class.ir, "__next__", object_rprimitive, fn_info): none_reg = builder.none_object() # Call the helper function with error flags set to Py_None, and return that result. - result = builder.add(Call(fn_decl, - [builder.self(), none_reg, none_reg, none_reg, none_reg], - fn_info.fitem.line)) + result = builder.add( + Call( + fn_decl, + [builder.self(), none_reg, none_reg, none_reg, none_reg], + fn_info.fitem.line, + ) + ) builder.add(Return(result)) -def add_send_to_generator_class(builder: IRBuilder, - fn_info: FuncInfo, - fn_decl: FuncDecl, - sig: FuncSignature) -> None: +def add_send_to_generator_class( + builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl, sig: FuncSignature +) -> None: """Generates the 'send' method for a generator class.""" - with builder.enter_method(fn_info.generator_class.ir, 'send', object_rprimitive, fn_info): - arg = builder.add_argument('arg', object_rprimitive) + with builder.enter_method(fn_info.generator_class.ir, "send", object_rprimitive, fn_info): + arg = builder.add_argument("arg", object_rprimitive) none_reg = builder.none_object() # Call the helper function with error flags set to Py_None, and return that result. - result = builder.add(Call( - fn_decl, - [builder.self(), none_reg, none_reg, none_reg, builder.read(arg)], - fn_info.fitem.line)) + result = builder.add( + Call( + fn_decl, + [builder.self(), none_reg, none_reg, none_reg, builder.read(arg)], + fn_info.fitem.line, + ) + ) builder.add(Return(result)) -def add_throw_to_generator_class(builder: IRBuilder, - fn_info: FuncInfo, - fn_decl: FuncDecl, - sig: FuncSignature) -> None: +def add_throw_to_generator_class( + builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl, sig: FuncSignature +) -> None: """Generates the 'throw' method for a generator class.""" - with builder.enter_method(fn_info.generator_class.ir, 'throw', - object_rprimitive, fn_info): - typ = builder.add_argument('type', object_rprimitive) - val = builder.add_argument('value', object_rprimitive, ARG_OPT) - tb = builder.add_argument('traceback', object_rprimitive, ARG_OPT) + with builder.enter_method(fn_info.generator_class.ir, "throw", object_rprimitive, fn_info): + typ = builder.add_argument("type", object_rprimitive) + val = builder.add_argument("value", object_rprimitive, ARG_OPT) + tb = builder.add_argument("traceback", object_rprimitive, ARG_OPT) # Because the value and traceback arguments are optional and hence # can be NULL if not passed in, we have to assign them Py_None if @@ -215,38 +242,45 @@ def add_throw_to_generator_class(builder: IRBuilder, builder.assign_if_null(tb, lambda: none_reg, builder.fn_info.fitem.line) # Call the helper function using the arguments passed in, and return that result. - result = builder.add(Call( - fn_decl, - [builder.self(), builder.read(typ), builder.read(val), builder.read(tb), none_reg], - fn_info.fitem.line)) + result = builder.add( + Call( + fn_decl, + [builder.self(), builder.read(typ), builder.read(val), builder.read(tb), none_reg], + fn_info.fitem.line, + ) + ) builder.add(Return(result)) def add_close_to_generator_class(builder: IRBuilder, fn_info: FuncInfo) -> None: """Generates the '__close__' method for a generator class.""" - with builder.enter_method(fn_info.generator_class.ir, 'close', object_rprimitive, fn_info): + with builder.enter_method(fn_info.generator_class.ir, "close", object_rprimitive, fn_info): except_block, else_block = BasicBlock(), BasicBlock() builder.builder.push_error_handler(except_block) builder.goto_and_activate(BasicBlock()) - generator_exit = builder.load_module_attr_by_fullname('builtins.GeneratorExit', - fn_info.fitem.line) - builder.add(MethodCall( - builder.self(), - 'throw', - [generator_exit, builder.none_object(), builder.none_object()])) + generator_exit = builder.load_module_attr_by_fullname( + "builtins.GeneratorExit", fn_info.fitem.line + ) + builder.add( + MethodCall( + builder.self(), + "throw", + [generator_exit, builder.none_object(), builder.none_object()], + ) + ) builder.goto(else_block) builder.builder.pop_error_handler() builder.activate_block(except_block) old_exc = builder.call_c(error_catch_op, [], fn_info.fitem.line) builder.nonlocal_control.append( - ExceptNonlocalControl(builder.nonlocal_control[-1], old_exc)) - stop_iteration = builder.load_module_attr_by_fullname('builtins.StopIteration', - fn_info.fitem.line) - exceptions = builder.add( - TupleSet([generator_exit, stop_iteration], fn_info.fitem.line)) - matches = builder.call_c( - exc_matches_op, [exceptions], fn_info.fitem.line) + ExceptNonlocalControl(builder.nonlocal_control[-1], old_exc) + ) + stop_iteration = builder.load_module_attr_by_fullname( + "builtins.StopIteration", fn_info.fitem.line + ) + exceptions = builder.add(TupleSet([generator_exit, stop_iteration], fn_info.fitem.line)) + matches = builder.call_c(exc_matches_op, [exceptions], fn_info.fitem.line) match_block, non_match_block = BasicBlock(), BasicBlock() builder.add(Branch(matches, match_block, non_match_block, Branch.BOOL)) @@ -262,15 +296,19 @@ def add_close_to_generator_class(builder: IRBuilder, fn_info: FuncInfo) -> None: builder.nonlocal_control.pop() builder.activate_block(else_block) - builder.add(RaiseStandardError(RaiseStandardError.RUNTIME_ERROR, - 'generator ignored GeneratorExit', - fn_info.fitem.line)) + builder.add( + RaiseStandardError( + RaiseStandardError.RUNTIME_ERROR, + "generator ignored GeneratorExit", + fn_info.fitem.line, + ) + ) builder.add(Unreachable()) def add_await_to_generator_class(builder: IRBuilder, fn_info: FuncInfo) -> None: """Generates the '__await__' method for a generator class.""" - with builder.enter_method(fn_info.generator_class.ir, '__await__', object_rprimitive, fn_info): + with builder.enter_method(fn_info.generator_class.ir, "__await__", object_rprimitive, fn_info): builder.add(Return(builder.self())) @@ -281,11 +319,11 @@ def setup_env_for_generator_class(builder: IRBuilder) -> None: self_target = builder.add_self_to_env(cls.ir) # Add the type, value, and traceback variables to the environment. - exc_type = builder.add_local(Var('type'), object_rprimitive, is_arg=True) - exc_val = builder.add_local(Var('value'), object_rprimitive, is_arg=True) - exc_tb = builder.add_local(Var('traceback'), object_rprimitive, is_arg=True) + exc_type = builder.add_local(Var("type"), object_rprimitive, is_arg=True) + exc_val = builder.add_local(Var("value"), object_rprimitive, is_arg=True) + exc_tb = builder.add_local(Var("traceback"), object_rprimitive, is_arg=True) # TODO: Use the right type here instead of object? - exc_arg = builder.add_local(Var('arg'), object_rprimitive, is_arg=True) + exc_arg = builder.add_local(Var("arg"), object_rprimitive, is_arg=True) cls.exc_regs = (exc_type, exc_val, exc_tb) cls.send_arg_reg = exc_arg @@ -297,10 +335,7 @@ def setup_env_for_generator_class(builder: IRBuilder) -> None: # the '__next__' function of the generator is called, and add it # as an attribute to the environment class. cls.next_label_target = builder.add_var_to_env_class( - Var(NEXT_LABEL_ATTR_NAME), - int_rprimitive, - cls, - reassign=False + Var(NEXT_LABEL_ATTR_NAME), int_rprimitive, cls, reassign=False ) # Add arguments from the original generator function to the diff --git a/mypyc/irbuild/ll_builder.py b/mypyc/irbuild/ll_builder.py index 20c8e3a80acfb..101a068ae8a2e 100644 --- a/mypyc/irbuild/ll_builder.py +++ b/mypyc/irbuild/ll_builder.py @@ -8,75 +8,137 @@ level---it has *no knowledge* of mypy types or expressions. """ -from typing import ( - Callable, List, Tuple, Optional, Sequence -) +from typing import Callable, List, Optional, Sequence, Tuple from typing_extensions import Final -from mypy.nodes import ArgKind, ARG_POS, ARG_STAR, ARG_STAR2 +from mypy.checkexpr import map_actuals_to_formals +from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, ArgKind from mypy.operators import op_methods from mypy.types import AnyType, TypeOfAny -from mypy.checkexpr import map_actuals_to_formals - -from mypyc.ir.ops import ( - BasicBlock, Op, Integer, Value, Register, Assign, Branch, Goto, Call, Box, Unbox, Cast, - GetAttr, LoadStatic, MethodCall, CallC, Truncate, LoadLiteral, AssignMulti, - RaiseStandardError, Unreachable, LoadErrorValue, - NAMESPACE_TYPE, NAMESPACE_MODULE, NAMESPACE_STATIC, IntOp, GetElementPtr, - LoadMem, ComparisonOp, LoadAddress, TupleGet, KeepAlive, ERR_NEVER, ERR_FALSE, SetMem -) -from mypyc.ir.rtypes import ( - RType, RUnion, RInstance, RArray, optional_value_type, int_rprimitive, float_rprimitive, - bool_rprimitive, list_rprimitive, str_rprimitive, is_none_rprimitive, object_rprimitive, - c_pyssize_t_rprimitive, is_short_int_rprimitive, is_tagged, PyVarObject, short_int_rprimitive, - is_list_rprimitive, is_tuple_rprimitive, is_dict_rprimitive, is_set_rprimitive, PySetObject, - none_rprimitive, RTuple, is_bool_rprimitive, is_str_rprimitive, c_int_rprimitive, - pointer_rprimitive, PyObject, PyListObject, bit_rprimitive, is_bit_rprimitive, - object_pointer_rprimitive, c_size_t_rprimitive, dict_rprimitive, bytes_rprimitive, - is_bytes_rprimitive +from mypyc.common import ( + FAST_ISINSTANCE_MAX_SUBCLASSES, + MAX_LITERAL_SHORT_INT, + MIN_LITERAL_SHORT_INT, + PLATFORM_SIZE, + use_method_vectorcall, + use_vectorcall, ) -from mypyc.ir.func_ir import FuncDecl, FuncSignature from mypyc.ir.class_ir import ClassIR, all_concrete_classes -from mypyc.common import ( - FAST_ISINSTANCE_MAX_SUBCLASSES, MAX_LITERAL_SHORT_INT, MIN_LITERAL_SHORT_INT, PLATFORM_SIZE, - use_vectorcall, use_method_vectorcall +from mypyc.ir.func_ir import FuncDecl, FuncSignature +from mypyc.ir.ops import ( + ERR_FALSE, + ERR_NEVER, + NAMESPACE_MODULE, + NAMESPACE_STATIC, + NAMESPACE_TYPE, + Assign, + AssignMulti, + BasicBlock, + Box, + Branch, + Call, + CallC, + Cast, + ComparisonOp, + GetAttr, + GetElementPtr, + Goto, + Integer, + IntOp, + KeepAlive, + LoadAddress, + LoadErrorValue, + LoadLiteral, + LoadMem, + LoadStatic, + MethodCall, + Op, + RaiseStandardError, + Register, + SetMem, + Truncate, + TupleGet, + Unbox, + Unreachable, + Value, ) -from mypyc.primitives.registry import ( - method_call_ops, CFunctionDescription, - binary_ops, unary_ops, ERR_NEG_INT +from mypyc.ir.rtypes import ( + PyListObject, + PyObject, + PySetObject, + PyVarObject, + RArray, + RInstance, + RTuple, + RType, + RUnion, + bit_rprimitive, + bool_rprimitive, + bytes_rprimitive, + c_int_rprimitive, + c_pyssize_t_rprimitive, + c_size_t_rprimitive, + dict_rprimitive, + float_rprimitive, + int_rprimitive, + is_bit_rprimitive, + is_bool_rprimitive, + is_bytes_rprimitive, + is_dict_rprimitive, + is_list_rprimitive, + is_none_rprimitive, + is_set_rprimitive, + is_short_int_rprimitive, + is_str_rprimitive, + is_tagged, + is_tuple_rprimitive, + list_rprimitive, + none_rprimitive, + object_pointer_rprimitive, + object_rprimitive, + optional_value_type, + pointer_rprimitive, + short_int_rprimitive, + str_rprimitive, ) +from mypyc.irbuild.mapper import Mapper +from mypyc.irbuild.util import concrete_arg_kind +from mypyc.options import CompilerOptions from mypyc.primitives.bytes_ops import bytes_compare -from mypyc.primitives.list_ops import ( - list_extend_op, new_list_op, list_build_op -) -from mypyc.primitives.tuple_ops import ( - list_tuple_op, new_tuple_op, new_tuple_with_length_op -) from mypyc.primitives.dict_ops import ( - dict_update_in_display_op, dict_new_op, dict_build_op, dict_ssize_t_size_op + dict_build_op, + dict_new_op, + dict_ssize_t_size_op, + dict_update_in_display_op, ) +from mypyc.primitives.exc_ops import err_occurred_op, keep_propagating_op from mypyc.primitives.generic_ops import ( - py_getattr_op, py_call_op, py_call_with_kwargs_op, py_method_call_op, - py_vectorcall_op, py_vectorcall_method_op, - generic_len_op, generic_ssize_t_len_op -) -from mypyc.primitives.misc_ops import ( - none_object_op, fast_isinstance_op, bool_op + generic_len_op, + generic_ssize_t_len_op, + py_call_op, + py_call_with_kwargs_op, + py_getattr_op, + py_method_call_op, + py_vectorcall_method_op, + py_vectorcall_op, ) from mypyc.primitives.int_ops import int_comparison_op_mapping -from mypyc.primitives.exc_ops import err_occurred_op, keep_propagating_op -from mypyc.primitives.str_ops import ( - unicode_compare, str_check_if_true, str_ssize_t_size_op +from mypyc.primitives.list_ops import list_build_op, list_extend_op, new_list_op +from mypyc.primitives.misc_ops import bool_op, fast_isinstance_op, none_object_op +from mypyc.primitives.registry import ( + ERR_NEG_INT, + CFunctionDescription, + binary_ops, + method_call_ops, + unary_ops, ) from mypyc.primitives.set_ops import new_set_op +from mypyc.primitives.str_ops import str_check_if_true, str_ssize_t_size_op, unicode_compare +from mypyc.primitives.tuple_ops import list_tuple_op, new_tuple_op, new_tuple_with_length_op from mypyc.rt_subtype import is_runtime_subtype -from mypyc.subtype import is_subtype from mypyc.sametype import is_same_type -from mypyc.irbuild.mapper import Mapper -from mypyc.options import CompilerOptions -from mypyc.irbuild.util import concrete_arg_kind - +from mypyc.subtype import is_subtype DictEntry = Tuple[Optional[Value], Value] @@ -92,12 +154,7 @@ class LowLevelIRBuilder: - def __init__( - self, - current_module: str, - mapper: Mapper, - options: CompilerOptions, - ) -> None: + def __init__(self, current_module: str, mapper: Mapper, options: CompilerOptions) -> None: self.current_module = current_module self.mapper = mapper self.options = options @@ -163,8 +220,9 @@ def box(self, src: Value) -> Value: else: return src - def unbox_or_cast(self, src: Value, target_type: RType, line: int, *, - can_borrow: bool = False) -> Value: + def unbox_or_cast( + self, src: Value, target_type: RType, line: int, *, can_borrow: bool = False + ) -> Value: if target_type.is_unboxed: return self.add(Unbox(src, target_type, line)) else: @@ -172,8 +230,15 @@ def unbox_or_cast(self, src: Value, target_type: RType, line: int, *, self.keep_alives.append(src) return self.add(Cast(src, target_type, line, borrow=can_borrow)) - def coerce(self, src: Value, target_type: RType, line: int, force: bool = False, *, - can_borrow: bool = False) -> Value: + def coerce( + self, + src: Value, + target_type: RType, + line: int, + force: bool = False, + *, + can_borrow: bool = False, + ) -> Value: """Generate a coercion/cast from one type to other (only if needed). For example, int -> object boxes the source int; int -> int emits nothing; @@ -186,14 +251,16 @@ def coerce(self, src: Value, target_type: RType, line: int, force: bool = False, """ if src.type.is_unboxed and not target_type.is_unboxed: return self.box(src) - if ((src.type.is_unboxed and target_type.is_unboxed) - and not is_runtime_subtype(src.type, target_type)): + if (src.type.is_unboxed and target_type.is_unboxed) and not is_runtime_subtype( + src.type, target_type + ): # To go from one unboxed type to another, we go through a boxed # in-between value, for simplicity. tmp = self.box(src) return self.unbox_or_cast(tmp, target_type, line) - if ((not src.type.is_unboxed and target_type.is_unboxed) - or not is_subtype(src.type, target_type)): + if (not src.type.is_unboxed and target_type.is_unboxed) or not is_subtype( + src.type, target_type + ): return self.unbox_or_cast(src, target_type, line, can_borrow=can_borrow) elif force: tmp = Register(target_type) @@ -203,12 +270,9 @@ def coerce(self, src: Value, target_type: RType, line: int, force: bool = False, def coerce_nullable(self, src: Value, target_type: RType, line: int) -> Value: """Generate a coercion from a potentially null value.""" - if ( - src.type.is_unboxed == target_type.is_unboxed - and ( - (target_type.is_unboxed and is_runtime_subtype(src.type, target_type)) - or (not target_type.is_unboxed and is_subtype(src.type, target_type)) - ) + if src.type.is_unboxed == target_type.is_unboxed and ( + (target_type.is_unboxed and is_runtime_subtype(src.type, target_type)) + or (not target_type.is_unboxed and is_subtype(src.type, target_type)) ): return src @@ -231,11 +295,15 @@ def coerce_nullable(self, src: Value, target_type: RType, line: int) -> Value: # Attribute access - def get_attr(self, obj: Value, attr: str, result_type: RType, line: int, *, - borrow: bool = False) -> Value: + def get_attr( + self, obj: Value, attr: str, result_type: RType, line: int, *, borrow: bool = False + ) -> Value: """Get a native or Python attribute of an object.""" - if (isinstance(obj.type, RInstance) and obj.type.class_ir.is_ext_class - and obj.type.class_ir.has_attr(attr)): + if ( + isinstance(obj.type, RInstance) + and obj.type.class_ir.is_ext_class + and obj.type.class_ir.has_attr(attr) + ): if borrow: self.keep_alives.append(obj) return self.add(GetAttr(obj, attr, line, borrow=borrow)) @@ -244,12 +312,9 @@ def get_attr(self, obj: Value, attr: str, result_type: RType, line: int, *, else: return self.py_get_attr(obj, attr, line) - def union_get_attr(self, - obj: Value, - rtype: RUnion, - attr: str, - result_type: RType, - line: int) -> Value: + def union_get_attr( + self, obj: Value, rtype: RUnion, attr: str, result_type: RType, line: int + ) -> Value: """Get an attribute of an object with a union type.""" def get_item_attr(value: Value) -> Value: @@ -273,13 +338,15 @@ def isinstance_helper(self, obj: Value, class_irs: List[ClassIR], line: int) -> return self.false() ret = self.isinstance_native(obj, class_irs[0], line) for class_ir in class_irs[1:]: + def other() -> Value: return self.isinstance_native(obj, class_ir, line) - ret = self.shortcircuit_helper('or', bool_rprimitive, lambda: ret, other, line) + + ret = self.shortcircuit_helper("or", bool_rprimitive, lambda: ret, other, line) return ret def get_type_of_obj(self, obj: Value, line: int) -> Value: - ob_type_address = self.add(GetElementPtr(obj, PyObject, 'ob_type', line)) + ob_type_address = self.add(GetElementPtr(obj, PyObject, "ob_type", line)) ob_type = self.add(LoadMem(object_rprimitive, ob_type_address)) self.add(KeepAlive([obj])) return ob_type @@ -297,28 +364,30 @@ def isinstance_native(self, obj: Value, class_ir: ClassIR, line: int) -> Value: """ concrete = all_concrete_classes(class_ir) if concrete is None or len(concrete) > FAST_ISINSTANCE_MAX_SUBCLASSES + 1: - return self.call_c(fast_isinstance_op, - [obj, self.get_native_type(class_ir)], - line) + return self.call_c(fast_isinstance_op, [obj, self.get_native_type(class_ir)], line) if not concrete: # There can't be any concrete instance that matches this. return self.false() type_obj = self.get_native_type(concrete[0]) ret = self.type_is_op(obj, type_obj, line) for c in concrete[1:]: + def other() -> Value: return self.type_is_op(obj, self.get_native_type(c), line) - ret = self.shortcircuit_helper('or', bool_rprimitive, lambda: ret, other, line) + + ret = self.shortcircuit_helper("or", bool_rprimitive, lambda: ret, other, line) return ret # Calls - def _construct_varargs(self, - args: Sequence[Tuple[Value, ArgKind, Optional[str]]], - line: int, - *, - has_star: bool, - has_star2: bool) -> Tuple[Optional[Value], Optional[Value]]: + def _construct_varargs( + self, + args: Sequence[Tuple[Value, ArgKind, Optional[str]]], + line: int, + *, + has_star: bool, + has_star2: bool, + ) -> Tuple[Optional[Value], Optional[Value]]: """Construct *args and **kwargs from a collection of arguments This is pretty complicated, and almost all of the complication here stems from @@ -404,11 +473,7 @@ def _construct_varargs(self, if star2_result is None: star2_result = self._create_dict(star2_keys, star2_values, line) - self.call_c( - dict_update_in_display_op, - [star2_result, value], - line=line - ) + self.call_c(dict_update_in_display_op, [star2_result, value], line=line) else: nullable = kind.is_optional() maybe_pos = kind.is_positional() and has_star @@ -467,7 +532,8 @@ def _construct_varargs(self, self.activate_block(pos_block) assert star_result self.translate_special_method_call( - star_result, 'append', [value], result_type=None, line=line) + star_result, "append", [value], result_type=None, line=line + ) self.goto(out) if maybe_named and (not maybe_pos or seen_empty_reg): @@ -476,7 +542,8 @@ def _construct_varargs(self, key = self.load_str(name) assert star2_result self.translate_special_method_call( - star2_result, '__setitem__', [key, value], result_type=None, line=line) + star2_result, "__setitem__", [key, value], result_type=None, line=line + ) self.goto(out) if nullable and maybe_pos and new_seen_empty_reg: @@ -504,12 +571,14 @@ def _construct_varargs(self, return star_result, star2_result - def py_call(self, - function: Value, - arg_values: List[Value], - line: int, - arg_kinds: Optional[List[ArgKind]] = None, - arg_names: Optional[Sequence[Optional[str]]] = None) -> Value: + def py_call( + self, + function: Value, + arg_values: List[Value], + line: int, + arg_kinds: Optional[List[ArgKind]] = None, + arg_names: Optional[Sequence[Optional[str]]] = None, + ) -> Value: """Call a Python function (non-native and slow). Use py_call_op or py_call_with_kwargs_op for Python function call. @@ -532,23 +601,25 @@ def py_call(self, ) assert pos_args_tuple and kw_args_dict - return self.call_c( - py_call_with_kwargs_op, [function, pos_args_tuple, kw_args_dict], line) + return self.call_c(py_call_with_kwargs_op, [function, pos_args_tuple, kw_args_dict], line) - def _py_vector_call(self, - function: Value, - arg_values: List[Value], - line: int, - arg_kinds: Optional[List[ArgKind]] = None, - arg_names: Optional[Sequence[Optional[str]]] = None) -> Optional[Value]: + def _py_vector_call( + self, + function: Value, + arg_values: List[Value], + line: int, + arg_kinds: Optional[List[ArgKind]] = None, + arg_names: Optional[Sequence[Optional[str]]] = None, + ) -> Optional[Value]: """Call function using the vectorcall API if possible. Return the return value if successful. Return None if a non-vectorcall API should be used instead. """ # We can do this if all args are positional or named (no *args or **kwargs, not optional). - if arg_kinds is None or all(not kind.is_star() and not kind.is_optional() - for kind in arg_kinds): + if arg_kinds is None or all( + not kind.is_star() and not kind.is_optional() for kind in arg_kinds + ): if arg_values: # Create a C array containing all arguments as boxed values. array = Register(RArray(object_rprimitive, len(arg_values))) @@ -559,11 +630,11 @@ def _py_vector_call(self, arg_ptr = Integer(0, object_pointer_rprimitive) num_pos = num_positional_args(arg_values, arg_kinds) keywords = self._vectorcall_keywords(arg_names) - value = self.call_c(py_vectorcall_op, [function, - arg_ptr, - Integer(num_pos, c_size_t_rprimitive), - keywords], - line) + value = self.call_c( + py_vectorcall_op, + [function, arg_ptr, Integer(num_pos, c_size_t_rprimitive), keywords], + line, + ) if arg_values: # Make sure arguments won't be freed until after the call. # We need this because RArray doesn't support automatic @@ -583,18 +654,21 @@ def _vectorcall_keywords(self, arg_names: Optional[Sequence[Optional[str]]]) -> return self.add(LoadLiteral(tuple(kw_list), object_rprimitive)) return Integer(0, object_rprimitive) - def py_method_call(self, - obj: Value, - method_name: str, - arg_values: List[Value], - line: int, - arg_kinds: Optional[List[ArgKind]], - arg_names: Optional[Sequence[Optional[str]]]) -> Value: + def py_method_call( + self, + obj: Value, + method_name: str, + arg_values: List[Value], + line: int, + arg_kinds: Optional[List[ArgKind]], + arg_names: Optional[Sequence[Optional[str]]], + ) -> Value: """Call a Python method (non-native and slow).""" if use_method_vectorcall(self.options.capi_version): # More recent Python versions support faster vectorcalls. result = self._py_vector_method_call( - obj, method_name, arg_values, line, arg_kinds, arg_names) + obj, method_name, arg_values, line, arg_kinds, arg_names + ) if result is not None: return result @@ -607,36 +681,43 @@ def py_method_call(self, method = self.py_get_attr(obj, method_name, line) return self.py_call(method, arg_values, line, arg_kinds=arg_kinds, arg_names=arg_names) - def _py_vector_method_call(self, - obj: Value, - method_name: str, - arg_values: List[Value], - line: int, - arg_kinds: Optional[List[ArgKind]], - arg_names: Optional[Sequence[Optional[str]]]) -> Optional[Value]: + def _py_vector_method_call( + self, + obj: Value, + method_name: str, + arg_values: List[Value], + line: int, + arg_kinds: Optional[List[ArgKind]], + arg_names: Optional[Sequence[Optional[str]]], + ) -> Optional[Value]: """Call method using the vectorcall API if possible. Return the return value if successful. Return None if a non-vectorcall API should be used instead. """ - if arg_kinds is None or all(not kind.is_star() and not kind.is_optional() - for kind in arg_kinds): + if arg_kinds is None or all( + not kind.is_star() and not kind.is_optional() for kind in arg_kinds + ): method_name_reg = self.load_str(method_name) array = Register(RArray(object_rprimitive, len(arg_values) + 1)) self_arg = self.coerce(obj, object_rprimitive, line) - coerced_args = [self_arg] + [self.coerce(arg, object_rprimitive, line) - for arg in arg_values] + coerced_args = [self_arg] + [ + self.coerce(arg, object_rprimitive, line) for arg in arg_values + ] self.add(AssignMulti(array, coerced_args)) arg_ptr = self.add(LoadAddress(object_pointer_rprimitive, array)) num_pos = num_positional_args(arg_values, arg_kinds) keywords = self._vectorcall_keywords(arg_names) - value = self.call_c(py_vectorcall_method_op, - [method_name_reg, - arg_ptr, - Integer((num_pos + 1) | PY_VECTORCALL_ARGUMENTS_OFFSET, - c_size_t_rprimitive), - keywords], - line) + value = self.call_c( + py_vectorcall_method_op, + [ + method_name_reg, + arg_ptr, + Integer((num_pos + 1) | PY_VECTORCALL_ARGUMENTS_OFFSET, c_size_t_rprimitive), + keywords, + ], + line, + ) # Make sure arguments won't be freed until after the call. # We need this because RArray doesn't support automatic # memory management. @@ -644,24 +725,27 @@ def _py_vector_method_call(self, return value return None - def call(self, - decl: FuncDecl, - args: Sequence[Value], - arg_kinds: List[ArgKind], - arg_names: Sequence[Optional[str]], - line: int) -> Value: + def call( + self, + decl: FuncDecl, + args: Sequence[Value], + arg_kinds: List[ArgKind], + arg_names: Sequence[Optional[str]], + line: int, + ) -> Value: """Call a native function.""" # Normalize args to positionals. - args = self.native_args_to_positional( - args, arg_kinds, arg_names, decl.sig, line) + args = self.native_args_to_positional(args, arg_kinds, arg_names, decl.sig, line) return self.add(Call(decl, args, line)) - def native_args_to_positional(self, - args: Sequence[Value], - arg_kinds: List[ArgKind], - arg_names: Sequence[Optional[str]], - sig: FuncSignature, - line: int) -> List[Value]: + def native_args_to_positional( + self, + args: Sequence[Value], + arg_kinds: List[ArgKind], + arg_names: Sequence[Optional[str]], + sig: FuncSignature, + line: int, + ) -> List[Value]: """Prepare arguments for a native call. Given args/kinds/names and a target signature for a native call, map @@ -674,11 +758,13 @@ def native_args_to_positional(self, sig_arg_kinds = [arg.kind for arg in sig.args] sig_arg_names = [arg.name for arg in sig.args] concrete_kinds = [concrete_arg_kind(arg_kind) for arg_kind in arg_kinds] - formal_to_actual = map_actuals_to_formals(concrete_kinds, - arg_names, - sig_arg_kinds, - sig_arg_names, - lambda n: AnyType(TypeOfAny.special_form)) + formal_to_actual = map_actuals_to_formals( + concrete_kinds, + arg_names, + sig_arg_kinds, + sig_arg_names, + lambda n: AnyType(TypeOfAny.special_form), + ) # First scan for */** and construct those has_star = has_star2 = False @@ -718,23 +804,28 @@ def native_args_to_positional(self, return output_args - def gen_method_call(self, - base: Value, - name: str, - arg_values: List[Value], - result_type: Optional[RType], - line: int, - arg_kinds: Optional[List[ArgKind]] = None, - arg_names: Optional[List[Optional[str]]] = None, - can_borrow: bool = False) -> Value: + def gen_method_call( + self, + base: Value, + name: str, + arg_values: List[Value], + result_type: Optional[RType], + line: int, + arg_kinds: Optional[List[ArgKind]] = None, + arg_names: Optional[List[Optional[str]]] = None, + can_borrow: bool = False, + ) -> Value: """Generate either a native or Python method call.""" # If we have *args, then fallback to Python method call. if arg_kinds is not None and any(kind.is_star() for kind in arg_kinds): return self.py_method_call(base, name, arg_values, base.line, arg_kinds, arg_names) # If the base type is one of ours, do a MethodCall - if (isinstance(base.type, RInstance) and base.type.class_ir.is_ext_class - and not base.type.class_ir.builtin_base): + if ( + isinstance(base.type, RInstance) + and base.type.class_ir.is_ext_class + and not base.type.class_ir.builtin_base + ): if base.type.class_ir.has_method(name): decl = base.type.class_ir.method_decl(name) if arg_kinds is None: @@ -747,44 +838,51 @@ def gen_method_call(self, # Normalize args to positionals. assert decl.bound_sig arg_values = self.native_args_to_positional( - arg_values, arg_kinds, arg_names, decl.bound_sig, line) + arg_values, arg_kinds, arg_names, decl.bound_sig, line + ) return self.add(MethodCall(base, name, arg_values, line)) elif base.type.class_ir.has_attr(name): function = self.add(GetAttr(base, name, line)) - return self.py_call(function, arg_values, line, - arg_kinds=arg_kinds, arg_names=arg_names) + return self.py_call( + function, arg_values, line, arg_kinds=arg_kinds, arg_names=arg_names + ) elif isinstance(base.type, RUnion): - return self.union_method_call(base, base.type, name, arg_values, result_type, line, - arg_kinds, arg_names) + return self.union_method_call( + base, base.type, name, arg_values, result_type, line, arg_kinds, arg_names + ) # Try to do a special-cased method call if not arg_kinds or arg_kinds == [ARG_POS] * len(arg_values): target = self.translate_special_method_call( - base, name, arg_values, result_type, line, can_borrow=can_borrow) + base, name, arg_values, result_type, line, can_borrow=can_borrow + ) if target: return target # Fall back to Python method call return self.py_method_call(base, name, arg_values, line, arg_kinds, arg_names) - def union_method_call(self, - base: Value, - obj_type: RUnion, - name: str, - arg_values: List[Value], - return_rtype: Optional[RType], - line: int, - arg_kinds: Optional[List[ArgKind]], - arg_names: Optional[List[Optional[str]]]) -> Value: + def union_method_call( + self, + base: Value, + obj_type: RUnion, + name: str, + arg_values: List[Value], + return_rtype: Optional[RType], + line: int, + arg_kinds: Optional[List[ArgKind]], + arg_names: Optional[List[Optional[str]]], + ) -> Value: """Generate a method call with a union type for the object.""" # Union method call needs a return_rtype for the type of the output register. # If we don't have one, use object_rprimitive. return_rtype = return_rtype or object_rprimitive def call_union_item(value: Value) -> Value: - return self.gen_method_call(value, name, arg_values, return_rtype, line, - arg_kinds, arg_names) + return self.gen_method_call( + value, name, arg_values, return_rtype, line, arg_kinds, arg_names + ) return self.decompose_union_helper(base, obj_type, return_rtype, call_union_item, line) @@ -833,19 +931,22 @@ def load_complex(self, value: complex) -> Value: """Load a complex literal value.""" return self.add(LoadLiteral(value, object_rprimitive)) - def load_static_checked(self, typ: RType, identifier: str, module_name: Optional[str] = None, - namespace: str = NAMESPACE_STATIC, - line: int = -1, - error_msg: Optional[str] = None) -> Value: + def load_static_checked( + self, + typ: RType, + identifier: str, + module_name: Optional[str] = None, + namespace: str = NAMESPACE_STATIC, + line: int = -1, + error_msg: Optional[str] = None, + ) -> Value: if error_msg is None: error_msg = f'name "{identifier}" is not defined' ok_block, error_block = BasicBlock(), BasicBlock() value = self.add(LoadStatic(typ, identifier, module_name, namespace, line=line)) self.add(Branch(value, error_block, ok_block, Branch.IS_ERROR, rare=True)) self.activate_block(error_block) - self.add(RaiseStandardError(RaiseStandardError.NAME_ERROR, - error_msg, - line)) + self.add(RaiseStandardError(RaiseStandardError.NAME_ERROR, error_msg, line)) self.add(Unreachable()) self.activate_block(ok_block) return value @@ -855,11 +956,11 @@ def load_module(self, name: str) -> Value: def get_native_type(self, cls: ClassIR) -> Value: """Load native type object.""" - fullname = f'{cls.module_name}.{cls.name}' + fullname = f"{cls.module_name}.{cls.name}" return self.load_native_type_object(fullname) def load_native_type_object(self, fullname: str) -> Value: - module, name = fullname.rsplit('.', 1) + module, name = fullname.rsplit(".", 1) return self.add(LoadStatic(object_rprimitive, name, module, NAMESPACE_TYPE)) # Other primitive operations @@ -868,35 +969,38 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: rtype = rreg.type # Special case tuple comparison here so that nested tuples can be supported - if isinstance(ltype, RTuple) and isinstance(rtype, RTuple) and op in ('==', '!='): + if isinstance(ltype, RTuple) and isinstance(rtype, RTuple) and op in ("==", "!="): return self.compare_tuples(lreg, rreg, op, line) # Special case == and != when we can resolve the method call statically - if op in ('==', '!='): + if op in ("==", "!="): value = self.translate_eq_cmp(lreg, rreg, op, line) if value is not None: return value # Special case various ops - if op in ('is', 'is not'): + if op in ("is", "is not"): return self.translate_is_op(lreg, rreg, op, line) # TODO: modify 'str' to use same interface as 'compare_bytes' as it avoids # call to PyErr_Occurred() - if is_str_rprimitive(ltype) and is_str_rprimitive(rtype) and op in ('==', '!='): + if is_str_rprimitive(ltype) and is_str_rprimitive(rtype) and op in ("==", "!="): return self.compare_strings(lreg, rreg, op, line) - if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ('==', '!='): + if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ("==", "!="): return self.compare_bytes(lreg, rreg, op, line) if is_tagged(ltype) and is_tagged(rtype) and op in int_comparison_op_mapping: return self.compare_tagged(lreg, rreg, op, line) - if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in ( - '&', '&=', '|', '|=', '^', '^='): + if ( + is_bool_rprimitive(ltype) + and is_bool_rprimitive(rtype) + and op in ("&", "&=", "|", "|=", "^", "^=") + ): return self.bool_bitwise_op(lreg, rreg, op[0], line) - if isinstance(rtype, RInstance) and op in ('in', 'not in'): + if isinstance(rtype, RInstance) and op in ("in", "not in"): return self.translate_instance_contains(rreg, lreg, op, line) call_c_ops_candidates = binary_ops.get(op, []) target = self.matching_call_c(call_c_ops_candidates, [lreg, rreg], line) - assert target, 'Unsupported binary operation: %s' % op + assert target, "Unsupported binary operation: %s" % op return target def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -> Value: @@ -946,13 +1050,9 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: self.goto_and_activate(out) return result - def compare_tagged_condition(self, - lhs: Value, - rhs: Value, - op: str, - true: BasicBlock, - false: BasicBlock, - line: int) -> None: + def compare_tagged_condition( + self, lhs: Value, rhs: Value, op: str, true: BasicBlock, false: BasicBlock, line: int + ) -> None: """Compare two tagged integers using given operator (conditional context). Assume lhs and and rhs are tagged integers. @@ -965,9 +1065,9 @@ def compare_tagged_condition(self, false: Branch target if comparison is false """ is_eq = op in ("==", "!=") - if ((is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type)) - or (is_eq and (is_short_int_rprimitive(lhs.type) or - is_short_int_rprimitive(rhs.type)))): + if (is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type)) or ( + is_eq and (is_short_int_rprimitive(lhs.type) or is_short_int_rprimitive(rhs.type)) + ): # We can skip the tag check check = self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line) self.flush_keep_alives() @@ -1008,8 +1108,9 @@ def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: """Compare two strings""" compare_result = self.call_c(unicode_compare, [lhs, rhs], line) error_constant = Integer(-1, c_int_rprimitive, line) - compare_error_check = self.add(ComparisonOp(compare_result, - error_constant, ComparisonOp.EQ, line)) + compare_error_check = self.add( + ComparisonOp(compare_result, error_constant, ComparisonOp.EQ, line) + ) exception_check, propagate, final_compare = BasicBlock(), BasicBlock(), BasicBlock() branch = Branch(compare_error_check, exception_check, final_compare, Branch.BOOL) branch.negated = False @@ -1017,8 +1118,9 @@ def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: self.activate_block(exception_check) check_error_result = self.call_c(err_occurred_op, [], line) null = Integer(0, pointer_rprimitive, line) - compare_error_check = self.add(ComparisonOp(check_error_result, - null, ComparisonOp.NEQ, line)) + compare_error_check = self.add( + ComparisonOp(check_error_result, null, ComparisonOp.NEQ, line) + ) branch = Branch(compare_error_check, propagate, final_compare, Branch.BOOL) branch.negated = False self.add(branch) @@ -1026,25 +1128,19 @@ def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: self.call_c(keep_propagating_op, [], line) self.goto(final_compare) self.activate_block(final_compare) - op_type = ComparisonOp.EQ if op == '==' else ComparisonOp.NEQ - return self.add(ComparisonOp(compare_result, - Integer(0, c_int_rprimitive), op_type, line)) + op_type = ComparisonOp.EQ if op == "==" else ComparisonOp.NEQ + return self.add(ComparisonOp(compare_result, Integer(0, c_int_rprimitive), op_type, line)) def compare_bytes(self, lhs: Value, rhs: Value, op: str, line: int) -> Value: compare_result = self.call_c(bytes_compare, [lhs, rhs], line) - op_type = ComparisonOp.EQ if op == '==' else ComparisonOp.NEQ - return self.add(ComparisonOp(compare_result, - Integer(1, c_int_rprimitive), op_type, line)) - - def compare_tuples(self, - lhs: Value, - rhs: Value, - op: str, - line: int = -1) -> Value: + op_type = ComparisonOp.EQ if op == "==" else ComparisonOp.NEQ + return self.add(ComparisonOp(compare_result, Integer(1, c_int_rprimitive), op_type, line)) + + def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Value: """Compare two tuples item by item""" # type cast to pass mypy check assert isinstance(lhs.type, RTuple) and isinstance(rhs.type, RTuple) - equal = True if op == '==' else False + equal = True if op == "==" else False result = Register(bool_rprimitive) # empty tuples if len(lhs.type.types) == 0 and len(rhs.type.types) == 0: @@ -1087,49 +1183,44 @@ def compare_tuples(self, return result def translate_instance_contains(self, inst: Value, item: Value, op: str, line: int) -> Value: - res = self.gen_method_call(inst, '__contains__', [item], None, line) + res = self.gen_method_call(inst, "__contains__", [item], None, line) if not is_bool_rprimitive(res.type): res = self.call_c(bool_op, [res], line) - if op == 'not in': - res = self.bool_bitwise_op(res, Integer(1, rtype=bool_rprimitive), '^', line) + if op == "not in": + res = self.bool_bitwise_op(res, Integer(1, rtype=bool_rprimitive), "^", line) return res def bool_bitwise_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value: - if op == '&': + if op == "&": code = IntOp.AND - elif op == '|': + elif op == "|": code = IntOp.OR - elif op == '^': + elif op == "^": code = IntOp.XOR else: assert False, op return self.add(IntOp(bool_rprimitive, lreg, rreg, code, line)) - def unary_not(self, - value: Value, - line: int) -> Value: + def unary_not(self, value: Value, line: int) -> Value: mask = Integer(1, value.type, line) return self.int_op(value.type, value, mask, IntOp.XOR, line) - def unary_op(self, - value: Value, - expr_op: str, - line: int) -> Value: + def unary_op(self, value: Value, expr_op: str, line: int) -> Value: typ = value.type - if (is_bool_rprimitive(typ) or is_bit_rprimitive(typ)) and expr_op == 'not': + if (is_bool_rprimitive(typ) or is_bit_rprimitive(typ)) and expr_op == "not": return self.unary_not(value, line) if isinstance(typ, RInstance): - if expr_op == '-': - method = '__neg__' - elif expr_op == '~': - method = '__invert__' + if expr_op == "-": + method = "__neg__" + elif expr_op == "~": + method = "__invert__" else: - method = '' + method = "" if method and typ.class_ir.has_method(method): return self.gen_method_call(value, method, [], None, line) call_c_ops_candidates = unary_ops.get(expr_op, []) target = self.matching_call_c(call_c_ops_candidates, [value], line) - assert target, 'Unsupported unary operation: %s' % expr_op + assert target, "Unsupported unary operation: %s" % expr_op return target def make_dict(self, key_value_pairs: Sequence[DictEntry], line: int) -> Value: @@ -1145,21 +1236,14 @@ def make_dict(self, key_value_pairs: Sequence[DictEntry], line: int) -> Value: continue self.translate_special_method_call( - result, - '__setitem__', - [key, value], - result_type=None, - line=line) + result, "__setitem__", [key, value], result_type=None, line=line + ) else: # **value if result is None: result = self._create_dict(keys, values, line) - self.call_c( - dict_update_in_display_op, - [result, value], - line=line - ) + self.call_c(dict_update_in_display_op, [result, value], line=line) if result is None: result = self._create_dict(keys, values, line) @@ -1193,15 +1277,16 @@ def new_list_op(self, values: List[Value], line: int) -> Value: if len(values) == 0: return result_list args = [self.coerce(item, object_rprimitive, line) for item in values] - ob_item_ptr = self.add(GetElementPtr(result_list, PyListObject, 'ob_item', line)) + ob_item_ptr = self.add(GetElementPtr(result_list, PyListObject, "ob_item", line)) ob_item_base = self.add(LoadMem(pointer_rprimitive, ob_item_ptr, line)) for i in range(len(values)): if i == 0: item_address = ob_item_base else: offset = Integer(PLATFORM_SIZE * i, c_pyssize_t_rprimitive, line) - item_address = self.add(IntOp(pointer_rprimitive, ob_item_base, offset, - IntOp.ADD, line)) + item_address = self.add( + IntOp(pointer_rprimitive, ob_item_base, offset, IntOp.ADD, line) + ) self.add(SetMem(object_rprimitive, item_address, args[i], line)) self.add(KeepAlive([result_list])) return result_list @@ -1209,10 +1294,14 @@ def new_list_op(self, values: List[Value], line: int) -> Value: def new_set_op(self, values: List[Value], line: int) -> Value: return self.call_c(new_set_op, values, line) - def shortcircuit_helper(self, op: str, - expr_type: RType, - left: Callable[[], Value], - right: Callable[[], Value], line: int) -> Value: + def shortcircuit_helper( + self, + op: str, + expr_type: RType, + left: Callable[[], Value], + right: Callable[[], Value], + line: int, + ) -> Value: # Having actual Phi nodes would be really nice here! target = Register(expr_type) # left_body takes the value of the left side, right_body the right @@ -1220,8 +1309,7 @@ def shortcircuit_helper(self, op: str, # true_body is taken if the left is true, false_body if it is false. # For 'and' the value is the right side if the left is true, and for 'or' # it is the right side if the left is false. - true_body, false_body = ( - (right_body, left_body) if op == 'and' else (left_body, right_body)) + true_body, false_body = (right_body, left_body) if op == "and" else (left_body, right_body) left_value = left() self.add_bool_branch(left_value, true_body, false_body) @@ -1243,30 +1331,35 @@ def shortcircuit_helper(self, op: str, def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) -> None: if is_runtime_subtype(value.type, int_rprimitive): zero = Integer(0, short_int_rprimitive) - self.compare_tagged_condition(value, zero, '!=', true, false, value.line) + self.compare_tagged_condition(value, zero, "!=", true, false, value.line) return elif is_same_type(value.type, str_rprimitive): value = self.call_c(str_check_if_true, [value], value.line) - elif (is_same_type(value.type, list_rprimitive) - or is_same_type(value.type, dict_rprimitive)): + elif is_same_type(value.type, list_rprimitive) or is_same_type( + value.type, dict_rprimitive + ): length = self.builtin_len(value, value.line) zero = Integer(0) - value = self.binary_op(length, zero, '!=', value.line) - elif (isinstance(value.type, RInstance) and value.type.class_ir.is_ext_class - and value.type.class_ir.has_method('__bool__')): + value = self.binary_op(length, zero, "!=", value.line) + elif ( + isinstance(value.type, RInstance) + and value.type.class_ir.is_ext_class + and value.type.class_ir.has_method("__bool__") + ): # Directly call the __bool__ method on classes that have it. - value = self.gen_method_call(value, '__bool__', [], bool_rprimitive, value.line) + value = self.gen_method_call(value, "__bool__", [], bool_rprimitive, value.line) else: value_type = optional_value_type(value.type) if value_type is not None: - is_none = self.translate_is_op(value, self.none_object(), 'is not', value.line) + is_none = self.translate_is_op(value, self.none_object(), "is not", value.line) branch = Branch(is_none, true, false, Branch.BOOL) self.add(branch) always_truthy = False if isinstance(value_type, RInstance): # check whether X.__bool__ is always just the default (object.__bool__) - if (not value_type.class_ir.has_method('__bool__') - and value_type.class_ir.is_method_final('__bool__')): + if not value_type.class_ir.has_method( + "__bool__" + ) and value_type.class_ir.is_method_final("__bool__"): always_truthy = True if not always_truthy: @@ -1282,11 +1375,13 @@ def add_bool_branch(self, value: Value, true: BasicBlock, false: BasicBlock) -> value = self.call_c(bool_op, [value], value.line) self.add(Branch(value, true, false, Branch.BOOL)) - def call_c(self, - desc: CFunctionDescription, - args: List[Value], - line: int, - result_type: Optional[RType] = None) -> Value: + def call_c( + self, + desc: CFunctionDescription, + args: List[Value], + line: int, + result_type: Optional[RType] = None, + ) -> Value: """Call function using C/native calling convention (not a Python callable).""" # Handle void function via singleton RVoid instance coerced = [] @@ -1317,8 +1412,18 @@ def call_c(self, if error_kind == ERR_NEG_INT: # Handled with an explicit comparison error_kind = ERR_NEVER - target = self.add(CallC(desc.c_function_name, coerced, desc.return_type, desc.steals, - desc.is_borrowed, error_kind, line, var_arg_idx)) + target = self.add( + CallC( + desc.c_function_name, + coerced, + desc.return_type, + desc.steals, + desc.is_borrowed, + error_kind, + line, + var_arg_idx, + ) + ) if desc.is_borrowed: # If the result is borrowed, force the arguments to be # kept alive afterwards, as otherwise the result might be @@ -1327,10 +1432,7 @@ def call_c(self, if not isinstance(arg, (Integer, LoadLiteral)): self.keep_alives.append(arg) if desc.error_kind == ERR_NEG_INT: - comp = ComparisonOp(target, - Integer(0, desc.return_type, line), - ComparisonOp.SGE, - line) + comp = ComparisonOp(target, Integer(0, desc.return_type, line), ComparisonOp.SGE, line) comp.error_kind = ERR_FALSE self.add(comp) @@ -1348,22 +1450,25 @@ def call_c(self, result = self.coerce(target, result_type, line, can_borrow=desc.is_borrowed) return result - def matching_call_c(self, - candidates: List[CFunctionDescription], - args: List[Value], - line: int, - result_type: Optional[RType] = None, - can_borrow: bool = False) -> Optional[Value]: + def matching_call_c( + self, + candidates: List[CFunctionDescription], + args: List[Value], + line: int, + result_type: Optional[RType] = None, + can_borrow: bool = False, + ) -> Optional[Value]: matching: Optional[CFunctionDescription] = None for desc in candidates: if len(desc.arg_types) != len(args): continue - if (all(is_subtype(actual.type, formal) - for actual, formal in zip(args, desc.arg_types)) and - (not desc.is_borrowed or can_borrow)): + if all( + is_subtype(actual.type, formal) for actual, formal in zip(args, desc.arg_types) + ) and (not desc.is_borrowed or can_borrow): if matching: - assert matching.priority != desc.priority, 'Ambiguous:\n1) {}\n2) {}'.format( - matching, desc) + assert matching.priority != desc.priority, "Ambiguous:\n1) {}\n2) {}".format( + matching, desc + ) if desc.priority > matching.priority: matching = desc else: @@ -1387,13 +1492,12 @@ def builtin_len(self, val: Value, line: int, use_pyssize_t: bool = False) -> Val """ typ = val.type size_value = None - if (is_list_rprimitive(typ) or is_tuple_rprimitive(typ) - or is_bytes_rprimitive(typ)): - elem_address = self.add(GetElementPtr(val, PyVarObject, 'ob_size')) + if is_list_rprimitive(typ) or is_tuple_rprimitive(typ) or is_bytes_rprimitive(typ): + elem_address = self.add(GetElementPtr(val, PyVarObject, "ob_size")) size_value = self.add(LoadMem(c_pyssize_t_rprimitive, elem_address)) self.add(KeepAlive([val])) elif is_set_rprimitive(typ): - elem_address = self.add(GetElementPtr(val, PySetObject, 'used')) + elem_address = self.add(GetElementPtr(val, PySetObject, "used")) size_value = self.add(LoadMem(c_pyssize_t_rprimitive, elem_address)) self.add(KeepAlive([val])) elif is_dict_rprimitive(typ): @@ -1405,20 +1509,21 @@ def builtin_len(self, val: Value, line: int, use_pyssize_t: bool = False) -> Val if use_pyssize_t: return size_value offset = Integer(1, c_pyssize_t_rprimitive, line) - return self.int_op(short_int_rprimitive, size_value, offset, - IntOp.LEFT_SHIFT, line) + return self.int_op(short_int_rprimitive, size_value, offset, IntOp.LEFT_SHIFT, line) if isinstance(typ, RInstance): # TODO: Support use_pyssize_t assert not use_pyssize_t - length = self.gen_method_call(val, '__len__', [], int_rprimitive, line) + length = self.gen_method_call(val, "__len__", [], int_rprimitive, line) length = self.coerce(length, int_rprimitive, line) ok, fail = BasicBlock(), BasicBlock() - self.compare_tagged_condition(length, Integer(0), '>=', ok, fail, line) + self.compare_tagged_condition(length, Integer(0), ">=", ok, fail, line) self.activate_block(fail) - self.add(RaiseStandardError(RaiseStandardError.VALUE_ERROR, - "__len__() should return >= 0", - line)) + self.add( + RaiseStandardError( + RaiseStandardError.VALUE_ERROR, "__len__() should return >= 0", line + ) + ) self.add(Unreachable()) self.activate_block(ok) return length @@ -1449,12 +1554,14 @@ def new_tuple_with_length(self, length: Value, line: int) -> Value: # Internal helpers - def decompose_union_helper(self, - obj: Value, - rtype: RUnion, - result_type: RType, - process_item: Callable[[Value], Value], - line: int) -> Value: + def decompose_union_helper( + self, + obj: Value, + rtype: RUnion, + result_type: RType, + process_item: Callable[[Value], Value], + line: int, + ) -> Value: """Generate isinstance() + specialized operations for union items. Say, for Union[A, B] generate ops resembling this (pseudocode): @@ -1510,13 +1617,15 @@ def decompose_union_helper(self, self.activate_block(exit_block) return result - def translate_special_method_call(self, - base_reg: Value, - name: str, - args: List[Value], - result_type: Optional[RType], - line: int, - can_borrow: bool = False) -> Optional[Value]: + def translate_special_method_call( + self, + base_reg: Value, + name: str, + args: List[Value], + result_type: Optional[RType], + line: int, + can_borrow: bool = False, + ) -> Optional[Value]: """Translate a method call which is handled nongenerically. These are special in the sense that we have code generated specifically for them. @@ -1526,15 +1635,14 @@ def translate_special_method_call(self, Return None if no translation found; otherwise return the target register. """ call_c_ops_candidates = method_call_ops.get(name, []) - call_c_op = self.matching_call_c(call_c_ops_candidates, [base_reg] + args, - line, result_type, can_borrow=can_borrow) + call_c_op = self.matching_call_c( + call_c_ops_candidates, [base_reg] + args, line, result_type, can_borrow=can_borrow + ) return call_c_op - def translate_eq_cmp(self, - lreg: Value, - rreg: Value, - expr_op: str, - line: int) -> Optional[Value]: + def translate_eq_cmp( + self, lreg: Value, rreg: Value, expr_op: str, line: int + ) -> Optional[Value]: """Add a equality comparison operation. Args: @@ -1550,8 +1658,8 @@ def translate_eq_cmp(self, # or it might be redefined in a Python parent class or by # dataclasses cmp_varies_at_runtime = ( - not class_ir.is_method_final('__eq__') - or not class_ir.is_method_final('__ne__') + not class_ir.is_method_final("__eq__") + or not class_ir.is_method_final("__ne__") or class_ir.inherits_python or class_ir.is_augmented ) @@ -1561,38 +1669,25 @@ def translate_eq_cmp(self, # depending on which is the more specific type. return None - if not class_ir.has_method('__eq__'): + if not class_ir.has_method("__eq__"): # There's no __eq__ defined, so just use object identity. - identity_ref_op = 'is' if expr_op == '==' else 'is not' + identity_ref_op = "is" if expr_op == "==" else "is not" return self.translate_is_op(lreg, rreg, identity_ref_op, line) - return self.gen_method_call( - lreg, - op_methods[expr_op], - [rreg], - ltype, - line - ) + return self.gen_method_call(lreg, op_methods[expr_op], [rreg], ltype, line) - def translate_is_op(self, - lreg: Value, - rreg: Value, - expr_op: str, - line: int) -> Value: + def translate_is_op(self, lreg: Value, rreg: Value, expr_op: str, line: int) -> Value: """Create equality comparison operation between object identities Args: expr_op: either 'is' or 'is not' """ - op = ComparisonOp.EQ if expr_op == 'is' else ComparisonOp.NEQ + op = ComparisonOp.EQ if expr_op == "is" else ComparisonOp.NEQ lhs = self.coerce(lreg, object_rprimitive, line) rhs = self.coerce(rreg, object_rprimitive, line) return self.add(ComparisonOp(lhs, rhs, op, line)) - def _create_dict(self, - keys: List[Value], - values: List[Value], - line: int) -> Value: + def _create_dict(self, keys: List[Value], values: List[Value], line: int) -> Value: """Create a dictionary(possibly empty) using keys and values""" # keys and values should have the same number of items size = len(keys) diff --git a/mypyc/irbuild/main.py b/mypyc/irbuild/main.py index 52c9d5cf32dfe..29df7f173424a 100644 --- a/mypyc/irbuild/main.py +++ b/mypyc/irbuild/main.py @@ -20,42 +20,42 @@ def f(x: int) -> int: below, mypyc.irbuild.builder, and mypyc.irbuild.visitor. """ -from mypy.backports import OrderedDict -from typing import List, Dict, Callable, Any, TypeVar, cast +from typing import Any, Callable, Dict, List, TypeVar, cast -from mypy.nodes import MypyFile, Expression, ClassDef -from mypy.types import Type -from mypy.state import state +from mypy.backports import OrderedDict from mypy.build import Graph - +from mypy.nodes import ClassDef, Expression, MypyFile +from mypy.state import state +from mypy.types import Type +from mypyc.analysis.attrdefined import analyze_always_defined_attrs from mypyc.common import TOP_LEVEL_NAME from mypyc.errors import Errors -from mypyc.options import CompilerOptions -from mypyc.ir.rtypes import none_rprimitive +from mypyc.ir.func_ir import FuncDecl, FuncIR, FuncSignature from mypyc.ir.module_ir import ModuleIR, ModuleIRs -from mypyc.ir.func_ir import FuncIR, FuncDecl, FuncSignature +from mypyc.ir.rtypes import none_rprimitive +from mypyc.irbuild.builder import IRBuilder +from mypyc.irbuild.mapper import Mapper from mypyc.irbuild.prebuildvisitor import PreBuildVisitor -from mypyc.irbuild.vtable import compute_vtable from mypyc.irbuild.prepare import build_type_map, find_singledispatch_register_impls -from mypyc.irbuild.builder import IRBuilder from mypyc.irbuild.visitor import IRBuilderVisitor -from mypyc.irbuild.mapper import Mapper -from mypyc.analysis.attrdefined import analyze_always_defined_attrs - +from mypyc.irbuild.vtable import compute_vtable +from mypyc.options import CompilerOptions # The stubs for callable contextmanagers are busted so cast it to the # right type... -F = TypeVar('F', bound=Callable[..., Any]) +F = TypeVar("F", bound=Callable[..., Any]) strict_optional_dec = cast(Callable[[F], F], state.strict_optional_set(True)) @strict_optional_dec # Turn on strict optional for any type manipulations we do -def build_ir(modules: List[MypyFile], - graph: Graph, - types: Dict[Expression, Type], - mapper: Mapper, - options: CompilerOptions, - errors: Errors) -> ModuleIRs: +def build_ir( + modules: List[MypyFile], + graph: Graph, + types: Dict[Expression, Type], + mapper: Mapper, + options: CompilerOptions, + errors: Errors, +) -> ModuleIRs: """Build IR for a set of modules that have been type-checked by mypy.""" build_type_map(mapper, modules, graph, types, options, errors) @@ -74,7 +74,14 @@ def build_ir(modules: List[MypyFile], # Construct and configure builder objects (cyclic runtime dependency). visitor = IRBuilderVisitor() builder = IRBuilder( - module.fullname, types, graph, errors, mapper, pbv, visitor, options, + module.fullname, + types, + graph, + errors, + mapper, + pbv, + visitor, + options, singledispatch_info.singledispatch_impls, ) visitor.builder = builder @@ -86,7 +93,7 @@ def build_ir(modules: List[MypyFile], list(builder.imports), builder.functions, builder.classes, - builder.final_names + builder.final_names, ) result[module.fullname] = module_ir class_irs.extend(builder.classes) @@ -104,7 +111,7 @@ def build_ir(modules: List[MypyFile], def transform_mypy_file(builder: IRBuilder, mypyfile: MypyFile) -> None: """Generate IR for a single module.""" - if mypyfile.fullname in ('typing', 'abc'): + if mypyfile.fullname in ("typing", "abc"): # These module are special; their contents are currently all # built-in primitives. return @@ -118,10 +125,10 @@ def transform_mypy_file(builder: IRBuilder, mypyfile: MypyFile) -> None: ir = builder.mapper.type_to_ir[cls.info] builder.classes.append(ir) - builder.enter('') + builder.enter("") # Make sure we have a builtins import - builder.gen_import('builtins', -1) + builder.gen_import("builtins", -1) # Generate ops. for node in mypyfile.defs: @@ -132,6 +139,10 @@ def transform_mypy_file(builder: IRBuilder, mypyfile: MypyFile) -> None: # Generate special function representing module top level. args, _, blocks, ret_type, _ = builder.leave() sig = FuncSignature([], none_rprimitive) - func_ir = FuncIR(FuncDecl(TOP_LEVEL_NAME, None, builder.module_name, sig), args, blocks, - traceback_name="") + func_ir = FuncIR( + FuncDecl(TOP_LEVEL_NAME, None, builder.module_name, sig), + args, + blocks, + traceback_name="", + ) builder.functions.append(func_ir) diff --git a/mypyc/irbuild/mapper.py b/mypyc/irbuild/mapper.py index 576eacc141df1..e86c99f51e69c 100644 --- a/mypyc/irbuild/mapper.py +++ b/mypyc/irbuild/mapper.py @@ -2,20 +2,45 @@ from typing import Dict, Optional -from mypy.nodes import FuncDef, TypeInfo, SymbolNode, RefExpr, ArgKind, ARG_STAR, ARG_STAR2, GDEF +from mypy.nodes import ARG_STAR, ARG_STAR2, GDEF, ArgKind, FuncDef, RefExpr, SymbolNode, TypeInfo from mypy.types import ( - Instance, Type, CallableType, LiteralType, TypedDictType, UnboundType, PartialType, - UninhabitedType, Overloaded, UnionType, TypeType, AnyType, NoneTyp, TupleType, TypeVarType, - get_proper_type + AnyType, + CallableType, + Instance, + LiteralType, + NoneTyp, + Overloaded, + PartialType, + TupleType, + Type, + TypedDictType, + TypeType, + TypeVarType, + UnboundType, + UninhabitedType, + UnionType, + get_proper_type, ) - +from mypyc.ir.class_ir import ClassIR +from mypyc.ir.func_ir import FuncDecl, FuncSignature, RuntimeArg from mypyc.ir.rtypes import ( - RType, RUnion, RTuple, RInstance, object_rprimitive, dict_rprimitive, tuple_rprimitive, - none_rprimitive, int_rprimitive, float_rprimitive, str_rprimitive, bool_rprimitive, - list_rprimitive, set_rprimitive, range_rprimitive, bytes_rprimitive + RInstance, + RTuple, + RType, + RUnion, + bool_rprimitive, + bytes_rprimitive, + dict_rprimitive, + float_rprimitive, + int_rprimitive, + list_rprimitive, + none_rprimitive, + object_rprimitive, + range_rprimitive, + set_rprimitive, + str_rprimitive, + tuple_rprimitive, ) -from mypyc.ir.func_ir import FuncSignature, FuncDecl, RuntimeArg -from mypyc.ir.class_ir import ClassIR class Mapper: @@ -39,28 +64,28 @@ def type_to_rtype(self, typ: Optional[Type]) -> RType: typ = get_proper_type(typ) if isinstance(typ, Instance): - if typ.type.fullname == 'builtins.int': + if typ.type.fullname == "builtins.int": return int_rprimitive - elif typ.type.fullname == 'builtins.float': + elif typ.type.fullname == "builtins.float": return float_rprimitive - elif typ.type.fullname == 'builtins.bool': + elif typ.type.fullname == "builtins.bool": return bool_rprimitive - elif typ.type.fullname == 'builtins.str': + elif typ.type.fullname == "builtins.str": return str_rprimitive - elif typ.type.fullname == 'builtins.bytes': + elif typ.type.fullname == "builtins.bytes": return bytes_rprimitive - elif typ.type.fullname == 'builtins.list': + elif typ.type.fullname == "builtins.list": return list_rprimitive # Dict subclasses are at least somewhat common and we # specifically support them, so make sure that dict operations # get optimized on them. - elif any(cls.fullname == 'builtins.dict' for cls in typ.type.mro): + elif any(cls.fullname == "builtins.dict" for cls in typ.type.mro): return dict_rprimitive - elif typ.type.fullname == 'builtins.set': + elif typ.type.fullname == "builtins.set": return set_rprimitive - elif typ.type.fullname == 'builtins.tuple': + elif typ.type.fullname == "builtins.tuple": return tuple_rprimitive # Varying-length tuple - elif typ.type.fullname == 'builtins.range': + elif typ.type.fullname == "builtins.range": return range_rprimitive elif typ.type in self.type_to_ir: inst = RInstance(self.type_to_ir[typ.type]) @@ -76,7 +101,7 @@ def type_to_rtype(self, typ: Optional[Type]) -> RType: elif isinstance(typ, TupleType): # Use our unboxed tuples for raw tuples but fall back to # being boxed for NamedTuple. - if typ.partial_fallback.type.fullname == 'builtins.tuple': + if typ.partial_fallback.type.fullname == "builtins.tuple": return RTuple([self.type_to_rtype(t) for t in typ.items]) else: return tuple_rprimitive @@ -85,8 +110,7 @@ def type_to_rtype(self, typ: Optional[Type]) -> RType: elif isinstance(typ, NoneTyp): return none_rprimitive elif isinstance(typ, UnionType): - return RUnion([self.type_to_rtype(item) - for item in typ.items]) + return RUnion([self.type_to_rtype(item) for item in typ.items]) elif isinstance(typ, AnyType): return object_rprimitive elif isinstance(typ, TypeType): @@ -110,7 +134,7 @@ def type_to_rtype(self, typ: Optional[Type]) -> RType: # I think we've covered everything that is supposed to # actually show up, so anything else is a bug somewhere. - assert False, 'unexpected type %s' % type(typ) + assert False, "unexpected type %s" % type(typ) def get_arg_rtype(self, typ: Type, kind: ArgKind) -> RType: if kind == ARG_STAR: @@ -122,8 +146,10 @@ def get_arg_rtype(self, typ: Type, kind: ArgKind) -> RType: def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature: if isinstance(fdef.type, CallableType): - arg_types = [self.get_arg_rtype(typ, kind) - for typ, kind in zip(fdef.type.arg_types, fdef.type.arg_kinds)] + arg_types = [ + self.get_arg_rtype(typ, kind) + for typ, kind in zip(fdef.type.arg_types, fdef.type.arg_kinds) + ] arg_pos_onlys = [name is None for name in fdef.type.arg_names] ret = self.type_to_rtype(fdef.type.ret_type) else: @@ -131,7 +157,7 @@ def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature: arg_types = [object_rprimitive for arg in fdef.arguments] arg_pos_onlys = [arg.pos_only for arg in fdef.arguments] # We at least know the return type for __init__ methods will be None. - is_init_method = fdef.name == '__init__' and bool(fdef.info) + is_init_method = fdef.name == "__init__" and bool(fdef.info) if is_init_method: ret = none_rprimitive else: @@ -145,19 +171,22 @@ def fdef_to_sig(self, fdef: FuncDef) -> FuncSignature: # deserialized FuncDef that lacks arguments. We won't ever # need to use those inside of a FuncIR, so we just make up # some crap. - if hasattr(fdef, 'arguments'): + if hasattr(fdef, "arguments"): arg_names = [arg.variable.name for arg in fdef.arguments] else: - arg_names = [name or '' for name in fdef.arg_names] + arg_names = [name or "" for name in fdef.arg_names] - args = [RuntimeArg(arg_name, arg_type, arg_kind, arg_pos_only) - for arg_name, arg_kind, arg_type, arg_pos_only - in zip(arg_names, fdef.arg_kinds, arg_types, arg_pos_onlys)] + args = [ + RuntimeArg(arg_name, arg_type, arg_kind, arg_pos_only) + for arg_name, arg_kind, arg_type, arg_pos_only in zip( + arg_names, fdef.arg_kinds, arg_types, arg_pos_onlys + ) + ] # We force certain dunder methods to return objects to support letting them # return NotImplemented. It also avoids some pointless boxing and unboxing, # since tp_richcompare needs an object anyways. - if fdef.name in ('__eq__', '__ne__', '__lt__', '__gt__', '__le__', '__ge__'): + if fdef.name in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"): ret = object_rprimitive return FuncSignature(args, ret) @@ -168,8 +197,8 @@ def is_native_module(self, module: str) -> bool: def is_native_ref_expr(self, expr: RefExpr) -> bool: if expr.node is None: return False - if '.' in expr.node.fullname: - return self.is_native_module(expr.node.fullname.rpartition('.')[0]) + if "." in expr.node.fullname: + return self.is_native_module(expr.node.fullname.rpartition(".")[0]) return True def is_native_module_ref_expr(self, expr: RefExpr) -> bool: diff --git a/mypyc/irbuild/nonlocalcontrol.py b/mypyc/irbuild/nonlocalcontrol.py index e2dcbec8fbc3d..6266d1db0ae5d 100644 --- a/mypyc/irbuild/nonlocalcontrol.py +++ b/mypyc/irbuild/nonlocalcontrol.py @@ -5,14 +5,23 @@ from abc import abstractmethod from typing import Optional, Union + from typing_extensions import TYPE_CHECKING from mypyc.ir.ops import ( - Branch, BasicBlock, Unreachable, Value, Goto, Integer, Assign, Register, Return, - NO_TRACEBACK_LINE_NO + NO_TRACEBACK_LINE_NO, + Assign, + BasicBlock, + Branch, + Goto, + Integer, + Register, + Return, + Unreachable, + Value, ) -from mypyc.primitives.exc_ops import set_stop_iteration_value, restore_exc_info_op from mypyc.irbuild.targets import AssignmentTarget +from mypyc.primitives.exc_ops import restore_exc_info_op, set_stop_iteration_value if TYPE_CHECKING: from mypyc.irbuild.builder import IRBuilder @@ -31,59 +40,59 @@ class NonlocalControl: """ @abstractmethod - def gen_break(self, builder: 'IRBuilder', line: int) -> None: pass + def gen_break(self, builder: "IRBuilder", line: int) -> None: + pass @abstractmethod - def gen_continue(self, builder: 'IRBuilder', line: int) -> None: pass + def gen_continue(self, builder: "IRBuilder", line: int) -> None: + pass @abstractmethod - def gen_return(self, builder: 'IRBuilder', value: Value, line: int) -> None: pass + def gen_return(self, builder: "IRBuilder", value: Value, line: int) -> None: + pass class BaseNonlocalControl(NonlocalControl): """Default nonlocal control outside any statements that affect it.""" - def gen_break(self, builder: 'IRBuilder', line: int) -> None: + def gen_break(self, builder: "IRBuilder", line: int) -> None: assert False, "break outside of loop" - def gen_continue(self, builder: 'IRBuilder', line: int) -> None: + def gen_continue(self, builder: "IRBuilder", line: int) -> None: assert False, "continue outside of loop" - def gen_return(self, builder: 'IRBuilder', value: Value, line: int) -> None: + def gen_return(self, builder: "IRBuilder", value: Value, line: int) -> None: builder.add(Return(value)) class LoopNonlocalControl(NonlocalControl): """Nonlocal control within a loop.""" - def __init__(self, - outer: NonlocalControl, - continue_block: BasicBlock, - break_block: BasicBlock) -> None: + def __init__( + self, outer: NonlocalControl, continue_block: BasicBlock, break_block: BasicBlock + ) -> None: self.outer = outer self.continue_block = continue_block self.break_block = break_block - def gen_break(self, builder: 'IRBuilder', line: int) -> None: + def gen_break(self, builder: "IRBuilder", line: int) -> None: builder.add(Goto(self.break_block)) - def gen_continue(self, builder: 'IRBuilder', line: int) -> None: + def gen_continue(self, builder: "IRBuilder", line: int) -> None: builder.add(Goto(self.continue_block)) - def gen_return(self, builder: 'IRBuilder', value: Value, line: int) -> None: + def gen_return(self, builder: "IRBuilder", value: Value, line: int) -> None: self.outer.gen_return(builder, value, line) class GeneratorNonlocalControl(BaseNonlocalControl): """Default nonlocal control in a generator function outside statements.""" - def gen_return(self, builder: 'IRBuilder', value: Value, line: int) -> None: + def gen_return(self, builder: "IRBuilder", value: Value, line: int) -> None: # Assign an invalid next label number so that the next time # __next__ is called, we jump to the case in which # StopIteration is raised. - builder.assign(builder.fn_info.generator_class.next_label_target, - Integer(-1), - line) + builder.assign(builder.fn_info.generator_class.next_label_target, Integer(-1), line) # Raise a StopIteration containing a field for the value that # should be returned. Before doing so, create a new block @@ -106,23 +115,24 @@ def gen_return(self, builder: 'IRBuilder', value: Value, line: int) -> None: class CleanupNonlocalControl(NonlocalControl): - """Abstract nonlocal control that runs some cleanup code. """ + """Abstract nonlocal control that runs some cleanup code.""" def __init__(self, outer: NonlocalControl) -> None: self.outer = outer @abstractmethod - def gen_cleanup(self, builder: 'IRBuilder', line: int) -> None: ... + def gen_cleanup(self, builder: "IRBuilder", line: int) -> None: + ... - def gen_break(self, builder: 'IRBuilder', line: int) -> None: + def gen_break(self, builder: "IRBuilder", line: int) -> None: self.gen_cleanup(builder, line) self.outer.gen_break(builder, line) - def gen_continue(self, builder: 'IRBuilder', line: int) -> None: + def gen_continue(self, builder: "IRBuilder", line: int) -> None: self.gen_cleanup(builder, line) self.outer.gen_continue(builder, line) - def gen_return(self, builder: 'IRBuilder', value: Value, line: int) -> None: + def gen_return(self, builder: "IRBuilder", value: Value, line: int) -> None: self.gen_cleanup(builder, line) self.outer.gen_return(builder, value, line) @@ -134,13 +144,13 @@ def __init__(self, target: BasicBlock) -> None: self.target = target self.ret_reg: Optional[Register] = None - def gen_break(self, builder: 'IRBuilder', line: int) -> None: + def gen_break(self, builder: "IRBuilder", line: int) -> None: builder.error("break inside try/finally block is unimplemented", line) - def gen_continue(self, builder: 'IRBuilder', line: int) -> None: + def gen_continue(self, builder: "IRBuilder", line: int) -> None: builder.error("continue inside try/finally block is unimplemented", line) - def gen_return(self, builder: 'IRBuilder', value: Value, line: int) -> None: + def gen_return(self, builder: "IRBuilder", value: Value, line: int) -> None: if self.ret_reg is None: self.ret_reg = Register(builder.ret_types[-1]) @@ -159,7 +169,7 @@ def __init__(self, outer: NonlocalControl, saved: Union[Value, AssignmentTarget] super().__init__(outer) self.saved = saved - def gen_cleanup(self, builder: 'IRBuilder', line: int) -> None: + def gen_cleanup(self, builder: "IRBuilder", line: int) -> None: builder.call_c(restore_exc_info_op, [builder.read(self.saved)], line) @@ -175,7 +185,7 @@ def __init__(self, outer: NonlocalControl, ret_reg: Optional[Value], saved: Valu self.ret_reg = ret_reg self.saved = saved - def gen_cleanup(self, builder: 'IRBuilder', line: int) -> None: + def gen_cleanup(self, builder: "IRBuilder", line: int) -> None: # Restore the old exc_info target, cleanup = BasicBlock(), BasicBlock() builder.add(Branch(self.saved, target, cleanup, Branch.IS_ERROR)) diff --git a/mypyc/irbuild/prebuildvisitor.py b/mypyc/irbuild/prebuildvisitor.py index 55928a57b839f..182286c38a75a 100644 --- a/mypyc/irbuild/prebuildvisitor.py +++ b/mypyc/irbuild/prebuildvisitor.py @@ -1,11 +1,19 @@ -from mypyc.errors import Errors from typing import Dict, List, Set from mypy.nodes import ( - Decorator, Expression, FuncDef, FuncItem, LambdaExpr, NameExpr, SymbolNode, Var, MemberExpr, - MypyFile + Decorator, + Expression, + FuncDef, + FuncItem, + LambdaExpr, + MemberExpr, + MypyFile, + NameExpr, + SymbolNode, + Var, ) from mypy.traverser import TraverserVisitor +from mypyc.errors import Errors class PreBuildVisitor(TraverserVisitor): @@ -73,7 +81,7 @@ def visit_decorator(self, dec: Decorator) -> None: # mypy. Functions decorated only by special decorators # (and property setters) are not treated as decorated # functions by the IR builder. - if isinstance(dec.decorators[0], MemberExpr) and dec.decorators[0].name == 'setter': + if isinstance(dec.decorators[0], MemberExpr) and dec.decorators[0].name == "setter": # Property setters are not treated as decorated methods. self.prop_setters.add(dec.func) else: diff --git a/mypyc/irbuild/prepare.py b/mypyc/irbuild/prepare.py index cc9505853db15..02ad40f2b3815 100644 --- a/mypyc/irbuild/prepare.py +++ b/mypyc/irbuild/prepare.py @@ -11,41 +11,63 @@ Also build a mapping from mypy TypeInfos to ClassIR objects. """ -from typing import List, Dict, Iterable, Optional, Union, DefaultDict, NamedTuple, Tuple +from collections import defaultdict +from typing import DefaultDict, Dict, Iterable, List, NamedTuple, Optional, Tuple, Union +from mypy.build import Graph from mypy.nodes import ( - ClassDef, OverloadedFuncDef, Var, - SymbolNode, ARG_STAR, ARG_STAR2, CallExpr, Decorator, Expression, FuncDef, - MemberExpr, MypyFile, NameExpr, RefExpr, TypeInfo + ARG_STAR, + ARG_STAR2, + CallExpr, + ClassDef, + Decorator, + Expression, + FuncDef, + MemberExpr, + MypyFile, + NameExpr, + OverloadedFuncDef, + RefExpr, + SymbolNode, + TypeInfo, + Var, ) -from mypy.types import Type, Instance, get_proper_type -from mypy.build import Graph - -from mypyc.ir.ops import DeserMaps -from mypyc.ir.rtypes import RInstance, tuple_rprimitive, dict_rprimitive +from mypy.semanal import refers_to_fullname +from mypy.traverser import TraverserVisitor +from mypy.types import Instance, Type, get_proper_type +from mypyc.common import PROPSET_PREFIX, get_id_from_name +from mypyc.crash import catch_errors +from mypyc.errors import Errors +from mypyc.ir.class_ir import ClassIR from mypyc.ir.func_ir import ( - FuncDecl, FuncSignature, RuntimeArg, FUNC_NORMAL, FUNC_STATICMETHOD, FUNC_CLASSMETHOD + FUNC_CLASSMETHOD, + FUNC_NORMAL, + FUNC_STATICMETHOD, + FuncDecl, + FuncSignature, + RuntimeArg, ) -from mypyc.ir.class_ir import ClassIR -from mypyc.common import PROPSET_PREFIX, get_id_from_name +from mypyc.ir.ops import DeserMaps +from mypyc.ir.rtypes import RInstance, dict_rprimitive, tuple_rprimitive from mypyc.irbuild.mapper import Mapper from mypyc.irbuild.util import ( - get_func_def, is_dataclass, is_trait, is_extension_class, get_mypyc_attrs + get_func_def, + get_mypyc_attrs, + is_dataclass, + is_extension_class, + is_trait, ) -from mypyc.errors import Errors from mypyc.options import CompilerOptions -from mypyc.crash import catch_errors -from collections import defaultdict -from mypy.traverser import TraverserVisitor -from mypy.semanal import refers_to_fullname -def build_type_map(mapper: Mapper, - modules: List[MypyFile], - graph: Graph, - types: Dict[Expression, Type], - options: CompilerOptions, - errors: Errors) -> None: +def build_type_map( + mapper: Mapper, + modules: List[MypyFile], + graph: Graph, + types: Dict[Expression, Type], + options: CompilerOptions, + errors: Errors, +) -> None: # Collect all classes defined in everything we are compiling classes = [] for module in modules: @@ -55,8 +77,9 @@ def build_type_map(mapper: Mapper, # Collect all class mappings so that we can bind arbitrary class name # references even if there are import cycles. for module, cdef in classes: - class_ir = ClassIR(cdef.name, module.fullname, is_trait(cdef), - is_abstract=cdef.info.is_abstract) + class_ir = ClassIR( + cdef.name, module.fullname, is_trait(cdef), is_abstract=cdef.info.is_abstract + ) class_ir.is_ext_class = is_extension_class(cdef) if class_ir.is_ext_class: class_ir.deletable = cdef.info.deletable_attributes[:] @@ -83,12 +106,10 @@ def build_type_map(mapper: Mapper, def is_from_module(node: SymbolNode, module: MypyFile) -> bool: - return node.fullname == module.fullname + '.' + node.name + return node.fullname == module.fullname + "." + node.name -def load_type_map(mapper: 'Mapper', - modules: List[MypyFile], - deser_ctx: DeserMaps) -> None: +def load_type_map(mapper: "Mapper", modules: List[MypyFile], deser_ctx: DeserMaps) -> None: """Populate a Mapper with deserialized IR from a list of modules.""" for module in modules: for name, node in module.names.items(): @@ -109,22 +130,28 @@ def get_module_func_defs(module: MypyFile) -> Iterable[FuncDef]: # We need to filter out functions that are imported or # aliases. The best way to do this seems to be by # checking that the fullname matches. - if (isinstance(node.node, (FuncDef, Decorator, OverloadedFuncDef)) - and is_from_module(node.node, module)): + if isinstance(node.node, (FuncDef, Decorator, OverloadedFuncDef)) and is_from_module( + node.node, module + ): yield get_func_def(node.node) -def prepare_func_def(module_name: str, class_name: Optional[str], - fdef: FuncDef, mapper: Mapper) -> FuncDecl: - kind = FUNC_STATICMETHOD if fdef.is_static else ( - FUNC_CLASSMETHOD if fdef.is_class else FUNC_NORMAL) +def prepare_func_def( + module_name: str, class_name: Optional[str], fdef: FuncDef, mapper: Mapper +) -> FuncDecl: + kind = ( + FUNC_STATICMETHOD + if fdef.is_static + else (FUNC_CLASSMETHOD if fdef.is_class else FUNC_NORMAL) + ) decl = FuncDecl(fdef.name, class_name, module_name, mapper.fdef_to_sig(fdef), kind) mapper.func_to_decl[fdef] = decl return decl -def prepare_method_def(ir: ClassIR, module_name: str, cdef: ClassDef, mapper: Mapper, - node: Union[FuncDef, Decorator]) -> None: +def prepare_method_def( + ir: ClassIR, module_name: str, cdef: ClassDef, mapper: Mapper, node: Union[FuncDef, Decorator] +) -> None: if isinstance(node, FuncDef): ir.method_decls[node.name] = prepare_func_def(module_name, cdef.name, node, mapper) elif isinstance(node, Decorator): @@ -133,7 +160,7 @@ def prepare_method_def(ir: ClassIR, module_name: str, cdef: ClassDef, mapper: Ma decl = prepare_func_def(module_name, cdef.name, node.func, mapper) if not node.decorators: ir.method_decls[node.name] = decl - elif isinstance(node.decorators[0], MemberExpr) and node.decorators[0].name == 'setter': + elif isinstance(node.decorators[0], MemberExpr) and node.decorators[0].name == "setter": # Make property setter name different than getter name so there are no # name clashes when generating C code, and property lookup at the IR level # works correctly. @@ -163,13 +190,21 @@ def is_valid_multipart_property_def(prop: OverloadedFuncDef) -> bool: def can_subclass_builtin(builtin_base: str) -> bool: # BaseException and dict are special cased. return builtin_base in ( - ('builtins.Exception', 'builtins.LookupError', 'builtins.IndexError', - 'builtins.Warning', 'builtins.UserWarning', 'builtins.ValueError', - 'builtins.object', )) - - -def prepare_class_def(path: str, module_name: str, cdef: ClassDef, - errors: Errors, mapper: Mapper) -> None: + ( + "builtins.Exception", + "builtins.LookupError", + "builtins.IndexError", + "builtins.Warning", + "builtins.UserWarning", + "builtins.ValueError", + "builtins.object", + ) + ) + + +def prepare_class_def( + path: str, module_name: str, cdef: ClassDef, errors: Errors, mapper: Mapper +) -> None: ir = mapper.type_to_ir[cdef.info] info = cdef.info @@ -189,7 +224,7 @@ def prepare_class_def(path: str, module_name: str, cdef: ClassDef, if isinstance(node.node, Var): assert node.node.type, "Class member %s missing type" % name - if not node.node.is_classvar and name not in ('__slots__', '__deletable__'): + if not node.node.is_classvar and name not in ("__slots__", "__deletable__"): ir.attributes[name] = mapper.type_to_rtype(node.node.type) elif isinstance(node.node, (FuncDef, Decorator)): prepare_method_def(ir, module_name, cdef, mapper, node.node) @@ -211,24 +246,25 @@ def prepare_class_def(path: str, module_name: str, cdef: ClassDef, for cls in info.mro: # Special case exceptions and dicts # XXX: How do we handle *other* things?? - if cls.fullname == 'builtins.BaseException': - ir.builtin_base = 'PyBaseExceptionObject' - elif cls.fullname == 'builtins.dict': - ir.builtin_base = 'PyDictObject' - elif cls.fullname.startswith('builtins.'): + if cls.fullname == "builtins.BaseException": + ir.builtin_base = "PyBaseExceptionObject" + elif cls.fullname == "builtins.dict": + ir.builtin_base = "PyDictObject" + elif cls.fullname.startswith("builtins."): if not can_subclass_builtin(cls.fullname): # Note that if we try to subclass a C extension class that # isn't in builtins, bad things will happen and we won't # catch it here! But this should catch a lot of the most # common pitfalls. - errors.error("Inheriting from most builtin types is unimplemented", - path, cdef.line) + errors.error( + "Inheriting from most builtin types is unimplemented", path, cdef.line + ) if ir.builtin_base: ir.attributes.clear() # Set up a constructor decl - init_node = cdef.info['__init__'].node + init_node = cdef.info["__init__"].node if not ir.is_trait and not ir.builtin_base and isinstance(init_node, FuncDef): init_sig = mapper.fdef_to_sig(init_node) @@ -236,22 +272,26 @@ def prepare_class_def(path: str, module_name: str, cdef: ClassDef, # If there is a nontrivial __init__ that wasn't defined in an # extension class, we need to make the constructor take *args, # **kwargs so it can call tp_init. - if ((defining_ir is None or not defining_ir.is_ext_class - or cdef.info['__init__'].plugin_generated) - and init_node.info.fullname != 'builtins.object'): + if ( + defining_ir is None + or not defining_ir.is_ext_class + or cdef.info["__init__"].plugin_generated + ) and init_node.info.fullname != "builtins.object": init_sig = FuncSignature( - [init_sig.args[0], - RuntimeArg("args", tuple_rprimitive, ARG_STAR), - RuntimeArg("kwargs", dict_rprimitive, ARG_STAR2)], - init_sig.ret_type) + [ + init_sig.args[0], + RuntimeArg("args", tuple_rprimitive, ARG_STAR), + RuntimeArg("kwargs", dict_rprimitive, ARG_STAR2), + ], + init_sig.ret_type, + ) ctor_sig = FuncSignature(init_sig.args[1:], RInstance(ir)) ir.ctor = FuncDecl(cdef.name, None, module_name, ctor_sig) mapper.func_to_decl[cdef.info] = ir.ctor # Set up the parent class - bases = [mapper.type_to_ir[base.type] for base in info.bases - if base.type in mapper.type_to_ir] + bases = [mapper.type_to_ir[base.type] for base in info.bases if base.type in mapper.type_to_ir] if not all(c.is_trait for c in bases[1:]): errors.error("Non-trait bases must appear first in parent list", path, cdef.line) ir.traits = [c for c in bases if c.is_trait] @@ -260,7 +300,7 @@ def prepare_class_def(path: str, module_name: str, cdef: ClassDef, base_mro = [] for cls in info.mro: if cls not in mapper.type_to_ir: - if cls.fullname != 'builtins.object': + if cls.fullname != "builtins.object": ir.inherits_python = True continue base_ir = mapper.type_to_ir[cls] @@ -285,8 +325,9 @@ def prepare_class_def(path: str, module_name: str, cdef: ClassDef, ir.is_augmented = True -def prepare_non_ext_class_def(path: str, module_name: str, cdef: ClassDef, - errors: Errors, mapper: Mapper) -> None: +def prepare_non_ext_class_def( + path: str, module_name: str, cdef: ClassDef, errors: Errors, mapper: Mapper +) -> None: ir = mapper.type_to_ir[cdef.info] info = cdef.info @@ -305,11 +346,10 @@ def prepare_non_ext_class_def(path: str, module_name: str, cdef: ClassDef, else: prepare_method_def(ir, module_name, cdef, mapper, get_func_def(node.node)) - if any( - cls in mapper.type_to_ir and mapper.type_to_ir[cls].is_ext_class for cls in info.mro - ): + if any(cls in mapper.type_to_ir and mapper.type_to_ir[cls].is_ext_class for cls in info.mro): errors.error( - "Non-extension classes may not inherit from extension classes", path, cdef.line) + "Non-extension classes may not inherit from extension classes", path, cdef.line + ) RegisterImplInfo = Tuple[TypeInfo, FuncDef] @@ -321,8 +361,7 @@ class SingledispatchInfo(NamedTuple): def find_singledispatch_register_impls( - modules: List[MypyFile], - errors: Errors, + modules: List[MypyFile], errors: Errors ) -> SingledispatchInfo: visitor = SingledispatchVisitor(errors) for module in modules: @@ -356,7 +395,8 @@ def visit_decorator(self, dec: Decorator) -> None: impl = get_singledispatch_register_call_info(d, dec.func) if impl is not None: self.singledispatch_impls[impl.singledispatch_func].append( - (impl.dispatch_type, dec.func)) + (impl.dispatch_type, dec.func) + ) decorators_to_remove.append(i) if last_non_register is not None: # found a register decorator after a non-register decorator, which we @@ -369,7 +409,7 @@ def visit_decorator(self, dec: Decorator) -> None: decorators_to_store[last_non_register].line, ) else: - if refers_to_fullname(d, 'functools.singledispatch'): + if refers_to_fullname(d, "functools.singledispatch"): decorators_to_remove.append(i) # make sure that we still treat the function as a singledispatch function # even if we don't find any registered implementations (which might happen @@ -391,12 +431,16 @@ class RegisteredImpl(NamedTuple): dispatch_type: TypeInfo -def get_singledispatch_register_call_info(decorator: Expression, func: FuncDef - ) -> Optional[RegisteredImpl]: +def get_singledispatch_register_call_info( + decorator: Expression, func: FuncDef +) -> Optional[RegisteredImpl]: # @fun.register(complex) # def g(arg): ... - if (isinstance(decorator, CallExpr) and len(decorator.args) == 1 - and isinstance(decorator.args[0], RefExpr)): + if ( + isinstance(decorator, CallExpr) + and len(decorator.args) == 1 + and isinstance(decorator.args[0], RefExpr) + ): callee = decorator.callee dispatch_type = decorator.args[0].node if not isinstance(dispatch_type, TypeInfo): @@ -419,9 +463,10 @@ def get_singledispatch_register_call_info(decorator: Expression, func: FuncDef return None -def registered_impl_from_possible_register_call(expr: MemberExpr, dispatch_type: TypeInfo - ) -> Optional[RegisteredImpl]: - if expr.name == 'register' and isinstance(expr.expr, NameExpr): +def registered_impl_from_possible_register_call( + expr: MemberExpr, dispatch_type: TypeInfo +) -> Optional[RegisteredImpl]: + if expr.name == "register" and isinstance(expr.expr, NameExpr): node = expr.expr.node if isinstance(node, Decorator): return RegisteredImpl(node.func, dispatch_type) diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index 1b4aa5e8c8c0b..9a08257e38ce5 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -12,34 +12,58 @@ See comment below for more documentation. """ -from typing import Callable, Optional, Dict, Tuple, List +from typing import Callable, Dict, List, Optional, Tuple from mypy.nodes import ( - CallExpr, RefExpr, MemberExpr, NameExpr, TupleExpr, GeneratorExpr, - ListExpr, DictExpr, StrExpr, IntExpr, ARG_POS, ARG_NAMED, Expression + ARG_NAMED, + ARG_POS, + CallExpr, + DictExpr, + Expression, + GeneratorExpr, + IntExpr, + ListExpr, + MemberExpr, + NameExpr, + RefExpr, + StrExpr, + TupleExpr, ) from mypy.types import AnyType, TypeOfAny - -from mypyc.ir.ops import ( - Value, Register, BasicBlock, Integer, RaiseStandardError, Unreachable -) +from mypyc.ir.ops import BasicBlock, Integer, RaiseStandardError, Register, Unreachable, Value from mypyc.ir.rtypes import ( - RType, RTuple, str_rprimitive, list_rprimitive, dict_rprimitive, set_rprimitive, - bool_rprimitive, c_int_rprimitive, is_dict_rprimitive, is_list_rprimitive + RTuple, + RType, + bool_rprimitive, + c_int_rprimitive, + dict_rprimitive, + is_dict_rprimitive, + is_list_rprimitive, + list_rprimitive, + set_rprimitive, + str_rprimitive, +) +from mypyc.irbuild.builder import IRBuilder +from mypyc.irbuild.for_helpers import ( + comprehension_helper, + sequence_from_generator_preallocate_helper, + translate_list_comprehension, + translate_set_comprehension, ) from mypyc.irbuild.format_str_tokenizer import ( - tokenizer_format_call, join_formatted_strings, convert_format_expr_to_str, FormatOp + FormatOp, + convert_format_expr_to_str, + join_formatted_strings, + tokenizer_format_call, ) from mypyc.primitives.dict_ops import ( - dict_keys_op, dict_values_op, dict_items_op, dict_setdefault_spec_init_op + dict_items_op, + dict_keys_op, + dict_setdefault_spec_init_op, + dict_values_op, ) from mypyc.primitives.list_ops import new_list_set_item_op from mypyc.primitives.tuple_ops import new_tuple_set_item_op -from mypyc.irbuild.builder import IRBuilder -from mypyc.irbuild.for_helpers import ( - translate_list_comprehension, translate_set_comprehension, - comprehension_helper, sequence_from_generator_preallocate_helper -) # Specializers are attempted before compiling the arguments to the # function. Specializers can return None to indicate that they failed @@ -48,7 +72,7 @@ # # Specializers take three arguments: the IRBuilder, the CallExpr being # compiled, and the RefExpr that is the left hand side of the call. -Specializer = Callable[['IRBuilder', CallExpr, RefExpr], Optional[Value]] +Specializer = Callable[["IRBuilder", CallExpr, RefExpr], Optional[Value]] # Dictionary containing all configured specializers. # @@ -57,8 +81,13 @@ specializers: Dict[Tuple[str, Optional[RType]], List[Specializer]] = {} -def _apply_specialization(builder: 'IRBuilder', expr: CallExpr, callee: RefExpr, - name: Optional[str], typ: Optional[RType] = None) -> Optional[Value]: +def _apply_specialization( + builder: "IRBuilder", + expr: CallExpr, + callee: RefExpr, + name: Optional[str], + typ: Optional[RType] = None, +) -> Optional[Value]: # TODO: Allow special cases to have default args or named args. Currently they don't since # they check that everything in arg_kinds is ARG_POS. @@ -72,21 +101,24 @@ def _apply_specialization(builder: 'IRBuilder', expr: CallExpr, callee: RefExpr, return None -def apply_function_specialization(builder: 'IRBuilder', expr: CallExpr, - callee: RefExpr) -> Optional[Value]: +def apply_function_specialization( + builder: "IRBuilder", expr: CallExpr, callee: RefExpr +) -> Optional[Value]: """Invoke the Specializer callback for a function if one has been registered""" return _apply_specialization(builder, expr, callee, callee.fullname) -def apply_method_specialization(builder: 'IRBuilder', expr: CallExpr, callee: MemberExpr, - typ: Optional[RType] = None) -> Optional[Value]: +def apply_method_specialization( + builder: "IRBuilder", expr: CallExpr, callee: MemberExpr, typ: Optional[RType] = None +) -> Optional[Value]: """Invoke the Specializer callback for a method if one has been registered""" name = callee.fullname if typ is None else callee.name return _apply_specialization(builder, expr, callee, name, typ) def specialize_function( - name: str, typ: Optional[RType] = None) -> Callable[[Specializer], Specializer]: + name: str, typ: Optional[RType] = None +) -> Callable[[Specializer], Specializer]: """Decorator to register a function as being a specializer. There may exist multiple specializers for one function. When @@ -101,18 +133,16 @@ def wrapper(f: Specializer) -> Specializer: return wrapper -@specialize_function('builtins.globals') +@specialize_function("builtins.globals") def translate_globals(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: if len(expr.args) == 0: return builder.load_globals_dict() return None -@specialize_function('builtins.len') -def translate_len( - builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: - if (len(expr.args) == 1 - and expr.arg_kinds == [ARG_POS]): +@specialize_function("builtins.len") +def translate_len(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: + if len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]: arg = expr.args[0] expr_rtype = builder.node_type(arg) if isinstance(expr_rtype, RTuple): @@ -130,9 +160,8 @@ def translate_len( return None -@specialize_function('builtins.list') -def dict_methods_fast_path( - builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: +@specialize_function("builtins.list") +def dict_methods_fast_path(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: """Specialize a common case when list() is called on a dictionary view method call. @@ -142,30 +171,30 @@ def dict_methods_fast_path( if not (len(expr.args) == 1 and expr.arg_kinds == [ARG_POS]): return None arg = expr.args[0] - if not (isinstance(arg, CallExpr) and not arg.args - and isinstance(arg.callee, MemberExpr)): + if not (isinstance(arg, CallExpr) and not arg.args and isinstance(arg.callee, MemberExpr)): return None base = arg.callee.expr attr = arg.callee.name rtype = builder.node_type(base) - if not (is_dict_rprimitive(rtype) and attr in ('keys', 'values', 'items')): + if not (is_dict_rprimitive(rtype) and attr in ("keys", "values", "items")): return None obj = builder.accept(base) # Note that it is not safe to use fast methods on dict subclasses, # so the corresponding helpers in CPy.h fallback to (inlined) # generic logic. - if attr == 'keys': + if attr == "keys": return builder.call_c(dict_keys_op, [obj], expr.line) - elif attr == 'values': + elif attr == "values": return builder.call_c(dict_values_op, [obj], expr.line) else: return builder.call_c(dict_items_op, [obj], expr.line) -@specialize_function('builtins.list') +@specialize_function("builtins.list") def translate_list_from_generator_call( - builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: + builder: IRBuilder, expr: CallExpr, callee: RefExpr +) -> Optional[Value]: """Special case for simplest list comprehension. For example: @@ -173,19 +202,24 @@ def translate_list_from_generator_call( 'translate_list_comprehension()' would take care of other cases if this fails. """ - if (len(expr.args) == 1 - and expr.arg_kinds[0] == ARG_POS - and isinstance(expr.args[0], GeneratorExpr)): + if ( + len(expr.args) == 1 + and expr.arg_kinds[0] == ARG_POS + and isinstance(expr.args[0], GeneratorExpr) + ): return sequence_from_generator_preallocate_helper( - builder, expr.args[0], + builder, + expr.args[0], empty_op_llbuilder=builder.builder.new_list_op_with_length, - set_item_op=new_list_set_item_op) + set_item_op=new_list_set_item_op, + ) return None -@specialize_function('builtins.tuple') +@specialize_function("builtins.tuple") def translate_tuple_from_generator_call( - builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: + builder: IRBuilder, expr: CallExpr, callee: RefExpr +) -> Optional[Value]: """Special case for simplest tuple creation from a generator. For example: @@ -193,42 +227,49 @@ def translate_tuple_from_generator_call( 'translate_safe_generator_call()' would take care of other cases if this fails. """ - if (len(expr.args) == 1 - and expr.arg_kinds[0] == ARG_POS - and isinstance(expr.args[0], GeneratorExpr)): + if ( + len(expr.args) == 1 + and expr.arg_kinds[0] == ARG_POS + and isinstance(expr.args[0], GeneratorExpr) + ): return sequence_from_generator_preallocate_helper( - builder, expr.args[0], + builder, + expr.args[0], empty_op_llbuilder=builder.builder.new_tuple_with_length, - set_item_op=new_tuple_set_item_op) + set_item_op=new_tuple_set_item_op, + ) return None -@specialize_function('builtins.set') +@specialize_function("builtins.set") def translate_set_from_generator_call( - builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: + builder: IRBuilder, expr: CallExpr, callee: RefExpr +) -> Optional[Value]: """Special case for set creation from a generator. For example: set(f(...) for ... in iterator/nested_generators...) """ - if (len(expr.args) == 1 - and expr.arg_kinds[0] == ARG_POS - and isinstance(expr.args[0], GeneratorExpr)): + if ( + len(expr.args) == 1 + and expr.arg_kinds[0] == ARG_POS + and isinstance(expr.args[0], GeneratorExpr) + ): return translate_set_comprehension(builder, expr.args[0]) return None -@specialize_function('builtins.min') -@specialize_function('builtins.max') +@specialize_function("builtins.min") +@specialize_function("builtins.max") def faster_min_max(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: if expr.arg_kinds == [ARG_POS, ARG_POS]: x, y = builder.accept(expr.args[0]), builder.accept(expr.args[1]) result = Register(builder.node_type(expr)) # CPython evaluates arguments reversely when calling min(...) or max(...) - if callee.fullname == 'builtins.min': - comparison = builder.binary_op(y, x, '<', expr.line) + if callee.fullname == "builtins.min": + comparison = builder.binary_op(y, x, "<", expr.line) else: - comparison = builder.binary_op(y, x, '>', expr.line) + comparison = builder.binary_op(y, x, ">", expr.line) true_block, false_block, next_block = BasicBlock(), BasicBlock(), BasicBlock() builder.add_bool_branch(comparison, true_block, false_block) @@ -246,67 +287,88 @@ def faster_min_max(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optio return None -@specialize_function('builtins.tuple') -@specialize_function('builtins.frozenset') -@specialize_function('builtins.dict') -@specialize_function('builtins.min') -@specialize_function('builtins.max') -@specialize_function('builtins.sorted') -@specialize_function('collections.OrderedDict') -@specialize_function('join', str_rprimitive) -@specialize_function('extend', list_rprimitive) -@specialize_function('update', dict_rprimitive) -@specialize_function('update', set_rprimitive) +@specialize_function("builtins.tuple") +@specialize_function("builtins.frozenset") +@specialize_function("builtins.dict") +@specialize_function("builtins.min") +@specialize_function("builtins.max") +@specialize_function("builtins.sorted") +@specialize_function("collections.OrderedDict") +@specialize_function("join", str_rprimitive) +@specialize_function("extend", list_rprimitive) +@specialize_function("update", dict_rprimitive) +@specialize_function("update", set_rprimitive) def translate_safe_generator_call( - builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: + builder: IRBuilder, expr: CallExpr, callee: RefExpr +) -> Optional[Value]: """Special cases for things that consume iterators where we know we can safely compile a generator into a list. """ - if (len(expr.args) > 0 - and expr.arg_kinds[0] == ARG_POS - and isinstance(expr.args[0], GeneratorExpr)): + if ( + len(expr.args) > 0 + and expr.arg_kinds[0] == ARG_POS + and isinstance(expr.args[0], GeneratorExpr) + ): if isinstance(callee, MemberExpr): return builder.gen_method_call( - builder.accept(callee.expr), callee.name, - ([translate_list_comprehension(builder, expr.args[0])] - + [builder.accept(arg) for arg in expr.args[1:]]), - builder.node_type(expr), expr.line, expr.arg_kinds, expr.arg_names) + builder.accept(callee.expr), + callee.name, + ( + [translate_list_comprehension(builder, expr.args[0])] + + [builder.accept(arg) for arg in expr.args[1:]] + ), + builder.node_type(expr), + expr.line, + expr.arg_kinds, + expr.arg_names, + ) else: return builder.call_refexpr_with_args( - expr, callee, - ([translate_list_comprehension(builder, expr.args[0])] - + [builder.accept(arg) for arg in expr.args[1:]])) + expr, + callee, + ( + [translate_list_comprehension(builder, expr.args[0])] + + [builder.accept(arg) for arg in expr.args[1:]] + ), + ) return None -@specialize_function('builtins.any') +@specialize_function("builtins.any") def translate_any_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: - if (len(expr.args) == 1 - and expr.arg_kinds == [ARG_POS] - and isinstance(expr.args[0], GeneratorExpr)): + if ( + len(expr.args) == 1 + and expr.arg_kinds == [ARG_POS] + and isinstance(expr.args[0], GeneratorExpr) + ): return any_all_helper(builder, expr.args[0], builder.false, lambda x: x, builder.true) return None -@specialize_function('builtins.all') +@specialize_function("builtins.all") def translate_all_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: - if (len(expr.args) == 1 - and expr.arg_kinds == [ARG_POS] - and isinstance(expr.args[0], GeneratorExpr)): + if ( + len(expr.args) == 1 + and expr.arg_kinds == [ARG_POS] + and isinstance(expr.args[0], GeneratorExpr) + ): return any_all_helper( - builder, expr.args[0], + builder, + expr.args[0], builder.true, - lambda x: builder.unary_op(x, 'not', expr.line), - builder.false + lambda x: builder.unary_op(x, "not", expr.line), + builder.false, ) return None -def any_all_helper(builder: IRBuilder, - gen: GeneratorExpr, - initial_value: Callable[[], Value], - modify: Callable[[Value], Value], - new_value: Callable[[], Value]) -> Value: +def any_all_helper( + builder: IRBuilder, + gen: GeneratorExpr, + initial_value: Callable[[], Value], + modify: Callable[[Value], Value], + new_value: Callable[[], Value], +) -> Value: retval = Register(bool_rprimitive) builder.assign(retval, initial_value(), -1) loop_params = list(zip(gen.indices, gen.sequences, gen.condlists)) @@ -326,15 +388,17 @@ def gen_inner_stmts() -> None: return retval -@specialize_function('builtins.sum') +@specialize_function("builtins.sum") def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: # specialized implementation is used if: # - only one or two arguments given (if not, sum() has been given invalid arguments) # - first argument is a Generator (there is no benefit to optimizing the performance of eg. # sum([1, 2, 3]), so non-Generator Iterables are not handled) - if not (len(expr.args) in (1, 2) - and expr.arg_kinds[0] == ARG_POS - and isinstance(expr.args[0], GeneratorExpr)): + if not ( + len(expr.args) in (1, 2) + and expr.arg_kinds[0] == ARG_POS + and isinstance(expr.args[0], GeneratorExpr) + ): return None # handle 'start' argument, if given @@ -353,7 +417,7 @@ def translate_sum_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> O def gen_inner_stmts() -> None: call_expr = builder.accept(gen_expr.left_expr) - builder.assign(retval, builder.binary_op(retval, call_expr, '+', -1), -1) + builder.assign(retval, builder.binary_op(retval, call_expr, "+", -1), -1) loop_params = list(zip(gen_expr.indices, gen_expr.sequences, gen_expr.condlists)) comprehension_helper(builder, loop_params, gen_inner_stmts, gen_expr.line) @@ -361,12 +425,13 @@ def gen_inner_stmts() -> None: return retval -@specialize_function('dataclasses.field') -@specialize_function('attr.ib') -@specialize_function('attr.attrib') -@specialize_function('attr.Factory') +@specialize_function("dataclasses.field") +@specialize_function("attr.ib") +@specialize_function("attr.attrib") +@specialize_function("attr.Factory") def translate_dataclasses_field_call( - builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: + builder: IRBuilder, expr: CallExpr, callee: RefExpr +) -> Optional[Value]: """Special case for 'dataclasses.field', 'attr.attrib', and 'attr.Factory' function calls because the results of such calls are type-checked by mypy using the types of the arguments to their respective @@ -377,7 +442,7 @@ def translate_dataclasses_field_call( return None -@specialize_function('builtins.next') +@specialize_function("builtins.next") def translate_next_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: """Special case for calling next() on a generator expression, an idiom that shows up some in mypy. @@ -387,8 +452,10 @@ def translate_next_call(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> and produce the first such object, or None if no such element exists. """ - if not (expr.arg_kinds in ([ARG_POS], [ARG_POS, ARG_POS]) - and isinstance(expr.args[0], GeneratorExpr)): + if not ( + expr.arg_kinds in ([ARG_POS], [ARG_POS, ARG_POS]) + and isinstance(expr.args[0], GeneratorExpr) + ): return None gen = expr.args[0] @@ -419,7 +486,7 @@ def gen_inner_stmts() -> None: return retval -@specialize_function('builtins.isinstance') +@specialize_function("builtins.isinstance") def translate_isinstance(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: """Special case for builtins.isinstance. @@ -427,25 +494,28 @@ def translate_isinstance(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> there is no need to coerce something to a new type before checking what type it is, and the coercion could lead to bugs. """ - if (len(expr.args) == 2 - and expr.arg_kinds == [ARG_POS, ARG_POS] - and isinstance(expr.args[1], (RefExpr, TupleExpr))): + if ( + len(expr.args) == 2 + and expr.arg_kinds == [ARG_POS, ARG_POS] + and isinstance(expr.args[1], (RefExpr, TupleExpr)) + ): builder.types[expr.args[0]] = AnyType(TypeOfAny.from_error) irs = builder.flatten_classes(expr.args[1]) if irs is not None: - can_borrow = all(ir.is_ext_class - and not ir.inherits_python - and not ir.allow_interpreted_subclasses - for ir in irs) + can_borrow = all( + ir.is_ext_class and not ir.inherits_python and not ir.allow_interpreted_subclasses + for ir in irs + ) obj = builder.accept(expr.args[0], can_borrow=can_borrow) return builder.builder.isinstance_helper(obj, irs, expr.line) return None -@specialize_function('setdefault', dict_rprimitive) +@specialize_function("setdefault", dict_rprimitive) def translate_dict_setdefault( - builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: + builder: IRBuilder, expr: CallExpr, callee: RefExpr +) -> Optional[Value]: """Special case for 'dict.setdefault' which would only construct default empty collection when needed. @@ -457,9 +527,11 @@ def translate_dict_setdefault( d.setdefault(key, []).append(value) d.setdefault(key, {})[inner_key] = inner_val """ - if (len(expr.args) == 2 - and expr.arg_kinds == [ARG_POS, ARG_POS] - and isinstance(callee, MemberExpr)): + if ( + len(expr.args) == 2 + and expr.arg_kinds == [ARG_POS, ARG_POS] + and isinstance(callee, MemberExpr) + ): arg = expr.args[1] if isinstance(arg, ListExpr): if len(arg.items): @@ -469,8 +541,11 @@ def translate_dict_setdefault( if len(arg.items): return None data_type = Integer(2, c_int_rprimitive, expr.line) - elif (isinstance(arg, CallExpr) and isinstance(arg.callee, NameExpr) - and arg.callee.fullname == 'builtins.set'): + elif ( + isinstance(arg, CallExpr) + and isinstance(arg.callee, NameExpr) + and arg.callee.fullname == "builtins.set" + ): if len(arg.args): return None data_type = Integer(3, c_int_rprimitive, expr.line) @@ -479,17 +554,19 @@ def translate_dict_setdefault( callee_dict = builder.accept(callee.expr) key_val = builder.accept(expr.args[0]) - return builder.call_c(dict_setdefault_spec_init_op, - [callee_dict, key_val, data_type], - expr.line) + return builder.call_c( + dict_setdefault_spec_init_op, [callee_dict, key_val, data_type], expr.line + ) return None -@specialize_function('format', str_rprimitive) -def translate_str_format( - builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: - if (isinstance(callee, MemberExpr) and isinstance(callee.expr, StrExpr) - and expr.arg_kinds.count(ARG_POS) == len(expr.arg_kinds)): +@specialize_function("format", str_rprimitive) +def translate_str_format(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: + if ( + isinstance(callee, MemberExpr) + and isinstance(callee.expr, StrExpr) + and expr.arg_kinds.count(ARG_POS) == len(expr.arg_kinds) + ): format_str = callee.expr.value tokens = tokenizer_format_call(format_str) if tokens is None: @@ -503,30 +580,33 @@ def translate_str_format( return None -@specialize_function('join', str_rprimitive) -def translate_fstring( - builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: +@specialize_function("join", str_rprimitive) +def translate_fstring(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Optional[Value]: """Special case for f-string, which is translated into str.join() in mypy AST. This specializer optimizes simplest f-strings which don't contain any format operation. """ - if (isinstance(callee, MemberExpr) - and isinstance(callee.expr, StrExpr) and callee.expr.value == '' - and expr.arg_kinds == [ARG_POS] and isinstance(expr.args[0], ListExpr)): + if ( + isinstance(callee, MemberExpr) + and isinstance(callee.expr, StrExpr) + and callee.expr.value == "" + and expr.arg_kinds == [ARG_POS] + and isinstance(expr.args[0], ListExpr) + ): for item in expr.args[0].items: if isinstance(item, StrExpr): continue elif isinstance(item, CallExpr): - if (not isinstance(item.callee, MemberExpr) - or item.callee.name != 'format'): + if not isinstance(item.callee, MemberExpr) or item.callee.name != "format": return None - elif (not isinstance(item.callee.expr, StrExpr) - or item.callee.expr.value != '{:{}}'): + elif ( + not isinstance(item.callee.expr, StrExpr) or item.callee.expr.value != "{:{}}" + ): return None - if not isinstance(item.args[1], StrExpr) or item.args[1].value != '': + if not isinstance(item.args[1], StrExpr) or item.args[1].value != "": return None else: return None @@ -535,7 +615,7 @@ def translate_fstring( exprs: List[Expression] = [] for item in expr.args[0].items: - if isinstance(item, StrExpr) and item.value != '': + if isinstance(item, StrExpr) and item.value != "": format_ops.append(FormatOp.STR) exprs.append(item) elif isinstance(item, CallExpr): diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 7ee757090d3d5..38764c972e5b8 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -6,37 +6,76 @@ A few statements are transformed in mypyc.irbuild.function (yield, for example). """ -from typing import Optional, List, Tuple, Sequence, Callable import importlib.util +from typing import Callable, List, Optional, Sequence, Tuple from mypy.nodes import ( - Block, ExpressionStmt, ReturnStmt, AssignmentStmt, OperatorAssignmentStmt, IfStmt, WhileStmt, - ForStmt, BreakStmt, ContinueStmt, RaiseStmt, TryStmt, WithStmt, AssertStmt, DelStmt, - Expression, StrExpr, TempNode, Lvalue, Import, ImportFrom, ImportAll, TupleExpr, ListExpr, - StarExpr + AssertStmt, + AssignmentStmt, + Block, + BreakStmt, + ContinueStmt, + DelStmt, + Expression, + ExpressionStmt, + ForStmt, + IfStmt, + Import, + ImportAll, + ImportFrom, + ListExpr, + Lvalue, + OperatorAssignmentStmt, + RaiseStmt, + ReturnStmt, + StarExpr, + StrExpr, + TempNode, + TryStmt, + TupleExpr, + WhileStmt, + WithStmt, ) - from mypyc.ir.ops import ( - Assign, Unreachable, RaiseStandardError, LoadErrorValue, BasicBlock, TupleGet, Value, Register, - Branch, NO_TRACEBACK_LINE_NO + NO_TRACEBACK_LINE_NO, + Assign, + BasicBlock, + Branch, + LoadErrorValue, + RaiseStandardError, + Register, + TupleGet, + Unreachable, + Value, ) from mypyc.ir.rtypes import RInstance, exc_rtuple, is_tagged -from mypyc.primitives.generic_ops import py_delattr_op -from mypyc.primitives.misc_ops import type_op, import_from_op -from mypyc.primitives.exc_ops import ( - raise_exception_op, reraise_exception_op, error_catch_op, exc_matches_op, restore_exc_info_op, - get_exc_value_op, keep_propagating_op, get_exc_info_op +from mypyc.irbuild.ast_helpers import is_borrow_friendly_expr, process_conditional +from mypyc.irbuild.builder import IRBuilder, int_borrow_friendly_op +from mypyc.irbuild.for_helpers import for_loop_helper +from mypyc.irbuild.nonlocalcontrol import ( + ExceptNonlocalControl, + FinallyNonlocalControl, + TryFinallyNonlocalControl, ) from mypyc.irbuild.targets import ( - AssignmentTarget, AssignmentTargetRegister, AssignmentTargetIndex, AssignmentTargetAttr, - AssignmentTargetTuple + AssignmentTarget, + AssignmentTargetAttr, + AssignmentTargetIndex, + AssignmentTargetRegister, + AssignmentTargetTuple, ) -from mypyc.irbuild.nonlocalcontrol import ( - ExceptNonlocalControl, FinallyNonlocalControl, TryFinallyNonlocalControl +from mypyc.primitives.exc_ops import ( + error_catch_op, + exc_matches_op, + get_exc_info_op, + get_exc_value_op, + keep_propagating_op, + raise_exception_op, + reraise_exception_op, + restore_exc_info_op, ) -from mypyc.irbuild.for_helpers import for_loop_helper -from mypyc.irbuild.builder import IRBuilder, int_borrow_friendly_op -from mypyc.irbuild.ast_helpers import process_conditional, is_borrow_friendly_expr +from mypyc.primitives.generic_ops import py_delattr_op +from mypyc.primitives.misc_ops import import_from_op, type_op GenFunc = Callable[[], None] @@ -49,9 +88,11 @@ def transform_block(builder: IRBuilder, block: Block) -> None: # Don't complain about empty unreachable blocks, since mypy inserts # those after `if MYPY`. elif block.body: - builder.add(RaiseStandardError(RaiseStandardError.RUNTIME_ERROR, - 'Reached allegedly unreachable code!', - block.line)) + builder.add( + RaiseStandardError( + RaiseStandardError.RUNTIME_ERROR, "Reached allegedly unreachable code!", block.line + ) + ) builder.add(Unreachable()) @@ -87,11 +128,13 @@ def transform_assignment_stmt(builder: IRBuilder, stmt: AssignmentStmt) -> None: return # Special case multiple assignments like 'x, y = e1, e2'. - if (isinstance(first_lvalue, (TupleExpr, ListExpr)) - and isinstance(stmt.rvalue, (TupleExpr, ListExpr)) - and len(first_lvalue.items) == len(stmt.rvalue.items) - and all(is_simple_lvalue(item) for item in first_lvalue.items) - and len(lvalues) == 1): + if ( + isinstance(first_lvalue, (TupleExpr, ListExpr)) + and isinstance(stmt.rvalue, (TupleExpr, ListExpr)) + and len(first_lvalue.items) == len(stmt.rvalue.items) + and all(is_simple_lvalue(item) for item in first_lvalue.items) + and len(lvalues) == 1 + ): temps = [] for right in stmt.rvalue.items: rvalue_reg = builder.accept(right) @@ -121,18 +164,21 @@ def is_simple_lvalue(expr: Expression) -> bool: def transform_operator_assignment_stmt(builder: IRBuilder, stmt: OperatorAssignmentStmt) -> None: """Operator assignment statement such as x += 1""" builder.disallow_class_assignments([stmt.lvalue], stmt.line) - if (is_tagged(builder.node_type(stmt.lvalue)) - and is_tagged(builder.node_type(stmt.rvalue)) - and stmt.op in int_borrow_friendly_op): - can_borrow = (is_borrow_friendly_expr(builder, stmt.rvalue) - and is_borrow_friendly_expr(builder, stmt.lvalue)) + if ( + is_tagged(builder.node_type(stmt.lvalue)) + and is_tagged(builder.node_type(stmt.rvalue)) + and stmt.op in int_borrow_friendly_op + ): + can_borrow = is_borrow_friendly_expr(builder, stmt.rvalue) and is_borrow_friendly_expr( + builder, stmt.lvalue + ) else: can_borrow = False target = builder.get_assignment_target(stmt.lvalue) target_value = builder.read(target, stmt.line, can_borrow=can_borrow) rreg = builder.accept(stmt.rvalue, can_borrow=can_borrow) # the Python parser strips the '=' from operator assignment statements, so re-add it - op = stmt.op + '=' + op = stmt.op + "=" res = builder.binary_op(target_value, rreg, op, stmt.line) # usually operator assignments are done in-place # but when target doesn't support that we need to manually assign @@ -162,13 +208,13 @@ def transform_import(builder: IRBuilder, node: Import) -> None: name = as_name base = node_id else: - base = name = node_id.split('.')[0] + base = name = node_id.split(".")[0] obj = builder.get_module(base, node.line) builder.gen_method_call( - globals, '__setitem__', [builder.load_str(name), obj], - result_type=None, line=node.line) + globals, "__setitem__", [builder.load_str(name), obj], result_type=None, line=node.line + ) def transform_import_from(builder: IRBuilder, node: ImportFrom) -> None: @@ -181,9 +227,9 @@ def transform_import_from(builder: IRBuilder, node: ImportFrom) -> None: elif builder.module_path.endswith("__init__.py"): module_package = builder.module_name else: - module_package = '' + module_package = "" - id = importlib.util.resolve_name('.' * node.relative + node.id, module_package) + id = importlib.util.resolve_name("." * node.relative + node.id, module_package) globals = builder.load_globals_dict() imported_names = [name for name, _ in node.names] @@ -195,13 +241,18 @@ def transform_import_from(builder: IRBuilder, node: ImportFrom) -> None: # This probably doesn't matter much and the code runs basically right. for name, maybe_as_name in node.names: as_name = maybe_as_name or name - obj = builder.call_c(import_from_op, - [module, builder.load_str(id), - builder.load_str(name), builder.load_str(as_name)], - node.line) + obj = builder.call_c( + import_from_op, + [module, builder.load_str(id), builder.load_str(name), builder.load_str(as_name)], + node.line, + ) builder.gen_method_call( - globals, '__setitem__', [builder.load_str(as_name), obj], - result_type=None, line=node.line) + globals, + "__setitem__", + [builder.load_str(as_name), obj], + result_type=None, + line=node.line, + ) def transform_import_all(builder: IRBuilder, node: ImportAll) -> None: @@ -255,7 +306,7 @@ def transform_while_stmt(builder: IRBuilder, s: WhileStmt) -> None: def transform_for_stmt(builder: IRBuilder, s: ForStmt) -> None: if s.is_async: - builder.error('async for is unimplemented', s.line) + builder.error("async for is unimplemented", s.line) def body() -> None: builder.accept(s.body) @@ -264,8 +315,7 @@ def else_block() -> None: assert s.else_body is not None builder.accept(s.else_body) - for_loop_helper(builder, s.index, s.expr, body, - else_block if s.else_body else None, s.line) + for_loop_helper(builder, s.index, s.expr, body, else_block if s.else_body else None, s.line) def transform_break_stmt(builder: IRBuilder, node: BreakStmt) -> None: @@ -287,12 +337,13 @@ def transform_raise_stmt(builder: IRBuilder, s: RaiseStmt) -> None: builder.add(Unreachable()) -def transform_try_except(builder: IRBuilder, - body: GenFunc, - handlers: Sequence[ - Tuple[Optional[Expression], Optional[Expression], GenFunc]], - else_body: Optional[GenFunc], - line: int) -> None: +def transform_try_except( + builder: IRBuilder, + body: GenFunc, + handlers: Sequence[Tuple[Optional[Expression], Optional[Expression], GenFunc]], + else_body: Optional[GenFunc], + line: int, +) -> None: """Generalized try/except/else handling that takes functions to gen the bodies. The point of this is to also be able to support with.""" @@ -320,26 +371,19 @@ def transform_try_except(builder: IRBuilder, builder.activate_block(except_entry) old_exc = builder.maybe_spill(builder.call_c(error_catch_op, [], line)) # Compile the except blocks with the nonlocal control flow overridden to clear exc_info - builder.nonlocal_control.append( - ExceptNonlocalControl(builder.nonlocal_control[-1], old_exc)) + builder.nonlocal_control.append(ExceptNonlocalControl(builder.nonlocal_control[-1], old_exc)) # Process the bodies for type, var, handler_body in handlers: next_block = None if type: next_block, body_block = BasicBlock(), BasicBlock() - matches = builder.call_c( - exc_matches_op, [builder.accept(type)], type.line - ) + matches = builder.call_c(exc_matches_op, [builder.accept(type)], type.line) builder.add(Branch(matches, body_block, next_block, Branch.BOOL)) builder.activate_block(body_block) if var: target = builder.get_assignment_target(var) - builder.assign( - target, - builder.call_c(get_exc_value_op, [], var.line), - var.line - ) + builder.assign(target, builder.call_c(get_exc_value_op, [], var.line), var.line) handler_body() builder.goto(cleanup_block) if next_block: @@ -385,17 +429,20 @@ def body() -> None: def make_handler(body: Block) -> GenFunc: return lambda: builder.accept(body) - handlers = [(type, var, make_handler(body)) - for type, var, body in zip(t.types, t.vars, t.handlers)] + handlers = [ + (type, var, make_handler(body)) for type, var, body in zip(t.types, t.vars, t.handlers) + ] else_body = (lambda: builder.accept(t.else_body)) if t.else_body else None transform_try_except(builder, body, handlers, else_body, t.line) -def try_finally_try(builder: IRBuilder, - err_handler: BasicBlock, - return_entry: BasicBlock, - main_entry: BasicBlock, - try_body: GenFunc) -> Optional[Register]: +def try_finally_try( + builder: IRBuilder, + err_handler: BasicBlock, + return_entry: BasicBlock, + main_entry: BasicBlock, + try_body: GenFunc, +) -> Optional[Register]: # Compile the try block with an error handler control = TryFinallyNonlocalControl(return_entry) builder.builder.push_error_handler(err_handler) @@ -410,23 +457,20 @@ def try_finally_try(builder: IRBuilder, return control.ret_reg -def try_finally_entry_blocks(builder: IRBuilder, - err_handler: BasicBlock, - return_entry: BasicBlock, - main_entry: BasicBlock, - finally_block: BasicBlock, - ret_reg: Optional[Register]) -> Value: +def try_finally_entry_blocks( + builder: IRBuilder, + err_handler: BasicBlock, + return_entry: BasicBlock, + main_entry: BasicBlock, + finally_block: BasicBlock, + ret_reg: Optional[Register], +) -> Value: old_exc = Register(exc_rtuple) # Entry block for non-exceptional flow builder.activate_block(main_entry) if ret_reg: - builder.add( - Assign( - ret_reg, - builder.add(LoadErrorValue(builder.ret_types[-1])) - ) - ) + builder.add(Assign(ret_reg, builder.add(LoadErrorValue(builder.ret_types[-1])))) builder.goto(return_entry) builder.activate_block(return_entry) @@ -436,12 +480,7 @@ def try_finally_entry_blocks(builder: IRBuilder, # Entry block for errors builder.activate_block(err_handler) if ret_reg: - builder.add( - Assign( - ret_reg, - builder.add(LoadErrorValue(builder.ret_types[-1])) - ) - ) + builder.add(Assign(ret_reg, builder.add(LoadErrorValue(builder.ret_types[-1])))) builder.add(Assign(old_exc, builder.call_c(error_catch_op, [], -1))) builder.goto(finally_block) @@ -449,16 +488,16 @@ def try_finally_entry_blocks(builder: IRBuilder, def try_finally_body( - builder: IRBuilder, - finally_block: BasicBlock, - finally_body: GenFunc, - ret_reg: Optional[Value], - old_exc: Value) -> Tuple[BasicBlock, FinallyNonlocalControl]: + builder: IRBuilder, + finally_block: BasicBlock, + finally_body: GenFunc, + ret_reg: Optional[Value], + old_exc: Value, +) -> Tuple[BasicBlock, FinallyNonlocalControl]: cleanup_block = BasicBlock() # Compile the finally block with the nonlocal control flow overridden to restore exc_info builder.builder.push_error_handler(cleanup_block) - finally_control = FinallyNonlocalControl( - builder.nonlocal_control[-1], ret_reg, old_exc) + finally_control = FinallyNonlocalControl(builder.nonlocal_control[-1], ret_reg, old_exc) builder.nonlocal_control.append(finally_control) builder.activate_block(finally_block) finally_body() @@ -467,11 +506,13 @@ def try_finally_body( return cleanup_block, finally_control -def try_finally_resolve_control(builder: IRBuilder, - cleanup_block: BasicBlock, - finally_control: FinallyNonlocalControl, - old_exc: Value, - ret_reg: Optional[Value]) -> BasicBlock: +def try_finally_resolve_control( + builder: IRBuilder, + cleanup_block: BasicBlock, + finally_control: FinallyNonlocalControl, + old_exc: Value, + ret_reg: Optional[Value], +) -> BasicBlock: """Resolve the control flow out of a finally block. This means returning if there was a return, propagating @@ -509,9 +550,9 @@ def try_finally_resolve_control(builder: IRBuilder, return out_block -def transform_try_finally_stmt(builder: IRBuilder, - try_body: GenFunc, - finally_body: GenFunc) -> None: +def transform_try_finally_stmt( + builder: IRBuilder, try_body: GenFunc, finally_body: GenFunc +) -> None: """Generalized try/finally handling that takes functions to gen the bodies. The point of this is to also be able to support with.""" @@ -519,23 +560,29 @@ def transform_try_finally_stmt(builder: IRBuilder, # exits can occur. We emit 10+ basic blocks for every finally! err_handler, main_entry, return_entry, finally_block = ( - BasicBlock(), BasicBlock(), BasicBlock(), BasicBlock()) + BasicBlock(), + BasicBlock(), + BasicBlock(), + BasicBlock(), + ) # Compile the body of the try - ret_reg = try_finally_try( - builder, err_handler, return_entry, main_entry, try_body) + ret_reg = try_finally_try(builder, err_handler, return_entry, main_entry, try_body) # Set up the entry blocks for the finally statement old_exc = try_finally_entry_blocks( - builder, err_handler, return_entry, main_entry, finally_block, ret_reg) + builder, err_handler, return_entry, main_entry, finally_block, ret_reg + ) # Compile the body of the finally cleanup_block, finally_control = try_finally_body( - builder, finally_block, finally_body, ret_reg, old_exc) + builder, finally_block, finally_body, ret_reg, old_exc + ) # Resolve the control flow out of the finally block out_block = try_finally_resolve_control( - builder, cleanup_block, finally_control, old_exc, ret_reg) + builder, cleanup_block, finally_control, old_exc, ret_reg + ) builder.activate_block(out_block) @@ -547,11 +594,13 @@ def transform_try_stmt(builder: IRBuilder, t: TryStmt) -> None: # try/except/else/finally, we treat the try/except/else as the # body of a try/finally block. if t.finally_body: + def transform_try_body() -> None: if t.handlers: transform_try_except_stmt(builder, t) else: builder.accept(t.body) + body = t.finally_body transform_try_finally_stmt(builder, transform_try_body, lambda: builder.accept(body)) @@ -564,21 +613,17 @@ def get_sys_exc_info(builder: IRBuilder) -> List[Value]: return [builder.add(TupleGet(exc_info, i, -1)) for i in range(3)] -def transform_with(builder: IRBuilder, - expr: Expression, - target: Optional[Lvalue], - body: GenFunc, - line: int) -> None: +def transform_with( + builder: IRBuilder, expr: Expression, target: Optional[Lvalue], body: GenFunc, line: int +) -> None: # This is basically a straight transcription of the Python code in PEP 343. # I don't actually understand why a bunch of it is the way it is. # We could probably optimize the case where the manager is compiled by us, # but that is not our common case at all, so. mgr_v = builder.accept(expr) typ = builder.call_c(type_op, [mgr_v], line) - exit_ = builder.maybe_spill(builder.py_get_attr(typ, '__exit__', line)) - value = builder.py_call( - builder.py_get_attr(typ, '__enter__', line), [mgr_v], line - ) + exit_ = builder.maybe_spill(builder.py_get_attr(typ, "__exit__", line)) + value = builder.py_call(builder.py_get_attr(typ, "__enter__", line), [mgr_v], line) mgr = builder.maybe_spill(mgr_v) exc = builder.maybe_spill_assignable(builder.true()) @@ -591,10 +636,11 @@ def except_body() -> None: builder.assign(exc, builder.false(), line) out_block, reraise_block = BasicBlock(), BasicBlock() builder.add_bool_branch( - builder.py_call(builder.read(exit_), - [builder.read(mgr)] + get_sys_exc_info(builder), line), + builder.py_call( + builder.read(exit_), [builder.read(mgr)] + get_sys_exc_info(builder), line + ), out_block, - reraise_block + reraise_block, ) builder.activate_block(reraise_block) builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO) @@ -603,30 +649,22 @@ def except_body() -> None: def finally_body() -> None: out_block, exit_block = BasicBlock(), BasicBlock() - builder.add( - Branch(builder.read(exc), exit_block, out_block, Branch.BOOL) - ) + builder.add(Branch(builder.read(exc), exit_block, out_block, Branch.BOOL)) builder.activate_block(exit_block) none = builder.none_object() - builder.py_call( - builder.read(exit_), [builder.read(mgr), none, none, none], line - ) + builder.py_call(builder.read(exit_), [builder.read(mgr), none, none, none], line) builder.goto_and_activate(out_block) transform_try_finally_stmt( builder, - lambda: transform_try_except(builder, - try_body, - [(None, None, except_body)], - None, - line), - finally_body + lambda: transform_try_except(builder, try_body, [(None, None, except_body)], None, line), + finally_body, ) def transform_with_stmt(builder: IRBuilder, o: WithStmt) -> None: if o.is_async: - builder.error('async with is unimplemented', o.line) + builder.error("async with is unimplemented", o.line) # Generate separate logic for each expr in it, left to right def generate(i: int) -> None: @@ -650,12 +688,11 @@ def transform_assert_stmt(builder: IRBuilder, a: AssertStmt) -> None: builder.add(RaiseStandardError(RaiseStandardError.ASSERTION_ERROR, None, a.line)) elif isinstance(a.msg, StrExpr): # Another special case - builder.add(RaiseStandardError(RaiseStandardError.ASSERTION_ERROR, a.msg.value, - a.line)) + builder.add(RaiseStandardError(RaiseStandardError.ASSERTION_ERROR, a.msg.value, a.line)) else: # The general case -- explicitly construct an exception instance message = builder.accept(a.msg) - exc_type = builder.load_module_attr_by_fullname('builtins.AssertionError', a.line) + exc_type = builder.load_module_attr_by_fullname("builtins.AssertionError", a.line) exc = builder.py_call(exc_type, [message], a.line) builder.call_c(raise_exception_op, [exc], a.line) builder.add(Unreachable()) @@ -669,11 +706,7 @@ def transform_del_stmt(builder: IRBuilder, o: DelStmt) -> None: def transform_del_item(builder: IRBuilder, target: AssignmentTarget, line: int) -> None: if isinstance(target, AssignmentTargetIndex): builder.gen_method_call( - target.base, - '__delitem__', - [target.index], - result_type=None, - line=line + target.base, "__delitem__", [target.index], result_type=None, line=line ) elif isinstance(target, AssignmentTargetAttr): if isinstance(target.obj_type, RInstance): @@ -681,15 +714,18 @@ def transform_del_item(builder: IRBuilder, target: AssignmentTarget, line: int) if not cl.is_deletable(target.attr): builder.error(f'"{target.attr}" cannot be deleted', line) builder.note( - 'Using "__deletable__ = ' + - '[\'\']" in the class body enables "del obj."', line) + 'Using "__deletable__ = ' + + '[\'\']" in the class body enables "del obj."', + line, + ) key = builder.load_str(target.attr) builder.call_c(py_delattr_op, [target.obj, key], line) elif isinstance(target, AssignmentTargetRegister): # Delete a local by assigning an error value to it, which will # prompt the insertion of uninit checks. - builder.add(Assign(target.register, - builder.add(LoadErrorValue(target.type, undefines=True)))) + builder.add( + Assign(target.register, builder.add(LoadErrorValue(target.type, undefines=True))) + ) elif isinstance(target, AssignmentTargetTuple): for subtarget in target.items: transform_del_item(builder, subtarget, line) diff --git a/mypyc/irbuild/targets.py b/mypyc/irbuild/targets.py index f2daa701f7e85..e47a5621f11d2 100644 --- a/mypyc/irbuild/targets.py +++ b/mypyc/irbuild/targets.py @@ -1,7 +1,7 @@ from typing import List, Optional -from mypyc.ir.ops import Value, Register -from mypyc.ir.rtypes import RType, RInstance, object_rprimitive +from mypyc.ir.ops import Register, Value +from mypyc.ir.rtypes import RInstance, RType, object_rprimitive class AssignmentTarget: @@ -52,8 +52,6 @@ def __init__(self, obj: Value, attr: str, can_borrow: bool = False) -> None: class AssignmentTargetTuple(AssignmentTarget): """x, ..., y as assignment target""" - def __init__(self, - items: List[AssignmentTarget], - star_idx: Optional[int] = None) -> None: + def __init__(self, items: List[AssignmentTarget], star_idx: Optional[int] = None) -> None: self.items = items self.star_idx = star_idx diff --git a/mypyc/irbuild/util.py b/mypyc/irbuild/util.py index 7a7b95245d4c3..5ab9d2f2fc9a1 100644 --- a/mypyc/irbuild/util.py +++ b/mypyc/irbuild/util.py @@ -1,23 +1,36 @@ """Various utilities that don't depend on other modules in mypyc.irbuild.""" -from typing import Dict, Any, Union, Optional +from typing import Any, Dict, Optional, Union from mypy.nodes import ( - ClassDef, FuncDef, Decorator, OverloadedFuncDef, StrExpr, CallExpr, RefExpr, Expression, - IntExpr, FloatExpr, Var, NameExpr, TupleExpr, UnaryExpr, BytesExpr, - ArgKind, ARG_NAMED, ARG_NAMED_OPT, ARG_POS, ARG_OPT, GDEF, + ARG_NAMED, + ARG_NAMED_OPT, + ARG_OPT, + ARG_POS, + GDEF, + ArgKind, + BytesExpr, + CallExpr, + ClassDef, + Decorator, + Expression, + FloatExpr, + FuncDef, + IntExpr, + NameExpr, + OverloadedFuncDef, + RefExpr, + StrExpr, + TupleExpr, + UnaryExpr, + Var, ) - -DATACLASS_DECORATORS = { - 'dataclasses.dataclass', - 'attr.s', - 'attr.attrs', -} +DATACLASS_DECORATORS = {"dataclasses.dataclass", "attr.s", "attr.attrs"} def is_trait_decorator(d: Expression) -> bool: - return isinstance(d, RefExpr) and d.fullname == 'mypy_extensions.trait' + return isinstance(d, RefExpr) and d.fullname == "mypy_extensions.trait" def is_trait(cdef: ClassDef) -> bool: @@ -26,17 +39,19 @@ def is_trait(cdef: ClassDef) -> bool: def dataclass_decorator_type(d: Expression) -> Optional[str]: if isinstance(d, RefExpr) and d.fullname in DATACLASS_DECORATORS: - return d.fullname.split('.')[0] - elif (isinstance(d, CallExpr) - and isinstance(d.callee, RefExpr) - and d.callee.fullname in DATACLASS_DECORATORS): - name = d.callee.fullname.split('.')[0] - if name == 'attr' and 'auto_attribs' in d.arg_names: + return d.fullname.split(".")[0] + elif ( + isinstance(d, CallExpr) + and isinstance(d.callee, RefExpr) + and d.callee.fullname in DATACLASS_DECORATORS + ): + name = d.callee.fullname.split(".")[0] + if name == "attr" and "auto_attribs" in d.arg_names: # Note: the mypy attrs plugin checks that the value of auto_attribs is # not computed at runtime, so we don't need to perform that check here - auto = d.args[d.arg_names.index('auto_attribs')] - if isinstance(auto, NameExpr) and auto.name == 'True': - return 'attr-auto' + auto = d.args[d.arg_names.index("auto_attribs")] + if isinstance(auto, NameExpr) and auto.name == "True": + return "attr-auto" return name else: return None @@ -64,11 +79,11 @@ def get_mypyc_attr_literal(e: Expression) -> Any: Supports a pretty limited range.""" if isinstance(e, (StrExpr, IntExpr, FloatExpr)): return e.value - elif isinstance(e, RefExpr) and e.fullname == 'builtins.True': + elif isinstance(e, RefExpr) and e.fullname == "builtins.True": return True - elif isinstance(e, RefExpr) and e.fullname == 'builtins.False': + elif isinstance(e, RefExpr) and e.fullname == "builtins.False": return False - elif isinstance(e, RefExpr) and e.fullname == 'builtins.None': + elif isinstance(e, RefExpr) and e.fullname == "builtins.None": return None return NotImplemented @@ -78,7 +93,7 @@ def get_mypyc_attr_call(d: Expression) -> Optional[CallExpr]: if ( isinstance(d, CallExpr) and isinstance(d.callee, RefExpr) - and d.callee.fullname == 'mypy_extensions.mypyc_attr' + and d.callee.fullname == "mypy_extensions.mypyc_attr" ): return d return None @@ -102,9 +117,7 @@ def get_mypyc_attrs(stmt: Union[ClassDef, Decorator]) -> Dict[str, Any]: def is_extension_class(cdef: ClassDef) -> bool: if any( - not is_trait_decorator(d) - and not is_dataclass_decorator(d) - and not get_mypyc_attr_call(d) + not is_trait_decorator(d) and not is_dataclass_decorator(d) and not get_mypyc_attr_call(d) for d in cdef.decorators ): return False @@ -112,8 +125,11 @@ def is_extension_class(cdef: ClassDef) -> bool: return False if cdef.info.is_named_tuple: return False - if (cdef.info.metaclass_type and cdef.info.metaclass_type.type.fullname not in ( - 'abc.ABCMeta', 'typing.TypingMeta', 'typing.GenericMeta')): + if cdef.info.metaclass_type and cdef.info.metaclass_type.type.fullname not in ( + "abc.ABCMeta", + "typing.TypingMeta", + "typing.GenericMeta", + ): return False return True @@ -146,11 +162,16 @@ def is_constant(e: Expression) -> bool: primitives types, None, and references to Final global variables. """ - return (isinstance(e, (StrExpr, BytesExpr, IntExpr, FloatExpr)) - or (isinstance(e, UnaryExpr) and e.op == '-' - and isinstance(e.expr, (IntExpr, FloatExpr))) - or (isinstance(e, TupleExpr) - and all(is_constant(e) for e in e.items)) - or (isinstance(e, RefExpr) and e.kind == GDEF - and (e.fullname in ('builtins.True', 'builtins.False', 'builtins.None') - or (isinstance(e.node, Var) and e.node.is_final)))) + return ( + isinstance(e, (StrExpr, BytesExpr, IntExpr, FloatExpr)) + or (isinstance(e, UnaryExpr) and e.op == "-" and isinstance(e.expr, (IntExpr, FloatExpr))) + or (isinstance(e, TupleExpr) and all(is_constant(e) for e in e.items)) + or ( + isinstance(e, RefExpr) + and e.kind == GDEF + and ( + e.fullname in ("builtins.True", "builtins.False", "builtins.None") + or (isinstance(e.node, Var) and e.node.is_final) + ) + ) + ) diff --git a/mypyc/irbuild/visitor.py b/mypyc/irbuild/visitor.py index 15ac08d9c9739..0887bec7cd55a 100644 --- a/mypyc/irbuild/visitor.py +++ b/mypyc/irbuild/visitor.py @@ -6,78 +6,141 @@ from typing_extensions import NoReturn from mypy.nodes import ( - AssertTypeExpr, MypyFile, FuncDef, ReturnStmt, AssignmentStmt, OpExpr, - IntExpr, NameExpr, Var, IfStmt, UnaryExpr, ComparisonExpr, WhileStmt, CallExpr, - IndexExpr, Block, ListExpr, ExpressionStmt, MemberExpr, ForStmt, - BreakStmt, ContinueStmt, ConditionalExpr, OperatorAssignmentStmt, TupleExpr, ClassDef, - Import, ImportFrom, ImportAll, DictExpr, StrExpr, CastExpr, TempNode, - PassStmt, PromoteExpr, AssignmentExpr, AwaitExpr, BackquoteExpr, AssertStmt, BytesExpr, - ComplexExpr, Decorator, DelStmt, DictionaryComprehension, EllipsisExpr, EnumCallExpr, ExecStmt, - FloatExpr, GeneratorExpr, GlobalDecl, LambdaExpr, ListComprehension, SetComprehension, - NamedTupleExpr, NewTypeExpr, NonlocalDecl, OverloadedFuncDef, PrintStmt, RaiseStmt, - RevealExpr, SetExpr, SliceExpr, StarExpr, SuperExpr, TryStmt, TypeAliasExpr, TypeApplication, - TypeVarExpr, TypedDictExpr, UnicodeExpr, WithStmt, YieldFromExpr, YieldExpr, ParamSpecExpr, - MatchStmt, TypeVarTupleExpr + AssertStmt, + AssertTypeExpr, + AssignmentExpr, + AssignmentStmt, + AwaitExpr, + BackquoteExpr, + Block, + BreakStmt, + BytesExpr, + CallExpr, + CastExpr, + ClassDef, + ComparisonExpr, + ComplexExpr, + ConditionalExpr, + ContinueStmt, + Decorator, + DelStmt, + DictExpr, + DictionaryComprehension, + EllipsisExpr, + EnumCallExpr, + ExecStmt, + ExpressionStmt, + FloatExpr, + ForStmt, + FuncDef, + GeneratorExpr, + GlobalDecl, + IfStmt, + Import, + ImportAll, + ImportFrom, + IndexExpr, + IntExpr, + LambdaExpr, + ListComprehension, + ListExpr, + MatchStmt, + MemberExpr, + MypyFile, + NamedTupleExpr, + NameExpr, + NewTypeExpr, + NonlocalDecl, + OperatorAssignmentStmt, + OpExpr, + OverloadedFuncDef, + ParamSpecExpr, + PassStmt, + PrintStmt, + PromoteExpr, + RaiseStmt, + ReturnStmt, + RevealExpr, + SetComprehension, + SetExpr, + SliceExpr, + StarExpr, + StrExpr, + SuperExpr, + TempNode, + TryStmt, + TupleExpr, + TypeAliasExpr, + TypeApplication, + TypedDictExpr, + TypeVarExpr, + TypeVarTupleExpr, + UnaryExpr, + UnicodeExpr, + Var, + WhileStmt, + WithStmt, + YieldExpr, + YieldFromExpr, ) - from mypyc.ir.ops import Value -from mypyc.irbuild.builder import IRVisitor, IRBuilder, UnsupportedException +from mypyc.irbuild.builder import IRBuilder, IRVisitor, UnsupportedException from mypyc.irbuild.classdef import transform_class_def +from mypyc.irbuild.expression import ( + transform_assignment_expr, + transform_bytes_expr, + transform_call_expr, + transform_comparison_expr, + transform_complex_expr, + transform_conditional_expr, + transform_dict_expr, + transform_dictionary_comprehension, + transform_ellipsis, + transform_float_expr, + transform_generator_expr, + transform_index_expr, + transform_int_expr, + transform_list_comprehension, + transform_list_expr, + transform_member_expr, + transform_name_expr, + transform_op_expr, + transform_set_comprehension, + transform_set_expr, + transform_slice_expr, + transform_str_expr, + transform_super_expr, + transform_tuple_expr, + transform_unary_expr, +) from mypyc.irbuild.function import ( - transform_func_def, - transform_overloaded_func_def, + transform_await_expr, transform_decorator, + transform_func_def, transform_lambda_expr, + transform_overloaded_func_def, transform_yield_expr, transform_yield_from_expr, - transform_await_expr, ) from mypyc.irbuild.statement import ( + transform_assert_stmt, + transform_assignment_stmt, transform_block, + transform_break_stmt, + transform_continue_stmt, + transform_del_stmt, transform_expression_stmt, - transform_return_stmt, - transform_assignment_stmt, - transform_operator_assignment_stmt, + transform_for_stmt, + transform_if_stmt, transform_import, - transform_import_from, transform_import_all, - transform_if_stmt, - transform_while_stmt, - transform_for_stmt, - transform_break_stmt, - transform_continue_stmt, + transform_import_from, + transform_operator_assignment_stmt, transform_raise_stmt, + transform_return_stmt, transform_try_stmt, + transform_while_stmt, transform_with_stmt, - transform_assert_stmt, - transform_del_stmt, -) -from mypyc.irbuild.expression import ( - transform_name_expr, - transform_member_expr, - transform_super_expr, - transform_call_expr, - transform_unary_expr, - transform_op_expr, - transform_index_expr, - transform_conditional_expr, - transform_int_expr, - transform_float_expr, - transform_complex_expr, - transform_comparison_expr, - transform_str_expr, - transform_bytes_expr, - transform_ellipsis, - transform_list_expr, - transform_tuple_expr, - transform_dict_expr, - transform_set_expr, - transform_list_comprehension, - transform_set_comprehension, - transform_dictionary_comprehension, - transform_slice_expr, - transform_generator_expr, - transform_assignment_expr, ) diff --git a/mypyc/irbuild/vtable.py b/mypyc/irbuild/vtable.py index ce2c2d3b22225..6fe0cf568fea3 100644 --- a/mypyc/irbuild/vtable.py +++ b/mypyc/irbuild/vtable.py @@ -64,12 +64,17 @@ def specialize_parent_vtable(cls: ClassIR, parent: ClassIR) -> VTableEntries: if method_cls: child_method, defining_cls = method_cls # TODO: emit a wrapper for __init__ that raises or something - if (is_same_method_signature(orig_parent_method.sig, child_method.sig) - or orig_parent_method.name == '__init__'): + if ( + is_same_method_signature(orig_parent_method.sig, child_method.sig) + or orig_parent_method.name == "__init__" + ): entry = VTableMethod(entry.cls, entry.name, child_method, entry.shadow_method) else: - entry = VTableMethod(entry.cls, entry.name, - defining_cls.glue_methods[(entry.cls, entry.name)], - entry.shadow_method) + entry = VTableMethod( + entry.cls, + entry.name, + defining_cls.glue_methods[(entry.cls, entry.name)], + entry.shadow_method, + ) updated.append(entry) return updated diff --git a/mypyc/lib-rt/setup.py b/mypyc/lib-rt/setup.py index 482db5ded8f72..8f07a55bd0cfe 100644 --- a/mypyc/lib-rt/setup.py +++ b/mypyc/lib-rt/setup.py @@ -3,25 +3,29 @@ The tests are written in C++ and use the Google Test framework. """ -from distutils.core import setup, Extension import sys +from distutils.core import Extension, setup -if sys.platform == 'darwin': - kwargs = {'language': 'c++'} +if sys.platform == "darwin": + kwargs = {"language": "c++"} compile_args = [] else: kwargs = {} # type: ignore - compile_args = ['--std=c++11'] + compile_args = ["--std=c++11"] -setup(name='test_capi', - version='0.1', - ext_modules=[Extension( - 'test_capi', - ['test_capi.cc', 'init.c', 'int_ops.c', 'list_ops.c', 'exc_ops.c', 'generic_ops.c'], - depends=['CPy.h', 'mypyc_util.h', 'pythonsupport.h'], - extra_compile_args=['-Wno-unused-function', '-Wno-sign-compare'] + compile_args, - library_dirs=['../external/googletest/make'], - libraries=['gtest'], - include_dirs=['../external/googletest', '../external/googletest/include'], - **kwargs - )]) +setup( + name="test_capi", + version="0.1", + ext_modules=[ + Extension( + "test_capi", + ["test_capi.cc", "init.c", "int_ops.c", "list_ops.c", "exc_ops.c", "generic_ops.c"], + depends=["CPy.h", "mypyc_util.h", "pythonsupport.h"], + extra_compile_args=["-Wno-unused-function", "-Wno-sign-compare"] + compile_args, + library_dirs=["../external/googletest/make"], + libraries=["gtest"], + include_dirs=["../external/googletest", "../external/googletest/include"], + **kwargs, + ) + ], +) diff --git a/mypyc/namegen.py b/mypyc/namegen.py index 99abf8a759ff4..9df9be82d3a74 100644 --- a/mypyc/namegen.py +++ b/mypyc/namegen.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Tuple, Set, Optional, Iterable +from typing import Dict, Iterable, List, Optional, Set, Tuple class NameGenerator: @@ -64,16 +64,16 @@ def private_name(self, module: str, partial_name: Optional[str] = None) -> str: """ # TODO: Support unicode if partial_name is None: - return exported_name(self.module_map[module].rstrip('.')) + return exported_name(self.module_map[module].rstrip(".")) if (module, partial_name) in self.translations: return self.translations[module, partial_name] if module in self.module_map: module_prefix = self.module_map[module] elif module: - module_prefix = module + '.' + module_prefix = module + "." else: - module_prefix = '' - actual = exported_name(f'{module_prefix}{partial_name}') + module_prefix = "" + actual = exported_name(f"{module_prefix}{partial_name}") self.translations[module, partial_name] = actual return actual @@ -86,7 +86,7 @@ def exported_name(fullname: str) -> str: builds. """ # TODO: Support unicode - return fullname.replace('___', '___3_').replace('.', '___') + return fullname.replace("___", "___3_").replace(".", "___") def make_module_translation_map(names: List[str]) -> Dict[str, str]: @@ -106,8 +106,8 @@ def make_module_translation_map(names: List[str]) -> Dict[str, str]: def candidate_suffixes(fullname: str) -> List[str]: - components = fullname.split('.') - result = [''] + components = fullname.split(".") + result = [""] for i in range(len(components)): - result.append('.'.join(components[-i - 1:]) + '.') + result.append(".".join(components[-i - 1 :]) + ".") return result diff --git a/mypyc/options.py b/mypyc/options.py index 94ef64cd0df79..bf8bacba9117b 100644 --- a/mypyc/options.py +++ b/mypyc/options.py @@ -1,22 +1,24 @@ -from typing import Optional, Tuple import sys +from typing import Optional, Tuple class CompilerOptions: - def __init__(self, - strip_asserts: bool = False, - multi_file: bool = False, - verbose: bool = False, - separate: bool = False, - target_dir: Optional[str] = None, - include_runtime_files: Optional[bool] = None, - capi_version: Optional[Tuple[int, int]] = None) -> None: + def __init__( + self, + strip_asserts: bool = False, + multi_file: bool = False, + verbose: bool = False, + separate: bool = False, + target_dir: Optional[str] = None, + include_runtime_files: Optional[bool] = None, + capi_version: Optional[Tuple[int, int]] = None, + ) -> None: self.strip_asserts = strip_asserts self.multi_file = multi_file self.verbose = verbose self.separate = separate self.global_opts = not separate - self.target_dir = target_dir or 'build' + self.target_dir = target_dir or "build" self.include_runtime_files = ( include_runtime_files if include_runtime_files is not None else not multi_file ) diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py index 6ddb5e38111c6..4ff1ffa86760e 100644 --- a/mypyc/primitives/bytes_ops.py +++ b/mypyc/primitives/bytes_ops.py @@ -2,83 +2,98 @@ from mypyc.ir.ops import ERR_MAGIC from mypyc.ir.rtypes import ( - object_rprimitive, bytes_rprimitive, list_rprimitive, dict_rprimitive, - str_rprimitive, c_int_rprimitive, RUnion, c_pyssize_t_rprimitive, + RUnion, + bytes_rprimitive, + c_int_rprimitive, + c_pyssize_t_rprimitive, + dict_rprimitive, int_rprimitive, + list_rprimitive, + object_rprimitive, + str_rprimitive, ) from mypyc.primitives.registry import ( - load_address_op, function_op, method_op, binary_op, custom_op, ERR_NEG_INT + ERR_NEG_INT, + binary_op, + custom_op, + function_op, + load_address_op, + method_op, ) # Get the 'bytes' type object. -load_address_op( - name='builtins.bytes', - type=object_rprimitive, - src='PyBytes_Type') +load_address_op(name="builtins.bytes", type=object_rprimitive, src="PyBytes_Type") # bytes(obj) function_op( - name='builtins.bytes', + name="builtins.bytes", arg_types=[RUnion([list_rprimitive, dict_rprimitive, str_rprimitive])], return_type=bytes_rprimitive, - c_function_name='PyBytes_FromObject', - error_kind=ERR_MAGIC) + c_function_name="PyBytes_FromObject", + error_kind=ERR_MAGIC, +) # bytearray(obj) function_op( - name='builtins.bytearray', + name="builtins.bytearray", arg_types=[object_rprimitive], return_type=bytes_rprimitive, - c_function_name='PyByteArray_FromObject', - error_kind=ERR_MAGIC) + c_function_name="PyByteArray_FromObject", + error_kind=ERR_MAGIC, +) # bytes ==/!= (return -1/0/1) bytes_compare = custom_op( arg_types=[bytes_rprimitive, bytes_rprimitive], return_type=c_int_rprimitive, - c_function_name='CPyBytes_Compare', - error_kind=ERR_NEG_INT) + c_function_name="CPyBytes_Compare", + error_kind=ERR_NEG_INT, +) # bytes + bytes # bytearray + bytearray binary_op( - name='+', + name="+", arg_types=[bytes_rprimitive, bytes_rprimitive], return_type=bytes_rprimitive, - c_function_name='CPyBytes_Concat', + c_function_name="CPyBytes_Concat", error_kind=ERR_MAGIC, - steals=[True, False]) + steals=[True, False], +) # bytes[begin:end] bytes_slice_op = custom_op( arg_types=[bytes_rprimitive, int_rprimitive, int_rprimitive], return_type=bytes_rprimitive, - c_function_name='CPyBytes_GetSlice', - error_kind=ERR_MAGIC) + c_function_name="CPyBytes_GetSlice", + error_kind=ERR_MAGIC, +) # bytes[index] # bytearray[index] method_op( - name='__getitem__', + name="__getitem__", arg_types=[bytes_rprimitive, int_rprimitive], return_type=int_rprimitive, - c_function_name='CPyBytes_GetItem', - error_kind=ERR_MAGIC) + c_function_name="CPyBytes_GetItem", + error_kind=ERR_MAGIC, +) # bytes.join(obj) method_op( - name='join', + name="join", arg_types=[bytes_rprimitive, object_rprimitive], return_type=bytes_rprimitive, - c_function_name='CPyBytes_Join', - error_kind=ERR_MAGIC) + c_function_name="CPyBytes_Join", + error_kind=ERR_MAGIC, +) # Join bytes objects and return a new bytes. # The first argument is the total number of the following bytes. bytes_build_op = custom_op( arg_types=[c_pyssize_t_rprimitive], return_type=bytes_rprimitive, - c_function_name='CPyBytes_Build', + c_function_name="CPyBytes_Build", error_kind=ERR_MAGIC, - var_arg_type=bytes_rprimitive + var_arg_type=bytes_rprimitive, ) diff --git a/mypyc/primitives/dict_ops.py b/mypyc/primitives/dict_ops.py index c97d49d71d01b..b103cfad2621b 100644 --- a/mypyc/primitives/dict_ops.py +++ b/mypyc/primitives/dict_ops.py @@ -2,35 +2,42 @@ from mypyc.ir.ops import ERR_FALSE, ERR_MAGIC, ERR_NEVER from mypyc.ir.rtypes import ( - dict_rprimitive, object_rprimitive, bool_rprimitive, int_rprimitive, - list_rprimitive, dict_next_rtuple_single, dict_next_rtuple_pair, c_pyssize_t_rprimitive, - c_int_rprimitive, bit_rprimitive + bit_rprimitive, + bool_rprimitive, + c_int_rprimitive, + c_pyssize_t_rprimitive, + dict_next_rtuple_pair, + dict_next_rtuple_single, + dict_rprimitive, + int_rprimitive, + list_rprimitive, + object_rprimitive, ) - from mypyc.primitives.registry import ( - custom_op, method_op, function_op, binary_op, load_address_op, ERR_NEG_INT + ERR_NEG_INT, + binary_op, + custom_op, + function_op, + load_address_op, + method_op, ) # Get the 'dict' type object. -load_address_op( - name='builtins.dict', - type=object_rprimitive, - src='PyDict_Type') +load_address_op(name="builtins.dict", type=object_rprimitive, src="PyDict_Type") # Construct an empty dictionary via dict(). function_op( - name='builtins.dict', + name="builtins.dict", arg_types=[], return_type=dict_rprimitive, - c_function_name='PyDict_New', - error_kind=ERR_MAGIC) + c_function_name="PyDict_New", + error_kind=ERR_MAGIC, +) # Construct an empty dictionary. dict_new_op = custom_op( - arg_types=[], - return_type=dict_rprimitive, - c_function_name='PyDict_New', - error_kind=ERR_MAGIC) + arg_types=[], return_type=dict_rprimitive, c_function_name="PyDict_New", error_kind=ERR_MAGIC +) # Construct a dictionary from keys and values. # Positional argument is the number of key-value pairs @@ -38,109 +45,122 @@ dict_build_op = custom_op( arg_types=[c_pyssize_t_rprimitive], return_type=dict_rprimitive, - c_function_name='CPyDict_Build', + c_function_name="CPyDict_Build", error_kind=ERR_MAGIC, - var_arg_type=object_rprimitive) + var_arg_type=object_rprimitive, +) # Construct a dictionary from another dictionary. function_op( - name='builtins.dict', + name="builtins.dict", arg_types=[dict_rprimitive], return_type=dict_rprimitive, - c_function_name='PyDict_Copy', + c_function_name="PyDict_Copy", error_kind=ERR_MAGIC, - priority=2) + priority=2, +) # Generic one-argument dict constructor: dict(obj) function_op( - name='builtins.dict', + name="builtins.dict", arg_types=[object_rprimitive], return_type=dict_rprimitive, - c_function_name='CPyDict_FromAny', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_FromAny", + error_kind=ERR_MAGIC, +) # dict[key] dict_get_item_op = method_op( - name='__getitem__', + name="__getitem__", arg_types=[dict_rprimitive, object_rprimitive], return_type=object_rprimitive, - c_function_name='CPyDict_GetItem', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_GetItem", + error_kind=ERR_MAGIC, +) # dict[key] = value dict_set_item_op = method_op( - name='__setitem__', + name="__setitem__", arg_types=[dict_rprimitive, object_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name='CPyDict_SetItem', - error_kind=ERR_NEG_INT) + c_function_name="CPyDict_SetItem", + error_kind=ERR_NEG_INT, +) # key in dict binary_op( - name='in', + name="in", arg_types=[object_rprimitive, dict_rprimitive], return_type=c_int_rprimitive, - c_function_name='PyDict_Contains', + c_function_name="PyDict_Contains", error_kind=ERR_NEG_INT, truncated_type=bool_rprimitive, - ordering=[1, 0]) + ordering=[1, 0], +) # dict1.update(dict2) dict_update_op = method_op( - name='update', + name="update", arg_types=[dict_rprimitive, dict_rprimitive], return_type=c_int_rprimitive, - c_function_name='CPyDict_Update', + c_function_name="CPyDict_Update", error_kind=ERR_NEG_INT, - priority=2) + priority=2, +) # Operation used for **value in dict displays. # This is mostly like dict.update(obj), but has customized error handling. dict_update_in_display_op = custom_op( arg_types=[dict_rprimitive, dict_rprimitive], return_type=c_int_rprimitive, - c_function_name='CPyDict_UpdateInDisplay', - error_kind=ERR_NEG_INT) + c_function_name="CPyDict_UpdateInDisplay", + error_kind=ERR_NEG_INT, +) # dict.update(obj) method_op( - name='update', + name="update", arg_types=[dict_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name='CPyDict_UpdateFromAny', - error_kind=ERR_NEG_INT) + c_function_name="CPyDict_UpdateFromAny", + error_kind=ERR_NEG_INT, +) # dict.get(key, default) method_op( - name='get', + name="get", arg_types=[dict_rprimitive, object_rprimitive, object_rprimitive], return_type=object_rprimitive, - c_function_name='CPyDict_Get', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_Get", + error_kind=ERR_MAGIC, +) # dict.get(key) dict_get_method_with_none = method_op( - name='get', + name="get", arg_types=[dict_rprimitive, object_rprimitive], return_type=object_rprimitive, - c_function_name='CPyDict_GetWithNone', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_GetWithNone", + error_kind=ERR_MAGIC, +) # dict.setdefault(key, default) dict_setdefault_op = method_op( - name='setdefault', + name="setdefault", arg_types=[dict_rprimitive, object_rprimitive, object_rprimitive], return_type=object_rprimitive, - c_function_name='CPyDict_SetDefault', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_SetDefault", + error_kind=ERR_MAGIC, +) # dict.setdefault(key) method_op( - name='setdefault', + name="setdefault", arg_types=[dict_rprimitive, object_rprimitive], return_type=object_rprimitive, - c_function_name='CPyDict_SetDefaultWithNone', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_SetDefaultWithNone", + error_kind=ERR_MAGIC, +) # dict.setdefault(key, empty tuple/list/set) # The third argument marks the data type of the second argument. @@ -149,116 +169,133 @@ dict_setdefault_spec_init_op = custom_op( arg_types=[dict_rprimitive, object_rprimitive, c_int_rprimitive], return_type=object_rprimitive, - c_function_name='CPyDict_SetDefaultWithEmptyDatatype', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_SetDefaultWithEmptyDatatype", + error_kind=ERR_MAGIC, +) # dict.keys() method_op( - name='keys', + name="keys", arg_types=[dict_rprimitive], return_type=object_rprimitive, - c_function_name='CPyDict_KeysView', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_KeysView", + error_kind=ERR_MAGIC, +) # dict.values() method_op( - name='values', + name="values", arg_types=[dict_rprimitive], return_type=object_rprimitive, - c_function_name='CPyDict_ValuesView', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_ValuesView", + error_kind=ERR_MAGIC, +) # dict.items() method_op( - name='items', + name="items", arg_types=[dict_rprimitive], return_type=object_rprimitive, - c_function_name='CPyDict_ItemsView', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_ItemsView", + error_kind=ERR_MAGIC, +) # dict.clear() method_op( - name='clear', + name="clear", arg_types=[dict_rprimitive], return_type=bit_rprimitive, - c_function_name='CPyDict_Clear', - error_kind=ERR_FALSE) + c_function_name="CPyDict_Clear", + error_kind=ERR_FALSE, +) # dict.copy() method_op( - name='copy', + name="copy", arg_types=[dict_rprimitive], return_type=dict_rprimitive, - c_function_name='CPyDict_Copy', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_Copy", + error_kind=ERR_MAGIC, +) # list(dict.keys()) dict_keys_op = custom_op( arg_types=[dict_rprimitive], return_type=list_rprimitive, - c_function_name='CPyDict_Keys', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_Keys", + error_kind=ERR_MAGIC, +) # list(dict.values()) dict_values_op = custom_op( arg_types=[dict_rprimitive], return_type=list_rprimitive, - c_function_name='CPyDict_Values', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_Values", + error_kind=ERR_MAGIC, +) # list(dict.items()) dict_items_op = custom_op( arg_types=[dict_rprimitive], return_type=list_rprimitive, - c_function_name='CPyDict_Items', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_Items", + error_kind=ERR_MAGIC, +) # PyDict_Next() fast iteration dict_key_iter_op = custom_op( arg_types=[dict_rprimitive], return_type=object_rprimitive, - c_function_name='CPyDict_GetKeysIter', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_GetKeysIter", + error_kind=ERR_MAGIC, +) dict_value_iter_op = custom_op( arg_types=[dict_rprimitive], return_type=object_rprimitive, - c_function_name='CPyDict_GetValuesIter', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_GetValuesIter", + error_kind=ERR_MAGIC, +) dict_item_iter_op = custom_op( arg_types=[dict_rprimitive], return_type=object_rprimitive, - c_function_name='CPyDict_GetItemsIter', - error_kind=ERR_MAGIC) + c_function_name="CPyDict_GetItemsIter", + error_kind=ERR_MAGIC, +) dict_next_key_op = custom_op( arg_types=[object_rprimitive, int_rprimitive], return_type=dict_next_rtuple_single, - c_function_name='CPyDict_NextKey', - error_kind=ERR_NEVER) + c_function_name="CPyDict_NextKey", + error_kind=ERR_NEVER, +) dict_next_value_op = custom_op( arg_types=[object_rprimitive, int_rprimitive], return_type=dict_next_rtuple_single, - c_function_name='CPyDict_NextValue', - error_kind=ERR_NEVER) + c_function_name="CPyDict_NextValue", + error_kind=ERR_NEVER, +) dict_next_item_op = custom_op( arg_types=[object_rprimitive, int_rprimitive], return_type=dict_next_rtuple_pair, - c_function_name='CPyDict_NextItem', - error_kind=ERR_NEVER) + c_function_name="CPyDict_NextItem", + error_kind=ERR_NEVER, +) # check that len(dict) == const during iteration dict_check_size_op = custom_op( arg_types=[dict_rprimitive, int_rprimitive], return_type=bit_rprimitive, - c_function_name='CPyDict_CheckSize', - error_kind=ERR_FALSE) + c_function_name="CPyDict_CheckSize", + error_kind=ERR_FALSE, +) dict_ssize_t_size_op = custom_op( arg_types=[dict_rprimitive], return_type=c_pyssize_t_rprimitive, - c_function_name='PyDict_Size', - error_kind=ERR_NEVER) + c_function_name="PyDict_Size", + error_kind=ERR_NEVER, +) diff --git a/mypyc/primitives/exc_ops.py b/mypyc/primitives/exc_ops.py index 99d57ff3aaa97..7fe21d2592abc 100644 --- a/mypyc/primitives/exc_ops.py +++ b/mypyc/primitives/exc_ops.py @@ -1,7 +1,7 @@ """Exception-related primitive ops.""" -from mypyc.ir.ops import ERR_NEVER, ERR_FALSE, ERR_ALWAYS -from mypyc.ir.rtypes import object_rprimitive, void_rtype, exc_rtuple, bit_rprimitive +from mypyc.ir.ops import ERR_ALWAYS, ERR_FALSE, ERR_NEVER +from mypyc.ir.rtypes import bit_rprimitive, exc_rtuple, object_rprimitive, void_rtype from mypyc.primitives.registry import custom_op # If the argument is a class, raise an instance of the class. Otherwise, assume @@ -9,88 +9,91 @@ raise_exception_op = custom_op( arg_types=[object_rprimitive], return_type=void_rtype, - c_function_name='CPy_Raise', - error_kind=ERR_ALWAYS) + c_function_name="CPy_Raise", + error_kind=ERR_ALWAYS, +) # Raise StopIteration exception with the specified value (which can be NULL). set_stop_iteration_value = custom_op( arg_types=[object_rprimitive], return_type=void_rtype, - c_function_name='CPyGen_SetStopIterationValue', - error_kind=ERR_ALWAYS) + c_function_name="CPyGen_SetStopIterationValue", + error_kind=ERR_ALWAYS, +) # Raise exception with traceback. # Arguments are (exception type, exception value, traceback). raise_exception_with_tb_op = custom_op( arg_types=[object_rprimitive, object_rprimitive, object_rprimitive], return_type=void_rtype, - c_function_name='CPyErr_SetObjectAndTraceback', - error_kind=ERR_ALWAYS) + c_function_name="CPyErr_SetObjectAndTraceback", + error_kind=ERR_ALWAYS, +) # Reraise the currently raised exception. reraise_exception_op = custom_op( - arg_types=[], - return_type=void_rtype, - c_function_name='CPy_Reraise', - error_kind=ERR_ALWAYS) + arg_types=[], return_type=void_rtype, c_function_name="CPy_Reraise", error_kind=ERR_ALWAYS +) # Propagate exception if the CPython error indicator is set (an exception was raised). no_err_occurred_op = custom_op( arg_types=[], return_type=bit_rprimitive, - c_function_name='CPy_NoErrOccured', - error_kind=ERR_FALSE) + c_function_name="CPy_NoErrOccured", + error_kind=ERR_FALSE, +) err_occurred_op = custom_op( arg_types=[], return_type=object_rprimitive, - c_function_name='PyErr_Occurred', + c_function_name="PyErr_Occurred", error_kind=ERR_NEVER, - is_borrowed=True) + is_borrowed=True, +) # Keep propagating a raised exception by unconditionally giving an error value. # This doesn't actually raise an exception. keep_propagating_op = custom_op( arg_types=[], return_type=bit_rprimitive, - c_function_name='CPy_KeepPropagating', - error_kind=ERR_FALSE) + c_function_name="CPy_KeepPropagating", + error_kind=ERR_FALSE, +) # Catches a propagating exception and makes it the "currently # handled exception" (by sticking it into sys.exc_info()). Returns the # exception that was previously being handled, which must be restored # later. error_catch_op = custom_op( - arg_types=[], - return_type=exc_rtuple, - c_function_name='CPy_CatchError', - error_kind=ERR_NEVER) + arg_types=[], return_type=exc_rtuple, c_function_name="CPy_CatchError", error_kind=ERR_NEVER +) # Restore an old "currently handled exception" returned from. # error_catch (by sticking it into sys.exc_info()) restore_exc_info_op = custom_op( arg_types=[exc_rtuple], return_type=void_rtype, - c_function_name='CPy_RestoreExcInfo', - error_kind=ERR_NEVER) + c_function_name="CPy_RestoreExcInfo", + error_kind=ERR_NEVER, +) # Checks whether the exception currently being handled matches a particular type. exc_matches_op = custom_op( arg_types=[object_rprimitive], return_type=bit_rprimitive, - c_function_name='CPy_ExceptionMatches', - error_kind=ERR_NEVER) + c_function_name="CPy_ExceptionMatches", + error_kind=ERR_NEVER, +) # Get the value of the exception currently being handled. get_exc_value_op = custom_op( arg_types=[], return_type=object_rprimitive, - c_function_name='CPy_GetExcValue', - error_kind=ERR_NEVER) + c_function_name="CPy_GetExcValue", + error_kind=ERR_NEVER, +) # Get exception info (exception type, exception instance, traceback object). get_exc_info_op = custom_op( - arg_types=[], - return_type=exc_rtuple, - c_function_name='CPy_GetExcInfo', - error_kind=ERR_NEVER) + arg_types=[], return_type=exc_rtuple, c_function_name="CPy_GetExcInfo", error_kind=ERR_NEVER +) diff --git a/mypyc/primitives/float_ops.py b/mypyc/primitives/float_ops.py index 3359cf6fe122f..f8b0855483fe8 100644 --- a/mypyc/primitives/float_ops.py +++ b/mypyc/primitives/float_ops.py @@ -1,31 +1,26 @@ """Primitive float ops.""" from mypyc.ir.ops import ERR_MAGIC -from mypyc.ir.rtypes import ( - str_rprimitive, float_rprimitive, object_rprimitive -) -from mypyc.primitives.registry import ( - load_address_op, function_op -) +from mypyc.ir.rtypes import float_rprimitive, object_rprimitive, str_rprimitive +from mypyc.primitives.registry import function_op, load_address_op # Get the 'builtins.float' type object. -load_address_op( - name='builtins.float', - type=object_rprimitive, - src='PyFloat_Type') +load_address_op(name="builtins.float", type=object_rprimitive, src="PyFloat_Type") # float(str) function_op( - name='builtins.float', + name="builtins.float", arg_types=[str_rprimitive], return_type=float_rprimitive, - c_function_name='PyFloat_FromString', - error_kind=ERR_MAGIC) + c_function_name="PyFloat_FromString", + error_kind=ERR_MAGIC, +) # abs(float) function_op( - name='builtins.abs', + name="builtins.abs", arg_types=[float_rprimitive], return_type=float_rprimitive, - c_function_name='PyNumber_Absolute', - error_kind=ERR_MAGIC) + c_function_name="PyNumber_Absolute", + error_kind=ERR_MAGIC, +) diff --git a/mypyc/primitives/generic_ops.py b/mypyc/primitives/generic_ops.py index 402de4524b881..4f2bfec002c39 100644 --- a/mypyc/primitives/generic_ops.py +++ b/mypyc/primitives/generic_ops.py @@ -9,261 +9,316 @@ check that the priorities are configured properly. """ -from mypyc.ir.ops import ERR_NEVER, ERR_MAGIC +from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER from mypyc.ir.rtypes import ( - object_rprimitive, int_rprimitive, bool_rprimitive, c_int_rprimitive, pointer_rprimitive, - object_pointer_rprimitive, c_size_t_rprimitive, c_pyssize_t_rprimitive + bool_rprimitive, + c_int_rprimitive, + c_pyssize_t_rprimitive, + c_size_t_rprimitive, + int_rprimitive, + object_pointer_rprimitive, + object_rprimitive, + pointer_rprimitive, ) from mypyc.primitives.registry import ( - binary_op, unary_op, method_op, function_op, custom_op, ERR_NEG_INT + ERR_NEG_INT, + binary_op, + custom_op, + function_op, + method_op, + unary_op, ) - # Binary operations -for op, opid in [('==', 2), # PY_EQ - ('!=', 3), # PY_NE - ('<', 0), # PY_LT - ('<=', 1), # PY_LE - ('>', 4), # PY_GT - ('>=', 5)]: # PY_GE +for op, opid in [ + ("==", 2), # PY_EQ + ("!=", 3), # PY_NE + ("<", 0), # PY_LT + ("<=", 1), # PY_LE + (">", 4), # PY_GT + (">=", 5), +]: # PY_GE # The result type is 'object' since that's what PyObject_RichCompare returns. - binary_op(name=op, - arg_types=[object_rprimitive, object_rprimitive], - return_type=object_rprimitive, - c_function_name='PyObject_RichCompare', - error_kind=ERR_MAGIC, - extra_int_constants=[(opid, c_int_rprimitive)], - priority=0) + binary_op( + name=op, + arg_types=[object_rprimitive, object_rprimitive], + return_type=object_rprimitive, + c_function_name="PyObject_RichCompare", + error_kind=ERR_MAGIC, + extra_int_constants=[(opid, c_int_rprimitive)], + priority=0, + ) -for op, funcname in [('+', 'PyNumber_Add'), - ('-', 'PyNumber_Subtract'), - ('*', 'PyNumber_Multiply'), - ('//', 'PyNumber_FloorDivide'), - ('/', 'PyNumber_TrueDivide'), - ('%', 'PyNumber_Remainder'), - ('<<', 'PyNumber_Lshift'), - ('>>', 'PyNumber_Rshift'), - ('&', 'PyNumber_And'), - ('^', 'PyNumber_Xor'), - ('|', 'PyNumber_Or'), - ('@', 'PyNumber_MatrixMultiply')]: - binary_op(name=op, - arg_types=[object_rprimitive, object_rprimitive], - return_type=object_rprimitive, - c_function_name=funcname, - error_kind=ERR_MAGIC, - priority=0) +for op, funcname in [ + ("+", "PyNumber_Add"), + ("-", "PyNumber_Subtract"), + ("*", "PyNumber_Multiply"), + ("//", "PyNumber_FloorDivide"), + ("/", "PyNumber_TrueDivide"), + ("%", "PyNumber_Remainder"), + ("<<", "PyNumber_Lshift"), + (">>", "PyNumber_Rshift"), + ("&", "PyNumber_And"), + ("^", "PyNumber_Xor"), + ("|", "PyNumber_Or"), + ("@", "PyNumber_MatrixMultiply"), +]: + binary_op( + name=op, + arg_types=[object_rprimitive, object_rprimitive], + return_type=object_rprimitive, + c_function_name=funcname, + error_kind=ERR_MAGIC, + priority=0, + ) -for op, funcname in [('+=', 'PyNumber_InPlaceAdd'), - ('-=', 'PyNumber_InPlaceSubtract'), - ('*=', 'PyNumber_InPlaceMultiply'), - ('@=', 'PyNumber_InPlaceMatrixMultiply'), - ('//=', 'PyNumber_InPlaceFloorDivide'), - ('/=', 'PyNumber_InPlaceTrueDivide'), - ('%=', 'PyNumber_InPlaceRemainder'), - ('<<=', 'PyNumber_InPlaceLshift'), - ('>>=', 'PyNumber_InPlaceRshift'), - ('&=', 'PyNumber_InPlaceAnd'), - ('^=', 'PyNumber_InPlaceXor'), - ('|=', 'PyNumber_InPlaceOr')]: - binary_op(name=op, - arg_types=[object_rprimitive, object_rprimitive], - return_type=object_rprimitive, - c_function_name=funcname, - error_kind=ERR_MAGIC, - priority=0) +for op, funcname in [ + ("+=", "PyNumber_InPlaceAdd"), + ("-=", "PyNumber_InPlaceSubtract"), + ("*=", "PyNumber_InPlaceMultiply"), + ("@=", "PyNumber_InPlaceMatrixMultiply"), + ("//=", "PyNumber_InPlaceFloorDivide"), + ("/=", "PyNumber_InPlaceTrueDivide"), + ("%=", "PyNumber_InPlaceRemainder"), + ("<<=", "PyNumber_InPlaceLshift"), + (">>=", "PyNumber_InPlaceRshift"), + ("&=", "PyNumber_InPlaceAnd"), + ("^=", "PyNumber_InPlaceXor"), + ("|=", "PyNumber_InPlaceOr"), +]: + binary_op( + name=op, + arg_types=[object_rprimitive, object_rprimitive], + return_type=object_rprimitive, + c_function_name=funcname, + error_kind=ERR_MAGIC, + priority=0, + ) -binary_op(name='**', - arg_types=[object_rprimitive, object_rprimitive], - return_type=object_rprimitive, - error_kind=ERR_MAGIC, - c_function_name='CPyNumber_Power', - priority=0) +binary_op( + name="**", + arg_types=[object_rprimitive, object_rprimitive], + return_type=object_rprimitive, + error_kind=ERR_MAGIC, + c_function_name="CPyNumber_Power", + priority=0, +) binary_op( - name='in', + name="in", arg_types=[object_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name='PySequence_Contains', + c_function_name="PySequence_Contains", error_kind=ERR_NEG_INT, truncated_type=bool_rprimitive, ordering=[1, 0], - priority=0) + priority=0, +) # Unary operations -for op, funcname in [('-', 'PyNumber_Negative'), - ('+', 'PyNumber_Positive'), - ('~', 'PyNumber_Invert')]: - unary_op(name=op, - arg_type=object_rprimitive, - return_type=object_rprimitive, - c_function_name=funcname, - error_kind=ERR_MAGIC, - priority=0) +for op, funcname in [ + ("-", "PyNumber_Negative"), + ("+", "PyNumber_Positive"), + ("~", "PyNumber_Invert"), +]: + unary_op( + name=op, + arg_type=object_rprimitive, + return_type=object_rprimitive, + c_function_name=funcname, + error_kind=ERR_MAGIC, + priority=0, + ) unary_op( - name='not', + name="not", arg_type=object_rprimitive, return_type=c_int_rprimitive, - c_function_name='PyObject_Not', + c_function_name="PyObject_Not", error_kind=ERR_NEG_INT, truncated_type=bool_rprimitive, - priority=0) + priority=0, +) # obj1[obj2] -method_op(name='__getitem__', - arg_types=[object_rprimitive, object_rprimitive], - return_type=object_rprimitive, - c_function_name='PyObject_GetItem', - error_kind=ERR_MAGIC, - priority=0) +method_op( + name="__getitem__", + arg_types=[object_rprimitive, object_rprimitive], + return_type=object_rprimitive, + c_function_name="PyObject_GetItem", + error_kind=ERR_MAGIC, + priority=0, +) # obj1[obj2] = obj3 method_op( - name='__setitem__', + name="__setitem__", arg_types=[object_rprimitive, object_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name='PyObject_SetItem', + c_function_name="PyObject_SetItem", error_kind=ERR_NEG_INT, - priority=0) + priority=0, +) # del obj1[obj2] method_op( - name='__delitem__', + name="__delitem__", arg_types=[object_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name='PyObject_DelItem', + c_function_name="PyObject_DelItem", error_kind=ERR_NEG_INT, - priority=0) + priority=0, +) # hash(obj) function_op( - name='builtins.hash', + name="builtins.hash", arg_types=[object_rprimitive], return_type=int_rprimitive, - c_function_name='CPyObject_Hash', - error_kind=ERR_MAGIC) + c_function_name="CPyObject_Hash", + error_kind=ERR_MAGIC, +) # getattr(obj, attr) py_getattr_op = function_op( - name='builtins.getattr', + name="builtins.getattr", arg_types=[object_rprimitive, object_rprimitive], return_type=object_rprimitive, - c_function_name='CPyObject_GetAttr', - error_kind=ERR_MAGIC) + c_function_name="CPyObject_GetAttr", + error_kind=ERR_MAGIC, +) # getattr(obj, attr, default) function_op( - name='builtins.getattr', + name="builtins.getattr", arg_types=[object_rprimitive, object_rprimitive, object_rprimitive], return_type=object_rprimitive, - c_function_name='CPyObject_GetAttr3', - error_kind=ERR_MAGIC) + c_function_name="CPyObject_GetAttr3", + error_kind=ERR_MAGIC, +) # setattr(obj, attr, value) py_setattr_op = function_op( - name='builtins.setattr', + name="builtins.setattr", arg_types=[object_rprimitive, object_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name='PyObject_SetAttr', - error_kind=ERR_NEG_INT) + c_function_name="PyObject_SetAttr", + error_kind=ERR_NEG_INT, +) # hasattr(obj, attr) py_hasattr_op = function_op( - name='builtins.hasattr', + name="builtins.hasattr", arg_types=[object_rprimitive, object_rprimitive], return_type=bool_rprimitive, - c_function_name='PyObject_HasAttr', - error_kind=ERR_NEVER) + c_function_name="PyObject_HasAttr", + error_kind=ERR_NEVER, +) # del obj.attr py_delattr_op = function_op( - name='builtins.delattr', + name="builtins.delattr", arg_types=[object_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name='PyObject_DelAttr', - error_kind=ERR_NEG_INT) + c_function_name="PyObject_DelAttr", + error_kind=ERR_NEG_INT, +) # Call callable object with N positional arguments: func(arg1, ..., argN) # Arguments are (func, arg1, ..., argN). py_call_op = custom_op( arg_types=[], return_type=object_rprimitive, - c_function_name='PyObject_CallFunctionObjArgs', + c_function_name="PyObject_CallFunctionObjArgs", error_kind=ERR_MAGIC, var_arg_type=object_rprimitive, - extra_int_constants=[(0, pointer_rprimitive)]) + extra_int_constants=[(0, pointer_rprimitive)], +) # Call callable object using positional and/or keyword arguments (Python 3.8+) py_vectorcall_op = custom_op( - arg_types=[object_rprimitive, # Callable - object_pointer_rprimitive, # Args (PyObject **) - c_size_t_rprimitive, # Number of positional args - object_rprimitive], # Keyword arg names tuple (or NULL) + arg_types=[ + object_rprimitive, # Callable + object_pointer_rprimitive, # Args (PyObject **) + c_size_t_rprimitive, # Number of positional args + object_rprimitive, + ], # Keyword arg names tuple (or NULL) return_type=object_rprimitive, - c_function_name='_PyObject_Vectorcall', - error_kind=ERR_MAGIC) + c_function_name="_PyObject_Vectorcall", + error_kind=ERR_MAGIC, +) # Call method using positional and/or keyword arguments (Python 3.9+) py_vectorcall_method_op = custom_op( - arg_types=[object_rprimitive, # Method name - object_pointer_rprimitive, # Args, including self (PyObject **) - c_size_t_rprimitive, # Number of positional args, including self - object_rprimitive], # Keyword arg names tuple (or NULL) + arg_types=[ + object_rprimitive, # Method name + object_pointer_rprimitive, # Args, including self (PyObject **) + c_size_t_rprimitive, # Number of positional args, including self + object_rprimitive, + ], # Keyword arg names tuple (or NULL) return_type=object_rprimitive, - c_function_name='PyObject_VectorcallMethod', - error_kind=ERR_MAGIC) + c_function_name="PyObject_VectorcallMethod", + error_kind=ERR_MAGIC, +) # Call callable object with positional + keyword args: func(*args, **kwargs) # Arguments are (func, *args tuple, **kwargs dict). py_call_with_kwargs_op = custom_op( arg_types=[object_rprimitive, object_rprimitive, object_rprimitive], return_type=object_rprimitive, - c_function_name='PyObject_Call', - error_kind=ERR_MAGIC) + c_function_name="PyObject_Call", + error_kind=ERR_MAGIC, +) # Call method with positional arguments: obj.method(arg1, ...) # Arguments are (object, attribute name, arg1, ...). py_method_call_op = custom_op( arg_types=[], return_type=object_rprimitive, - c_function_name='CPyObject_CallMethodObjArgs', + c_function_name="CPyObject_CallMethodObjArgs", error_kind=ERR_MAGIC, var_arg_type=object_rprimitive, - extra_int_constants=[(0, pointer_rprimitive)]) + extra_int_constants=[(0, pointer_rprimitive)], +) # len(obj) generic_len_op = custom_op( arg_types=[object_rprimitive], return_type=int_rprimitive, - c_function_name='CPyObject_Size', - error_kind=ERR_MAGIC) + c_function_name="CPyObject_Size", + error_kind=ERR_MAGIC, +) # len(obj) # same as generic_len_op, however return py_ssize_t generic_ssize_t_len_op = custom_op( arg_types=[object_rprimitive], return_type=c_pyssize_t_rprimitive, - c_function_name='PyObject_Size', - error_kind=ERR_NEG_INT) + c_function_name="PyObject_Size", + error_kind=ERR_NEG_INT, +) # iter(obj) -iter_op = function_op(name='builtins.iter', - arg_types=[object_rprimitive], - return_type=object_rprimitive, - c_function_name='PyObject_GetIter', - error_kind=ERR_MAGIC) +iter_op = function_op( + name="builtins.iter", + arg_types=[object_rprimitive], + return_type=object_rprimitive, + c_function_name="PyObject_GetIter", + error_kind=ERR_MAGIC, +) # next(iterator) # # Although the error_kind is set to be ERR_NEVER, this can actually # return NULL, and thus it must be checked using Branch.IS_ERROR. -next_op = custom_op(arg_types=[object_rprimitive], - return_type=object_rprimitive, - c_function_name='PyIter_Next', - error_kind=ERR_NEVER) +next_op = custom_op( + arg_types=[object_rprimitive], + return_type=object_rprimitive, + c_function_name="PyIter_Next", + error_kind=ERR_NEVER, +) # next(iterator) # # Do a next, don't swallow StopIteration, but also don't propagate an @@ -271,7 +326,9 @@ # represent an implicit StopIteration, but if StopIteration is # *explicitly* raised this will not swallow it.) # Can return NULL: see next_op. -next_raw_op = custom_op(arg_types=[object_rprimitive], - return_type=object_rprimitive, - c_function_name='CPyIter_Next', - error_kind=ERR_NEVER) +next_raw_op = custom_op( + arg_types=[object_rprimitive], + return_type=object_rprimitive, + c_function_name="CPyIter_Next", + error_kind=ERR_NEVER, +) diff --git a/mypyc/primitives/int_ops.py b/mypyc/primitives/int_ops.py index ad33de059f026..e84415d1ae164 100644 --- a/mypyc/primitives/int_ops.py +++ b/mypyc/primitives/int_ops.py @@ -9,14 +9,28 @@ """ from typing import Dict, NamedTuple -from mypyc.ir.ops import ERR_NEVER, ERR_MAGIC, ERR_MAGIC_OVERLAPPING, ERR_ALWAYS, ComparisonOp + +from mypyc.ir.ops import ERR_ALWAYS, ERR_MAGIC, ERR_MAGIC_OVERLAPPING, ERR_NEVER, ComparisonOp from mypyc.ir.rtypes import ( - int_rprimitive, bool_rprimitive, float_rprimitive, object_rprimitive, - str_rprimitive, bit_rprimitive, int64_rprimitive, int32_rprimitive, void_rtype, RType, - c_pyssize_t_rprimitive + RType, + bit_rprimitive, + bool_rprimitive, + c_pyssize_t_rprimitive, + float_rprimitive, + int32_rprimitive, + int64_rprimitive, + int_rprimitive, + object_rprimitive, + str_rprimitive, + void_rtype, ) from mypyc.primitives.registry import ( - load_address_op, unary_op, CFunctionDescription, function_op, binary_op, custom_op + CFunctionDescription, + binary_op, + custom_op, + function_op, + load_address_op, + unary_op, ) # These int constructors produce object_rprimitives that then need to be unboxed @@ -24,106 +38,115 @@ # Get the type object for 'builtins.int'. # For ordinary calls to int() we use a load_address to the type -load_address_op( - name='builtins.int', - type=object_rprimitive, - src='PyLong_Type') +load_address_op(name="builtins.int", type=object_rprimitive, src="PyLong_Type") # int(float). We could do a bit better directly. function_op( - name='builtins.int', + name="builtins.int", arg_types=[float_rprimitive], return_type=object_rprimitive, - c_function_name='CPyLong_FromFloat', - error_kind=ERR_MAGIC) + c_function_name="CPyLong_FromFloat", + error_kind=ERR_MAGIC, +) # int(string) function_op( - name='builtins.int', + name="builtins.int", arg_types=[str_rprimitive], return_type=object_rprimitive, - c_function_name='CPyLong_FromStr', - error_kind=ERR_MAGIC) + c_function_name="CPyLong_FromStr", + error_kind=ERR_MAGIC, +) # int(string, base) function_op( - name='builtins.int', + name="builtins.int", arg_types=[str_rprimitive, int_rprimitive], return_type=object_rprimitive, - c_function_name='CPyLong_FromStrWithBase', - error_kind=ERR_MAGIC) + c_function_name="CPyLong_FromStrWithBase", + error_kind=ERR_MAGIC, +) # str(int) int_to_str_op = function_op( - name='builtins.str', + name="builtins.str", arg_types=[int_rprimitive], return_type=str_rprimitive, - c_function_name='CPyTagged_Str', + c_function_name="CPyTagged_Str", error_kind=ERR_MAGIC, - priority=2) + priority=2, +) # We need a specialization for str on bools also since the int one is wrong... function_op( - name='builtins.str', + name="builtins.str", arg_types=[bool_rprimitive], return_type=str_rprimitive, - c_function_name='CPyBool_Str', + c_function_name="CPyBool_Str", error_kind=ERR_MAGIC, - priority=3) + priority=3, +) -def int_binary_op(name: str, c_function_name: str, - return_type: RType = int_rprimitive, - error_kind: int = ERR_NEVER) -> None: - binary_op(name=name, - arg_types=[int_rprimitive, int_rprimitive], - return_type=return_type, - c_function_name=c_function_name, - error_kind=error_kind) +def int_binary_op( + name: str, + c_function_name: str, + return_type: RType = int_rprimitive, + error_kind: int = ERR_NEVER, +) -> None: + binary_op( + name=name, + arg_types=[int_rprimitive, int_rprimitive], + return_type=return_type, + c_function_name=c_function_name, + error_kind=error_kind, + ) # Binary, unary and augmented assignment operations that operate on CPyTagged ints # are implemented as C functions. -int_binary_op('+', 'CPyTagged_Add') -int_binary_op('-', 'CPyTagged_Subtract') -int_binary_op('*', 'CPyTagged_Multiply') -int_binary_op('&', 'CPyTagged_And') -int_binary_op('|', 'CPyTagged_Or') -int_binary_op('^', 'CPyTagged_Xor') +int_binary_op("+", "CPyTagged_Add") +int_binary_op("-", "CPyTagged_Subtract") +int_binary_op("*", "CPyTagged_Multiply") +int_binary_op("&", "CPyTagged_And") +int_binary_op("|", "CPyTagged_Or") +int_binary_op("^", "CPyTagged_Xor") # Divide and remainder we honestly propagate errors from because they # can raise ZeroDivisionError -int_binary_op('//', 'CPyTagged_FloorDivide', error_kind=ERR_MAGIC) -int_binary_op('%', 'CPyTagged_Remainder', error_kind=ERR_MAGIC) +int_binary_op("//", "CPyTagged_FloorDivide", error_kind=ERR_MAGIC) +int_binary_op("%", "CPyTagged_Remainder", error_kind=ERR_MAGIC) # Negative shift counts raise an exception -int_binary_op('>>', 'CPyTagged_Rshift', error_kind=ERR_MAGIC) -int_binary_op('<<', 'CPyTagged_Lshift', error_kind=ERR_MAGIC) +int_binary_op(">>", "CPyTagged_Rshift", error_kind=ERR_MAGIC) +int_binary_op("<<", "CPyTagged_Lshift", error_kind=ERR_MAGIC) # This should work because assignment operators are parsed differently # and the code in irbuild that handles it does the assignment # regardless of whether or not the operator works in place anyway. -int_binary_op('+=', 'CPyTagged_Add') -int_binary_op('-=', 'CPyTagged_Subtract') -int_binary_op('*=', 'CPyTagged_Multiply') -int_binary_op('&=', 'CPyTagged_And') -int_binary_op('|=', 'CPyTagged_Or') -int_binary_op('^=', 'CPyTagged_Xor') -int_binary_op('//=', 'CPyTagged_FloorDivide', error_kind=ERR_MAGIC) -int_binary_op('%=', 'CPyTagged_Remainder', error_kind=ERR_MAGIC) -int_binary_op('>>=', 'CPyTagged_Rshift', error_kind=ERR_MAGIC) -int_binary_op('<<=', 'CPyTagged_Lshift', error_kind=ERR_MAGIC) +int_binary_op("+=", "CPyTagged_Add") +int_binary_op("-=", "CPyTagged_Subtract") +int_binary_op("*=", "CPyTagged_Multiply") +int_binary_op("&=", "CPyTagged_And") +int_binary_op("|=", "CPyTagged_Or") +int_binary_op("^=", "CPyTagged_Xor") +int_binary_op("//=", "CPyTagged_FloorDivide", error_kind=ERR_MAGIC) +int_binary_op("%=", "CPyTagged_Remainder", error_kind=ERR_MAGIC) +int_binary_op(">>=", "CPyTagged_Rshift", error_kind=ERR_MAGIC) +int_binary_op("<<=", "CPyTagged_Lshift", error_kind=ERR_MAGIC) def int_unary_op(name: str, c_function_name: str) -> CFunctionDescription: - return unary_op(name=name, - arg_type=int_rprimitive, - return_type=int_rprimitive, - c_function_name=c_function_name, - error_kind=ERR_NEVER) + return unary_op( + name=name, + arg_type=int_rprimitive, + return_type=int_rprimitive, + c_function_name=c_function_name, + error_kind=ERR_NEVER, + ) -int_neg_op = int_unary_op('-', 'CPyTagged_Negate') -int_invert_op = int_unary_op('~', 'CPyTagged_Invert') +int_neg_op = int_unary_op("-", "CPyTagged_Negate") +int_invert_op = int_unary_op("~", "CPyTagged_Invert") # Primitives related to integer comparison operations: @@ -136,89 +159,104 @@ def int_unary_op(name: str, c_function_name: str) -> CFunctionDescription: # c_func_negated: whether to negate the C function call's result # c_func_swap_operands: whether to swap lhs and rhs when call the function IntComparisonOpDescription = NamedTuple( - 'IntComparisonOpDescription', [('binary_op_variant', int), - ('c_func_description', CFunctionDescription), - ('c_func_negated', bool), - ('c_func_swap_operands', bool)]) + "IntComparisonOpDescription", + [ + ("binary_op_variant", int), + ("c_func_description", CFunctionDescription), + ("c_func_negated", bool), + ("c_func_swap_operands", bool), + ], +) # Equals operation on two boxed tagged integers int_equal_ = custom_op( arg_types=[int_rprimitive, int_rprimitive], return_type=bit_rprimitive, - c_function_name='CPyTagged_IsEq_', - error_kind=ERR_NEVER) + c_function_name="CPyTagged_IsEq_", + error_kind=ERR_NEVER, +) # Less than operation on two boxed tagged integers int_less_than_ = custom_op( arg_types=[int_rprimitive, int_rprimitive], return_type=bit_rprimitive, - c_function_name='CPyTagged_IsLt_', - error_kind=ERR_NEVER) + c_function_name="CPyTagged_IsLt_", + error_kind=ERR_NEVER, +) # Provide mapping from textual op to short int's op variant and boxed int's description. # Note that these are not complete implementations and require extra IR. int_comparison_op_mapping: Dict[str, IntComparisonOpDescription] = { - '==': IntComparisonOpDescription(ComparisonOp.EQ, int_equal_, False, False), - '!=': IntComparisonOpDescription(ComparisonOp.NEQ, int_equal_, True, False), - '<': IntComparisonOpDescription(ComparisonOp.SLT, int_less_than_, False, False), - '<=': IntComparisonOpDescription(ComparisonOp.SLE, int_less_than_, True, True), - '>': IntComparisonOpDescription(ComparisonOp.SGT, int_less_than_, False, True), - '>=': IntComparisonOpDescription(ComparisonOp.SGE, int_less_than_, True, False), + "==": IntComparisonOpDescription(ComparisonOp.EQ, int_equal_, False, False), + "!=": IntComparisonOpDescription(ComparisonOp.NEQ, int_equal_, True, False), + "<": IntComparisonOpDescription(ComparisonOp.SLT, int_less_than_, False, False), + "<=": IntComparisonOpDescription(ComparisonOp.SLE, int_less_than_, True, True), + ">": IntComparisonOpDescription(ComparisonOp.SGT, int_less_than_, False, True), + ">=": IntComparisonOpDescription(ComparisonOp.SGE, int_less_than_, True, False), } int64_divide_op = custom_op( arg_types=[int64_rprimitive, int64_rprimitive], return_type=int64_rprimitive, - c_function_name='CPyInt64_Divide', - error_kind=ERR_MAGIC_OVERLAPPING) + c_function_name="CPyInt64_Divide", + error_kind=ERR_MAGIC_OVERLAPPING, +) int64_mod_op = custom_op( arg_types=[int64_rprimitive, int64_rprimitive], return_type=int64_rprimitive, - c_function_name='CPyInt64_Remainder', - error_kind=ERR_MAGIC_OVERLAPPING) + c_function_name="CPyInt64_Remainder", + error_kind=ERR_MAGIC_OVERLAPPING, +) int32_divide_op = custom_op( arg_types=[int32_rprimitive, int32_rprimitive], return_type=int32_rprimitive, - c_function_name='CPyInt32_Divide', - error_kind=ERR_MAGIC_OVERLAPPING) + c_function_name="CPyInt32_Divide", + error_kind=ERR_MAGIC_OVERLAPPING, +) int32_mod_op = custom_op( arg_types=[int32_rprimitive, int32_rprimitive], return_type=int32_rprimitive, - c_function_name='CPyInt32_Remainder', - error_kind=ERR_MAGIC_OVERLAPPING) + c_function_name="CPyInt32_Remainder", + error_kind=ERR_MAGIC_OVERLAPPING, +) # Convert tagged int (as PyObject *) to i64 int_to_int64_op = custom_op( arg_types=[object_rprimitive], return_type=int64_rprimitive, - c_function_name='CPyLong_AsInt64', - error_kind=ERR_MAGIC_OVERLAPPING) + c_function_name="CPyLong_AsInt64", + error_kind=ERR_MAGIC_OVERLAPPING, +) ssize_t_to_int_op = custom_op( arg_types=[c_pyssize_t_rprimitive], return_type=int_rprimitive, - c_function_name='CPyTagged_FromSsize_t', - error_kind=ERR_MAGIC) + c_function_name="CPyTagged_FromSsize_t", + error_kind=ERR_MAGIC, +) int64_to_int_op = custom_op( arg_types=[int64_rprimitive], return_type=int_rprimitive, - c_function_name='CPyTagged_FromInt64', - error_kind=ERR_MAGIC) + c_function_name="CPyTagged_FromInt64", + error_kind=ERR_MAGIC, +) # Convert tagged int (as PyObject *) to i32 int_to_int32_op = custom_op( arg_types=[object_rprimitive], return_type=int32_rprimitive, - c_function_name='CPyLong_AsInt32', - error_kind=ERR_MAGIC_OVERLAPPING) + c_function_name="CPyLong_AsInt32", + error_kind=ERR_MAGIC_OVERLAPPING, +) int32_overflow = custom_op( arg_types=[], return_type=void_rtype, - c_function_name='CPyInt32_Overflow', - error_kind=ERR_ALWAYS) + c_function_name="CPyInt32_Overflow", + error_kind=ERR_ALWAYS, +) diff --git a/mypyc/primitives/list_ops.py b/mypyc/primitives/list_ops.py index 2bba4207cd27b..b66504d39d383 100644 --- a/mypyc/primitives/list_ops.py +++ b/mypyc/primitives/list_ops.py @@ -1,243 +1,277 @@ """List primitive ops.""" -from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER, ERR_FALSE +from mypyc.ir.ops import ERR_FALSE, ERR_MAGIC, ERR_NEVER from mypyc.ir.rtypes import ( - int_rprimitive, short_int_rprimitive, list_rprimitive, object_rprimitive, c_int_rprimitive, - c_pyssize_t_rprimitive, bit_rprimitive, int64_rprimitive + bit_rprimitive, + c_int_rprimitive, + c_pyssize_t_rprimitive, + int64_rprimitive, + int_rprimitive, + list_rprimitive, + object_rprimitive, + short_int_rprimitive, ) from mypyc.primitives.registry import ( - load_address_op, function_op, binary_op, method_op, custom_op, ERR_NEG_INT + ERR_NEG_INT, + binary_op, + custom_op, + function_op, + load_address_op, + method_op, ) - # Get the 'builtins.list' type object. -load_address_op( - name='builtins.list', - type=object_rprimitive, - src='PyList_Type') +load_address_op(name="builtins.list", type=object_rprimitive, src="PyList_Type") # list(obj) to_list = function_op( - name='builtins.list', + name="builtins.list", arg_types=[object_rprimitive], return_type=list_rprimitive, - c_function_name='PySequence_List', - error_kind=ERR_MAGIC) + c_function_name="PySequence_List", + error_kind=ERR_MAGIC, +) # Construct an empty list via list(). function_op( - name='builtins.list', + name="builtins.list", arg_types=[], return_type=list_rprimitive, - c_function_name='PyList_New', + c_function_name="PyList_New", error_kind=ERR_MAGIC, - extra_int_constants=[(0, int_rprimitive)]) + extra_int_constants=[(0, int_rprimitive)], +) new_list_op = custom_op( arg_types=[c_pyssize_t_rprimitive], return_type=list_rprimitive, - c_function_name='PyList_New', - error_kind=ERR_MAGIC) + c_function_name="PyList_New", + error_kind=ERR_MAGIC, +) list_build_op = custom_op( arg_types=[c_pyssize_t_rprimitive], return_type=list_rprimitive, - c_function_name='CPyList_Build', + c_function_name="CPyList_Build", error_kind=ERR_MAGIC, var_arg_type=object_rprimitive, - steals=True) + steals=True, +) # list[index] (for an integer index) list_get_item_op = method_op( - name='__getitem__', + name="__getitem__", arg_types=[list_rprimitive, int_rprimitive], return_type=object_rprimitive, - c_function_name='CPyList_GetItem', - error_kind=ERR_MAGIC) + c_function_name="CPyList_GetItem", + error_kind=ERR_MAGIC, +) # list[index] version with no int tag check for when it is known to be short method_op( - name='__getitem__', + name="__getitem__", arg_types=[list_rprimitive, short_int_rprimitive], return_type=object_rprimitive, - c_function_name='CPyList_GetItemShort', + c_function_name="CPyList_GetItemShort", error_kind=ERR_MAGIC, - priority=2) + priority=2, +) # list[index] that produces a borrowed result method_op( - name='__getitem__', + name="__getitem__", arg_types=[list_rprimitive, int_rprimitive], return_type=object_rprimitive, - c_function_name='CPyList_GetItemBorrow', + c_function_name="CPyList_GetItemBorrow", error_kind=ERR_MAGIC, is_borrowed=True, - priority=3) + priority=3, +) # list[index] that produces a borrowed result and index is known to be short method_op( - name='__getitem__', + name="__getitem__", arg_types=[list_rprimitive, short_int_rprimitive], return_type=object_rprimitive, - c_function_name='CPyList_GetItemShortBorrow', + c_function_name="CPyList_GetItemShortBorrow", error_kind=ERR_MAGIC, is_borrowed=True, - priority=4) + priority=4, +) # Version with native int index method_op( - name='__getitem__', + name="__getitem__", arg_types=[list_rprimitive, int64_rprimitive], return_type=object_rprimitive, - c_function_name='CPyList_GetItemInt64', + c_function_name="CPyList_GetItemInt64", error_kind=ERR_MAGIC, - priority=5) + priority=5, +) # Version with native int index method_op( - name='__getitem__', + name="__getitem__", arg_types=[list_rprimitive, int64_rprimitive], return_type=object_rprimitive, - c_function_name='CPyList_GetItemInt64Borrow', + c_function_name="CPyList_GetItemInt64Borrow", is_borrowed=True, error_kind=ERR_MAGIC, - priority=6) + priority=6, +) # This is unsafe because it assumes that the index is a non-negative short integer # that is in-bounds for the list. list_get_item_unsafe_op = custom_op( arg_types=[list_rprimitive, short_int_rprimitive], return_type=object_rprimitive, - c_function_name='CPyList_GetItemUnsafe', - error_kind=ERR_NEVER) + c_function_name="CPyList_GetItemUnsafe", + error_kind=ERR_NEVER, +) # list[index] = obj list_set_item_op = method_op( - name='__setitem__', + name="__setitem__", arg_types=[list_rprimitive, int_rprimitive, object_rprimitive], return_type=bit_rprimitive, - c_function_name='CPyList_SetItem', + c_function_name="CPyList_SetItem", error_kind=ERR_FALSE, - steals=[False, False, True]) + steals=[False, False, True], +) # list[index_i64] = obj method_op( - name='__setitem__', + name="__setitem__", arg_types=[list_rprimitive, int64_rprimitive, object_rprimitive], return_type=bit_rprimitive, - c_function_name='CPyList_SetItemInt64', + c_function_name="CPyList_SetItemInt64", error_kind=ERR_FALSE, steals=[False, False, True], - priority=2) + priority=2, +) # PyList_SET_ITEM does no error checking, # and should only be used to fill in brand new lists. new_list_set_item_op = custom_op( arg_types=[list_rprimitive, int_rprimitive, object_rprimitive], return_type=bit_rprimitive, - c_function_name='CPyList_SetItemUnsafe', + c_function_name="CPyList_SetItemUnsafe", error_kind=ERR_FALSE, - steals=[False, False, True]) + steals=[False, False, True], +) # list.append(obj) list_append_op = method_op( - name='append', + name="append", arg_types=[list_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name='PyList_Append', - error_kind=ERR_NEG_INT) + c_function_name="PyList_Append", + error_kind=ERR_NEG_INT, +) # list.extend(obj) list_extend_op = method_op( - name='extend', + name="extend", arg_types=[list_rprimitive, object_rprimitive], return_type=object_rprimitive, - c_function_name='CPyList_Extend', - error_kind=ERR_MAGIC) + c_function_name="CPyList_Extend", + error_kind=ERR_MAGIC, +) # list.pop() list_pop_last = method_op( - name='pop', + name="pop", arg_types=[list_rprimitive], return_type=object_rprimitive, - c_function_name='CPyList_PopLast', - error_kind=ERR_MAGIC) + c_function_name="CPyList_PopLast", + error_kind=ERR_MAGIC, +) # list.pop(index) list_pop = method_op( - name='pop', + name="pop", arg_types=[list_rprimitive, int_rprimitive], return_type=object_rprimitive, - c_function_name='CPyList_Pop', - error_kind=ERR_MAGIC) + c_function_name="CPyList_Pop", + error_kind=ERR_MAGIC, +) # list.count(obj) method_op( - name='count', + name="count", arg_types=[list_rprimitive, object_rprimitive], return_type=short_int_rprimitive, - c_function_name='CPyList_Count', - error_kind=ERR_MAGIC) + c_function_name="CPyList_Count", + error_kind=ERR_MAGIC, +) # list.insert(index, obj) method_op( - name='insert', + name="insert", arg_types=[list_rprimitive, int_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name='CPyList_Insert', - error_kind=ERR_NEG_INT) + c_function_name="CPyList_Insert", + error_kind=ERR_NEG_INT, +) # list.sort() method_op( - name='sort', + name="sort", arg_types=[list_rprimitive], return_type=c_int_rprimitive, - c_function_name='PyList_Sort', - error_kind=ERR_NEG_INT) + c_function_name="PyList_Sort", + error_kind=ERR_NEG_INT, +) # list.reverse() method_op( - name='reverse', + name="reverse", arg_types=[list_rprimitive], return_type=c_int_rprimitive, - c_function_name='PyList_Reverse', - error_kind=ERR_NEG_INT) + c_function_name="PyList_Reverse", + error_kind=ERR_NEG_INT, +) # list.remove(obj) method_op( - name='remove', + name="remove", arg_types=[list_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name='CPyList_Remove', - error_kind=ERR_NEG_INT) + c_function_name="CPyList_Remove", + error_kind=ERR_NEG_INT, +) # list.index(obj) method_op( - name='index', + name="index", arg_types=[list_rprimitive, object_rprimitive], return_type=int_rprimitive, - c_function_name='CPyList_Index', - error_kind=ERR_MAGIC) + c_function_name="CPyList_Index", + error_kind=ERR_MAGIC, +) # list * int binary_op( - name='*', + name="*", arg_types=[list_rprimitive, int_rprimitive], return_type=list_rprimitive, - c_function_name='CPySequence_Multiply', - error_kind=ERR_MAGIC) + c_function_name="CPySequence_Multiply", + error_kind=ERR_MAGIC, +) # int * list binary_op( - name='*', + name="*", arg_types=[int_rprimitive, list_rprimitive], return_type=list_rprimitive, - c_function_name='CPySequence_RMultiply', - error_kind=ERR_MAGIC) + c_function_name="CPySequence_RMultiply", + error_kind=ERR_MAGIC, +) # list[begin:end] list_slice_op = custom_op( arg_types=[list_rprimitive, int_rprimitive, int_rprimitive], return_type=object_rprimitive, - c_function_name='CPyList_GetSlice', - error_kind=ERR_MAGIC,) + c_function_name="CPyList_GetSlice", + error_kind=ERR_MAGIC, +) diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index cfdbb8a0f78d7..4232f16f10fc8 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -1,59 +1,53 @@ """Miscellaneous primitive ops.""" -from mypyc.ir.ops import ERR_NEVER, ERR_MAGIC, ERR_FALSE +from mypyc.ir.ops import ERR_FALSE, ERR_MAGIC, ERR_NEVER from mypyc.ir.rtypes import ( - bool_rprimitive, object_rprimitive, str_rprimitive, object_pointer_rprimitive, - int_rprimitive, dict_rprimitive, c_int_rprimitive, bit_rprimitive, c_pyssize_t_rprimitive, + bit_rprimitive, + bool_rprimitive, + c_int_rprimitive, + c_pyssize_t_rprimitive, + dict_rprimitive, + int_rprimitive, list_rprimitive, + object_pointer_rprimitive, + object_rprimitive, + str_rprimitive, ) -from mypyc.primitives.registry import ( - function_op, custom_op, load_address_op, ERR_NEG_INT -) +from mypyc.primitives.registry import ERR_NEG_INT, custom_op, function_op, load_address_op # Get the 'bool' type object. -load_address_op( - name='builtins.bool', - type=object_rprimitive, - src='PyBool_Type') +load_address_op(name="builtins.bool", type=object_rprimitive, src="PyBool_Type") # Get the 'range' type object. -load_address_op( - name='builtins.range', - type=object_rprimitive, - src='PyRange_Type') +load_address_op(name="builtins.range", type=object_rprimitive, src="PyRange_Type") # Get the boxed Python 'None' object -none_object_op = load_address_op( - name='Py_None', - type=object_rprimitive, - src='_Py_NoneStruct') +none_object_op = load_address_op(name="Py_None", type=object_rprimitive, src="_Py_NoneStruct") # Get the boxed object '...' -ellipsis_op = load_address_op( - name='...', - type=object_rprimitive, - src='_Py_EllipsisObject') +ellipsis_op = load_address_op(name="...", type=object_rprimitive, src="_Py_EllipsisObject") # Get the boxed NotImplemented object not_implemented_op = load_address_op( - name='builtins.NotImplemented', - type=object_rprimitive, - src='_Py_NotImplementedStruct') + name="builtins.NotImplemented", type=object_rprimitive, src="_Py_NotImplementedStruct" +) # id(obj) function_op( - name='builtins.id', + name="builtins.id", arg_types=[object_rprimitive], return_type=int_rprimitive, - c_function_name='CPyTagged_Id', - error_kind=ERR_NEVER) + c_function_name="CPyTagged_Id", + error_kind=ERR_NEVER, +) # Return the result of obj.__await()__ or obj.__iter__() (if no __await__ exists) coro_op = custom_op( arg_types=[object_rprimitive], return_type=object_rprimitive, - c_function_name='CPy_GetCoro', - error_kind=ERR_MAGIC) + c_function_name="CPy_GetCoro", + error_kind=ERR_MAGIC, +) # Do obj.send(value), or a next(obj) if second arg is None. # (This behavior is to match the PEP 380 spec for yield from.) @@ -63,8 +57,9 @@ send_op = custom_op( arg_types=[object_rprimitive, object_rprimitive], return_type=object_rprimitive, - c_function_name='CPyIter_Send', - error_kind=ERR_NEVER) + c_function_name="CPyIter_Send", + error_kind=ERR_NEVER, +) # This is sort of unfortunate but oh well: yield_from_except performs most of the # error handling logic in `yield from` operations. It returns a bool and passes @@ -78,15 +73,17 @@ yield_from_except_op = custom_op( arg_types=[object_rprimitive, object_pointer_rprimitive], return_type=bool_rprimitive, - c_function_name='CPy_YieldFromErrorHandle', - error_kind=ERR_MAGIC) + c_function_name="CPy_YieldFromErrorHandle", + error_kind=ERR_MAGIC, +) # Create method object from a callable object and self. method_new_op = custom_op( arg_types=[object_rprimitive, object_rprimitive], return_type=object_rprimitive, - c_function_name='PyMethod_New', - error_kind=ERR_MAGIC) + c_function_name="PyMethod_New", + error_kind=ERR_MAGIC, +) # Check if the current exception is a StopIteration and return its value if so. # Treats "no exception" as StopIteration with a None value. @@ -94,132 +91,143 @@ check_stop_op = custom_op( arg_types=[], return_type=object_rprimitive, - c_function_name='CPy_FetchStopIterationValue', - error_kind=ERR_MAGIC) + c_function_name="CPy_FetchStopIterationValue", + error_kind=ERR_MAGIC, +) # Determine the most derived metaclass and check for metaclass conflicts. # Arguments are (metaclass, bases). py_calc_meta_op = custom_op( arg_types=[object_rprimitive, object_rprimitive], return_type=object_rprimitive, - c_function_name='CPy_CalculateMetaclass', + c_function_name="CPy_CalculateMetaclass", error_kind=ERR_MAGIC, - is_borrowed=True + is_borrowed=True, ) # Import a module import_op = custom_op( arg_types=[str_rprimitive], return_type=object_rprimitive, - c_function_name='PyImport_Import', - error_kind=ERR_MAGIC) + c_function_name="PyImport_Import", + error_kind=ERR_MAGIC, +) # Import with extra arguments (used in from import handling) import_extra_args_op = custom_op( - arg_types=[str_rprimitive, dict_rprimitive, dict_rprimitive, - list_rprimitive, c_int_rprimitive], + arg_types=[ + str_rprimitive, + dict_rprimitive, + dict_rprimitive, + list_rprimitive, + c_int_rprimitive, + ], return_type=object_rprimitive, - c_function_name='PyImport_ImportModuleLevelObject', - error_kind=ERR_MAGIC + c_function_name="PyImport_ImportModuleLevelObject", + error_kind=ERR_MAGIC, ) # Import-from helper op import_from_op = custom_op( - arg_types=[object_rprimitive, str_rprimitive, - str_rprimitive, str_rprimitive], + arg_types=[object_rprimitive, str_rprimitive, str_rprimitive, str_rprimitive], return_type=object_rprimitive, - c_function_name='CPyImport_ImportFrom', - error_kind=ERR_MAGIC + c_function_name="CPyImport_ImportFrom", + error_kind=ERR_MAGIC, ) # Get the sys.modules dictionary get_module_dict_op = custom_op( arg_types=[], return_type=dict_rprimitive, - c_function_name='PyImport_GetModuleDict', + c_function_name="PyImport_GetModuleDict", error_kind=ERR_NEVER, - is_borrowed=True) + is_borrowed=True, +) # isinstance(obj, cls) slow_isinstance_op = function_op( - name='builtins.isinstance', + name="builtins.isinstance", arg_types=[object_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name='PyObject_IsInstance', + c_function_name="PyObject_IsInstance", error_kind=ERR_NEG_INT, - truncated_type=bool_rprimitive + truncated_type=bool_rprimitive, ) # Faster isinstance(obj, cls) that only works with native classes and doesn't perform # type checking of the type argument. fast_isinstance_op = function_op( - 'builtins.isinstance', + "builtins.isinstance", arg_types=[object_rprimitive, object_rprimitive], return_type=bool_rprimitive, - c_function_name='CPy_TypeCheck', + c_function_name="CPy_TypeCheck", error_kind=ERR_NEVER, - priority=0) + priority=0, +) # bool(obj) with unboxed result bool_op = function_op( - name='builtins.bool', + name="builtins.bool", arg_types=[object_rprimitive], return_type=c_int_rprimitive, - c_function_name='PyObject_IsTrue', + c_function_name="PyObject_IsTrue", error_kind=ERR_NEG_INT, - truncated_type=bool_rprimitive) + truncated_type=bool_rprimitive, +) # slice(start, stop, step) new_slice_op = function_op( - name='builtins.slice', + name="builtins.slice", arg_types=[object_rprimitive, object_rprimitive, object_rprimitive], - c_function_name='PySlice_New', + c_function_name="PySlice_New", return_type=object_rprimitive, - error_kind=ERR_MAGIC) + error_kind=ERR_MAGIC, +) # type(obj) type_op = function_op( - name='builtins.type', + name="builtins.type", arg_types=[object_rprimitive], - c_function_name='PyObject_Type', + c_function_name="PyObject_Type", return_type=object_rprimitive, - error_kind=ERR_NEVER) + error_kind=ERR_NEVER, +) # Get 'builtins.type' (base class of all classes) -type_object_op = load_address_op( - name='builtins.type', - type=object_rprimitive, - src='PyType_Type') +type_object_op = load_address_op(name="builtins.type", type=object_rprimitive, src="PyType_Type") # Create a heap type based on a template non-heap type. # See CPyType_FromTemplate for more docs. pytype_from_template_op = custom_op( arg_types=[object_rprimitive, object_rprimitive, str_rprimitive], return_type=object_rprimitive, - c_function_name='CPyType_FromTemplate', - error_kind=ERR_MAGIC) + c_function_name="CPyType_FromTemplate", + error_kind=ERR_MAGIC, +) # Create a dataclass from an extension class. See # CPyDataclass_SleightOfHand for more docs. dataclass_sleight_of_hand = custom_op( arg_types=[object_rprimitive, object_rprimitive, dict_rprimitive, dict_rprimitive], return_type=bit_rprimitive, - c_function_name='CPyDataclass_SleightOfHand', - error_kind=ERR_FALSE) + c_function_name="CPyDataclass_SleightOfHand", + error_kind=ERR_FALSE, +) # Raise ValueError if length of first argument is not equal to the second argument. # The first argument must be a list or a variable-length tuple. check_unpack_count_op = custom_op( arg_types=[object_rprimitive, c_pyssize_t_rprimitive], return_type=c_int_rprimitive, - c_function_name='CPySequence_CheckUnpackCount', - error_kind=ERR_NEG_INT) + c_function_name="CPySequence_CheckUnpackCount", + error_kind=ERR_NEG_INT, +) # register an implementation for a singledispatch function register_function = custom_op( arg_types=[object_rprimitive, object_rprimitive, object_rprimitive], return_type=object_rprimitive, - c_function_name='CPySingledispatch_RegisterFunction', + c_function_name="CPySingledispatch_RegisterFunction", error_kind=ERR_MAGIC, ) diff --git a/mypyc/primitives/registry.py b/mypyc/primitives/registry.py index 0174051ec98d4..ca9f937ce768d 100644 --- a/mypyc/primitives/registry.py +++ b/mypyc/primitives/registry.py @@ -35,7 +35,8 @@ optimized implementations of all ops. """ -from typing import Dict, List, Optional, NamedTuple, Tuple +from typing import Dict, List, NamedTuple, Optional, Tuple + from typing_extensions import Final from mypyc.ir.ops import StealsDescription @@ -47,25 +48,28 @@ CFunctionDescription = NamedTuple( - 'CFunctionDescription', [('name', str), - ('arg_types', List[RType]), - ('return_type', RType), - ('var_arg_type', Optional[RType]), - ('truncated_type', Optional[RType]), - ('c_function_name', str), - ('error_kind', int), - ('steals', StealsDescription), - ('is_borrowed', bool), - ('ordering', Optional[List[int]]), - ('extra_int_constants', List[Tuple[int, RType]]), - ('priority', int)]) + "CFunctionDescription", + [ + ("name", str), + ("arg_types", List[RType]), + ("return_type", RType), + ("var_arg_type", Optional[RType]), + ("truncated_type", Optional[RType]), + ("c_function_name", str), + ("error_kind", int), + ("steals", StealsDescription), + ("is_borrowed", bool), + ("ordering", Optional[List[int]]), + ("extra_int_constants", List[Tuple[int, RType]]), + ("priority", int), + ], +) # A description for C load operations including LoadGlobal and LoadAddress LoadAddressDescription = NamedTuple( - 'LoadAddressDescription', [('name', str), - ('type', RType), - ('src', str)]) # name of the target to load + "LoadAddressDescription", [("name", str), ("type", RType), ("src", str)] +) # name of the target to load # CallC op for method call(such as 'str.join') @@ -83,18 +87,20 @@ builtin_names: Dict[str, Tuple[RType, str]] = {} -def method_op(name: str, - arg_types: List[RType], - return_type: RType, - c_function_name: str, - error_kind: int, - var_arg_type: Optional[RType] = None, - truncated_type: Optional[RType] = None, - ordering: Optional[List[int]] = None, - extra_int_constants: List[Tuple[int, RType]] = [], - steals: StealsDescription = False, - is_borrowed: bool = False, - priority: int = 1) -> CFunctionDescription: +def method_op( + name: str, + arg_types: List[RType], + return_type: RType, + c_function_name: str, + error_kind: int, + var_arg_type: Optional[RType] = None, + truncated_type: Optional[RType] = None, + ordering: Optional[List[int]] = None, + extra_int_constants: List[Tuple[int, RType]] = [], + steals: StealsDescription = False, + is_borrowed: bool = False, + priority: int = 1, +) -> CFunctionDescription: """Define a c function call op that replaces a method call. This will be automatically generated by matching against the AST. @@ -120,25 +126,38 @@ def method_op(name: str, priority: if multiple ops match, the one with the highest priority is picked """ ops = method_call_ops.setdefault(name, []) - desc = CFunctionDescription(name, arg_types, return_type, var_arg_type, truncated_type, - c_function_name, error_kind, steals, is_borrowed, ordering, - extra_int_constants, priority) + desc = CFunctionDescription( + name, + arg_types, + return_type, + var_arg_type, + truncated_type, + c_function_name, + error_kind, + steals, + is_borrowed, + ordering, + extra_int_constants, + priority, + ) ops.append(desc) return desc -def function_op(name: str, - arg_types: List[RType], - return_type: RType, - c_function_name: str, - error_kind: int, - var_arg_type: Optional[RType] = None, - truncated_type: Optional[RType] = None, - ordering: Optional[List[int]] = None, - extra_int_constants: List[Tuple[int, RType]] = [], - steals: StealsDescription = False, - is_borrowed: bool = False, - priority: int = 1) -> CFunctionDescription: +def function_op( + name: str, + arg_types: List[RType], + return_type: RType, + c_function_name: str, + error_kind: int, + var_arg_type: Optional[RType] = None, + truncated_type: Optional[RType] = None, + ordering: Optional[List[int]] = None, + extra_int_constants: List[Tuple[int, RType]] = [], + steals: StealsDescription = False, + is_borrowed: bool = False, + priority: int = 1, +) -> CFunctionDescription: """Define a c function call op that replaces a function call. This will be automatically generated by matching against the AST. @@ -150,25 +169,38 @@ def function_op(name: str, arg_types: positional argument types for which this applies """ ops = function_ops.setdefault(name, []) - desc = CFunctionDescription(name, arg_types, return_type, var_arg_type, truncated_type, - c_function_name, error_kind, steals, is_borrowed, ordering, - extra_int_constants, priority) + desc = CFunctionDescription( + name, + arg_types, + return_type, + var_arg_type, + truncated_type, + c_function_name, + error_kind, + steals, + is_borrowed, + ordering, + extra_int_constants, + priority, + ) ops.append(desc) return desc -def binary_op(name: str, - arg_types: List[RType], - return_type: RType, - c_function_name: str, - error_kind: int, - var_arg_type: Optional[RType] = None, - truncated_type: Optional[RType] = None, - ordering: Optional[List[int]] = None, - extra_int_constants: List[Tuple[int, RType]] = [], - steals: StealsDescription = False, - is_borrowed: bool = False, - priority: int = 1) -> CFunctionDescription: +def binary_op( + name: str, + arg_types: List[RType], + return_type: RType, + c_function_name: str, + error_kind: int, + var_arg_type: Optional[RType] = None, + truncated_type: Optional[RType] = None, + ordering: Optional[List[int]] = None, + extra_int_constants: List[Tuple[int, RType]] = [], + steals: StealsDescription = False, + is_borrowed: bool = False, + priority: int = 1, +) -> CFunctionDescription: """Define a c function call op for a binary operation. This will be automatically generated by matching against the AST. @@ -177,43 +209,69 @@ def binary_op(name: str, are expected. """ ops = binary_ops.setdefault(name, []) - desc = CFunctionDescription(name, arg_types, return_type, var_arg_type, truncated_type, - c_function_name, error_kind, steals, is_borrowed, ordering, - extra_int_constants, priority) + desc = CFunctionDescription( + name, + arg_types, + return_type, + var_arg_type, + truncated_type, + c_function_name, + error_kind, + steals, + is_borrowed, + ordering, + extra_int_constants, + priority, + ) ops.append(desc) return desc -def custom_op(arg_types: List[RType], - return_type: RType, - c_function_name: str, - error_kind: int, - var_arg_type: Optional[RType] = None, - truncated_type: Optional[RType] = None, - ordering: Optional[List[int]] = None, - extra_int_constants: List[Tuple[int, RType]] = [], - steals: StealsDescription = False, - is_borrowed: bool = False) -> CFunctionDescription: +def custom_op( + arg_types: List[RType], + return_type: RType, + c_function_name: str, + error_kind: int, + var_arg_type: Optional[RType] = None, + truncated_type: Optional[RType] = None, + ordering: Optional[List[int]] = None, + extra_int_constants: List[Tuple[int, RType]] = [], + steals: StealsDescription = False, + is_borrowed: bool = False, +) -> CFunctionDescription: """Create a one-off CallC op that can't be automatically generated from the AST. Most arguments are similar to method_op(). """ - return CFunctionDescription('', arg_types, return_type, var_arg_type, truncated_type, - c_function_name, error_kind, steals, is_borrowed, ordering, - extra_int_constants, 0) - - -def unary_op(name: str, - arg_type: RType, - return_type: RType, - c_function_name: str, - error_kind: int, - truncated_type: Optional[RType] = None, - ordering: Optional[List[int]] = None, - extra_int_constants: List[Tuple[int, RType]] = [], - steals: StealsDescription = False, - is_borrowed: bool = False, - priority: int = 1) -> CFunctionDescription: + return CFunctionDescription( + "", + arg_types, + return_type, + var_arg_type, + truncated_type, + c_function_name, + error_kind, + steals, + is_borrowed, + ordering, + extra_int_constants, + 0, + ) + + +def unary_op( + name: str, + arg_type: RType, + return_type: RType, + c_function_name: str, + error_kind: int, + truncated_type: Optional[RType] = None, + ordering: Optional[List[int]] = None, + extra_int_constants: List[Tuple[int, RType]] = [], + steals: StealsDescription = False, + is_borrowed: bool = False, + priority: int = 1, +) -> CFunctionDescription: """Define a c function call op for an unary operation. This will be automatically generated by matching against the AST. @@ -222,27 +280,37 @@ def unary_op(name: str, is expected. """ ops = unary_ops.setdefault(name, []) - desc = CFunctionDescription(name, [arg_type], return_type, None, truncated_type, - c_function_name, error_kind, steals, is_borrowed, ordering, - extra_int_constants, priority) + desc = CFunctionDescription( + name, + [arg_type], + return_type, + None, + truncated_type, + c_function_name, + error_kind, + steals, + is_borrowed, + ordering, + extra_int_constants, + priority, + ) ops.append(desc) return desc -def load_address_op(name: str, - type: RType, - src: str) -> LoadAddressDescription: - assert name not in builtin_names, 'already defined: %s' % name +def load_address_op(name: str, type: RType, src: str) -> LoadAddressDescription: + assert name not in builtin_names, "already defined: %s" % name builtin_names[name] = (type, src) return LoadAddressDescription(name, type, src) +import mypyc.primitives.bytes_ops # noqa +import mypyc.primitives.dict_ops # noqa +import mypyc.primitives.float_ops # noqa + # Import various modules that set up global state. import mypyc.primitives.int_ops # noqa -import mypyc.primitives.str_ops # noqa -import mypyc.primitives.bytes_ops # noqa import mypyc.primitives.list_ops # noqa -import mypyc.primitives.dict_ops # noqa -import mypyc.primitives.tuple_ops # noqa import mypyc.primitives.misc_ops # noqa -import mypyc.primitives.float_ops # noqa +import mypyc.primitives.str_ops # noqa +import mypyc.primitives.tuple_ops # noqa diff --git a/mypyc/primitives/set_ops.py b/mypyc/primitives/set_ops.py index 5d18e45ad528b..bc6523c17c08f 100644 --- a/mypyc/primitives/set_ops.py +++ b/mypyc/primitives/set_ops.py @@ -1,108 +1,119 @@ """Primitive set (and frozenset) ops.""" -from mypyc.primitives.registry import ( - load_address_op, function_op, method_op, binary_op, ERR_NEG_INT -) -from mypyc.ir.ops import ERR_MAGIC, ERR_FALSE +from mypyc.ir.ops import ERR_FALSE, ERR_MAGIC from mypyc.ir.rtypes import ( - object_rprimitive, bool_rprimitive, set_rprimitive, c_int_rprimitive, pointer_rprimitive, - bit_rprimitive + bit_rprimitive, + bool_rprimitive, + c_int_rprimitive, + object_rprimitive, + pointer_rprimitive, + set_rprimitive, +) +from mypyc.primitives.registry import ( + ERR_NEG_INT, + binary_op, + function_op, + load_address_op, + method_op, ) - # Get the 'builtins.set' type object. -load_address_op( - name='builtins.set', - type=object_rprimitive, - src='PySet_Type') +load_address_op(name="builtins.set", type=object_rprimitive, src="PySet_Type") # Get the 'builtins.frozenset' tyoe object. -load_address_op( - name='builtins.frozenset', - type=object_rprimitive, - src='PyFrozenSet_Type') +load_address_op(name="builtins.frozenset", type=object_rprimitive, src="PyFrozenSet_Type") # Construct an empty set. new_set_op = function_op( - name='builtins.set', + name="builtins.set", arg_types=[], return_type=set_rprimitive, - c_function_name='PySet_New', + c_function_name="PySet_New", error_kind=ERR_MAGIC, - extra_int_constants=[(0, pointer_rprimitive)]) + extra_int_constants=[(0, pointer_rprimitive)], +) # set(obj) function_op( - name='builtins.set', + name="builtins.set", arg_types=[object_rprimitive], return_type=set_rprimitive, - c_function_name='PySet_New', - error_kind=ERR_MAGIC) + c_function_name="PySet_New", + error_kind=ERR_MAGIC, +) # frozenset(obj) function_op( - name='builtins.frozenset', + name="builtins.frozenset", arg_types=[object_rprimitive], return_type=object_rprimitive, - c_function_name='PyFrozenSet_New', - error_kind=ERR_MAGIC) + c_function_name="PyFrozenSet_New", + error_kind=ERR_MAGIC, +) # item in set binary_op( - name='in', + name="in", arg_types=[object_rprimitive, set_rprimitive], return_type=c_int_rprimitive, - c_function_name='PySet_Contains', + c_function_name="PySet_Contains", error_kind=ERR_NEG_INT, truncated_type=bool_rprimitive, - ordering=[1, 0]) + ordering=[1, 0], +) # set.remove(obj) method_op( - name='remove', + name="remove", arg_types=[set_rprimitive, object_rprimitive], return_type=bit_rprimitive, - c_function_name='CPySet_Remove', - error_kind=ERR_FALSE) + c_function_name="CPySet_Remove", + error_kind=ERR_FALSE, +) # set.discard(obj) method_op( - name='discard', + name="discard", arg_types=[set_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name='PySet_Discard', - error_kind=ERR_NEG_INT) + c_function_name="PySet_Discard", + error_kind=ERR_NEG_INT, +) # set.add(obj) set_add_op = method_op( - name='add', + name="add", arg_types=[set_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name='PySet_Add', - error_kind=ERR_NEG_INT) + c_function_name="PySet_Add", + error_kind=ERR_NEG_INT, +) # set.update(obj) # # This is not a public API but looks like it should be fine. set_update_op = method_op( - name='update', + name="update", arg_types=[set_rprimitive, object_rprimitive], return_type=c_int_rprimitive, - c_function_name='_PySet_Update', - error_kind=ERR_NEG_INT) + c_function_name="_PySet_Update", + error_kind=ERR_NEG_INT, +) # set.clear() method_op( - name='clear', + name="clear", arg_types=[set_rprimitive], return_type=c_int_rprimitive, - c_function_name='PySet_Clear', - error_kind=ERR_NEG_INT) + c_function_name="PySet_Clear", + error_kind=ERR_NEG_INT, +) # set.pop() method_op( - name='pop', + name="pop", arg_types=[set_rprimitive], return_type=object_rprimitive, - c_function_name='PySet_Pop', - error_kind=ERR_MAGIC) + c_function_name="PySet_Pop", + error_kind=ERR_MAGIC, +) diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index e7db008f4218c..79ac15c602343 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -4,105 +4,118 @@ from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER from mypyc.ir.rtypes import ( - RType, object_rprimitive, str_rprimitive, int_rprimitive, list_rprimitive, - c_int_rprimitive, pointer_rprimitive, bool_rprimitive, bit_rprimitive, - c_pyssize_t_rprimitive, bytes_rprimitive + RType, + bit_rprimitive, + bool_rprimitive, + bytes_rprimitive, + c_int_rprimitive, + c_pyssize_t_rprimitive, + int_rprimitive, + list_rprimitive, + object_rprimitive, + pointer_rprimitive, + str_rprimitive, ) from mypyc.primitives.registry import ( - method_op, binary_op, function_op, - load_address_op, custom_op, ERR_NEG_INT + ERR_NEG_INT, + binary_op, + custom_op, + function_op, + load_address_op, + method_op, ) - # Get the 'str' type object. -load_address_op( - name='builtins.str', - type=object_rprimitive, - src='PyUnicode_Type') +load_address_op(name="builtins.str", type=object_rprimitive, src="PyUnicode_Type") # str(obj) str_op = function_op( - name='builtins.str', + name="builtins.str", arg_types=[object_rprimitive], return_type=str_rprimitive, - c_function_name='PyObject_Str', - error_kind=ERR_MAGIC) + c_function_name="PyObject_Str", + error_kind=ERR_MAGIC, +) # str1 + str2 binary_op( - name='+', + name="+", arg_types=[str_rprimitive, str_rprimitive], return_type=str_rprimitive, - c_function_name='PyUnicode_Concat', - error_kind=ERR_MAGIC) + c_function_name="PyUnicode_Concat", + error_kind=ERR_MAGIC, +) # str1 += str2 # # PyUnicode_Append makes an effort to reuse the LHS when the refcount # is 1. This is super dodgy but oh well, the interpreter does it. binary_op( - name='+=', + name="+=", arg_types=[str_rprimitive, str_rprimitive], return_type=str_rprimitive, - c_function_name='CPyStr_Append', + c_function_name="CPyStr_Append", error_kind=ERR_MAGIC, - steals=[True, False]) + steals=[True, False], +) unicode_compare = custom_op( arg_types=[str_rprimitive, str_rprimitive], return_type=c_int_rprimitive, - c_function_name='PyUnicode_Compare', - error_kind=ERR_NEVER) + c_function_name="PyUnicode_Compare", + error_kind=ERR_NEVER, +) # str[index] (for an int index) method_op( - name='__getitem__', + name="__getitem__", arg_types=[str_rprimitive, int_rprimitive], return_type=str_rprimitive, - c_function_name='CPyStr_GetItem', - error_kind=ERR_MAGIC + c_function_name="CPyStr_GetItem", + error_kind=ERR_MAGIC, ) # str[begin:end] str_slice_op = custom_op( arg_types=[str_rprimitive, int_rprimitive, int_rprimitive], return_type=object_rprimitive, - c_function_name='CPyStr_GetSlice', - error_kind=ERR_MAGIC) + c_function_name="CPyStr_GetSlice", + error_kind=ERR_MAGIC, +) # str.join(obj) method_op( - name='join', + name="join", arg_types=[str_rprimitive, object_rprimitive], return_type=str_rprimitive, - c_function_name='PyUnicode_Join', - error_kind=ERR_MAGIC + c_function_name="PyUnicode_Join", + error_kind=ERR_MAGIC, ) str_build_op = custom_op( arg_types=[c_pyssize_t_rprimitive], return_type=str_rprimitive, - c_function_name='CPyStr_Build', + c_function_name="CPyStr_Build", error_kind=ERR_MAGIC, - var_arg_type=str_rprimitive + var_arg_type=str_rprimitive, ) # str.startswith(str) method_op( - name='startswith', + name="startswith", arg_types=[str_rprimitive, str_rprimitive], return_type=bool_rprimitive, - c_function_name='CPyStr_Startswith', - error_kind=ERR_NEVER + c_function_name="CPyStr_Startswith", + error_kind=ERR_NEVER, ) # str.endswith(str) method_op( - name='endswith', + name="endswith", arg_types=[str_rprimitive, str_rprimitive], return_type=bool_rprimitive, - c_function_name='CPyStr_Endswith', - error_kind=ERR_NEVER + c_function_name="CPyStr_Endswith", + error_kind=ERR_NEVER, ) # str.split(...) @@ -115,91 +128,102 @@ ] for i in range(len(str_split_types)): method_op( - name='split', - arg_types=str_split_types[0:i+1], + name="split", + arg_types=str_split_types[0 : i + 1], return_type=list_rprimitive, c_function_name=str_split_functions[i], extra_int_constants=str_split_constants[i], - error_kind=ERR_MAGIC) + error_kind=ERR_MAGIC, + ) # str.replace(old, new) method_op( - name='replace', + name="replace", arg_types=[str_rprimitive, str_rprimitive, str_rprimitive], return_type=str_rprimitive, - c_function_name='PyUnicode_Replace', + c_function_name="PyUnicode_Replace", error_kind=ERR_MAGIC, - extra_int_constants=[(-1, c_int_rprimitive)]) + extra_int_constants=[(-1, c_int_rprimitive)], +) # str.replace(old, new, count) method_op( - name='replace', + name="replace", arg_types=[str_rprimitive, str_rprimitive, str_rprimitive, int_rprimitive], return_type=str_rprimitive, - c_function_name='CPyStr_Replace', - error_kind=ERR_MAGIC) + c_function_name="CPyStr_Replace", + error_kind=ERR_MAGIC, +) # check if a string is true (isn't an empty string) str_check_if_true = custom_op( arg_types=[str_rprimitive], return_type=bit_rprimitive, - c_function_name='CPyStr_IsTrue', - error_kind=ERR_NEVER) + c_function_name="CPyStr_IsTrue", + error_kind=ERR_NEVER, +) str_ssize_t_size_op = custom_op( arg_types=[str_rprimitive], return_type=c_pyssize_t_rprimitive, - c_function_name='CPyStr_Size_size_t', - error_kind=ERR_NEG_INT) + c_function_name="CPyStr_Size_size_t", + error_kind=ERR_NEG_INT, +) # obj.decode() method_op( - name='decode', + name="decode", arg_types=[bytes_rprimitive], return_type=str_rprimitive, - c_function_name='CPy_Decode', + c_function_name="CPy_Decode", error_kind=ERR_MAGIC, - extra_int_constants=[(0, pointer_rprimitive), (0, pointer_rprimitive)]) + extra_int_constants=[(0, pointer_rprimitive), (0, pointer_rprimitive)], +) # obj.decode(encoding) method_op( - name='decode', + name="decode", arg_types=[bytes_rprimitive, str_rprimitive], return_type=str_rprimitive, - c_function_name='CPy_Decode', + c_function_name="CPy_Decode", error_kind=ERR_MAGIC, - extra_int_constants=[(0, pointer_rprimitive)]) + extra_int_constants=[(0, pointer_rprimitive)], +) # obj.decode(encoding, errors) method_op( - name='decode', + name="decode", arg_types=[bytes_rprimitive, str_rprimitive, str_rprimitive], return_type=str_rprimitive, - c_function_name='CPy_Decode', - error_kind=ERR_MAGIC) + c_function_name="CPy_Decode", + error_kind=ERR_MAGIC, +) # str.encode() method_op( - name='encode', + name="encode", arg_types=[str_rprimitive], return_type=bytes_rprimitive, - c_function_name='CPy_Encode', + c_function_name="CPy_Encode", error_kind=ERR_MAGIC, - extra_int_constants=[(0, pointer_rprimitive), (0, pointer_rprimitive)]) + extra_int_constants=[(0, pointer_rprimitive), (0, pointer_rprimitive)], +) # str.encode(encoding) method_op( - name='encode', + name="encode", arg_types=[str_rprimitive, str_rprimitive], return_type=bytes_rprimitive, - c_function_name='CPy_Encode', + c_function_name="CPy_Encode", error_kind=ERR_MAGIC, - extra_int_constants=[(0, pointer_rprimitive)]) + extra_int_constants=[(0, pointer_rprimitive)], +) # str.encode(encoding, errors) method_op( - name='encode', + name="encode", arg_types=[str_rprimitive, str_rprimitive, str_rprimitive], return_type=bytes_rprimitive, - c_function_name='CPy_Encode', - error_kind=ERR_MAGIC) + c_function_name="CPy_Encode", + error_kind=ERR_MAGIC, +) diff --git a/mypyc/primitives/tuple_ops.py b/mypyc/primitives/tuple_ops.py index 33f8e331b56dc..862d2b0ca0785 100644 --- a/mypyc/primitives/tuple_ops.py +++ b/mypyc/primitives/tuple_ops.py @@ -4,70 +4,78 @@ objects, i.e. tuple_rprimitive (RPrimitive), not RTuple. """ -from mypyc.ir.ops import ERR_MAGIC, ERR_FALSE +from mypyc.ir.ops import ERR_FALSE, ERR_MAGIC from mypyc.ir.rtypes import ( - tuple_rprimitive, int_rprimitive, list_rprimitive, object_rprimitive, - c_pyssize_t_rprimitive, bit_rprimitive + bit_rprimitive, + c_pyssize_t_rprimitive, + int_rprimitive, + list_rprimitive, + object_rprimitive, + tuple_rprimitive, ) -from mypyc.primitives.registry import load_address_op, method_op, function_op, custom_op +from mypyc.primitives.registry import custom_op, function_op, load_address_op, method_op # Get the 'builtins.tuple' type object. -load_address_op( - name='builtins.tuple', - type=object_rprimitive, - src='PyTuple_Type') +load_address_op(name="builtins.tuple", type=object_rprimitive, src="PyTuple_Type") # tuple[index] (for an int index) tuple_get_item_op = method_op( - name='__getitem__', + name="__getitem__", arg_types=[tuple_rprimitive, int_rprimitive], return_type=object_rprimitive, - c_function_name='CPySequenceTuple_GetItem', - error_kind=ERR_MAGIC) + c_function_name="CPySequenceTuple_GetItem", + error_kind=ERR_MAGIC, +) # Construct a boxed tuple from items: (item1, item2, ...) new_tuple_op = custom_op( arg_types=[c_pyssize_t_rprimitive], return_type=tuple_rprimitive, - c_function_name='PyTuple_Pack', + c_function_name="PyTuple_Pack", error_kind=ERR_MAGIC, - var_arg_type=object_rprimitive) + var_arg_type=object_rprimitive, +) new_tuple_with_length_op = custom_op( arg_types=[c_pyssize_t_rprimitive], return_type=tuple_rprimitive, - c_function_name='PyTuple_New', - error_kind=ERR_MAGIC) + c_function_name="PyTuple_New", + error_kind=ERR_MAGIC, +) # PyTuple_SET_ITEM does no error checking, # and should only be used to fill in brand new tuples. new_tuple_set_item_op = custom_op( arg_types=[tuple_rprimitive, int_rprimitive, object_rprimitive], return_type=bit_rprimitive, - c_function_name='CPySequenceTuple_SetItemUnsafe', + c_function_name="CPySequenceTuple_SetItemUnsafe", error_kind=ERR_FALSE, - steals=[False, False, True]) + steals=[False, False, True], +) # Construct tuple from a list. list_tuple_op = function_op( - name='builtins.tuple', + name="builtins.tuple", arg_types=[list_rprimitive], return_type=tuple_rprimitive, - c_function_name='PyList_AsTuple', + c_function_name="PyList_AsTuple", error_kind=ERR_MAGIC, - priority=2) + priority=2, +) # Construct tuple from an arbitrary (iterable) object. function_op( - name='builtins.tuple', + name="builtins.tuple", arg_types=[object_rprimitive], return_type=tuple_rprimitive, - c_function_name='PySequence_Tuple', - error_kind=ERR_MAGIC) + c_function_name="PySequence_Tuple", + error_kind=ERR_MAGIC, +) # tuple[begin:end] tuple_slice_op = custom_op( arg_types=[tuple_rprimitive, int_rprimitive, int_rprimitive], return_type=object_rprimitive, - c_function_name='CPySequenceTuple_GetSlice', - error_kind=ERR_MAGIC) + c_function_name="CPySequenceTuple_GetSlice", + error_kind=ERR_MAGIC, +) diff --git a/mypyc/rt_subtype.py b/mypyc/rt_subtype.py index 7b1d207957d2e..4e41914063331 100644 --- a/mypyc/rt_subtype.py +++ b/mypyc/rt_subtype.py @@ -14,8 +14,19 @@ """ from mypyc.ir.rtypes import ( - RType, RUnion, RInstance, RPrimitive, RTuple, RVoid, RTypeVisitor, RStruct, RArray, - is_int_rprimitive, is_short_int_rprimitive, is_bool_rprimitive, is_bit_rprimitive + RArray, + RInstance, + RPrimitive, + RStruct, + RTuple, + RType, + RTypeVisitor, + RUnion, + RVoid, + is_bit_rprimitive, + is_bool_rprimitive, + is_int_rprimitive, + is_short_int_rprimitive, ) from mypyc.subtype import is_subtype @@ -50,7 +61,8 @@ def visit_rprimitive(self, left: RPrimitive) -> bool: def visit_rtuple(self, left: RTuple) -> bool: if isinstance(self.right, RTuple): return len(self.right.types) == len(left.types) and all( - is_runtime_subtype(t1, t2) for t1, t2 in zip(left.types, self.right.types)) + is_runtime_subtype(t1, t2) for t1, t2 in zip(left.types, self.right.types) + ) return False def visit_rstruct(self, left: RStruct) -> bool: diff --git a/mypyc/sametype.py b/mypyc/sametype.py index 912585ceabfa2..c16b2e658d589 100644 --- a/mypyc/sametype.py +++ b/mypyc/sametype.py @@ -1,9 +1,17 @@ """Same type check for RTypes.""" +from mypyc.ir.func_ir import FuncSignature from mypyc.ir.rtypes import ( - RType, RTypeVisitor, RInstance, RPrimitive, RTuple, RVoid, RUnion, RStruct, RArray + RArray, + RInstance, + RPrimitive, + RStruct, + RTuple, + RType, + RTypeVisitor, + RUnion, + RVoid, ) -from mypyc.ir.func_ir import FuncSignature def is_same_type(a: RType, b: RType) -> bool: @@ -11,17 +19,24 @@ def is_same_type(a: RType, b: RType) -> bool: def is_same_signature(a: FuncSignature, b: FuncSignature) -> bool: - return (len(a.args) == len(b.args) - and is_same_type(a.ret_type, b.ret_type) - and all(is_same_type(t1.type, t2.type) and t1.name == t2.name - for t1, t2 in zip(a.args, b.args))) + return ( + len(a.args) == len(b.args) + and is_same_type(a.ret_type, b.ret_type) + and all( + is_same_type(t1.type, t2.type) and t1.name == t2.name for t1, t2 in zip(a.args, b.args) + ) + ) def is_same_method_signature(a: FuncSignature, b: FuncSignature) -> bool: - return (len(a.args) == len(b.args) - and is_same_type(a.ret_type, b.ret_type) - and all(is_same_type(t1.type, t2.type) and t1.name == t2.name - for t1, t2 in zip(a.args[1:], b.args[1:]))) + return ( + len(a.args) == len(b.args) + and is_same_type(a.ret_type, b.ret_type) + and all( + is_same_type(t1.type, t2.type) and t1.name == t2.name + for t1, t2 in zip(a.args[1:], b.args[1:]) + ) + ) class SameTypeVisitor(RTypeVisitor[bool]): @@ -48,9 +63,11 @@ def visit_rprimitive(self, left: RPrimitive) -> bool: return left is self.right def visit_rtuple(self, left: RTuple) -> bool: - return (isinstance(self.right, RTuple) + return ( + isinstance(self.right, RTuple) and len(self.right.types) == len(left.types) - and all(is_same_type(t1, t2) for t1, t2 in zip(left.types, self.right.types))) + and all(is_same_type(t1, t2) for t1, t2 in zip(left.types, self.right.types)) + ) def visit_rstruct(self, left: RStruct) -> bool: return isinstance(self.right, RStruct) and self.right.name == left.name diff --git a/mypyc/subtype.py b/mypyc/subtype.py index 4ba8f6301c63e..26ceb9e308f1e 100644 --- a/mypyc/subtype.py +++ b/mypyc/subtype.py @@ -1,9 +1,23 @@ """Subtype check for RTypes.""" from mypyc.ir.rtypes import ( - RType, RInstance, RPrimitive, RTuple, RVoid, RTypeVisitor, RUnion, RStruct, RArray, - is_bool_rprimitive, is_int_rprimitive, is_tuple_rprimitive, is_short_int_rprimitive, - is_object_rprimitive, is_bit_rprimitive, is_tagged, is_fixed_width_rtype + RArray, + RInstance, + RPrimitive, + RStruct, + RTuple, + RType, + RTypeVisitor, + RUnion, + RVoid, + is_bit_rprimitive, + is_bool_rprimitive, + is_fixed_width_rtype, + is_int_rprimitive, + is_object_rprimitive, + is_short_int_rprimitive, + is_tagged, + is_tuple_rprimitive, ) @@ -13,13 +27,11 @@ def is_subtype(left: RType, right: RType) -> bool: elif isinstance(right, RUnion): if isinstance(left, RUnion): for left_item in left.items: - if not any(is_subtype(left_item, right_item) - for right_item in right.items): + if not any(is_subtype(left_item, right_item) for right_item in right.items): return False return True else: - return any(is_subtype(left, item) - for item in right.items) + return any(is_subtype(left, item) for item in right.items) return left.accept(SubtypeVisitor(right)) @@ -37,8 +49,7 @@ def visit_rinstance(self, left: RInstance) -> bool: return isinstance(self.right, RInstance) and self.right.class_ir in left.class_ir.mro def visit_runion(self, left: RUnion) -> bool: - return all(is_subtype(item, self.right) - for item in left.items) + return all(is_subtype(item, self.right) for item in left.items) def visit_rprimitive(self, left: RPrimitive) -> bool: right = self.right @@ -61,7 +72,8 @@ def visit_rtuple(self, left: RTuple) -> bool: return True if isinstance(self.right, RTuple): return len(self.right.types) == len(left.types) and all( - is_subtype(t1, t2) for t1, t2 in zip(left.types, self.right.types)) + is_subtype(t1, t2) for t1, t2 in zip(left.types, self.right.types) + ) return False def visit_rstruct(self, left: RStruct) -> bool: diff --git a/mypyc/test/config.py b/mypyc/test/config.py index 1158c5c459be8..f515806fb58c9 100644 --- a/mypyc/test/config.py +++ b/mypyc/test/config.py @@ -1,6 +1,6 @@ import os -provided_prefix = os.getenv('MYPY_TEST_PREFIX', None) +provided_prefix = os.getenv("MYPY_TEST_PREFIX", None) if provided_prefix: PREFIX = provided_prefix else: @@ -8,4 +8,4 @@ PREFIX = os.path.dirname(os.path.dirname(this_file_dir)) # Location of test data files such as test case descriptions. -test_data_prefix = os.path.join(PREFIX, 'mypyc', 'test-data') +test_data_prefix = os.path.join(PREFIX, "mypyc", "test-data") diff --git a/mypyc/test/test_alwaysdefined.py b/mypyc/test/test_alwaysdefined.py index f9a90fabf2a10..5eba8c979839e 100644 --- a/mypyc/test/test_alwaysdefined.py +++ b/mypyc/test/test_alwaysdefined.py @@ -2,18 +2,19 @@ import os.path +from mypy.errors import CompileError from mypy.test.config import test_temp_dir from mypy.test.data import DataDrivenTestCase -from mypy.errors import CompileError - from mypyc.test.testutil import ( - ICODE_GEN_BUILTINS, use_custom_builtins, MypycDataSuite, build_ir_for_single_file2, - assert_test_output, infer_ir_build_options_from_test_name + ICODE_GEN_BUILTINS, + MypycDataSuite, + assert_test_output, + build_ir_for_single_file2, + infer_ir_build_options_from_test_name, + use_custom_builtins, ) -files = [ - 'alwaysdefined.test' -] +files = ["alwaysdefined.test"] class TestAlwaysDefined(MypycDataSuite): @@ -34,9 +35,10 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: else: actual = [] for cl in ir.classes: - if cl.name.startswith('_'): + if cl.name.startswith("_"): continue - actual.append('{}: [{}]'.format( - cl.name, ', '.join(sorted(cl._always_initialized_attrs)))) + actual.append( + "{}: [{}]".format(cl.name, ", ".join(sorted(cl._always_initialized_attrs))) + ) - assert_test_output(testcase, actual, 'Invalid test output', testcase.output) + assert_test_output(testcase, actual, "Invalid test output", testcase.output) diff --git a/mypyc/test/test_analysis.py b/mypyc/test/test_analysis.py index b71983705b65e..944d64e50a2f4 100644 --- a/mypyc/test/test_analysis.py +++ b/mypyc/test/test_analysis.py @@ -3,24 +3,24 @@ import os.path from typing import Set -from mypy.test.data import DataDrivenTestCase -from mypy.test.config import test_temp_dir from mypy.errors import CompileError - -from mypyc.common import TOP_LEVEL_NAME +from mypy.test.config import test_temp_dir +from mypy.test.data import DataDrivenTestCase from mypyc.analysis import dataflow -from mypyc.transform import exceptions -from mypyc.ir.pprint import format_func, generate_names_for_ir -from mypyc.ir.ops import Value +from mypyc.common import TOP_LEVEL_NAME from mypyc.ir.func_ir import all_values +from mypyc.ir.ops import Value +from mypyc.ir.pprint import format_func, generate_names_for_ir from mypyc.test.testutil import ( - ICODE_GEN_BUILTINS, use_custom_builtins, MypycDataSuite, build_ir_for_single_file, - assert_test_output + ICODE_GEN_BUILTINS, + MypycDataSuite, + assert_test_output, + build_ir_for_single_file, + use_custom_builtins, ) +from mypyc.transform import exceptions -files = [ - 'analysis.test' -] +files = ["analysis.test"] class TestAnalysis(MypycDataSuite): @@ -39,39 +39,38 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: else: actual = [] for fn in ir: - if (fn.name == TOP_LEVEL_NAME - and not testcase.name.endswith('_toplevel')): + if fn.name == TOP_LEVEL_NAME and not testcase.name.endswith("_toplevel"): continue exceptions.insert_exception_handling(fn) actual.extend(format_func(fn)) cfg = dataflow.get_cfg(fn.blocks) args: Set[Value] = set(fn.arg_regs) name = testcase.name - if name.endswith('_MaybeDefined'): + if name.endswith("_MaybeDefined"): # Forward, maybe analysis_result = dataflow.analyze_maybe_defined_regs(fn.blocks, cfg, args) - elif name.endswith('_Liveness'): + elif name.endswith("_Liveness"): # Backward, maybe analysis_result = dataflow.analyze_live_regs(fn.blocks, cfg) - elif name.endswith('_MustDefined'): + elif name.endswith("_MustDefined"): # Forward, must analysis_result = dataflow.analyze_must_defined_regs( - fn.blocks, cfg, args, - regs=all_values(fn.arg_regs, fn.blocks)) - elif name.endswith('_BorrowedArgument'): + fn.blocks, cfg, args, regs=all_values(fn.arg_regs, fn.blocks) + ) + elif name.endswith("_BorrowedArgument"): # Forward, must analysis_result = dataflow.analyze_borrowed_arguments(fn.blocks, cfg, args) else: - assert False, 'No recognized _AnalysisName suffix in test case' + assert False, "No recognized _AnalysisName suffix in test case" names = generate_names_for_ir(fn.arg_regs, fn.blocks) - for key in sorted(analysis_result.before.keys(), - key=lambda x: (x[0].label, x[1])): - pre = ', '.join(sorted(names[reg] - for reg in analysis_result.before[key])) - post = ', '.join(sorted(names[reg] - for reg in analysis_result.after[key])) - actual.append('%-8s %-23s %s' % ((key[0].label, key[1]), - '{%s}' % pre, '{%s}' % post)) - assert_test_output(testcase, actual, 'Invalid source code output') + for key in sorted( + analysis_result.before.keys(), key=lambda x: (x[0].label, x[1]) + ): + pre = ", ".join(sorted(names[reg] for reg in analysis_result.before[key])) + post = ", ".join(sorted(names[reg] for reg in analysis_result.after[key])) + actual.append( + "%-8s %-23s %s" % ((key[0].label, key[1]), "{%s}" % pre, "{%s}" % post) + ) + assert_test_output(testcase, actual, "Invalid source code output") diff --git a/mypyc/test/test_cheader.py b/mypyc/test/test_cheader.py index 0966059e24438..f0313649090e6 100644 --- a/mypyc/test/test_cheader.py +++ b/mypyc/test/test_cheader.py @@ -11,29 +11,32 @@ class TestHeaderInclusion(unittest.TestCase): def test_primitives_included_in_header(self) -> None: - base_dir = os.path.join(os.path.dirname(__file__), '..', 'lib-rt') - with open(os.path.join(base_dir, 'CPy.h')) as f: + base_dir = os.path.join(os.path.dirname(__file__), "..", "lib-rt") + with open(os.path.join(base_dir, "CPy.h")) as f: header = f.read() - with open(os.path.join(base_dir, 'pythonsupport.h')) as f: + with open(os.path.join(base_dir, "pythonsupport.h")) as f: header += f.read() def check_name(name: str) -> None: - if name.startswith('CPy'): - assert re.search(fr'\b{name}\b', header), ( - f'"{name}" is used in mypyc.primitives but not declared in CPy.h') + if name.startswith("CPy"): + assert re.search( + rf"\b{name}\b", header + ), f'"{name}" is used in mypyc.primitives but not declared in CPy.h' - for values in [registry.method_call_ops.values(), - registry.function_ops.values(), - registry.binary_ops.values(), - registry.unary_ops.values()]: + for values in [ + registry.method_call_ops.values(), + registry.function_ops.values(), + registry.binary_ops.values(), + registry.unary_ops.values(), + ]: for ops in values: if isinstance(ops, CFunctionDescription): ops = [ops] for op in ops: check_name(op.c_function_name) - primitives_path = os.path.join(os.path.dirname(__file__), '..', 'primitives') - for fnam in glob.glob(f'{primitives_path}/*.py'): + primitives_path = os.path.join(os.path.dirname(__file__), "..", "primitives") + for fnam in glob.glob(f"{primitives_path}/*.py"): with open(fnam) as f: content = f.read() for name in re.findall(r'c_function_name=["\'](CPy[A-Z_a-z0-9]+)', content): diff --git a/mypyc/test/test_commandline.py b/mypyc/test/test_commandline.py index 3ca380f8eebd5..e7f7dc190d9f9 100644 --- a/mypyc/test/test_commandline.py +++ b/mypyc/test/test_commandline.py @@ -10,18 +10,15 @@ import subprocess import sys -from mypy.test.data import DataDrivenTestCase from mypy.test.config import test_temp_dir +from mypy.test.data import DataDrivenTestCase from mypy.test.helpers import normalize_error_messages - from mypyc.test.testutil import MypycDataSuite, assert_test_output -files = [ - 'commandline.test', -] +files = ["commandline.test"] -base_path = os.path.join(os.path.dirname(__file__), '..', '..') +base_path = os.path.join(os.path.dirname(__file__), "..", "..") python3_path = sys.executable @@ -33,40 +30,42 @@ class TestCommandLine(MypycDataSuite): def run_case(self, testcase: DataDrivenTestCase) -> None: # Parse options from test case description (arguments must not have spaces) - text = '\n'.join(testcase.input) - m = re.search(r'# *cmd: *(.*)', text) + text = "\n".join(testcase.input) + m = re.search(r"# *cmd: *(.*)", text) assert m is not None, 'Test case missing "# cmd: " section' args = m.group(1).split() # Write main program to run (not compiled) - program = '_%s.py' % testcase.name + program = "_%s.py" % testcase.name program_path = os.path.join(test_temp_dir, program) - with open(program_path, 'w') as f: + with open(program_path, "w") as f: f.write(text) - out = b'' + out = b"" try: # Compile program - cmd = subprocess.run([sys.executable, '-m', 'mypyc', *args], - stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd='tmp') - if 'ErrorOutput' in testcase.name or cmd.returncode != 0: + cmd = subprocess.run( + [sys.executable, "-m", "mypyc", *args], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + cwd="tmp", + ) + if "ErrorOutput" in testcase.name or cmd.returncode != 0: out += cmd.stdout if cmd.returncode == 0: # Run main program - out += subprocess.check_output( - [python3_path, program], - cwd='tmp') + out += subprocess.check_output([python3_path, program], cwd="tmp") finally: - suffix = 'pyd' if sys.platform == 'win32' else 'so' - so_paths = glob.glob(f'tmp/**/*.{suffix}', recursive=True) + suffix = "pyd" if sys.platform == "win32" else "so" + so_paths = glob.glob(f"tmp/**/*.{suffix}", recursive=True) for path in so_paths: os.remove(path) # Strip out 'tmp/' from error message paths in the testcase output, # due to a mismatch between this test and mypy's test suite. - expected = [x.replace('tmp/', '') for x in testcase.output] + expected = [x.replace("tmp/", "") for x in testcase.output] # Verify output actual = normalize_error_messages(out.decode().splitlines()) - assert_test_output(testcase, actual, 'Invalid output', expected=expected) + assert_test_output(testcase, actual, "Invalid output", expected=expected) diff --git a/mypyc/test/test_emit.py b/mypyc/test/test_emit.py index 1721a68769843..a09b300a2c415 100644 --- a/mypyc/test/test_emit.py +++ b/mypyc/test/test_emit.py @@ -2,32 +2,29 @@ from typing import Dict from mypyc.codegen.emit import Emitter, EmitterContext -from mypyc.ir.ops import BasicBlock, Value, Register +from mypyc.ir.ops import BasicBlock, Register, Value from mypyc.ir.rtypes import int_rprimitive from mypyc.namegen import NameGenerator class TestEmitter(unittest.TestCase): def setUp(self) -> None: - self.n = Register(int_rprimitive, 'n') - self.context = EmitterContext(NameGenerator([['mod']])) + self.n = Register(int_rprimitive, "n") + self.context = EmitterContext(NameGenerator([["mod"]])) def test_label(self) -> None: emitter = Emitter(self.context, {}) - assert emitter.label(BasicBlock(4)) == 'CPyL4' + assert emitter.label(BasicBlock(4)) == "CPyL4" def test_reg(self) -> None: names: Dict[Value, str] = {self.n: "n"} emitter = Emitter(self.context, names) - assert emitter.reg(self.n) == 'cpy_r_n' + assert emitter.reg(self.n) == "cpy_r_n" def test_emit_line(self) -> None: emitter = Emitter(self.context, {}) - emitter.emit_line('line;') - emitter.emit_line('a {') - emitter.emit_line('f();') - emitter.emit_line('}') - assert emitter.fragments == ['line;\n', - 'a {\n', - ' f();\n', - '}\n'] + emitter.emit_line("line;") + emitter.emit_line("a {") + emitter.emit_line("f();") + emitter.emit_line("}") + assert emitter.fragments == ["line;\n", "a {\n", " f();\n", "}\n"] diff --git a/mypyc/test/test_emitclass.py b/mypyc/test/test_emitclass.py index 42bf0af04359c..4e4354a7977c0 100644 --- a/mypyc/test/test_emitclass.py +++ b/mypyc/test/test_emitclass.py @@ -7,21 +7,27 @@ class TestEmitClass(unittest.TestCase): def test_slot_key(self) -> None: - attrs = ['__add__', '__radd__', '__rshift__', '__rrshift__', '__setitem__', '__delitem__'] + attrs = ["__add__", "__radd__", "__rshift__", "__rrshift__", "__setitem__", "__delitem__"] s = sorted(attrs, key=lambda x: slot_key(x)) # __delitem__ and reverse methods should come last. assert s == [ - '__add__', '__rshift__', '__setitem__', '__delitem__', '__radd__', '__rrshift__'] + "__add__", + "__rshift__", + "__setitem__", + "__delitem__", + "__radd__", + "__rrshift__", + ] def test_setter_name(self) -> None: cls = ClassIR(module_name="testing", name="SomeClass") - generator = NameGenerator([['mod']]) + generator = NameGenerator([["mod"]]) # This should never be `setup`, as it will conflict with the class `setup` assert setter_name(cls, "up", generator) == "testing___SomeClass_set_up" def test_getter_name(self) -> None: cls = ClassIR(module_name="testing", name="SomeClass") - generator = NameGenerator([['mod']]) + generator = NameGenerator([["mod"]]) assert getter_name(cls, "down", generator) == "testing___SomeClass_get_down" diff --git a/mypyc/test/test_emitfunc.py b/mypyc/test/test_emitfunc.py index 8ea0906aec613..9f2b7516e2da5 100644 --- a/mypyc/test/test_emitfunc.py +++ b/mypyc/test/test_emitfunc.py @@ -1,37 +1,74 @@ import unittest - from typing import List, Optional from mypy.backports import OrderedDict - from mypy.test.helpers import assert_string_arrays_equal - +from mypyc.codegen.emit import Emitter, EmitterContext +from mypyc.codegen.emitfunc import FunctionEmitterVisitor, generate_native_function +from mypyc.common import PLATFORM_SIZE +from mypyc.ir.class_ir import ClassIR +from mypyc.ir.func_ir import FuncDecl, FuncIR, FuncSignature, RuntimeArg from mypyc.ir.ops import ( - BasicBlock, Goto, Return, Integer, Assign, AssignMulti, IncRef, DecRef, Branch, - Call, Unbox, Box, TupleGet, GetAttr, SetAttr, Op, Value, CallC, IntOp, LoadMem, - GetElementPtr, LoadAddress, ComparisonOp, SetMem, Register, Unreachable, Cast, Extend + Assign, + AssignMulti, + BasicBlock, + Box, + Branch, + Call, + CallC, + Cast, + ComparisonOp, + DecRef, + Extend, + GetAttr, + GetElementPtr, + Goto, + IncRef, + Integer, + IntOp, + LoadAddress, + LoadMem, + Op, + Register, + Return, + SetAttr, + SetMem, + TupleGet, + Unbox, + Unreachable, + Value, ) +from mypyc.ir.pprint import generate_names_for_ir from mypyc.ir.rtypes import ( - RTuple, RInstance, RType, RArray, int_rprimitive, bool_rprimitive, list_rprimitive, - dict_rprimitive, object_rprimitive, c_int_rprimitive, short_int_rprimitive, int32_rprimitive, - int64_rprimitive, RStruct, pointer_rprimitive + RArray, + RInstance, + RStruct, + RTuple, + RType, + bool_rprimitive, + c_int_rprimitive, + dict_rprimitive, + int32_rprimitive, + int64_rprimitive, + int_rprimitive, + list_rprimitive, + object_rprimitive, + pointer_rprimitive, + short_int_rprimitive, ) -from mypyc.ir.func_ir import FuncIR, FuncDecl, RuntimeArg, FuncSignature -from mypyc.ir.class_ir import ClassIR -from mypyc.ir.pprint import generate_names_for_ir from mypyc.irbuild.vtable import compute_vtable -from mypyc.codegen.emit import Emitter, EmitterContext -from mypyc.codegen.emitfunc import generate_native_function, FunctionEmitterVisitor -from mypyc.primitives.registry import binary_ops -from mypyc.primitives.misc_ops import none_object_op -from mypyc.primitives.list_ops import list_get_item_op, list_set_item_op, list_append_op +from mypyc.namegen import NameGenerator from mypyc.primitives.dict_ops import ( - dict_new_op, dict_update_op, dict_get_item_op, dict_set_item_op + dict_get_item_op, + dict_new_op, + dict_set_item_op, + dict_update_op, ) from mypyc.primitives.int_ops import int_neg_op +from mypyc.primitives.list_ops import list_append_op, list_get_item_op, list_set_item_op +from mypyc.primitives.misc_ops import none_object_op +from mypyc.primitives.registry import binary_ops from mypyc.subtype import is_subtype -from mypyc.namegen import NameGenerator -from mypyc.common import PLATFORM_SIZE class TestFunctionEmitterVisitor(unittest.TestCase): @@ -45,185 +82,193 @@ def add_local(name: str, rtype: RType) -> Register: self.registers.append(reg) return reg - self.n = add_local('n', int_rprimitive) - self.m = add_local('m', int_rprimitive) - self.k = add_local('k', int_rprimitive) - self.l = add_local('l', list_rprimitive) # noqa - self.ll = add_local('ll', list_rprimitive) - self.o = add_local('o', object_rprimitive) - self.o2 = add_local('o2', object_rprimitive) - self.d = add_local('d', dict_rprimitive) - self.b = add_local('b', bool_rprimitive) - self.s1 = add_local('s1', short_int_rprimitive) - self.s2 = add_local('s2', short_int_rprimitive) - self.i32 = add_local('i32', int32_rprimitive) - self.i32_1 = add_local('i32_1', int32_rprimitive) - self.i64 = add_local('i64', int64_rprimitive) - self.i64_1 = add_local('i64_1', int64_rprimitive) - self.ptr = add_local('ptr', pointer_rprimitive) - self.t = add_local('t', RTuple([int_rprimitive, bool_rprimitive])) + self.n = add_local("n", int_rprimitive) + self.m = add_local("m", int_rprimitive) + self.k = add_local("k", int_rprimitive) + self.l = add_local("l", list_rprimitive) # noqa + self.ll = add_local("ll", list_rprimitive) + self.o = add_local("o", object_rprimitive) + self.o2 = add_local("o2", object_rprimitive) + self.d = add_local("d", dict_rprimitive) + self.b = add_local("b", bool_rprimitive) + self.s1 = add_local("s1", short_int_rprimitive) + self.s2 = add_local("s2", short_int_rprimitive) + self.i32 = add_local("i32", int32_rprimitive) + self.i32_1 = add_local("i32_1", int32_rprimitive) + self.i64 = add_local("i64", int64_rprimitive) + self.i64_1 = add_local("i64_1", int64_rprimitive) + self.ptr = add_local("ptr", pointer_rprimitive) + self.t = add_local("t", RTuple([int_rprimitive, bool_rprimitive])) self.tt = add_local( - 'tt', RTuple([RTuple([int_rprimitive, bool_rprimitive]), bool_rprimitive])) - ir = ClassIR('A', 'mod') - ir.attributes = OrderedDict([('x', bool_rprimitive), ('y', int_rprimitive)]) + "tt", RTuple([RTuple([int_rprimitive, bool_rprimitive]), bool_rprimitive]) + ) + ir = ClassIR("A", "mod") + ir.attributes = OrderedDict([("x", bool_rprimitive), ("y", int_rprimitive)]) compute_vtable(ir) ir.mro = [ir] - self.r = add_local('r', RInstance(ir)) + self.r = add_local("r", RInstance(ir)) - self.context = EmitterContext(NameGenerator([['mod']])) + self.context = EmitterContext(NameGenerator([["mod"]])) def test_goto(self) -> None: - self.assert_emit(Goto(BasicBlock(2)), - "goto CPyL2;") + self.assert_emit(Goto(BasicBlock(2)), "goto CPyL2;") def test_goto_next_block(self) -> None: next_block = BasicBlock(2) self.assert_emit(Goto(next_block), "", next_block=next_block) def test_return(self) -> None: - self.assert_emit(Return(self.m), - "return cpy_r_m;") + self.assert_emit(Return(self.m), "return cpy_r_m;") def test_integer(self) -> None: - self.assert_emit(Assign(self.n, Integer(5)), - "cpy_r_n = 10;") - self.assert_emit(Assign(self.i32, Integer(5, c_int_rprimitive)), - "cpy_r_i32 = 5;") + self.assert_emit(Assign(self.n, Integer(5)), "cpy_r_n = 10;") + self.assert_emit(Assign(self.i32, Integer(5, c_int_rprimitive)), "cpy_r_i32 = 5;") def test_tuple_get(self) -> None: - self.assert_emit(TupleGet(self.t, 1, 0), 'cpy_r_r0 = cpy_r_t.f1;') + self.assert_emit(TupleGet(self.t, 1, 0), "cpy_r_r0 = cpy_r_t.f1;") def test_load_None(self) -> None: - self.assert_emit(LoadAddress(none_object_op.type, none_object_op.src, 0), - "cpy_r_r0 = (PyObject *)&_Py_NoneStruct;") + self.assert_emit( + LoadAddress(none_object_op.type, none_object_op.src, 0), + "cpy_r_r0 = (PyObject *)&_Py_NoneStruct;", + ) def test_assign_int(self) -> None: - self.assert_emit(Assign(self.m, self.n), - "cpy_r_m = cpy_r_n;") + self.assert_emit(Assign(self.m, self.n), "cpy_r_m = cpy_r_n;") def test_int_add(self) -> None: self.assert_emit_binary_op( - '+', self.n, self.m, self.k, - "cpy_r_r0 = CPyTagged_Add(cpy_r_m, cpy_r_k);") + "+", self.n, self.m, self.k, "cpy_r_r0 = CPyTagged_Add(cpy_r_m, cpy_r_k);" + ) def test_int_sub(self) -> None: self.assert_emit_binary_op( - '-', self.n, self.m, self.k, - "cpy_r_r0 = CPyTagged_Subtract(cpy_r_m, cpy_r_k);") + "-", self.n, self.m, self.k, "cpy_r_r0 = CPyTagged_Subtract(cpy_r_m, cpy_r_k);" + ) def test_int_neg(self) -> None: - self.assert_emit(CallC(int_neg_op.c_function_name, [self.m], int_neg_op.return_type, - int_neg_op.steals, int_neg_op.is_borrowed, int_neg_op.is_borrowed, - int_neg_op.error_kind, 55), - "cpy_r_r0 = CPyTagged_Negate(cpy_r_m);") + self.assert_emit( + CallC( + int_neg_op.c_function_name, + [self.m], + int_neg_op.return_type, + int_neg_op.steals, + int_neg_op.is_borrowed, + int_neg_op.is_borrowed, + int_neg_op.error_kind, + 55, + ), + "cpy_r_r0 = CPyTagged_Negate(cpy_r_m);", + ) def test_branch(self) -> None: - self.assert_emit(Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.BOOL), - """if (cpy_r_b) { + self.assert_emit( + Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.BOOL), + """if (cpy_r_b) { goto CPyL8; } else goto CPyL9; - """) + """, + ) b = Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.BOOL) b.negated = True - self.assert_emit(b, - """if (!cpy_r_b) { + self.assert_emit( + b, + """if (!cpy_r_b) { goto CPyL8; } else goto CPyL9; - """) + """, + ) def test_branch_no_else(self) -> None: next_block = BasicBlock(9) b = Branch(self.b, BasicBlock(8), next_block, Branch.BOOL) - self.assert_emit(b, - """if (cpy_r_b) goto CPyL8;""", - next_block=next_block) + self.assert_emit(b, """if (cpy_r_b) goto CPyL8;""", next_block=next_block) next_block = BasicBlock(9) b = Branch(self.b, BasicBlock(8), next_block, Branch.BOOL) b.negated = True - self.assert_emit(b, - """if (!cpy_r_b) goto CPyL8;""", - next_block=next_block) + self.assert_emit(b, """if (!cpy_r_b) goto CPyL8;""", next_block=next_block) def test_branch_no_else_negated(self) -> None: next_block = BasicBlock(1) b = Branch(self.b, next_block, BasicBlock(2), Branch.BOOL) - self.assert_emit(b, - """if (!cpy_r_b) goto CPyL2;""", - next_block=next_block) + self.assert_emit(b, """if (!cpy_r_b) goto CPyL2;""", next_block=next_block) next_block = BasicBlock(1) b = Branch(self.b, next_block, BasicBlock(2), Branch.BOOL) b.negated = True - self.assert_emit(b, - """if (cpy_r_b) goto CPyL2;""", - next_block=next_block) + self.assert_emit(b, """if (cpy_r_b) goto CPyL2;""", next_block=next_block) def test_branch_is_error(self) -> None: b = Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.IS_ERROR) - self.assert_emit(b, - """if (cpy_r_b == 2) { + self.assert_emit( + b, + """if (cpy_r_b == 2) { goto CPyL8; } else goto CPyL9; - """) + """, + ) b = Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.IS_ERROR) b.negated = True - self.assert_emit(b, - """if (cpy_r_b != 2) { + self.assert_emit( + b, + """if (cpy_r_b != 2) { goto CPyL8; } else goto CPyL9; - """) + """, + ) def test_branch_is_error_next_block(self) -> None: next_block = BasicBlock(8) b = Branch(self.b, next_block, BasicBlock(9), Branch.IS_ERROR) - self.assert_emit(b, - """if (cpy_r_b != 2) goto CPyL9;""", - next_block=next_block) + self.assert_emit(b, """if (cpy_r_b != 2) goto CPyL9;""", next_block=next_block) b = Branch(self.b, next_block, BasicBlock(9), Branch.IS_ERROR) b.negated = True - self.assert_emit(b, - """if (cpy_r_b == 2) goto CPyL9;""", - next_block=next_block) + self.assert_emit(b, """if (cpy_r_b == 2) goto CPyL9;""", next_block=next_block) def test_branch_rare(self) -> None: - self.assert_emit(Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.BOOL, rare=True), - """if (unlikely(cpy_r_b)) { + self.assert_emit( + Branch(self.b, BasicBlock(8), BasicBlock(9), Branch.BOOL, rare=True), + """if (unlikely(cpy_r_b)) { goto CPyL8; } else goto CPyL9; - """) + """, + ) next_block = BasicBlock(9) - self.assert_emit(Branch(self.b, BasicBlock(8), next_block, Branch.BOOL, rare=True), - """if (unlikely(cpy_r_b)) goto CPyL8;""", - next_block=next_block) + self.assert_emit( + Branch(self.b, BasicBlock(8), next_block, Branch.BOOL, rare=True), + """if (unlikely(cpy_r_b)) goto CPyL8;""", + next_block=next_block, + ) next_block = BasicBlock(8) b = Branch(self.b, next_block, BasicBlock(9), Branch.BOOL, rare=True) - self.assert_emit(b, - """if (likely(!cpy_r_b)) goto CPyL9;""", - next_block=next_block) + self.assert_emit(b, """if (likely(!cpy_r_b)) goto CPyL9;""", next_block=next_block) next_block = BasicBlock(8) b = Branch(self.b, next_block, BasicBlock(9), Branch.BOOL, rare=True) b.negated = True - self.assert_emit(b, - """if (likely(cpy_r_b)) goto CPyL9;""", - next_block=next_block) + self.assert_emit(b, """if (likely(cpy_r_b)) goto CPyL9;""", next_block=next_block) def test_call(self) -> None: - decl = FuncDecl('myfn', None, 'mod', - FuncSignature([RuntimeArg('m', int_rprimitive)], int_rprimitive)) - self.assert_emit(Call(decl, [self.m], 55), - "cpy_r_r0 = CPyDef_myfn(cpy_r_m);") + decl = FuncDecl( + "myfn", None, "mod", FuncSignature([RuntimeArg("m", int_rprimitive)], int_rprimitive) + ) + self.assert_emit(Call(decl, [self.m], 55), "cpy_r_r0 = CPyDef_myfn(cpy_r_m);") def test_call_two_args(self) -> None: - decl = FuncDecl('myfn', None, 'mod', - FuncSignature([RuntimeArg('m', int_rprimitive), - RuntimeArg('n', int_rprimitive)], - int_rprimitive)) - self.assert_emit(Call(decl, [self.m, self.k], 55), - "cpy_r_r0 = CPyDef_myfn(cpy_r_m, cpy_r_k);") + decl = FuncDecl( + "myfn", + None, + "mod", + FuncSignature( + [RuntimeArg("m", int_rprimitive), RuntimeArg("n", int_rprimitive)], int_rprimitive + ), + ) + self.assert_emit( + Call(decl, [self.m, self.k], 55), "cpy_r_r0 = CPyDef_myfn(cpy_r_m, cpy_r_k);" + ) def test_inc_ref(self) -> None: self.assert_emit(IncRef(self.o), "CPy_INCREF(cpy_r_o);") @@ -242,74 +287,101 @@ def test_dec_ref_int(self) -> None: self.assert_emit(DecRef(self.m), "CPyTagged_DecRef(cpy_r_m);", rare=True) def test_dec_ref_tuple(self) -> None: - self.assert_emit(DecRef(self.t), 'CPyTagged_DECREF(cpy_r_t.f0);') + self.assert_emit(DecRef(self.t), "CPyTagged_DECREF(cpy_r_t.f0);") def test_dec_ref_tuple_nested(self) -> None: - self.assert_emit(DecRef(self.tt), 'CPyTagged_DECREF(cpy_r_tt.f0.f0);') + self.assert_emit(DecRef(self.tt), "CPyTagged_DECREF(cpy_r_tt.f0.f0);") def test_list_get_item(self) -> None: - self.assert_emit(CallC(list_get_item_op.c_function_name, [self.m, self.k], - list_get_item_op.return_type, list_get_item_op.steals, - list_get_item_op.is_borrowed, list_get_item_op.error_kind, 55), - """cpy_r_r0 = CPyList_GetItem(cpy_r_m, cpy_r_k);""") + self.assert_emit( + CallC( + list_get_item_op.c_function_name, + [self.m, self.k], + list_get_item_op.return_type, + list_get_item_op.steals, + list_get_item_op.is_borrowed, + list_get_item_op.error_kind, + 55, + ), + """cpy_r_r0 = CPyList_GetItem(cpy_r_m, cpy_r_k);""", + ) def test_list_set_item(self) -> None: - self.assert_emit(CallC(list_set_item_op.c_function_name, [self.l, self.n, self.o], - list_set_item_op.return_type, list_set_item_op.steals, - list_set_item_op.is_borrowed, list_set_item_op.error_kind, 55), - """cpy_r_r0 = CPyList_SetItem(cpy_r_l, cpy_r_n, cpy_r_o);""") + self.assert_emit( + CallC( + list_set_item_op.c_function_name, + [self.l, self.n, self.o], + list_set_item_op.return_type, + list_set_item_op.steals, + list_set_item_op.is_borrowed, + list_set_item_op.error_kind, + 55, + ), + """cpy_r_r0 = CPyList_SetItem(cpy_r_l, cpy_r_n, cpy_r_o);""", + ) def test_box_int(self) -> None: - self.assert_emit(Box(self.n), - """cpy_r_r0 = CPyTagged_StealAsObject(cpy_r_n);""") + self.assert_emit(Box(self.n), """cpy_r_r0 = CPyTagged_StealAsObject(cpy_r_n);""") def test_unbox_int(self) -> None: - self.assert_emit(Unbox(self.m, int_rprimitive, 55), - """if (likely(PyLong_Check(cpy_r_m))) + self.assert_emit( + Unbox(self.m, int_rprimitive, 55), + """if (likely(PyLong_Check(cpy_r_m))) cpy_r_r0 = CPyTagged_FromObject(cpy_r_m); else { CPy_TypeError("int", cpy_r_m); cpy_r_r0 = CPY_INT_TAG; } - """) + """, + ) def test_box_i64(self) -> None: - self.assert_emit(Box(self.i64), - """cpy_r_r0 = PyLong_FromLongLong(cpy_r_i64);""") + self.assert_emit(Box(self.i64), """cpy_r_r0 = PyLong_FromLongLong(cpy_r_i64);""") def test_unbox_i64(self) -> None: - self.assert_emit(Unbox(self.o, int64_rprimitive, 55), - """cpy_r_r0 = CPyLong_AsInt64(cpy_r_o);""") + self.assert_emit( + Unbox(self.o, int64_rprimitive, 55), """cpy_r_r0 = CPyLong_AsInt64(cpy_r_o);""" + ) def test_list_append(self) -> None: - self.assert_emit(CallC(list_append_op.c_function_name, [self.l, self.o], - list_append_op.return_type, list_append_op.steals, - list_append_op.is_borrowed, list_append_op.error_kind, 1), - """cpy_r_r0 = PyList_Append(cpy_r_l, cpy_r_o);""") + self.assert_emit( + CallC( + list_append_op.c_function_name, + [self.l, self.o], + list_append_op.return_type, + list_append_op.steals, + list_append_op.is_borrowed, + list_append_op.error_kind, + 1, + ), + """cpy_r_r0 = PyList_Append(cpy_r_l, cpy_r_o);""", + ) def test_get_attr(self) -> None: self.assert_emit( - GetAttr(self.r, 'y', 1), + GetAttr(self.r, "y", 1), """cpy_r_r0 = ((mod___AObject *)cpy_r_r)->_y; if (unlikely(cpy_r_r0 == CPY_INT_TAG)) { PyErr_SetString(PyExc_AttributeError, "attribute 'y' of 'A' undefined"); } else { CPyTagged_INCREF(cpy_r_r0); } - """) + """, + ) def test_get_attr_non_refcounted(self) -> None: self.assert_emit( - GetAttr(self.r, 'x', 1), + GetAttr(self.r, "x", 1), """cpy_r_r0 = ((mod___AObject *)cpy_r_r)->_x; if (unlikely(cpy_r_r0 == 2)) { PyErr_SetString(PyExc_AttributeError, "attribute 'x' of 'A' undefined"); } - """) + """, + ) def test_get_attr_merged(self) -> None: - op = GetAttr(self.r, 'y', 1) + op = GetAttr(self.r, "y", 1) branch = Branch(op, BasicBlock(8), BasicBlock(9), Branch.IS_ERROR) - branch.traceback_entry = ('foobar', 123) + branch.traceback_entry = ("foobar", 123) self.assert_emit( op, """\ @@ -327,144 +399,227 @@ def test_get_attr_merged(self) -> None: def test_set_attr(self) -> None: self.assert_emit( - SetAttr(self.r, 'y', self.m, 1), + SetAttr(self.r, "y", self.m, 1), """if (((mod___AObject *)cpy_r_r)->_y != CPY_INT_TAG) { CPyTagged_DECREF(((mod___AObject *)cpy_r_r)->_y); } ((mod___AObject *)cpy_r_r)->_y = cpy_r_m; cpy_r_r0 = 1; - """) + """, + ) def test_set_attr_non_refcounted(self) -> None: self.assert_emit( - SetAttr(self.r, 'x', self.b, 1), + SetAttr(self.r, "x", self.b, 1), """((mod___AObject *)cpy_r_r)->_x = cpy_r_b; cpy_r_r0 = 1; - """) + """, + ) def test_dict_get_item(self) -> None: - self.assert_emit(CallC(dict_get_item_op.c_function_name, [self.d, self.o2], - dict_get_item_op.return_type, dict_get_item_op.steals, - dict_get_item_op.is_borrowed, dict_get_item_op.error_kind, 1), - """cpy_r_r0 = CPyDict_GetItem(cpy_r_d, cpy_r_o2);""") + self.assert_emit( + CallC( + dict_get_item_op.c_function_name, + [self.d, self.o2], + dict_get_item_op.return_type, + dict_get_item_op.steals, + dict_get_item_op.is_borrowed, + dict_get_item_op.error_kind, + 1, + ), + """cpy_r_r0 = CPyDict_GetItem(cpy_r_d, cpy_r_o2);""", + ) def test_dict_set_item(self) -> None: - self.assert_emit(CallC(dict_set_item_op.c_function_name, [self.d, self.o, self.o2], - dict_set_item_op.return_type, dict_set_item_op.steals, - dict_set_item_op.is_borrowed, dict_set_item_op.error_kind, 1), - """cpy_r_r0 = CPyDict_SetItem(cpy_r_d, cpy_r_o, cpy_r_o2);""") + self.assert_emit( + CallC( + dict_set_item_op.c_function_name, + [self.d, self.o, self.o2], + dict_set_item_op.return_type, + dict_set_item_op.steals, + dict_set_item_op.is_borrowed, + dict_set_item_op.error_kind, + 1, + ), + """cpy_r_r0 = CPyDict_SetItem(cpy_r_d, cpy_r_o, cpy_r_o2);""", + ) def test_dict_update(self) -> None: - self.assert_emit(CallC(dict_update_op.c_function_name, [self.d, self.o], - dict_update_op.return_type, dict_update_op.steals, - dict_update_op.is_borrowed, dict_update_op.error_kind, 1), - """cpy_r_r0 = CPyDict_Update(cpy_r_d, cpy_r_o);""") + self.assert_emit( + CallC( + dict_update_op.c_function_name, + [self.d, self.o], + dict_update_op.return_type, + dict_update_op.steals, + dict_update_op.is_borrowed, + dict_update_op.error_kind, + 1, + ), + """cpy_r_r0 = CPyDict_Update(cpy_r_d, cpy_r_o);""", + ) def test_new_dict(self) -> None: - self.assert_emit(CallC(dict_new_op.c_function_name, [], dict_new_op.return_type, - dict_new_op.steals, dict_new_op.is_borrowed, - dict_new_op.error_kind, 1), - """cpy_r_r0 = PyDict_New();""") + self.assert_emit( + CallC( + dict_new_op.c_function_name, + [], + dict_new_op.return_type, + dict_new_op.steals, + dict_new_op.is_borrowed, + dict_new_op.error_kind, + 1, + ), + """cpy_r_r0 = PyDict_New();""", + ) def test_dict_contains(self) -> None: self.assert_emit_binary_op( - 'in', self.b, self.o, self.d, - """cpy_r_r0 = PyDict_Contains(cpy_r_d, cpy_r_o);""") + "in", self.b, self.o, self.d, """cpy_r_r0 = PyDict_Contains(cpy_r_d, cpy_r_o);""" + ) def test_int_op(self) -> None: - self.assert_emit(IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.ADD, 1), - """cpy_r_r0 = cpy_r_s1 + cpy_r_s2;""") - self.assert_emit(IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.SUB, 1), - """cpy_r_r0 = cpy_r_s1 - cpy_r_s2;""") - self.assert_emit(IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.MUL, 1), - """cpy_r_r0 = cpy_r_s1 * cpy_r_s2;""") - self.assert_emit(IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.DIV, 1), - """cpy_r_r0 = cpy_r_s1 / cpy_r_s2;""") - self.assert_emit(IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.MOD, 1), - """cpy_r_r0 = cpy_r_s1 % cpy_r_s2;""") - self.assert_emit(IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.AND, 1), - """cpy_r_r0 = cpy_r_s1 & cpy_r_s2;""") - self.assert_emit(IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.OR, 1), - """cpy_r_r0 = cpy_r_s1 | cpy_r_s2;""") - self.assert_emit(IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.XOR, 1), - """cpy_r_r0 = cpy_r_s1 ^ cpy_r_s2;""") - self.assert_emit(IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.LEFT_SHIFT, 1), - """cpy_r_r0 = cpy_r_s1 << cpy_r_s2;""") - self.assert_emit(IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.RIGHT_SHIFT, 1), - """cpy_r_r0 = (Py_ssize_t)cpy_r_s1 >> (Py_ssize_t)cpy_r_s2;""") - self.assert_emit(IntOp(short_int_rprimitive, self.i64, self.i64_1, IntOp.RIGHT_SHIFT, 1), - """cpy_r_r0 = cpy_r_i64 >> cpy_r_i64_1;""") + self.assert_emit( + IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.ADD, 1), + """cpy_r_r0 = cpy_r_s1 + cpy_r_s2;""", + ) + self.assert_emit( + IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.SUB, 1), + """cpy_r_r0 = cpy_r_s1 - cpy_r_s2;""", + ) + self.assert_emit( + IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.MUL, 1), + """cpy_r_r0 = cpy_r_s1 * cpy_r_s2;""", + ) + self.assert_emit( + IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.DIV, 1), + """cpy_r_r0 = cpy_r_s1 / cpy_r_s2;""", + ) + self.assert_emit( + IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.MOD, 1), + """cpy_r_r0 = cpy_r_s1 % cpy_r_s2;""", + ) + self.assert_emit( + IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.AND, 1), + """cpy_r_r0 = cpy_r_s1 & cpy_r_s2;""", + ) + self.assert_emit( + IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.OR, 1), + """cpy_r_r0 = cpy_r_s1 | cpy_r_s2;""", + ) + self.assert_emit( + IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.XOR, 1), + """cpy_r_r0 = cpy_r_s1 ^ cpy_r_s2;""", + ) + self.assert_emit( + IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.LEFT_SHIFT, 1), + """cpy_r_r0 = cpy_r_s1 << cpy_r_s2;""", + ) + self.assert_emit( + IntOp(short_int_rprimitive, self.s1, self.s2, IntOp.RIGHT_SHIFT, 1), + """cpy_r_r0 = (Py_ssize_t)cpy_r_s1 >> (Py_ssize_t)cpy_r_s2;""", + ) + self.assert_emit( + IntOp(short_int_rprimitive, self.i64, self.i64_1, IntOp.RIGHT_SHIFT, 1), + """cpy_r_r0 = cpy_r_i64 >> cpy_r_i64_1;""", + ) def test_comparison_op(self) -> None: # signed - self.assert_emit(ComparisonOp(self.s1, self.s2, ComparisonOp.SLT, 1), - """cpy_r_r0 = (Py_ssize_t)cpy_r_s1 < (Py_ssize_t)cpy_r_s2;""") - self.assert_emit(ComparisonOp(self.i32, self.i32_1, ComparisonOp.SLT, 1), - """cpy_r_r0 = cpy_r_i32 < cpy_r_i32_1;""") - self.assert_emit(ComparisonOp(self.i64, self.i64_1, ComparisonOp.SLT, 1), - """cpy_r_r0 = cpy_r_i64 < cpy_r_i64_1;""") + self.assert_emit( + ComparisonOp(self.s1, self.s2, ComparisonOp.SLT, 1), + """cpy_r_r0 = (Py_ssize_t)cpy_r_s1 < (Py_ssize_t)cpy_r_s2;""", + ) + self.assert_emit( + ComparisonOp(self.i32, self.i32_1, ComparisonOp.SLT, 1), + """cpy_r_r0 = cpy_r_i32 < cpy_r_i32_1;""", + ) + self.assert_emit( + ComparisonOp(self.i64, self.i64_1, ComparisonOp.SLT, 1), + """cpy_r_r0 = cpy_r_i64 < cpy_r_i64_1;""", + ) # unsigned - self.assert_emit(ComparisonOp(self.s1, self.s2, ComparisonOp.ULT, 1), - """cpy_r_r0 = cpy_r_s1 < cpy_r_s2;""") - self.assert_emit(ComparisonOp(self.i32, self.i32_1, ComparisonOp.ULT, 1), - """cpy_r_r0 = (uint32_t)cpy_r_i32 < (uint32_t)cpy_r_i32_1;""") - self.assert_emit(ComparisonOp(self.i64, self.i64_1, ComparisonOp.ULT, 1), - """cpy_r_r0 = (uint64_t)cpy_r_i64 < (uint64_t)cpy_r_i64_1;""") + self.assert_emit( + ComparisonOp(self.s1, self.s2, ComparisonOp.ULT, 1), + """cpy_r_r0 = cpy_r_s1 < cpy_r_s2;""", + ) + self.assert_emit( + ComparisonOp(self.i32, self.i32_1, ComparisonOp.ULT, 1), + """cpy_r_r0 = (uint32_t)cpy_r_i32 < (uint32_t)cpy_r_i32_1;""", + ) + self.assert_emit( + ComparisonOp(self.i64, self.i64_1, ComparisonOp.ULT, 1), + """cpy_r_r0 = (uint64_t)cpy_r_i64 < (uint64_t)cpy_r_i64_1;""", + ) # object type - self.assert_emit(ComparisonOp(self.o, self.o2, ComparisonOp.EQ, 1), - """cpy_r_r0 = cpy_r_o == cpy_r_o2;""") - self.assert_emit(ComparisonOp(self.o, self.o2, ComparisonOp.NEQ, 1), - """cpy_r_r0 = cpy_r_o != cpy_r_o2;""") + self.assert_emit( + ComparisonOp(self.o, self.o2, ComparisonOp.EQ, 1), + """cpy_r_r0 = cpy_r_o == cpy_r_o2;""", + ) + self.assert_emit( + ComparisonOp(self.o, self.o2, ComparisonOp.NEQ, 1), + """cpy_r_r0 = cpy_r_o != cpy_r_o2;""", + ) def test_load_mem(self) -> None: - self.assert_emit(LoadMem(bool_rprimitive, self.ptr), - """cpy_r_r0 = *(char *)cpy_r_ptr;""") + self.assert_emit(LoadMem(bool_rprimitive, self.ptr), """cpy_r_r0 = *(char *)cpy_r_ptr;""") def test_set_mem(self) -> None: - self.assert_emit(SetMem(bool_rprimitive, self.ptr, self.b), - """*(char *)cpy_r_ptr = cpy_r_b;""") + self.assert_emit( + SetMem(bool_rprimitive, self.ptr, self.b), """*(char *)cpy_r_ptr = cpy_r_b;""" + ) def test_get_element_ptr(self) -> None: - r = RStruct("Foo", ["b", "i32", "i64"], [bool_rprimitive, - int32_rprimitive, int64_rprimitive]) - self.assert_emit(GetElementPtr(self.o, r, "b"), - """cpy_r_r0 = (CPyPtr)&((Foo *)cpy_r_o)->b;""") - self.assert_emit(GetElementPtr(self.o, r, "i32"), - """cpy_r_r0 = (CPyPtr)&((Foo *)cpy_r_o)->i32;""") - self.assert_emit(GetElementPtr(self.o, r, "i64"), - """cpy_r_r0 = (CPyPtr)&((Foo *)cpy_r_o)->i64;""") + r = RStruct( + "Foo", ["b", "i32", "i64"], [bool_rprimitive, int32_rprimitive, int64_rprimitive] + ) + self.assert_emit( + GetElementPtr(self.o, r, "b"), """cpy_r_r0 = (CPyPtr)&((Foo *)cpy_r_o)->b;""" + ) + self.assert_emit( + GetElementPtr(self.o, r, "i32"), """cpy_r_r0 = (CPyPtr)&((Foo *)cpy_r_o)->i32;""" + ) + self.assert_emit( + GetElementPtr(self.o, r, "i64"), """cpy_r_r0 = (CPyPtr)&((Foo *)cpy_r_o)->i64;""" + ) def test_load_address(self) -> None: - self.assert_emit(LoadAddress(object_rprimitive, "PyDict_Type"), - """cpy_r_r0 = (PyObject *)&PyDict_Type;""") + self.assert_emit( + LoadAddress(object_rprimitive, "PyDict_Type"), + """cpy_r_r0 = (PyObject *)&PyDict_Type;""", + ) def test_assign_multi(self) -> None: t = RArray(object_rprimitive, 2) - a = Register(t, 'a') + a = Register(t, "a") self.registers.append(a) - self.assert_emit(AssignMulti(a, [self.o, self.o2]), - """PyObject *cpy_r_a[2] = {cpy_r_o, cpy_r_o2};""") + self.assert_emit( + AssignMulti(a, [self.o, self.o2]), """PyObject *cpy_r_a[2] = {cpy_r_o, cpy_r_o2};""" + ) def test_long_unsigned(self) -> None: - a = Register(int64_rprimitive, 'a') - self.assert_emit(Assign(a, Integer(1 << 31, int64_rprimitive)), - """cpy_r_a = 2147483648LL;""") - self.assert_emit(Assign(a, Integer((1 << 31) - 1, int64_rprimitive)), - """cpy_r_a = 2147483647;""") + a = Register(int64_rprimitive, "a") + self.assert_emit( + Assign(a, Integer(1 << 31, int64_rprimitive)), """cpy_r_a = 2147483648LL;""" + ) + self.assert_emit( + Assign(a, Integer((1 << 31) - 1, int64_rprimitive)), """cpy_r_a = 2147483647;""" + ) def test_long_signed(self) -> None: - a = Register(int64_rprimitive, 'a') - self.assert_emit(Assign(a, Integer(-(1 << 31) + 1, int64_rprimitive)), - """cpy_r_a = -2147483647;""") - self.assert_emit(Assign(a, Integer(-(1 << 31), int64_rprimitive)), - """cpy_r_a = -2147483648LL;""") + a = Register(int64_rprimitive, "a") + self.assert_emit( + Assign(a, Integer(-(1 << 31) + 1, int64_rprimitive)), """cpy_r_a = -2147483647;""" + ) + self.assert_emit( + Assign(a, Integer(-(1 << 31), int64_rprimitive)), """cpy_r_a = -2147483648LL;""" + ) def test_cast_and_branch_merge(self) -> None: op = Cast(self.r, dict_rprimitive, 1) next_block = BasicBlock(9) branch = Branch(op, BasicBlock(8), next_block, Branch.IS_ERROR) - branch.traceback_entry = ('foobar', 123) + branch.traceback_entry = ("foobar", 123) self.assert_emit( op, """\ @@ -483,7 +638,7 @@ def test_cast_and_branch_merge(self) -> None: def test_cast_and_branch_no_merge_1(self) -> None: op = Cast(self.r, dict_rprimitive, 1) branch = Branch(op, BasicBlock(8), BasicBlock(9), Branch.IS_ERROR) - branch.traceback_entry = ('foobar', 123) + branch.traceback_entry = ("foobar", 123) self.assert_emit( op, """\ @@ -504,7 +659,7 @@ def test_cast_and_branch_no_merge_2(self) -> None: next_block = BasicBlock(9) branch = Branch(op, BasicBlock(8), next_block, Branch.IS_ERROR) branch.negated = True - branch.traceback_entry = ('foobar', 123) + branch.traceback_entry = ("foobar", 123) self.assert_emit( op, """\ @@ -523,7 +678,7 @@ def test_cast_and_branch_no_merge_3(self) -> None: op = Cast(self.r, dict_rprimitive, 1) next_block = BasicBlock(9) branch = Branch(op, BasicBlock(8), next_block, Branch.BOOL) - branch.traceback_entry = ('foobar', 123) + branch.traceback_entry = ("foobar", 123) self.assert_emit( op, """\ @@ -557,30 +712,35 @@ def test_cast_and_branch_no_merge_4(self) -> None: ) def test_extend(self) -> None: - a = Register(int32_rprimitive, 'a') - self.assert_emit(Extend(a, int64_rprimitive, signed=True), - """cpy_r_r0 = cpy_r_a;""") - self.assert_emit(Extend(a, int64_rprimitive, signed=False), - """cpy_r_r0 = (uint32_t)cpy_r_a;""") + a = Register(int32_rprimitive, "a") + self.assert_emit(Extend(a, int64_rprimitive, signed=True), """cpy_r_r0 = cpy_r_a;""") + self.assert_emit( + Extend(a, int64_rprimitive, signed=False), """cpy_r_r0 = (uint32_t)cpy_r_a;""" + ) if PLATFORM_SIZE == 4: - self.assert_emit(Extend(self.n, int64_rprimitive, signed=True), - """cpy_r_r0 = (Py_ssize_t)cpy_r_n;""") - self.assert_emit(Extend(self.n, int64_rprimitive, signed=False), - """cpy_r_r0 = cpy_r_n;""") + self.assert_emit( + Extend(self.n, int64_rprimitive, signed=True), + """cpy_r_r0 = (Py_ssize_t)cpy_r_n;""", + ) + self.assert_emit( + Extend(self.n, int64_rprimitive, signed=False), """cpy_r_r0 = cpy_r_n;""" + ) if PLATFORM_SIZE == 8: - self.assert_emit(Extend(a, int_rprimitive, signed=True), - """cpy_r_r0 = cpy_r_a;""") - self.assert_emit(Extend(a, int_rprimitive, signed=False), - """cpy_r_r0 = (uint32_t)cpy_r_a;""") - - def assert_emit(self, - op: Op, - expected: str, - next_block: Optional[BasicBlock] = None, - *, - rare: bool = False, - next_branch: Optional[Branch] = None, - skip_next: bool = False) -> None: + self.assert_emit(Extend(a, int_rprimitive, signed=True), """cpy_r_r0 = cpy_r_a;""") + self.assert_emit( + Extend(a, int_rprimitive, signed=False), """cpy_r_r0 = (uint32_t)cpy_r_a;""" + ) + + def assert_emit( + self, + op: Op, + expected: str, + next_block: Optional[BasicBlock] = None, + *, + rare: bool = False, + next_branch: Optional[Branch] = None, + skip_next: bool = False, + ) -> None: block = BasicBlock(0) block.ops.append(op) value_names = generate_names_for_ir(self.registers, [block]) @@ -589,7 +749,7 @@ def assert_emit(self, emitter.fragments = [] declarations.fragments = [] - visitor = FunctionEmitterVisitor(emitter, declarations, 'prog.py', 'prog') + visitor = FunctionEmitterVisitor(emitter, declarations, "prog.py", "prog") visitor.next_block = next_block visitor.rare = rare if next_branch: @@ -600,86 +760,100 @@ def assert_emit(self, op.accept(visitor) frags = declarations.fragments + emitter.fragments - actual_lines = [line.strip(' ') for line in frags] - assert all(line.endswith('\n') for line in actual_lines) - actual_lines = [line.rstrip('\n') for line in actual_lines] + actual_lines = [line.strip(" ") for line in frags] + assert all(line.endswith("\n") for line in actual_lines) + actual_lines = [line.rstrip("\n") for line in actual_lines] if not expected.strip(): expected_lines = [] else: - expected_lines = expected.rstrip().split('\n') - expected_lines = [line.strip(' ') for line in expected_lines] - assert_string_arrays_equal(expected_lines, actual_lines, - msg='Generated code unexpected') + expected_lines = expected.rstrip().split("\n") + expected_lines = [line.strip(" ") for line in expected_lines] + assert_string_arrays_equal(expected_lines, actual_lines, msg="Generated code unexpected") if skip_next: assert visitor.op_index == 1 else: assert visitor.op_index == 0 - def assert_emit_binary_op(self, - op: str, - dest: Value, - left: Value, - right: Value, - expected: str) -> None: + def assert_emit_binary_op( + self, op: str, dest: Value, left: Value, right: Value, expected: str + ) -> None: if op in binary_ops: ops = binary_ops[op] for desc in ops: - if (is_subtype(left.type, desc.arg_types[0]) - and is_subtype(right.type, desc.arg_types[1])): + if is_subtype(left.type, desc.arg_types[0]) and is_subtype( + right.type, desc.arg_types[1] + ): args = [left, right] if desc.ordering is not None: args = [args[i] for i in desc.ordering] - self.assert_emit(CallC(desc.c_function_name, args, desc.return_type, - desc.steals, desc.is_borrowed, - desc.error_kind, 55), expected) + self.assert_emit( + CallC( + desc.c_function_name, + args, + desc.return_type, + desc.steals, + desc.is_borrowed, + desc.error_kind, + 55, + ), + expected, + ) return else: - assert False, 'Could not find matching op' + assert False, "Could not find matching op" class TestGenerateFunction(unittest.TestCase): def setUp(self) -> None: - self.arg = RuntimeArg('arg', int_rprimitive) - self.reg = Register(int_rprimitive, 'arg') + self.arg = RuntimeArg("arg", int_rprimitive) + self.reg = Register(int_rprimitive, "arg") self.block = BasicBlock(0) def test_simple(self) -> None: self.block.ops.append(Return(self.reg)) - fn = FuncIR(FuncDecl('myfunc', None, 'mod', FuncSignature([self.arg], int_rprimitive)), - [self.reg], - [self.block]) + fn = FuncIR( + FuncDecl("myfunc", None, "mod", FuncSignature([self.arg], int_rprimitive)), + [self.reg], + [self.block], + ) value_names = generate_names_for_ir(fn.arg_regs, fn.blocks) - emitter = Emitter(EmitterContext(NameGenerator([['mod']])), value_names) - generate_native_function(fn, emitter, 'prog.py', 'prog') + emitter = Emitter(EmitterContext(NameGenerator([["mod"]])), value_names) + generate_native_function(fn, emitter, "prog.py", "prog") result = emitter.fragments assert_string_arrays_equal( [ - 'CPyTagged CPyDef_myfunc(CPyTagged cpy_r_arg) {\n', - 'CPyL0: ;\n', - ' return cpy_r_arg;\n', - '}\n', + "CPyTagged CPyDef_myfunc(CPyTagged cpy_r_arg) {\n", + "CPyL0: ;\n", + " return cpy_r_arg;\n", + "}\n", ], - result, msg='Generated code invalid') + result, + msg="Generated code invalid", + ) def test_register(self) -> None: reg = Register(int_rprimitive) op = Assign(reg, Integer(5)) self.block.ops.append(op) self.block.ops.append(Unreachable()) - fn = FuncIR(FuncDecl('myfunc', None, 'mod', FuncSignature([self.arg], list_rprimitive)), - [self.reg], - [self.block]) + fn = FuncIR( + FuncDecl("myfunc", None, "mod", FuncSignature([self.arg], list_rprimitive)), + [self.reg], + [self.block], + ) value_names = generate_names_for_ir(fn.arg_regs, fn.blocks) - emitter = Emitter(EmitterContext(NameGenerator([['mod']])), value_names) - generate_native_function(fn, emitter, 'prog.py', 'prog') + emitter = Emitter(EmitterContext(NameGenerator([["mod"]])), value_names) + generate_native_function(fn, emitter, "prog.py", "prog") result = emitter.fragments assert_string_arrays_equal( [ - 'PyObject *CPyDef_myfunc(CPyTagged cpy_r_arg) {\n', - ' CPyTagged cpy_r_r0;\n', - 'CPyL0: ;\n', - ' cpy_r_r0 = 10;\n', - ' CPy_Unreachable();\n', - '}\n', + "PyObject *CPyDef_myfunc(CPyTagged cpy_r_arg) {\n", + " CPyTagged cpy_r_r0;\n", + "CPyL0: ;\n", + " cpy_r_r0 = 10;\n", + " CPy_Unreachable();\n", + "}\n", ], - result, msg='Generated code invalid') + result, + msg="Generated code invalid", + ) diff --git a/mypyc/test/test_emitwrapper.py b/mypyc/test/test_emitwrapper.py index 3eb1be37bfb68..3556a6e01d0f2 100644 --- a/mypyc/test/test_emitwrapper.py +++ b/mypyc/test/test_emitwrapper.py @@ -2,53 +2,58 @@ from typing import List from mypy.test.helpers import assert_string_arrays_equal - from mypyc.codegen.emit import Emitter, EmitterContext, ReturnHandler from mypyc.codegen.emitwrapper import generate_arg_check -from mypyc.ir.rtypes import list_rprimitive, int_rprimitive +from mypyc.ir.rtypes import int_rprimitive, list_rprimitive from mypyc.namegen import NameGenerator class TestArgCheck(unittest.TestCase): def setUp(self) -> None: - self.context = EmitterContext(NameGenerator([['mod']])) + self.context = EmitterContext(NameGenerator([["mod"]])) def test_check_list(self) -> None: emitter = Emitter(self.context) - generate_arg_check('x', list_rprimitive, emitter, ReturnHandler('NULL')) + generate_arg_check("x", list_rprimitive, emitter, ReturnHandler("NULL")) lines = emitter.fragments - self.assert_lines([ - 'PyObject *arg_x;', - 'if (likely(PyList_Check(obj_x)))', - ' arg_x = obj_x;', - 'else {', - ' CPy_TypeError("list", obj_x);', - ' return NULL;', - '}', - ], lines) + self.assert_lines( + [ + "PyObject *arg_x;", + "if (likely(PyList_Check(obj_x)))", + " arg_x = obj_x;", + "else {", + ' CPy_TypeError("list", obj_x);', + " return NULL;", + "}", + ], + lines, + ) def test_check_int(self) -> None: emitter = Emitter(self.context) - generate_arg_check('x', int_rprimitive, emitter, ReturnHandler('NULL')) - generate_arg_check('y', int_rprimitive, emitter, ReturnHandler('NULL'), optional=True) + generate_arg_check("x", int_rprimitive, emitter, ReturnHandler("NULL")) + generate_arg_check("y", int_rprimitive, emitter, ReturnHandler("NULL"), optional=True) lines = emitter.fragments - self.assert_lines([ - 'CPyTagged arg_x;', - 'if (likely(PyLong_Check(obj_x)))', - ' arg_x = CPyTagged_BorrowFromObject(obj_x);', - 'else {', - ' CPy_TypeError("int", obj_x); return NULL;', - '}', - 'CPyTagged arg_y;', - 'if (obj_y == NULL) {', - ' arg_y = CPY_INT_TAG;', - '} else if (likely(PyLong_Check(obj_y)))', - ' arg_y = CPyTagged_BorrowFromObject(obj_y);', - 'else {', - ' CPy_TypeError("int", obj_y); return NULL;', - '}', - ], lines) + self.assert_lines( + [ + "CPyTagged arg_x;", + "if (likely(PyLong_Check(obj_x)))", + " arg_x = CPyTagged_BorrowFromObject(obj_x);", + "else {", + ' CPy_TypeError("int", obj_x); return NULL;', + "}", + "CPyTagged arg_y;", + "if (obj_y == NULL) {", + " arg_y = CPY_INT_TAG;", + "} else if (likely(PyLong_Check(obj_y)))", + " arg_y = CPyTagged_BorrowFromObject(obj_y);", + "else {", + ' CPy_TypeError("int", obj_y); return NULL;', + "}", + ], + lines, + ) def assert_lines(self, expected: List[str], actual: List[str]) -> None: - actual = [line.rstrip('\n') for line in actual] - assert_string_arrays_equal(expected, actual, 'Invalid output') + actual = [line.rstrip("\n") for line in actual] + assert_string_arrays_equal(expected, actual, "Invalid output") diff --git a/mypyc/test/test_exceptions.py b/mypyc/test/test_exceptions.py index 802024f2c86b1..790e984f84b5b 100644 --- a/mypyc/test/test_exceptions.py +++ b/mypyc/test/test_exceptions.py @@ -5,25 +5,25 @@ import os.path +from mypy.errors import CompileError from mypy.test.config import test_temp_dir from mypy.test.data import DataDrivenTestCase -from mypy.errors import CompileError - +from mypyc.analysis.blockfreq import frequently_executed_blocks from mypyc.common import TOP_LEVEL_NAME from mypyc.ir.pprint import format_func -from mypyc.transform.uninit import insert_uninit_checks -from mypyc.transform.exceptions import insert_exception_handling -from mypyc.transform.refcount import insert_ref_count_opcodes from mypyc.test.testutil import ( - ICODE_GEN_BUILTINS, use_custom_builtins, MypycDataSuite, build_ir_for_single_file, - assert_test_output, remove_comment_lines + ICODE_GEN_BUILTINS, + MypycDataSuite, + assert_test_output, + build_ir_for_single_file, + remove_comment_lines, + use_custom_builtins, ) -from mypyc.analysis.blockfreq import frequently_executed_blocks +from mypyc.transform.exceptions import insert_exception_handling +from mypyc.transform.refcount import insert_ref_count_opcodes +from mypyc.transform.uninit import insert_uninit_checks -files = [ - 'exceptions.test', - 'exceptions-freq.test', -] +files = ["exceptions.test", "exceptions-freq.test"] class TestExceptionTransform(MypycDataSuite): @@ -41,16 +41,14 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: else: actual = [] for fn in ir: - if (fn.name == TOP_LEVEL_NAME - and not testcase.name.endswith('_toplevel')): + if fn.name == TOP_LEVEL_NAME and not testcase.name.endswith("_toplevel"): continue insert_uninit_checks(fn) insert_exception_handling(fn) insert_ref_count_opcodes(fn) actual.extend(format_func(fn)) - if testcase.name.endswith('_freq'): + if testcase.name.endswith("_freq"): common = frequently_executed_blocks(fn.blocks[0]) - actual.append('hot blocks: %s' % sorted(b.label for b in common)) + actual.append("hot blocks: %s" % sorted(b.label for b in common)) - assert_test_output(testcase, actual, 'Invalid source code output', - expected_output) + assert_test_output(testcase, actual, "Invalid source code output", expected_output) diff --git a/mypyc/test/test_external.py b/mypyc/test/test_external.py index 5e8e5b54dd8a3..de3f2a147f7b8 100644 --- a/mypyc/test/test_external.py +++ b/mypyc/test/test_external.py @@ -1,14 +1,12 @@ """Test cases that run tests as subprocesses.""" -from typing import List - import os import subprocess import sys import unittest +from typing import List - -base_dir = os.path.join(os.path.dirname(__file__), '..', '..') +base_dir = os.path.join(os.path.dirname(__file__), "..", "..") class TestExternal(unittest.TestCase): @@ -21,27 +19,30 @@ def test_c_unit_test(self) -> None: # The source code for Google Test is copied to this repository. cppflags: List[str] = [] env = os.environ.copy() - if sys.platform == 'darwin': - cppflags += ['-mmacosx-version-min=10.10', '-stdlib=libc++'] - env['CPPFLAGS'] = ' '.join(cppflags) + if sys.platform == "darwin": + cppflags += ["-mmacosx-version-min=10.10", "-stdlib=libc++"] + env["CPPFLAGS"] = " ".join(cppflags) subprocess.check_call( - ['make', 'libgtest.a'], + ["make", "libgtest.a"], env=env, - cwd=os.path.join(base_dir, 'mypyc', 'external', 'googletest', 'make')) + cwd=os.path.join(base_dir, "mypyc", "external", "googletest", "make"), + ) # Build Python wrapper for C unit tests. env = os.environ.copy() - env['CPPFLAGS'] = ' '.join(cppflags) + env["CPPFLAGS"] = " ".join(cppflags) status = subprocess.check_call( - [sys.executable, 'setup.py', 'build_ext', '--inplace'], + [sys.executable, "setup.py", "build_ext", "--inplace"], env=env, - cwd=os.path.join(base_dir, 'mypyc', 'lib-rt')) + cwd=os.path.join(base_dir, "mypyc", "lib-rt"), + ) # Run C unit tests. env = os.environ.copy() - if 'GTEST_COLOR' not in os.environ: - env['GTEST_COLOR'] = 'yes' # Use fancy colors - status = subprocess.call([sys.executable, '-c', - 'import sys, test_capi; sys.exit(test_capi.run_tests())'], - env=env, - cwd=os.path.join(base_dir, 'mypyc', 'lib-rt')) + if "GTEST_COLOR" not in os.environ: + env["GTEST_COLOR"] = "yes" # Use fancy colors + status = subprocess.call( + [sys.executable, "-c", "import sys, test_capi; sys.exit(test_capi.run_tests())"], + env=env, + cwd=os.path.join(base_dir, "mypyc", "lib-rt"), + ) if status != 0: raise AssertionError("make test: C unit test failure") diff --git a/mypyc/test/test_irbuild.py b/mypyc/test/test_irbuild.py index 12da67b8dc1a0..10c406f0486c0 100644 --- a/mypyc/test/test_irbuild.py +++ b/mypyc/test/test_irbuild.py @@ -2,41 +2,45 @@ import os.path +from mypy.errors import CompileError from mypy.test.config import test_temp_dir from mypy.test.data import DataDrivenTestCase -from mypy.errors import CompileError - from mypyc.common import TOP_LEVEL_NAME from mypyc.ir.pprint import format_func from mypyc.test.testutil import ( - ICODE_GEN_BUILTINS, use_custom_builtins, MypycDataSuite, build_ir_for_single_file, - assert_test_output, remove_comment_lines, replace_word_size, - infer_ir_build_options_from_test_name + ICODE_GEN_BUILTINS, + MypycDataSuite, + assert_test_output, + build_ir_for_single_file, + infer_ir_build_options_from_test_name, + remove_comment_lines, + replace_word_size, + use_custom_builtins, ) files = [ - 'irbuild-basic.test', - 'irbuild-int.test', - 'irbuild-lists.test', - 'irbuild-tuple.test', - 'irbuild-dict.test', - 'irbuild-set.test', - 'irbuild-str.test', - 'irbuild-bytes.test', - 'irbuild-statements.test', - 'irbuild-nested.test', - 'irbuild-classes.test', - 'irbuild-optional.test', - 'irbuild-any.test', - 'irbuild-generics.test', - 'irbuild-try.test', - 'irbuild-strip-asserts.test', - 'irbuild-vectorcall.test', - 'irbuild-unreachable.test', - 'irbuild-isinstance.test', - 'irbuild-dunders.test', - 'irbuild-singledispatch.test', - 'irbuild-constant-fold.test', + "irbuild-basic.test", + "irbuild-int.test", + "irbuild-lists.test", + "irbuild-tuple.test", + "irbuild-dict.test", + "irbuild-set.test", + "irbuild-str.test", + "irbuild-bytes.test", + "irbuild-statements.test", + "irbuild-nested.test", + "irbuild-classes.test", + "irbuild-optional.test", + "irbuild-any.test", + "irbuild-generics.test", + "irbuild-try.test", + "irbuild-strip-asserts.test", + "irbuild-vectorcall.test", + "irbuild-unreachable.test", + "irbuild-isinstance.test", + "irbuild-dunders.test", + "irbuild-singledispatch.test", + "irbuild-constant-fold.test", ] @@ -62,10 +66,8 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: else: actual = [] for fn in ir: - if (fn.name == TOP_LEVEL_NAME - and not name.endswith('_toplevel')): + if fn.name == TOP_LEVEL_NAME and not name.endswith("_toplevel"): continue actual.extend(format_func(fn)) - assert_test_output(testcase, actual, 'Invalid source code output', - expected_output) + assert_test_output(testcase, actual, "Invalid source code output", expected_output) diff --git a/mypyc/test/test_ircheck.py b/mypyc/test/test_ircheck.py index 3c56ddac3e858..0141d7adee336 100644 --- a/mypyc/test/test_ircheck.py +++ b/mypyc/test/test_ircheck.py @@ -1,17 +1,22 @@ import unittest from typing import List, Optional -from mypyc.analysis.ircheck import check_func_ir, FnError, can_coerce_to +from mypyc.analysis.ircheck import FnError, can_coerce_to, check_func_ir from mypyc.ir.class_ir import ClassIR +from mypyc.ir.func_ir import FuncDecl, FuncIR, FuncSignature +from mypyc.ir.ops import Assign, BasicBlock, Goto, Integer, LoadLiteral, Op, Register, Return +from mypyc.ir.pprint import format_func from mypyc.ir.rtypes import ( - none_rprimitive, str_rprimitive, int32_rprimitive, int64_rprimitive, - RType, RUnion, bytes_rprimitive, RInstance, object_rprimitive -) -from mypyc.ir.ops import ( - BasicBlock, Op, Return, Integer, Goto, Register, LoadLiteral, Assign + RInstance, + RType, + RUnion, + bytes_rprimitive, + int32_rprimitive, + int64_rprimitive, + none_rprimitive, + object_rprimitive, + str_rprimitive, ) -from mypyc.ir.func_ir import FuncIR, FuncDecl, FuncSignature -from mypyc.ir.pprint import format_func def assert_has_error(fn: FuncIR, error: FnError) -> None: @@ -43,10 +48,7 @@ def func_decl(self, name: str, ret_type: Optional[RType] = None) -> FuncDecl: name=name, class_name=None, module_name="module", - sig=FuncSignature( - args=[], - ret_type=ret_type, - ), + sig=FuncSignature(args=[], ret_type=ret_type), ) def test_valid_fn(self) -> None: @@ -54,33 +56,19 @@ def test_valid_fn(self) -> None: FuncIR( decl=self.func_decl(name="func_1"), arg_regs=[], - blocks=[ - self.basic_block( - ops=[ - Return(value=NONE_VALUE), - ] - ) - ], + blocks=[self.basic_block(ops=[Return(value=NONE_VALUE)])], ) ) def test_block_not_terminated_empty_block(self) -> None: block = self.basic_block([]) - fn = FuncIR( - decl=self.func_decl(name="func_1"), - arg_regs=[], - blocks=[block], - ) + fn = FuncIR(decl=self.func_decl(name="func_1"), arg_regs=[], blocks=[block]) assert_has_error(fn, FnError(source=block, desc="Block not terminated")) def test_valid_goto(self) -> None: block_1 = self.basic_block([Return(value=NONE_VALUE)]) block_2 = self.basic_block([Goto(label=block_1)]) - fn = FuncIR( - decl=self.func_decl(name="func_1"), - arg_regs=[], - blocks=[block_1, block_2], - ) + fn = FuncIR(decl=self.func_decl(name="func_1"), arg_regs=[], blocks=[block_1, block_2]) assert_no_errors(fn) def test_invalid_goto(self) -> None: @@ -93,43 +81,20 @@ def test_invalid_goto(self) -> None: # block_1 omitted blocks=[block_2], ) - assert_has_error( - fn, FnError(source=goto, desc="Invalid control operation target: 1") - ) + assert_has_error(fn, FnError(source=goto, desc="Invalid control operation target: 1")) def test_invalid_register_source(self) -> None: - ret = Return( - value=Register( - type=none_rprimitive, - name="r1", - ) - ) + ret = Return(value=Register(type=none_rprimitive, name="r1")) block = self.basic_block([ret]) - fn = FuncIR( - decl=self.func_decl(name="func_1"), - arg_regs=[], - blocks=[block], - ) - assert_has_error( - fn, FnError(source=ret, desc="Invalid op reference to register r1") - ) + fn = FuncIR(decl=self.func_decl(name="func_1"), arg_regs=[], blocks=[block]) + assert_has_error(fn, FnError(source=ret, desc="Invalid op reference to register r1")) def test_invalid_op_source(self) -> None: - ret = Return( - value=LoadLiteral( - value="foo", - rtype=str_rprimitive, - ) - ) + ret = Return(value=LoadLiteral(value="foo", rtype=str_rprimitive)) block = self.basic_block([ret]) - fn = FuncIR( - decl=self.func_decl(name="func_1"), - arg_regs=[], - blocks=[block], - ) + fn = FuncIR(decl=self.func_decl(name="func_1"), arg_regs=[], blocks=[block]) assert_has_error( - fn, - FnError(source=ret, desc="Invalid op reference to op of type LoadLiteral"), + fn, FnError(source=ret, desc="Invalid op reference to op of type LoadLiteral") ) def test_invalid_return_type(self) -> None: @@ -140,10 +105,7 @@ def test_invalid_return_type(self) -> None: blocks=[self.basic_block([ret])], ) assert_has_error( - fn, - FnError( - source=ret, desc="Cannot coerce source type int32 to dest type int64" - ), + fn, FnError(source=ret, desc="Cannot coerce source type int32 to dest type int64") ) def test_invalid_assign(self) -> None: @@ -156,10 +118,7 @@ def test_invalid_assign(self) -> None: blocks=[self.basic_block([assign, ret])], ) assert_has_error( - fn, - FnError( - source=assign, desc="Cannot coerce source type int32 to dest type int64" - ), + fn, FnError(source=assign, desc="Cannot coerce source type int32 to dest type int64") ) def test_can_coerce_to(self) -> None: @@ -189,11 +148,7 @@ def test_duplicate_op(self) -> None: arg_reg = Register(type=int32_rprimitive, name="r1") assign = Assign(dest=arg_reg, src=Integer(value=5, rtype=int32_rprimitive)) block = self.basic_block([assign, assign, Return(value=NONE_VALUE)]) - fn = FuncIR( - decl=self.func_decl(name="func_1"), - arg_regs=[], - blocks=[block], - ) + fn = FuncIR(decl=self.func_decl(name="func_1"), arg_regs=[], blocks=[block]) assert_has_error(fn, FnError(source=assign, desc="Func has a duplicate op")) def test_pprint(self) -> None: diff --git a/mypyc/test/test_literals.py b/mypyc/test/test_literals.py index 5c7b685d39ef9..6473820d50429 100644 --- a/mypyc/test/test_literals.py +++ b/mypyc/test/test_literals.py @@ -3,85 +3,86 @@ import unittest from mypyc.codegen.literals import ( - Literals, format_str_literal, _encode_str_values, _encode_bytes_values, _encode_int_values + Literals, + _encode_bytes_values, + _encode_int_values, + _encode_str_values, + format_str_literal, ) class TestLiterals(unittest.TestCase): def test_format_str_literal(self) -> None: - assert format_str_literal('') == b'\x00' - assert format_str_literal('xyz') == b'\x03xyz' - assert format_str_literal('x' * 127) == b'\x7f' + b'x' * 127 - assert format_str_literal('x' * 128) == b'\x81\x00' + b'x' * 128 - assert format_str_literal('x' * 131) == b'\x81\x03' + b'x' * 131 + assert format_str_literal("") == b"\x00" + assert format_str_literal("xyz") == b"\x03xyz" + assert format_str_literal("x" * 127) == b"\x7f" + b"x" * 127 + assert format_str_literal("x" * 128) == b"\x81\x00" + b"x" * 128 + assert format_str_literal("x" * 131) == b"\x81\x03" + b"x" * 131 def test_encode_str_values(self) -> None: - assert _encode_str_values({}) == [b''] - assert _encode_str_values({'foo': 0}) == [b'\x01\x03foo', b''] - assert _encode_str_values({'foo': 0, 'b': 1}) == [b'\x02\x03foo\x01b', b''] - assert _encode_str_values({'foo': 0, 'x' * 70: 1}) == [ - b'\x01\x03foo', - bytes([1, 70]) + b'x' * 70, - b'' - ] - assert _encode_str_values({'y' * 100: 0}) == [ - bytes([1, 100]) + b'y' * 100, - b'' + assert _encode_str_values({}) == [b""] + assert _encode_str_values({"foo": 0}) == [b"\x01\x03foo", b""] + assert _encode_str_values({"foo": 0, "b": 1}) == [b"\x02\x03foo\x01b", b""] + assert _encode_str_values({"foo": 0, "x" * 70: 1}) == [ + b"\x01\x03foo", + bytes([1, 70]) + b"x" * 70, + b"", ] + assert _encode_str_values({"y" * 100: 0}) == [bytes([1, 100]) + b"y" * 100, b""] def test_encode_bytes_values(self) -> None: - assert _encode_bytes_values({}) == [b''] - assert _encode_bytes_values({b'foo': 0}) == [b'\x01\x03foo', b''] - assert _encode_bytes_values({b'foo': 0, b'b': 1}) == [b'\x02\x03foo\x01b', b''] - assert _encode_bytes_values({b'foo': 0, b'x' * 70: 1}) == [ - b'\x01\x03foo', - bytes([1, 70]) + b'x' * 70, - b'' - ] - assert _encode_bytes_values({b'y' * 100: 0}) == [ - bytes([1, 100]) + b'y' * 100, - b'' + assert _encode_bytes_values({}) == [b""] + assert _encode_bytes_values({b"foo": 0}) == [b"\x01\x03foo", b""] + assert _encode_bytes_values({b"foo": 0, b"b": 1}) == [b"\x02\x03foo\x01b", b""] + assert _encode_bytes_values({b"foo": 0, b"x" * 70: 1}) == [ + b"\x01\x03foo", + bytes([1, 70]) + b"x" * 70, + b"", ] + assert _encode_bytes_values({b"y" * 100: 0}) == [bytes([1, 100]) + b"y" * 100, b""] def test_encode_int_values(self) -> None: - assert _encode_int_values({}) == [b''] - assert _encode_int_values({123: 0}) == [b'\x01123', b''] - assert _encode_int_values({123: 0, 9: 1}) == [b'\x02123\x009', b''] + assert _encode_int_values({}) == [b""] + assert _encode_int_values({123: 0}) == [b"\x01123", b""] + assert _encode_int_values({123: 0, 9: 1}) == [b"\x02123\x009", b""] assert _encode_int_values({123: 0, 45: 1, 5 * 10**70: 2}) == [ - b'\x02123\x0045', - b'\x015' + b'0' * 70, - b'' - ] - assert _encode_int_values({6 * 10**100: 0}) == [ - b'\x016' + b'0' * 100, - b'' + b"\x02123\x0045", + b"\x015" + b"0" * 70, + b"", ] + assert _encode_int_values({6 * 10**100: 0}) == [b"\x016" + b"0" * 100, b""] def test_simple_literal_index(self) -> None: lit = Literals() lit.record_literal(1) - lit.record_literal('y') + lit.record_literal("y") lit.record_literal(True) lit.record_literal(None) lit.record_literal(False) assert lit.literal_index(None) == 0 assert lit.literal_index(False) == 1 assert lit.literal_index(True) == 2 - assert lit.literal_index('y') == 3 + assert lit.literal_index("y") == 3 assert lit.literal_index(1) == 4 def test_tuple_literal(self) -> None: lit = Literals() - lit.record_literal((1, 'y', None, (b'a', 'b'))) - lit.record_literal((b'a', 'b')) + lit.record_literal((1, "y", None, (b"a", "b"))) + lit.record_literal((b"a", "b")) lit.record_literal(()) - assert lit.literal_index((b'a', 'b')) == 7 - assert lit.literal_index((1, 'y', None, (b'a', 'b'))) == 8 + assert lit.literal_index((b"a", "b")) == 7 + assert lit.literal_index((1, "y", None, (b"a", "b"))) == 8 assert lit.literal_index(()) == 9 print(lit.encoded_tuple_values()) assert lit.encoded_tuple_values() == [ - '3', # Number of tuples - '2', '5', '4', # First tuple (length=2) - '4', '6', '3', '0', '7', # Second tuple (length=4) - '0', # Third tuple (length=0) + "3", # Number of tuples + "2", + "5", + "4", # First tuple (length=2) + "4", + "6", + "3", + "0", + "7", # Second tuple (length=4) + "0", # Third tuple (length=0) ] diff --git a/mypyc/test/test_namegen.py b/mypyc/test/test_namegen.py index 5baacc0eecf9e..c4b83f9a58e21 100644 --- a/mypyc/test/test_namegen.py +++ b/mypyc/test/test_namegen.py @@ -1,40 +1,46 @@ import unittest from mypyc.namegen import ( - NameGenerator, exported_name, candidate_suffixes, make_module_translation_map + NameGenerator, + candidate_suffixes, + exported_name, + make_module_translation_map, ) class TestNameGen(unittest.TestCase): def test_candidate_suffixes(self) -> None: - assert candidate_suffixes('foo') == ['', 'foo.'] - assert candidate_suffixes('foo.bar') == ['', 'bar.', 'foo.bar.'] + assert candidate_suffixes("foo") == ["", "foo."] + assert candidate_suffixes("foo.bar") == ["", "bar.", "foo.bar."] def test_exported_name(self) -> None: - assert exported_name('foo') == 'foo' - assert exported_name('foo.bar') == 'foo___bar' + assert exported_name("foo") == "foo" + assert exported_name("foo.bar") == "foo___bar" def test_make_module_translation_map(self) -> None: - assert make_module_translation_map( - ['foo', 'bar']) == {'foo': 'foo.', 'bar': 'bar.'} - assert make_module_translation_map( - ['foo.bar', 'foo.baz']) == {'foo.bar': 'bar.', 'foo.baz': 'baz.'} - assert make_module_translation_map( - ['zar', 'foo.bar', 'foo.baz']) == {'foo.bar': 'bar.', - 'foo.baz': 'baz.', - 'zar': 'zar.'} - assert make_module_translation_map( - ['foo.bar', 'fu.bar', 'foo.baz']) == {'foo.bar': 'foo.bar.', - 'fu.bar': 'fu.bar.', - 'foo.baz': 'baz.'} + assert make_module_translation_map(["foo", "bar"]) == {"foo": "foo.", "bar": "bar."} + assert make_module_translation_map(["foo.bar", "foo.baz"]) == { + "foo.bar": "bar.", + "foo.baz": "baz.", + } + assert make_module_translation_map(["zar", "foo.bar", "foo.baz"]) == { + "foo.bar": "bar.", + "foo.baz": "baz.", + "zar": "zar.", + } + assert make_module_translation_map(["foo.bar", "fu.bar", "foo.baz"]) == { + "foo.bar": "foo.bar.", + "fu.bar": "fu.bar.", + "foo.baz": "baz.", + } def test_name_generator(self) -> None: - g = NameGenerator([['foo', 'foo.zar']]) - assert g.private_name('foo', 'f') == 'foo___f' - assert g.private_name('foo', 'C.x.y') == 'foo___C___x___y' - assert g.private_name('foo', 'C.x.y') == 'foo___C___x___y' - assert g.private_name('foo.zar', 'C.x.y') == 'zar___C___x___y' - assert g.private_name('foo', 'C.x_y') == 'foo___C___x_y' - assert g.private_name('foo', 'C_x_y') == 'foo___C_x_y' - assert g.private_name('foo', 'C_x_y') == 'foo___C_x_y' - assert g.private_name('foo', '___') == 'foo______3_' + g = NameGenerator([["foo", "foo.zar"]]) + assert g.private_name("foo", "f") == "foo___f" + assert g.private_name("foo", "C.x.y") == "foo___C___x___y" + assert g.private_name("foo", "C.x.y") == "foo___C___x___y" + assert g.private_name("foo.zar", "C.x.y") == "zar___C___x___y" + assert g.private_name("foo", "C.x_y") == "foo___C___x_y" + assert g.private_name("foo", "C_x_y") == "foo___C_x_y" + assert g.private_name("foo", "C_x_y") == "foo___C_x_y" + assert g.private_name("foo", "___") == "foo______3_" diff --git a/mypyc/test/test_pprint.py b/mypyc/test/test_pprint.py index 4c3374cddcc19..33fbbc43e0428 100644 --- a/mypyc/test/test_pprint.py +++ b/mypyc/test/test_pprint.py @@ -1,13 +1,13 @@ import unittest from typing import List -from mypyc.ir.ops import BasicBlock, Register, Op, Integer, IntOp, Unreachable, Assign -from mypyc.ir.rtypes import int_rprimitive +from mypyc.ir.ops import Assign, BasicBlock, Integer, IntOp, Op, Register, Unreachable from mypyc.ir.pprint import generate_names_for_ir +from mypyc.ir.rtypes import int_rprimitive def register(name: str) -> Register: - return Register(int_rprimitive, 'foo', is_arg=True) + return Register(int_rprimitive, "foo", is_arg=True) def make_block(ops: List[Op]) -> BasicBlock: @@ -21,8 +21,8 @@ def test_empty(self) -> None: assert generate_names_for_ir([], []) == {} def test_arg(self) -> None: - reg = register('foo') - assert generate_names_for_ir([reg], []) == {reg: 'foo'} + reg = register("foo") + assert generate_names_for_ir([reg], []) == {reg: "foo"} def test_int_op(self) -> None: n1 = Integer(2) @@ -30,12 +30,12 @@ def test_int_op(self) -> None: op1 = IntOp(int_rprimitive, n1, n2, IntOp.ADD) op2 = IntOp(int_rprimitive, op1, n2, IntOp.ADD) block = make_block([op1, op2, Unreachable()]) - assert generate_names_for_ir([], [block]) == {op1: 'r0', op2: 'r1'} + assert generate_names_for_ir([], [block]) == {op1: "r0", op2: "r1"} def test_assign(self) -> None: - reg = register('foo') + reg = register("foo") n = Integer(2) op1 = Assign(reg, n) op2 = Assign(reg, n) block = make_block([op1, op2]) - assert generate_names_for_ir([reg], [block]) == {reg: 'foo'} + assert generate_names_for_ir([reg], [block]) == {reg: "foo"} diff --git a/mypyc/test/test_rarray.py b/mypyc/test/test_rarray.py index a6702c811077d..c599f663d3c98 100644 --- a/mypyc/test/test_rarray.py +++ b/mypyc/test/test_rarray.py @@ -4,7 +4,11 @@ from mypyc.common import PLATFORM_SIZE from mypyc.ir.rtypes import ( - RArray, int_rprimitive, bool_rprimitive, compute_rtype_alignment, compute_rtype_size + RArray, + bool_rprimitive, + compute_rtype_alignment, + compute_rtype_size, + int_rprimitive, ) @@ -16,8 +20,8 @@ def test_basics(self) -> None: def test_str_conversion(self) -> None: a = RArray(int_rprimitive, 10) - assert str(a) == 'int[10]' - assert repr(a) == '[10]>' + assert str(a) == "int[10]" + assert repr(a) == "[10]>" def test_eq(self) -> None: a = RArray(int_rprimitive, 10) diff --git a/mypyc/test/test_refcount.py b/mypyc/test/test_refcount.py index 2c9502330cd53..1bd8ff79ba7b1 100644 --- a/mypyc/test/test_refcount.py +++ b/mypyc/test/test_refcount.py @@ -6,23 +6,25 @@ import os.path +from mypy.errors import CompileError from mypy.test.config import test_temp_dir from mypy.test.data import DataDrivenTestCase -from mypy.errors import CompileError - from mypyc.common import TOP_LEVEL_NAME from mypyc.ir.pprint import format_func -from mypyc.transform.refcount import insert_ref_count_opcodes -from mypyc.transform.uninit import insert_uninit_checks from mypyc.test.testutil import ( - ICODE_GEN_BUILTINS, use_custom_builtins, MypycDataSuite, build_ir_for_single_file, - assert_test_output, remove_comment_lines, replace_word_size, - infer_ir_build_options_from_test_name + ICODE_GEN_BUILTINS, + MypycDataSuite, + assert_test_output, + build_ir_for_single_file, + infer_ir_build_options_from_test_name, + remove_comment_lines, + replace_word_size, + use_custom_builtins, ) +from mypyc.transform.refcount import insert_ref_count_opcodes +from mypyc.transform.uninit import insert_uninit_checks -files = [ - 'refcount.test' -] +files = ["refcount.test"] class TestRefCountTransform(MypycDataSuite): @@ -46,12 +48,10 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: else: actual = [] for fn in ir: - if (fn.name == TOP_LEVEL_NAME - and not testcase.name.endswith('_toplevel')): + if fn.name == TOP_LEVEL_NAME and not testcase.name.endswith("_toplevel"): continue insert_uninit_checks(fn) insert_ref_count_opcodes(fn) actual.extend(format_func(fn)) - assert_test_output(testcase, actual, 'Invalid source code output', - expected_output) + assert_test_output(testcase, actual, "Invalid source code output", expected_output) diff --git a/mypyc/test/test_run.py b/mypyc/test/test_run.py index 4013c30c8bc83..075a5b33c4807 100644 --- a/mypyc/test/test_run.py +++ b/mypyc/test/test_run.py @@ -1,64 +1,67 @@ """Test cases for building an C extension and running it.""" import ast +import contextlib import glob import os.path import re -import subprocess -import contextlib import shutil +import subprocess import sys from typing import Any, Iterator, List, cast from mypy import build -from mypy.test.data import DataDrivenTestCase -from mypy.test.config import test_temp_dir from mypy.errors import CompileError from mypy.options import Options +from mypy.test.config import test_temp_dir +from mypy.test.data import DataDrivenTestCase from mypy.test.helpers import assert_module_equivalence, perform_file_operations - +from mypyc.build import construct_groups from mypyc.codegen import emitmodule -from mypyc.options import CompilerOptions from mypyc.errors import Errors -from mypyc.build import construct_groups +from mypyc.options import CompilerOptions +from mypyc.test.test_serialization import check_serialization_roundtrip from mypyc.test.testutil import ( - ICODE_GEN_BUILTINS, TESTUTIL_PATH, - use_custom_builtins, MypycDataSuite, assert_test_output, - show_c, fudge_dir_mtimes, + ICODE_GEN_BUILTINS, + TESTUTIL_PATH, + MypycDataSuite, + assert_test_output, + fudge_dir_mtimes, + show_c, + use_custom_builtins, ) -from mypyc.test.test_serialization import check_serialization_roundtrip files = [ - 'run-misc.test', - 'run-functions.test', - 'run-integers.test', - 'run-floats.test', - 'run-bools.test', - 'run-strings.test', - 'run-bytes.test', - 'run-tuples.test', - 'run-lists.test', - 'run-dicts.test', - 'run-sets.test', - 'run-primitives.test', - 'run-loops.test', - 'run-exceptions.test', - 'run-imports.test', - 'run-classes.test', - 'run-traits.test', - 'run-generators.test', - 'run-multimodule.test', - 'run-bench.test', - 'run-mypy-sim.test', - 'run-dunders.test', - 'run-singledispatch.test', - 'run-attrs.test', + "run-misc.test", + "run-functions.test", + "run-integers.test", + "run-floats.test", + "run-bools.test", + "run-strings.test", + "run-bytes.test", + "run-tuples.test", + "run-lists.test", + "run-dicts.test", + "run-sets.test", + "run-primitives.test", + "run-loops.test", + "run-exceptions.test", + "run-imports.test", + "run-classes.test", + "run-traits.test", + "run-generators.test", + "run-multimodule.test", + "run-bench.test", + "run-mypy-sim.test", + "run-dunders.test", + "run-singledispatch.test", + "run-attrs.test", ] if sys.version_info >= (3, 7): - files.append('run-python37.test') + files.append("run-python37.test") if sys.version_info >= (3, 8): - files.append('run-python38.test') + files.append("run-python38.test") setup_format = """\ from setuptools import setup @@ -70,7 +73,7 @@ ) """ -WORKDIR = 'build' +WORKDIR = "build" def run_setup(script_name: str, script_args: List[str]) -> bool: @@ -87,12 +90,12 @@ def run_setup(script_name: str, script_args: List[str]) -> bool: Returns whether the setup succeeded. """ save_argv = sys.argv.copy() - g = {'__file__': script_name} + g = {"__file__": script_name} try: try: sys.argv[0] = script_name sys.argv[1:] = script_args - with open(script_name, 'rb') as f: + with open(script_name, "rb") as f: exec(f.read(), g) finally: sys.argv = save_argv @@ -122,6 +125,7 @@ def chdir_manager(target: str) -> Iterator[None]: class TestRun(MypycDataSuite): """Test cases that build a C extension and run code.""" + files = files base_path = test_temp_dir optional_out = True @@ -132,21 +136,22 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: # setup.py wants to be run from the root directory of the package, which we accommodate # by chdiring into tmp/ with use_custom_builtins(os.path.join(self.data_prefix, ICODE_GEN_BUILTINS), testcase), ( - chdir_manager('tmp')): + chdir_manager("tmp") + ): self.run_case_inner(testcase) def run_case_inner(self, testcase: DataDrivenTestCase) -> None: if not os.path.isdir(WORKDIR): # (one test puts something in build...) os.mkdir(WORKDIR) - text = '\n'.join(testcase.input) + text = "\n".join(testcase.input) - with open('native.py', 'w', encoding='utf-8') as f: + with open("native.py", "w", encoding="utf-8") as f: f.write(text) - with open('interpreted.py', 'w', encoding='utf-8') as f: + with open("interpreted.py", "w", encoding="utf-8") as f: f.write(text) - shutil.copyfile(TESTUTIL_PATH, 'testutil.py') + shutil.copyfile(TESTUTIL_PATH, "testutil.py") step = 1 self.run_case_step(testcase, step) @@ -162,12 +167,12 @@ def run_case_inner(self, testcase: DataDrivenTestCase) -> None: fudge_dir_mtimes(WORKDIR, -1) step += 1 - with chdir_manager('..'): + with chdir_manager(".."): perform_file_operations(operations) self.run_case_step(testcase, step) def run_case_step(self, testcase: DataDrivenTestCase, incremental_step: int) -> None: - bench = testcase.config.getoption('--bench', False) and 'Benchmark' in testcase.name + bench = testcase.config.getoption("--bench", False) and "Benchmark" in testcase.name options = Options() options.use_builtins_fixtures = True @@ -180,33 +185,35 @@ def run_case_step(self, testcase: DataDrivenTestCase, incremental_step: int) -> # Avoid checking modules/packages named 'unchecked', to provide a way # to test interacting with code we don't have types for. - options.per_module_options['unchecked.*'] = {'follow_imports': 'error'} + options.per_module_options["unchecked.*"] = {"follow_imports": "error"} - source = build.BuildSource('native.py', 'native', None) + source = build.BuildSource("native.py", "native", None) sources = [source] - module_names = ['native'] - module_paths = ['native.py'] + module_names = ["native"] + module_paths = ["native.py"] # Hard code another module name to compile in the same compilation unit. to_delete = [] for fn, text in testcase.files: fn = os.path.relpath(fn, test_temp_dir) - if os.path.basename(fn).startswith('other') and fn.endswith('.py'): - name = fn.split('.')[0].replace(os.sep, '.') + if os.path.basename(fn).startswith("other") and fn.endswith(".py"): + name = fn.split(".")[0].replace(os.sep, ".") module_names.append(name) sources.append(build.BuildSource(fn, name, None)) to_delete.append(fn) module_paths.append(fn) - shutil.copyfile(fn, - os.path.join(os.path.dirname(fn), name + '_interpreted.py')) + shutil.copyfile(fn, os.path.join(os.path.dirname(fn), name + "_interpreted.py")) for source in sources: - options.per_module_options.setdefault(source.module, {})['mypyc'] = True + options.per_module_options.setdefault(source.module, {})["mypyc"] = True - separate = (self.get_separate('\n'.join(testcase.input), incremental_step) if self.separate - else False) + separate = ( + self.get_separate("\n".join(testcase.input), incremental_step) + if self.separate + else False + ) groups = construct_groups(sources, separate, len(module_names) > 1) @@ -217,13 +224,11 @@ def run_case_step(self, testcase: DataDrivenTestCase, incremental_step: int) -> options=options, compiler_options=compiler_options, groups=groups, - alt_lib_path='.') + alt_lib_path=".", + ) errors = Errors() ir, cfiles = emitmodule.compile_modules_to_c( - result, - compiler_options=compiler_options, - errors=errors, - groups=groups, + result, compiler_options=compiler_options, errors=errors, groups=groups ) if errors.num_errors: errors.flush_errors() @@ -231,111 +236,115 @@ def run_case_step(self, testcase: DataDrivenTestCase, incremental_step: int) -> except CompileError as e: for line in e.messages: print(fix_native_line_number(line, testcase.file, testcase.line)) - assert False, 'Compile error' + assert False, "Compile error" # Check that serialization works on this IR. (Only on the first # step because the the returned ir only includes updated code.) if incremental_step == 1: check_serialization_roundtrip(ir) - opt_level = int(os.environ.get('MYPYC_OPT_LEVEL', 0)) - debug_level = int(os.environ.get('MYPYC_DEBUG_LEVEL', 0)) + opt_level = int(os.environ.get("MYPYC_OPT_LEVEL", 0)) + debug_level = int(os.environ.get("MYPYC_DEBUG_LEVEL", 0)) - setup_file = os.path.abspath(os.path.join(WORKDIR, 'setup.py')) + setup_file = os.path.abspath(os.path.join(WORKDIR, "setup.py")) # We pass the C file information to the build script via setup.py unfortunately - with open(setup_file, 'w', encoding='utf-8') as f: - f.write(setup_format.format(module_paths, - separate, - cfiles, - self.multi_file, - opt_level, - debug_level)) - - if not run_setup(setup_file, ['build_ext', '--inplace']): - if testcase.config.getoption('--mypyc-showc'): + with open(setup_file, "w", encoding="utf-8") as f: + f.write( + setup_format.format( + module_paths, separate, cfiles, self.multi_file, opt_level, debug_level + ) + ) + + if not run_setup(setup_file, ["build_ext", "--inplace"]): + if testcase.config.getoption("--mypyc-showc"): show_c(cfiles) assert False, "Compilation failed" # Assert that an output file got created - suffix = 'pyd' if sys.platform == 'win32' else 'so' - assert glob.glob(f'native.*.{suffix}') or glob.glob(f'native.{suffix}') + suffix = "pyd" if sys.platform == "win32" else "so" + assert glob.glob(f"native.*.{suffix}") or glob.glob(f"native.{suffix}") - driver_path = 'driver.py' + driver_path = "driver.py" if not os.path.isfile(driver_path): # No driver.py provided by test case. Use the default one # (mypyc/test-data/driver/driver.py) that calls each # function named test_*. default_driver = os.path.join( - os.path.dirname(__file__), '..', 'test-data', 'driver', 'driver.py') + os.path.dirname(__file__), "..", "test-data", "driver", "driver.py" + ) shutil.copy(default_driver, driver_path) env = os.environ.copy() - env['MYPYC_RUN_BENCH'] = '1' if bench else '0' + env["MYPYC_RUN_BENCH"] = "1" if bench else "0" - debugger = testcase.config.getoption('debugger') + debugger = testcase.config.getoption("debugger") if debugger: - if debugger == 'lldb': - subprocess.check_call(['lldb', '--', sys.executable, driver_path], env=env) - elif debugger == 'gdb': - subprocess.check_call(['gdb', '--args', sys.executable, driver_path], env=env) + if debugger == "lldb": + subprocess.check_call(["lldb", "--", sys.executable, driver_path], env=env) + elif debugger == "gdb": + subprocess.check_call(["gdb", "--args", sys.executable, driver_path], env=env) else: - assert False, 'Unsupported debugger' + assert False, "Unsupported debugger" # TODO: find a way to automatically disable capturing # stdin/stdout when in debugging mode assert False, ( "Test can't pass in debugging mode. " "(Make sure to pass -s to pytest to interact with the debugger)" ) - proc = subprocess.Popen([sys.executable, driver_path], stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, env=env) - output = proc.communicate()[0].decode('utf8') + proc = subprocess.Popen( + [sys.executable, driver_path], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + ) + output = proc.communicate()[0].decode("utf8") outlines = output.splitlines() - if testcase.config.getoption('--mypyc-showc'): + if testcase.config.getoption("--mypyc-showc"): show_c(cfiles) if proc.returncode != 0: print() - print('*** Exit status: %d' % proc.returncode) + print("*** Exit status: %d" % proc.returncode) # Verify output. if bench: - print('Test output:') + print("Test output:") print(output) else: if incremental_step == 1: - msg = 'Invalid output' + msg = "Invalid output" expected = testcase.output else: - msg = f'Invalid output (step {incremental_step})' + msg = f"Invalid output (step {incremental_step})" expected = testcase.output2.get(incremental_step, []) if not expected: # Tweak some line numbers, but only if the expected output is empty, # as tweaked output might not match expected output. - outlines = [fix_native_line_number(line, testcase.file, testcase.line) - for line in outlines] + outlines = [ + fix_native_line_number(line, testcase.file, testcase.line) for line in outlines + ] assert_test_output(testcase, outlines, msg, expected) if incremental_step > 1 and options.incremental: - suffix = '' if incremental_step == 2 else str(incremental_step - 1) + suffix = "" if incremental_step == 2 else str(incremental_step - 1) expected_rechecked = testcase.expected_rechecked_modules.get(incremental_step - 1) if expected_rechecked is not None: assert_module_equivalence( - 'rechecked' + suffix, - expected_rechecked, result.manager.rechecked_modules) + "rechecked" + suffix, expected_rechecked, result.manager.rechecked_modules + ) expected_stale = testcase.expected_stale_modules.get(incremental_step - 1) if expected_stale is not None: assert_module_equivalence( - 'stale' + suffix, - expected_stale, result.manager.stale_modules) + "stale" + suffix, expected_stale, result.manager.stale_modules + ) assert proc.returncode == 0 - def get_separate(self, program_text: str, - incremental_step: int) -> Any: - template = r'# separate{}: (\[.*\])$' + def get_separate(self, program_text: str, incremental_step: int) -> Any: + template = r"# separate{}: (\[.*\])$" m = re.search(template.format(incremental_step), program_text, flags=re.MULTILINE) if not m: - m = re.search(template.format(''), program_text, flags=re.MULTILINE) + m = re.search(template.format(""), program_text, flags=re.MULTILINE) if m: return ast.literal_eval(m.group(1)) else: @@ -350,11 +359,8 @@ class TestRunMultiFile(TestRun): """ multi_file = True - test_name_suffix = '_multi' - files = [ - 'run-multimodule.test', - 'run-mypy-sim.test', - ] + test_name_suffix = "_multi" + files = ["run-multimodule.test", "run-mypy-sim.test"] class TestRunSeparate(TestRun): @@ -372,12 +378,10 @@ class TestRunSeparate(TestRun): This puts other.py and other_b.py into a compilation group named "stuff". Any files not mentioned in the comment will get single-file groups. """ + separate = True - test_name_suffix = '_separate' - files = [ - 'run-multimodule.test', - 'run-mypy-sim.test', - ] + test_name_suffix = "_separate" + files = ["run-multimodule.test", "run-mypy-sim.test"] def fix_native_line_number(message: str, fnam: str, delta: int) -> str: @@ -396,10 +400,12 @@ def fix_native_line_number(message: str, fnam: str, delta: int) -> str: Returns updated message (or original message if we couldn't find anything). """ fnam = os.path.basename(fnam) - message = re.sub(r'native\.py:([0-9]+):', - lambda m: '%s:%d:' % (fnam, int(m.group(1)) + delta), - message) - message = re.sub(r'"native.py", line ([0-9]+),', - lambda m: '"%s", line %d,' % (fnam, int(m.group(1)) + delta), - message) + message = re.sub( + r"native\.py:([0-9]+):", lambda m: "%s:%d:" % (fnam, int(m.group(1)) + delta), message + ) + message = re.sub( + r'"native.py", line ([0-9]+),', + lambda m: '"%s", line %d,' % (fnam, int(m.group(1)) + delta), + message, + ) return message diff --git a/mypyc/test/test_serialization.py b/mypyc/test/test_serialization.py index eeef6beb1305a..1c54b4ae074a9 100644 --- a/mypyc/test/test_serialization.py +++ b/mypyc/test/test_serialization.py @@ -3,20 +3,20 @@ # This file is named test_serialization.py even though it doesn't # contain its own tests so that pytest will rewrite the asserts... -from typing import Any, Dict, Tuple -from mypy.backports import OrderedDict from collections.abc import Iterable +from typing import Any, Dict, Tuple -from mypyc.ir.ops import DeserMaps -from mypyc.ir.rtypes import RType -from mypyc.ir.func_ir import FuncDecl, FuncIR, FuncSignature +from mypy.backports import OrderedDict from mypyc.ir.class_ir import ClassIR +from mypyc.ir.func_ir import FuncDecl, FuncIR, FuncSignature from mypyc.ir.module_ir import ModuleIR, deserialize_modules -from mypyc.sametype import is_same_type, is_same_signature +from mypyc.ir.ops import DeserMaps +from mypyc.ir.rtypes import RType +from mypyc.sametype import is_same_signature, is_same_type def get_dict(x: Any) -> Dict[str, Any]: - if hasattr(x, '__mypyc_attrs__'): + if hasattr(x, "__mypyc_attrs__"): return {k: getattr(x, k) for k in x.__mypyc_attrs__ if hasattr(x, k)} else: return dict(x.__dict__) @@ -25,8 +25,8 @@ def get_dict(x: Any) -> Dict[str, Any]: def get_function_dict(x: FuncIR) -> Dict[str, Any]: """Get a dict of function attributes safe to compare across serialization""" d = get_dict(x) - d.pop('blocks', None) - d.pop('env', None) + d.pop("blocks", None) + d.pop("env", None) return d @@ -87,12 +87,12 @@ def assert_modules_same(ir1: ModuleIR, ir2: ModuleIR) -> None: assert_blobs_same(get_dict(cls1), get_dict(cls2), (ir1.fullname, cls1.fullname)) for fn1, fn2 in zip(ir1.functions, ir2.functions): - assert_blobs_same(get_function_dict(fn1), get_function_dict(fn2), - (ir1.fullname, fn1.fullname)) - assert_blobs_same(get_dict(fn1.decl), get_dict(fn2.decl), - (ir1.fullname, fn1.fullname)) + assert_blobs_same( + get_function_dict(fn1), get_function_dict(fn2), (ir1.fullname, fn1.fullname) + ) + assert_blobs_same(get_dict(fn1.decl), get_dict(fn2.decl), (ir1.fullname, fn1.fullname)) - assert_blobs_same(ir1.final_names, ir2.final_names, (ir1.fullname, 'final_names')) + assert_blobs_same(ir1.final_names, ir2.final_names, (ir1.fullname, "final_names")) def check_serialization_roundtrip(irs: Dict[str, ModuleIR]) -> None: diff --git a/mypyc/test/test_struct.py b/mypyc/test/test_struct.py index 0617f83bbb38d..b9d97adcbdf3b 100644 --- a/mypyc/test/test_struct.py +++ b/mypyc/test/test_struct.py @@ -1,8 +1,12 @@ import unittest from mypyc.ir.rtypes import ( - RStruct, bool_rprimitive, int64_rprimitive, int32_rprimitive, object_rprimitive, - int_rprimitive + RStruct, + bool_rprimitive, + int32_rprimitive, + int64_rprimitive, + int_rprimitive, + object_rprimitive, ) from mypyc.rt_subtype import is_runtime_subtype @@ -25,8 +29,7 @@ def test_struct_offsets(self) -> None: assert r2.size == 8 assert r3.size == 16 - r4 = RStruct("", [], [bool_rprimitive, bool_rprimitive, - bool_rprimitive, int32_rprimitive]) + r4 = RStruct("", [], [bool_rprimitive, bool_rprimitive, bool_rprimitive, int32_rprimitive]) assert r4.size == 8 assert r4.offsets == [0, 1, 2, 4] @@ -43,11 +46,12 @@ def test_struct_offsets(self) -> None: assert r7.size == 12 def test_struct_str(self) -> None: - r = RStruct("Foo", ["a", "b"], - [bool_rprimitive, object_rprimitive]) + r = RStruct("Foo", ["a", "b"], [bool_rprimitive, object_rprimitive]) assert str(r) == "Foo{a:bool, b:object}" - assert repr(r) == ", " \ - "b:}>" + assert ( + repr(r) == ", " + "b:}>" + ) r1 = RStruct("Bar", ["c"], [int32_rprimitive]) assert str(r1) == "Bar{c:int32}" assert repr(r1) == "}>" @@ -57,28 +61,24 @@ def test_struct_str(self) -> None: def test_runtime_subtype(self) -> None: # right type to check with - r = RStruct("Foo", ["a", "b"], - [bool_rprimitive, int_rprimitive]) + r = RStruct("Foo", ["a", "b"], [bool_rprimitive, int_rprimitive]) # using the exact same fields - r1 = RStruct("Foo", ["a", "b"], - [bool_rprimitive, int_rprimitive]) + r1 = RStruct("Foo", ["a", "b"], [bool_rprimitive, int_rprimitive]) # names different - r2 = RStruct("Bar", ["c", "b"], - [bool_rprimitive, int_rprimitive]) + r2 = RStruct("Bar", ["c", "b"], [bool_rprimitive, int_rprimitive]) # name different - r3 = RStruct("Baz", ["a", "b"], - [bool_rprimitive, int_rprimitive]) + r3 = RStruct("Baz", ["a", "b"], [bool_rprimitive, int_rprimitive]) # type different - r4 = RStruct("FooBar", ["a", "b"], - [bool_rprimitive, int32_rprimitive]) + r4 = RStruct("FooBar", ["a", "b"], [bool_rprimitive, int32_rprimitive]) # number of types different - r5 = RStruct("FooBarBaz", ["a", "b", "c"], - [bool_rprimitive, int_rprimitive, bool_rprimitive]) + r5 = RStruct( + "FooBarBaz", ["a", "b", "c"], [bool_rprimitive, int_rprimitive, bool_rprimitive] + ) assert is_runtime_subtype(r1, r) is True assert is_runtime_subtype(r2, r) is False @@ -87,29 +87,24 @@ def test_runtime_subtype(self) -> None: assert is_runtime_subtype(r5, r) is False def test_eq_and_hash(self) -> None: - r = RStruct("Foo", ["a", "b"], - [bool_rprimitive, int_rprimitive]) + r = RStruct("Foo", ["a", "b"], [bool_rprimitive, int_rprimitive]) # using the exact same fields - r1 = RStruct("Foo", ["a", "b"], - [bool_rprimitive, int_rprimitive]) + r1 = RStruct("Foo", ["a", "b"], [bool_rprimitive, int_rprimitive]) assert hash(r) == hash(r1) assert r == r1 # different name - r2 = RStruct("Foq", ["a", "b"], - [bool_rprimitive, int_rprimitive]) + r2 = RStruct("Foq", ["a", "b"], [bool_rprimitive, int_rprimitive]) assert hash(r) != hash(r2) assert r != r2 # different names - r3 = RStruct("Foo", ["a", "c"], - [bool_rprimitive, int_rprimitive]) + r3 = RStruct("Foo", ["a", "c"], [bool_rprimitive, int_rprimitive]) assert hash(r) != hash(r3) assert r != r3 # different type - r4 = RStruct("Foo", ["a", "b"], - [bool_rprimitive, int_rprimitive, bool_rprimitive]) + r4 = RStruct("Foo", ["a", "b"], [bool_rprimitive, int_rprimitive, bool_rprimitive]) assert hash(r) != hash(r4) assert r != r4 diff --git a/mypyc/test/test_subtype.py b/mypyc/test/test_subtype.py index e006e5425174d..85baac9065442 100644 --- a/mypyc/test/test_subtype.py +++ b/mypyc/test/test_subtype.py @@ -3,11 +3,15 @@ import unittest from mypyc.ir.rtypes import ( - bit_rprimitive, bool_rprimitive, int_rprimitive, int64_rprimitive, int32_rprimitive, - short_int_rprimitive + bit_rprimitive, + bool_rprimitive, + int32_rprimitive, + int64_rprimitive, + int_rprimitive, + short_int_rprimitive, ) -from mypyc.subtype import is_subtype from mypyc.rt_subtype import is_runtime_subtype +from mypyc.subtype import is_subtype class TestSubtype(unittest.TestCase): diff --git a/mypyc/test/test_tuplename.py b/mypyc/test/test_tuplename.py index 7f3fd2000d290..eab4e6102a7ef 100644 --- a/mypyc/test/test_tuplename.py +++ b/mypyc/test/test_tuplename.py @@ -1,23 +1,31 @@ import unittest +from mypyc.ir.class_ir import ClassIR from mypyc.ir.rtypes import ( - RTuple, object_rprimitive, int_rprimitive, bool_rprimitive, list_rprimitive, - RInstance, RUnion, + RInstance, + RTuple, + RUnion, + bool_rprimitive, + int_rprimitive, + list_rprimitive, + object_rprimitive, ) -from mypyc.ir.class_ir import ClassIR class TestTupleNames(unittest.TestCase): def setUp(self) -> None: - self.inst_a = RInstance(ClassIR('A', '__main__')) - self.inst_b = RInstance(ClassIR('B', '__main__')) + self.inst_a = RInstance(ClassIR("A", "__main__")) + self.inst_b = RInstance(ClassIR("B", "__main__")) def test_names(self) -> None: assert RTuple([int_rprimitive, int_rprimitive]).unique_id == "T2II" assert RTuple([list_rprimitive, object_rprimitive, self.inst_a]).unique_id == "T3OOO" assert RTuple([list_rprimitive, object_rprimitive, self.inst_b]).unique_id == "T3OOO" assert RTuple([]).unique_id == "T0" - assert RTuple([RTuple([]), - RTuple([int_rprimitive, int_rprimitive])]).unique_id == "T2T0T2II" - assert RTuple([bool_rprimitive, - RUnion([bool_rprimitive, int_rprimitive])]).unique_id == "T2CO" + assert ( + RTuple([RTuple([]), RTuple([int_rprimitive, int_rprimitive])]).unique_id == "T2T0T2II" + ) + assert ( + RTuple([bool_rprimitive, RUnion([bool_rprimitive, int_rprimitive])]).unique_id + == "T2CO" + ) diff --git a/mypyc/test/testutil.py b/mypyc/test/testutil.py index d5c5dea2d6344..f7129ace1ed31 100644 --- a/mypyc/test/testutil.py +++ b/mypyc/test/testutil.py @@ -5,29 +5,28 @@ import os.path import re import shutil -from typing import List, Callable, Iterator, Optional, Tuple +from typing import Callable, Iterator, List, Optional, Tuple from mypy import build from mypy.errors import CompileError from mypy.options import Options -from mypy.test.data import DataSuite, DataDrivenTestCase from mypy.test.config import test_temp_dir +from mypy.test.data import DataDrivenTestCase, DataSuite from mypy.test.helpers import assert_string_arrays_equal - -from mypyc.options import CompilerOptions from mypyc.analysis.ircheck import assert_func_ir_valid +from mypyc.common import IS_32_BIT_PLATFORM, PLATFORM_SIZE +from mypyc.errors import Errors from mypyc.ir.func_ir import FuncIR from mypyc.ir.module_ir import ModuleIR -from mypyc.errors import Errors from mypyc.irbuild.main import build_ir from mypyc.irbuild.mapper import Mapper +from mypyc.options import CompilerOptions from mypyc.test.config import test_data_prefix -from mypyc.common import IS_32_BIT_PLATFORM, PLATFORM_SIZE # The builtins stub used during icode generation test cases. -ICODE_GEN_BUILTINS = os.path.join(test_data_prefix, 'fixtures/ir.py') +ICODE_GEN_BUILTINS = os.path.join(test_data_prefix, "fixtures/ir.py") # The testutil support library -TESTUTIL_PATH = os.path.join(test_data_prefix, 'fixtures/testutil.py') +TESTUTIL_PATH = os.path.join(test_data_prefix, "fixtures/testutil.py") class MypycDataSuite(DataSuite): @@ -36,8 +35,9 @@ class MypycDataSuite(DataSuite): data_prefix = test_data_prefix -def builtins_wrapper(func: Callable[[DataDrivenTestCase], None], - path: str) -> Callable[[DataDrivenTestCase], None]: +def builtins_wrapper( + func: Callable[[DataDrivenTestCase], None], path: str +) -> Callable[[DataDrivenTestCase], None]: """Decorate a function that implements a data-driven test case to copy an alternative builtins module implementation in place before performing the test case. Clean up after executing the test case. @@ -48,12 +48,12 @@ def builtins_wrapper(func: Callable[[DataDrivenTestCase], None], @contextlib.contextmanager def use_custom_builtins(builtins_path: str, testcase: DataDrivenTestCase) -> Iterator[None]: for path, _ in testcase.files: - if os.path.basename(path) == 'builtins.pyi': + if os.path.basename(path) == "builtins.pyi": default_builtins = False break else: # Use default builtins. - builtins = os.path.abspath(os.path.join(test_temp_dir, 'builtins.pyi')) + builtins = os.path.abspath(os.path.join(test_temp_dir, "builtins.pyi")) shutil.copyfile(builtins_path, builtins) default_builtins = True @@ -66,15 +66,16 @@ def use_custom_builtins(builtins_path: str, testcase: DataDrivenTestCase) -> Ite os.remove(builtins) -def perform_test(func: Callable[[DataDrivenTestCase], None], - builtins_path: str, testcase: DataDrivenTestCase) -> None: +def perform_test( + func: Callable[[DataDrivenTestCase], None], builtins_path: str, testcase: DataDrivenTestCase +) -> None: for path, _ in testcase.files: - if os.path.basename(path) == 'builtins.py': + if os.path.basename(path) == "builtins.py": default_builtins = False break else: # Use default builtins. - builtins = os.path.join(test_temp_dir, 'builtins.py') + builtins = os.path.join(test_temp_dir, "builtins.py") shutil.copyfile(builtins_path, builtins) default_builtins = True @@ -86,15 +87,16 @@ def perform_test(func: Callable[[DataDrivenTestCase], None], os.remove(builtins) -def build_ir_for_single_file(input_lines: List[str], - compiler_options: Optional[CompilerOptions] = None) -> List[FuncIR]: +def build_ir_for_single_file( + input_lines: List[str], compiler_options: Optional[CompilerOptions] = None +) -> List[FuncIR]: return build_ir_for_single_file2(input_lines, compiler_options).functions -def build_ir_for_single_file2(input_lines: List[str], - compiler_options: Optional[CompilerOptions] = None - ) -> ModuleIR: - program_text = '\n'.join(input_lines) +def build_ir_for_single_file2( + input_lines: List[str], compiler_options: Optional[CompilerOptions] = None +) -> ModuleIR: + program_text = "\n".join(input_lines) # By default generate IR compatible with the earliest supported Python C API. # If a test needs more recent API features, this should be overridden. @@ -106,22 +108,24 @@ def build_ir_for_single_file2(input_lines: List[str], options.python_version = (3, 6) options.export_types = True options.preserve_asts = True - options.per_module_options['__main__'] = {'mypyc': True} + options.per_module_options["__main__"] = {"mypyc": True} - source = build.BuildSource('main', '__main__', program_text) + source = build.BuildSource("main", "__main__", program_text) # Construct input as a single single. # Parse and type check the input program. - result = build.build(sources=[source], - options=options, - alt_lib_path=test_temp_dir) + result = build.build(sources=[source], options=options, alt_lib_path=test_temp_dir) if result.errors: raise CompileError(result.errors) errors = Errors() modules = build_ir( - [result.files['__main__']], result.graph, result.types, - Mapper({'__main__': None}), - compiler_options, errors) + [result.files["__main__"]], + result.graph, + result.types, + Mapper({"__main__": None}), + compiler_options, + errors, + ) if errors.num_errors: raise CompileError(errors.new_messages()) @@ -141,44 +145,46 @@ def update_testcase_output(testcase: DataDrivenTestCase, output: List[str]) -> N # We can't rely on the test line numbers to *find* the test, since # we might fix multiple tests in a run. So find it by the case # header. Give up if there are multiple tests with the same name. - test_slug = f'[case {testcase.name}]' + test_slug = f"[case {testcase.name}]" if data_lines.count(test_slug) != 1: return start_idx = data_lines.index(test_slug) stop_idx = start_idx + 11 - while stop_idx < len(data_lines) and not data_lines[stop_idx].startswith('[case '): + while stop_idx < len(data_lines) and not data_lines[stop_idx].startswith("[case "): stop_idx += 1 test = data_lines[start_idx:stop_idx] - out_start = test.index('[out]') - test[out_start + 1:] = output - data_lines[start_idx:stop_idx] = test + [''] - data = '\n'.join(data_lines) + out_start = test.index("[out]") + test[out_start + 1 :] = output + data_lines[start_idx:stop_idx] = test + [""] + data = "\n".join(data_lines) - with open(testcase_path, 'w') as f: + with open(testcase_path, "w") as f: print(data, file=f) -def assert_test_output(testcase: DataDrivenTestCase, - actual: List[str], - message: str, - expected: Optional[List[str]] = None, - formatted: Optional[List[str]] = None) -> None: +def assert_test_output( + testcase: DataDrivenTestCase, + actual: List[str], + message: str, + expected: Optional[List[str]] = None, + formatted: Optional[List[str]] = None, +) -> None: __tracebackhide__ = True expected_output = expected if expected is not None else testcase.output - if expected_output != actual and testcase.config.getoption('--update-data', False): + if expected_output != actual and testcase.config.getoption("--update-data", False): update_testcase_output(testcase, actual) assert_string_arrays_equal( - expected_output, actual, - f'{message} ({testcase.file}, line {testcase.line})') + expected_output, actual, f"{message} ({testcase.file}, line {testcase.line})" + ) def get_func_names(expected: List[str]) -> List[str]: res = [] for s in expected: - m = re.match(r'def ([_a-zA-Z0-9.*$]+)\(', s) + m = re.match(r"def ([_a-zA-Z0-9.*$]+)\(", s) if m: res.append(m.group(1)) return res @@ -191,7 +197,7 @@ def remove_comment_lines(a: List[str]) -> List[str]: """ r = [] for s in a: - if s.strip().startswith('--') and not s.strip().startswith('---'): + if s.strip().startswith("--") and not s.strip().startswith("---"): pass else: r.append(s) @@ -201,20 +207,20 @@ def remove_comment_lines(a: List[str]) -> List[str]: def print_with_line_numbers(s: str) -> None: lines = s.splitlines() for i, line in enumerate(lines): - print('%-4d %s' % (i + 1, line)) + print("%-4d %s" % (i + 1, line)) def heading(text: str) -> None: - print('=' * 20 + ' ' + text + ' ' + '=' * 20) + print("=" * 20 + " " + text + " " + "=" * 20) def show_c(cfiles: List[List[Tuple[str, str]]]) -> None: - heading('Generated C') + heading("Generated C") for group in cfiles: for cfile, ctext in group: - print(f'== {cfile} ==') + print(f"== {cfile} ==") print_with_line_numbers(ctext) - heading('End C') + heading("End C") def fudge_dir_mtimes(dir: str, delta: int) -> None: @@ -229,7 +235,7 @@ def replace_word_size(text: List[str]) -> List[str]: """Replace WORDSIZE with platform specific word sizes""" result = [] for line in text: - index = line.find('WORD_SIZE') + index = line.find("WORD_SIZE") if index != -1: # get 'WORDSIZE*n' token word_size_token = line[index:].split()[0] @@ -258,16 +264,15 @@ def infer_ir_build_options_from_test_name(name: str) -> Optional[CompilerOptions Don't generate code for assert statements """ # If this is specific to some bit width, always pass if platform doesn't match. - if '_64bit' in name and IS_32_BIT_PLATFORM: + if "_64bit" in name and IS_32_BIT_PLATFORM: return None - if '_32bit' in name and not IS_32_BIT_PLATFORM: + if "_32bit" in name and not IS_32_BIT_PLATFORM: return None - options = CompilerOptions(strip_asserts='StripAssert' in name, - capi_version=(3, 5)) + options = CompilerOptions(strip_asserts="StripAssert" in name, capi_version=(3, 5)) # A suffix like _python3.8 is used to set the target C API version. - m = re.search(r'_python([3-9]+)_([0-9]+)(_|\b)', name) + m = re.search(r"_python([3-9]+)_([0-9]+)(_|\b)", name) if m: options.capi_version = (int(m.group(1)), int(m.group(2))) - elif '_py' in name or '_Python' in name: - assert False, f'Invalid _py* suffix (should be _pythonX_Y): {name}' + elif "_py" in name or "_Python" in name: + assert False, f"Invalid _py* suffix (should be _pythonX_Y): {name}" return options diff --git a/mypyc/transform/exceptions.py b/mypyc/transform/exceptions.py index e845de1fcf191..d140f050d6aa6 100644 --- a/mypyc/transform/exceptions.py +++ b/mypyc/transform/exceptions.py @@ -11,15 +11,27 @@ from typing import List, Optional +from mypyc.ir.func_ir import FuncIR from mypyc.ir.ops import ( - Value, BasicBlock, LoadErrorValue, Return, Branch, RegisterOp, ComparisonOp, CallC, - Integer, ERR_NEVER, ERR_MAGIC, ERR_FALSE, ERR_ALWAYS, ERR_MAGIC_OVERLAPPING, - NO_TRACEBACK_LINE_NO + ERR_ALWAYS, + ERR_FALSE, + ERR_MAGIC, + ERR_MAGIC_OVERLAPPING, + ERR_NEVER, + NO_TRACEBACK_LINE_NO, + BasicBlock, + Branch, + CallC, + ComparisonOp, + Integer, + LoadErrorValue, + RegisterOp, + Return, + Value, ) -from mypyc.ir.func_ir import FuncIR from mypyc.ir.rtypes import bool_rprimitive -from mypyc.primitives.registry import CFunctionDescription from mypyc.primitives.exc_ops import err_occurred_op +from mypyc.primitives.registry import CFunctionDescription def insert_exception_handling(ir: FuncIR) -> None: @@ -45,9 +57,9 @@ def add_handler_block(ir: FuncIR) -> BasicBlock: return block -def split_blocks_at_errors(blocks: List[BasicBlock], - default_error_handler: BasicBlock, - func_name: Optional[str]) -> List[BasicBlock]: +def split_blocks_at_errors( + blocks: List[BasicBlock], default_error_handler: BasicBlock, func_name: Optional[str] +) -> List[BasicBlock]: new_blocks: List[BasicBlock] = [] # First split blocks on ops that may raise. @@ -90,8 +102,9 @@ def split_blocks_at_errors(blocks: List[BasicBlock], cur_block.ops.append(comp) new_block2 = BasicBlock() new_blocks.append(new_block2) - branch = Branch(comp, true_label=new_block2, false_label=new_block, - op=Branch.BOOL) + branch = Branch( + comp, true_label=new_block2, false_label=new_block, op=Branch.BOOL + ) cur_block.ops.append(branch) cur_block = new_block2 target = primitive_call(err_occurred_op, [], target.line) @@ -99,18 +112,16 @@ def split_blocks_at_errors(blocks: List[BasicBlock], variant = Branch.IS_ERROR negated = True else: - assert False, 'unknown error kind %d' % op.error_kind + assert False, "unknown error kind %d" % op.error_kind # Void ops can't generate errors since error is always # indicated by a special value stored in a register. if op.error_kind != ERR_ALWAYS: assert not op.is_void, "void op generating errors?" - branch = Branch(target, - true_label=error_label, - false_label=new_block, - op=variant, - line=op.line) + branch = Branch( + target, true_label=error_label, false_label=new_block, op=variant, line=op.line + ) branch.negated = negated if op.line != NO_TRACEBACK_LINE_NO and func_name is not None: branch.traceback_entry = (func_name, op.line) diff --git a/mypyc/transform/refcount.py b/mypyc/transform/refcount.py index 60163e385c2de..05e2843fe8865 100644 --- a/mypyc/transform/refcount.py +++ b/mypyc/transform/refcount.py @@ -19,19 +19,30 @@ from typing import Dict, Iterable, List, Set, Tuple from mypyc.analysis.dataflow import ( - get_cfg, - analyze_must_defined_regs, - analyze_live_regs, + AnalysisDict, analyze_borrowed_arguments, + analyze_live_regs, + analyze_must_defined_regs, cleanup_cfg, - AnalysisDict + get_cfg, ) +from mypyc.ir.func_ir import FuncIR, all_values from mypyc.ir.ops import ( - BasicBlock, Assign, RegisterOp, DecRef, IncRef, Branch, Goto, Op, ControlOp, Value, Register, - LoadAddress, Integer, KeepAlive + Assign, + BasicBlock, + Branch, + ControlOp, + DecRef, + Goto, + IncRef, + Integer, + KeepAlive, + LoadAddress, + Op, + Register, + RegisterOp, + Value, ) -from mypyc.ir.func_ir import FuncIR, all_values - Decs = Tuple[Tuple[Value, bool], ...] Incs = Tuple[Value, ...] @@ -59,14 +70,16 @@ def insert_ref_count_opcodes(ir: FuncIR) -> None: cache: BlockCache = {} for block in ir.blocks[:]: if isinstance(block.ops[-1], (Branch, Goto)): - insert_branch_inc_and_decrefs(block, - cache, - ir.blocks, - live.before, - borrow.before, - borrow.after, - defined.after, - ordering) + insert_branch_inc_and_decrefs( + block, + cache, + ir.blocks, + live.before, + borrow.before, + borrow.after, + defined.after, + ordering, + ) transform_block(block, live.before, live.after, borrow.before, defined.after) cleanup_cfg(ir.blocks) @@ -76,8 +89,9 @@ def is_maybe_undefined(post_must_defined: Set[Value], src: Value) -> bool: return isinstance(src, Register) and src not in post_must_defined -def maybe_append_dec_ref(ops: List[Op], dest: Value, - defined: 'AnalysisDict[Value]', key: Tuple[BasicBlock, int]) -> None: +def maybe_append_dec_ref( + ops: List[Op], dest: Value, defined: "AnalysisDict[Value]", key: Tuple[BasicBlock, int] +) -> None: if dest.type.is_refcounted and not isinstance(dest, Integer): ops.append(DecRef(dest, is_xdec=is_maybe_undefined(defined[key], dest))) @@ -87,11 +101,13 @@ def maybe_append_inc_ref(ops: List[Op], dest: Value) -> None: ops.append(IncRef(dest)) -def transform_block(block: BasicBlock, - pre_live: 'AnalysisDict[Value]', - post_live: 'AnalysisDict[Value]', - pre_borrow: 'AnalysisDict[Value]', - post_must_defined: 'AnalysisDict[Value]') -> None: +def transform_block( + block: BasicBlock, + pre_live: "AnalysisDict[Value]", + post_live: "AnalysisDict[Value]", + pre_borrow: "AnalysisDict[Value]", + post_must_defined: "AnalysisDict[Value]", +) -> None: old_ops = block.ops ops: List[Op] = [] for i, op in enumerate(old_ops): @@ -108,7 +124,7 @@ def transform_block(block: BasicBlock, maybe_append_inc_ref(ops, src) # For assignments to registers that were already live, # decref the old value. - if (dest not in pre_borrow[key] and dest in pre_live[key]): + if dest not in pre_borrow[key] and dest in pre_live[key]: assert isinstance(op, Assign) maybe_append_dec_ref(ops, dest, post_must_defined, key) @@ -127,21 +143,25 @@ def transform_block(block: BasicBlock, maybe_append_dec_ref(ops, src, post_must_defined, key) # Decrement the destination if it is dead after the op and # wasn't a borrowed RegisterOp - if (not dest.is_void and dest not in post_live[key] - and not (isinstance(op, RegisterOp) and dest.is_borrowed)): + if ( + not dest.is_void + and dest not in post_live[key] + and not (isinstance(op, RegisterOp) and dest.is_borrowed) + ): maybe_append_dec_ref(ops, dest, post_must_defined, key) block.ops = ops def insert_branch_inc_and_decrefs( - block: BasicBlock, - cache: BlockCache, - blocks: List[BasicBlock], - pre_live: 'AnalysisDict[Value]', - pre_borrow: 'AnalysisDict[Value]', - post_borrow: 'AnalysisDict[Value]', - post_must_defined: 'AnalysisDict[Value]', - ordering: Dict[Value, int]) -> None: + block: BasicBlock, + cache: BlockCache, + blocks: List[BasicBlock], + pre_live: "AnalysisDict[Value]", + pre_borrow: "AnalysisDict[Value]", + post_borrow: "AnalysisDict[Value]", + post_must_defined: "AnalysisDict[Value]", + ordering: Dict[Value, int], +) -> None: """Insert inc_refs and/or dec_refs after a branch/goto. Add dec_refs for registers that become dead after a branch. @@ -176,46 +196,52 @@ def f(a: int) -> None omitted = () decs = after_branch_decrefs( - target, pre_live, source_defined, - source_borrowed, source_live_regs, ordering, omitted) - incs = after_branch_increfs( - target, pre_live, pre_borrow, source_borrowed, ordering) + target, pre_live, source_defined, source_borrowed, source_live_regs, ordering, omitted + ) + incs = after_branch_increfs(target, pre_live, pre_borrow, source_borrowed, ordering) term.set_target(i, add_block(decs, incs, cache, blocks, target)) -def after_branch_decrefs(label: BasicBlock, - pre_live: 'AnalysisDict[Value]', - source_defined: Set[Value], - source_borrowed: Set[Value], - source_live_regs: Set[Value], - ordering: Dict[Value, int], - omitted: Iterable[Value]) -> Tuple[Tuple[Value, bool], ...]: +def after_branch_decrefs( + label: BasicBlock, + pre_live: "AnalysisDict[Value]", + source_defined: Set[Value], + source_borrowed: Set[Value], + source_live_regs: Set[Value], + ordering: Dict[Value, int], + omitted: Iterable[Value], +) -> Tuple[Tuple[Value, bool], ...]: target_pre_live = pre_live[label, 0] decref = source_live_regs - target_pre_live - source_borrowed if decref: - return tuple((reg, is_maybe_undefined(source_defined, reg)) - for reg in sorted(decref, key=lambda r: ordering[r]) - if reg.type.is_refcounted and reg not in omitted) + return tuple( + (reg, is_maybe_undefined(source_defined, reg)) + for reg in sorted(decref, key=lambda r: ordering[r]) + if reg.type.is_refcounted and reg not in omitted + ) return () -def after_branch_increfs(label: BasicBlock, - pre_live: 'AnalysisDict[Value]', - pre_borrow: 'AnalysisDict[Value]', - source_borrowed: Set[Value], - ordering: Dict[Value, int]) -> Tuple[Value, ...]: +def after_branch_increfs( + label: BasicBlock, + pre_live: "AnalysisDict[Value]", + pre_borrow: "AnalysisDict[Value]", + source_borrowed: Set[Value], + ordering: Dict[Value, int], +) -> Tuple[Value, ...]: target_pre_live = pre_live[label, 0] target_borrowed = pre_borrow[label, 0] incref = (source_borrowed - target_borrowed) & target_pre_live if incref: - return tuple(reg - for reg in sorted(incref, key=lambda r: ordering[r]) - if reg.type.is_refcounted) + return tuple( + reg for reg in sorted(incref, key=lambda r: ordering[r]) if reg.type.is_refcounted + ) return () -def add_block(decs: Decs, incs: Incs, cache: BlockCache, - blocks: List[BasicBlock], label: BasicBlock) -> BasicBlock: +def add_block( + decs: Decs, incs: Incs, cache: BlockCache, blocks: List[BasicBlock], label: BasicBlock +) -> BasicBlock: if not decs and not incs: return label @@ -247,9 +273,11 @@ def make_value_ordering(ir: FuncIR) -> Dict[Value, int]: for block in ir.blocks: for op in block.ops: - if (isinstance(op, LoadAddress) - and isinstance(op.src, Register) - and op.src not in result): + if ( + isinstance(op, LoadAddress) + and isinstance(op.src, Register) + and op.src not in result + ): # Taking the address of a register allows initialization. result[op.src] = n n += 1 diff --git a/mypyc/transform/uninit.py b/mypyc/transform/uninit.py index ca21d2690636d..3b51ee26ad313 100644 --- a/mypyc/transform/uninit.py +++ b/mypyc/transform/uninit.py @@ -2,17 +2,20 @@ from typing import List -from mypyc.analysis.dataflow import ( - get_cfg, - cleanup_cfg, - analyze_must_defined_regs, - AnalysisDict -) +from mypyc.analysis.dataflow import AnalysisDict, analyze_must_defined_regs, cleanup_cfg, get_cfg +from mypyc.ir.func_ir import FuncIR, all_values from mypyc.ir.ops import ( - BasicBlock, Op, Branch, Value, RaiseStandardError, Unreachable, Register, - LoadAddress, Assign, LoadErrorValue + Assign, + BasicBlock, + Branch, + LoadAddress, + LoadErrorValue, + Op, + RaiseStandardError, + Register, + Unreachable, + Value, ) -from mypyc.ir.func_ir import FuncIR, all_values def insert_uninit_checks(ir: FuncIR) -> None: @@ -22,16 +25,15 @@ def insert_uninit_checks(ir: FuncIR) -> None: cfg = get_cfg(ir.blocks) must_defined = analyze_must_defined_regs( - ir.blocks, - cfg, - set(ir.arg_regs), - all_values(ir.arg_regs, ir.blocks)) + ir.blocks, cfg, set(ir.arg_regs), all_values(ir.arg_regs, ir.blocks) + ) ir.blocks = split_blocks_at_uninits(ir.blocks, must_defined.before) -def split_blocks_at_uninits(blocks: List[BasicBlock], - pre_must_defined: 'AnalysisDict[Value]') -> List[BasicBlock]: +def split_blocks_at_uninits( + blocks: List[BasicBlock], pre_must_defined: "AnalysisDict[Value]" +) -> List[BasicBlock]: new_blocks: List[BasicBlock] = [] init_registers = [] @@ -54,9 +56,12 @@ def split_blocks_at_uninits(blocks: List[BasicBlock], # Note that for register operand in a LoadAddress op, # we should be able to use it without initialization # as we may need to use its address to update itself - if (isinstance(src, Register) and src not in defined - and not (isinstance(op, Branch) and op.op == Branch.IS_ERROR) - and not isinstance(op, LoadAddress)): + if ( + isinstance(src, Register) + and src not in defined + and not (isinstance(op, Branch) and op.op == Branch.IS_ERROR) + and not isinstance(op, LoadAddress) + ): new_block, error_block = BasicBlock(), BasicBlock() new_block.error_handler = error_block.error_handler = cur_block.error_handler new_blocks += [error_block, new_block] @@ -65,15 +70,20 @@ def split_blocks_at_uninits(blocks: List[BasicBlock], init_registers.append(src) init_registers_set.add(src) - cur_block.ops.append(Branch(src, - true_label=error_block, - false_label=new_block, - op=Branch.IS_ERROR, - line=op.line)) + cur_block.ops.append( + Branch( + src, + true_label=error_block, + false_label=new_block, + op=Branch.IS_ERROR, + line=op.line, + ) + ) raise_std = RaiseStandardError( RaiseStandardError.UNBOUND_LOCAL_ERROR, f'local variable "{src.name}" referenced before assignment', - op.line) + op.line, + ) error_block.ops.append(raise_std) error_block.ops.append(Unreachable()) cur_block = new_block diff --git a/runtests.py b/runtests.py index b075fdb4a5199..bd991d2ca2503 100755 --- a/runtests.py +++ b/runtests.py @@ -1,22 +1,22 @@ #!/usr/bin/env python3 import subprocess from subprocess import Popen -from sys import argv, exit, executable +from sys import argv, executable, exit # Slow test suites -CMDLINE = 'PythonCmdline' -SAMPLES = 'SamplesSuite' -TYPESHED = 'TypeshedSuite' -PEP561 = 'PEP561Suite' -EVALUATION = 'PythonEvaluation' -DAEMON = 'testdaemon' -STUBGEN_CMD = 'StubgenCmdLine' -STUBGEN_PY = 'StubgenPythonSuite' -MYPYC_RUN = 'TestRun' -MYPYC_RUN_MULTI = 'TestRunMultiFile' -MYPYC_EXTERNAL = 'TestExternal' -MYPYC_COMMAND_LINE = 'TestCommandLine' -ERROR_STREAM = 'ErrorStreamSuite' +CMDLINE = "PythonCmdline" +SAMPLES = "SamplesSuite" +TYPESHED = "TypeshedSuite" +PEP561 = "PEP561Suite" +EVALUATION = "PythonEvaluation" +DAEMON = "testdaemon" +STUBGEN_CMD = "StubgenCmdLine" +STUBGEN_PY = "StubgenPythonSuite" +MYPYC_RUN = "TestRun" +MYPYC_RUN_MULTI = "TestRunMultiFile" +MYPYC_EXTERNAL = "TestExternal" +MYPYC_COMMAND_LINE = "TestCommandLine" +ERROR_STREAM = "ErrorStreamSuite" ALL_NON_FAST = [ @@ -49,45 +49,40 @@ # time to run. cmds = { # Self type check - 'self': [executable, '-m', 'mypy', '--config-file', 'mypy_self_check.ini', '-p', 'mypy'], + "self": [executable, "-m", "mypy", "--config-file", "mypy_self_check.ini", "-p", "mypy"], # Lint - 'lint': ['flake8', '-j0'], + "lint": ["flake8", "-j0"], "format-black": ["black", "."], "format-isort": ["isort", "."], # Fast test cases only (this is the bulk of the test suite) - 'pytest-fast': ['pytest', '-q', '-k', f"not ({' or '.join(ALL_NON_FAST)})"], + "pytest-fast": ["pytest", "-q", "-k", f"not ({' or '.join(ALL_NON_FAST)})"], # Test cases that invoke mypy (with small inputs) - 'pytest-cmdline': ['pytest', '-q', '-k', ' or '.join([CMDLINE, - EVALUATION, - STUBGEN_CMD, - STUBGEN_PY])], + "pytest-cmdline": [ + "pytest", + "-q", + "-k", + " or ".join([CMDLINE, EVALUATION, STUBGEN_CMD, STUBGEN_PY]), + ], # Test cases that may take seconds to run each - 'pytest-slow': ['pytest', '-q', '-k', ' or '.join( - [SAMPLES, - TYPESHED, - DAEMON, - MYPYC_EXTERNAL, - MYPYC_COMMAND_LINE, - ERROR_STREAM])], - + "pytest-slow": [ + "pytest", + "-q", + "-k", + " or ".join([SAMPLES, TYPESHED, DAEMON, MYPYC_EXTERNAL, MYPYC_COMMAND_LINE, ERROR_STREAM]), + ], # Test cases that might take minutes to run - 'pytest-extra': ['pytest', '-q', '-k', ' or '.join(PYTEST_OPT_IN)], - + "pytest-extra": ["pytest", "-q", "-k", " or ".join(PYTEST_OPT_IN)], # Test cases to run in typeshed CI - 'typeshed-ci': ['pytest', '-q', '-k', ' or '.join([CMDLINE, - EVALUATION, - SAMPLES, - TYPESHED])], - + "typeshed-ci": ["pytest", "-q", "-k", " or ".join([CMDLINE, EVALUATION, SAMPLES, TYPESHED])], # Mypyc tests that aren't run by default, since they are slow and rarely # fail for commits that don't touch mypyc - 'mypyc-extra': ['pytest', '-q', '-k', ' or '.join(MYPYC_OPT_IN)], + "mypyc-extra": ["pytest", "-q", "-k", " or ".join(MYPYC_OPT_IN)], } # Stop run immediately if these commands fail -FAST_FAIL = ['self', 'lint'] +FAST_FAIL = ["self", "lint"] -EXTRA_COMMANDS = ('pytest-extra', 'mypyc-extra', 'typeshed-ci') +EXTRA_COMMANDS = ("pytest-extra", "mypyc-extra", "typeshed-ci") DEFAULT_COMMANDS = [cmd for cmd in cmds if cmd not in EXTRA_COMMANDS] assert all(cmd in cmds for cmd in FAST_FAIL) @@ -96,10 +91,10 @@ def run_cmd(name: str) -> int: status = 0 cmd = cmds[name] - print(f'run {name}: {cmd}') + print(f"run {name}: {cmd}") proc = subprocess.run(cmd, stderr=subprocess.STDOUT) if proc.returncode: - print('\nFAILED: %s' % name) + print("\nFAILED: %s" % name) status = proc.returncode if name in FAST_FAIL: exit(status) @@ -108,16 +103,14 @@ def run_cmd(name: str) -> int: def start_background_cmd(name: str) -> Popen: cmd = cmds[name] - proc = subprocess.Popen(cmd, - stderr=subprocess.STDOUT, - stdout=subprocess.PIPE) + proc = subprocess.Popen(cmd, stderr=subprocess.STDOUT, stdout=subprocess.PIPE) return proc def wait_background_cmd(name: str, proc: Popen) -> int: output = proc.communicate()[0] status = proc.returncode - print(f'run {name}: {cmds[name]}') + print(f"run {name}: {cmds[name]}") if status: print(output.decode().rstrip()) print("\nFAILED:", name) @@ -132,8 +125,10 @@ def main() -> None: if not set(args).issubset(cmds): print("usage:", prog, " ".join(f"[{k}]" for k in cmds)) print() - print('Run the given tests. If given no arguments, run everything except' - + ' pytest-extra and mypyc-extra.') + print( + "Run the given tests. If given no arguments, run everything except" + + " pytest-extra and mypyc-extra." + ) exit(1) if not args: @@ -141,16 +136,16 @@ def main() -> None: status = 0 - if 'self' in args and 'lint' in args: + if "self" in args and "lint" in args: # Perform lint and self check in parallel as it's faster. - proc = start_background_cmd('lint') - cmd_status = run_cmd('self') + proc = start_background_cmd("lint") + cmd_status = run_cmd("self") if cmd_status: status = cmd_status - cmd_status = wait_background_cmd('lint', proc) + cmd_status = wait_background_cmd("lint", proc) if cmd_status: status = cmd_status - args = [arg for arg in args if arg not in ('self', 'lint')] + args = [arg for arg in args if arg not in ("self", "lint")] for arg in args: cmd_status = run_cmd(arg) @@ -160,5 +155,5 @@ def main() -> None: exit(status) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/scripts/find_type.py b/scripts/find_type.py index 757c2a40fd15c..d52424952a33a 100755 --- a/scripts/find_type.py +++ b/scripts/find_type.py @@ -23,54 +23,65 @@ # # For an Emacs example, see misc/macs.el. -from typing import List, Tuple, Optional +import os.path +import re import subprocess import sys import tempfile -import os.path -import re +from typing import List, Optional, Tuple + +REVEAL_TYPE_START = "reveal_type(" +REVEAL_TYPE_END = ")" -REVEAL_TYPE_START = 'reveal_type(' -REVEAL_TYPE_END = ')' def update_line(line: str, s: str, pos: int) -> str: return line[:pos] + s + line[pos:] + def run_mypy(mypy_and_args: List[str], filename: str, tmp_name: str) -> str: - proc = subprocess.run(mypy_and_args + ['--shadow-file', filename, tmp_name], stdout=subprocess.PIPE) - assert(isinstance(proc.stdout, bytes)) # Guaranteed to be true because we called run with universal_newlines=False + proc = subprocess.run( + mypy_and_args + ["--shadow-file", filename, tmp_name], stdout=subprocess.PIPE + ) + assert isinstance( + proc.stdout, bytes + ) # Guaranteed to be true because we called run with universal_newlines=False return proc.stdout.decode(encoding="utf-8") + def get_revealed_type(line: str, relevant_file: str, relevant_line: int) -> Optional[str]: m = re.match(r'(.+?):(\d+): note: Revealed type is "(.*)"$', line) - if (m and - int(m.group(2)) == relevant_line and - os.path.samefile(relevant_file, m.group(1))): + if m and int(m.group(2)) == relevant_line and os.path.samefile(relevant_file, m.group(1)): return m.group(3) else: return None + def process_output(output: str, filename: str, start_line: int) -> Tuple[Optional[str], bool]: error_found = False for line in output.splitlines(): t = get_revealed_type(line, filename, start_line) if t: return t, error_found - elif 'error:' in line: + elif "error:" in line: error_found = True return None, True # finding no reveal_type is an error + def main(): - filename, start_line_str, start_col_str, end_line_str, end_col_str, *mypy_and_args = sys.argv[1:] + filename, start_line_str, start_col_str, end_line_str, end_col_str, *mypy_and_args = sys.argv[ + 1: + ] start_line = int(start_line_str) start_col = int(start_col_str) end_line = int(end_line_str) end_col = int(end_col_str) with open(filename) as f: lines = f.readlines() - lines[end_line - 1] = update_line(lines[end_line - 1], REVEAL_TYPE_END, end_col) # insert after end_col + lines[end_line - 1] = update_line( + lines[end_line - 1], REVEAL_TYPE_END, end_col + ) # insert after end_col lines[start_line - 1] = update_line(lines[start_line - 1], REVEAL_TYPE_START, start_col) - with tempfile.NamedTemporaryFile(mode='w', prefix='mypy') as tmp_f: + with tempfile.NamedTemporaryFile(mode="w", prefix="mypy") as tmp_f: tmp_f.writelines(lines) tmp_f.flush() diff --git a/setup.py b/setup.py index 7999fb20216e8..3a10bb54726d6 100644 --- a/setup.py +++ b/setup.py @@ -15,12 +15,13 @@ # This requires setuptools when building; setuptools is not needed # when installing from a wheel file (though it is still needed for # alternative forms of installing, as suggested by README.md). -from setuptools import setup, find_packages +from setuptools import find_packages, setup from setuptools.command.build_py import build_py + from mypy.version import __version__ as version -description = 'Optional static typing for Python' -long_description = ''' +description = "Optional static typing for Python" +long_description = """ Mypy -- Optional Static Typing for Python ========================================= @@ -30,10 +31,10 @@ actually having to run it. Mypy has a powerful type system with features such as type inference, gradual typing, generics and union types. -'''.lstrip() +""".lstrip() -def find_package_data(base, globs, root='mypy'): +def find_package_data(base, globs, root="mypy"): """Find all interesting data files, for setup(package_data=) Arguments: @@ -55,9 +56,9 @@ def find_package_data(base, globs, root='mypy'): class CustomPythonBuild(build_py): def pin_version(self): - path = os.path.join(self.build_lib, 'mypy') + path = os.path.join(self.build_lib, "mypy") self.mkpath(path) - with open(os.path.join(path, 'version.py'), 'w') as stream: + with open(os.path.join(path, "version.py"), "w") as stream: stream.write(f'__version__ = "{version}"\n') def run(self): @@ -65,152 +66,164 @@ def run(self): build_py.run(self) -cmdclass = {'build_py': CustomPythonBuild} +cmdclass = {"build_py": CustomPythonBuild} -package_data = ['py.typed'] +package_data = ["py.typed"] -package_data += find_package_data(os.path.join('mypy', 'typeshed'), ['*.py', '*.pyi']) -package_data += [os.path.join('mypy', 'typeshed', 'stdlib', 'VERSIONS')] +package_data += find_package_data(os.path.join("mypy", "typeshed"), ["*.py", "*.pyi"]) +package_data += [os.path.join("mypy", "typeshed", "stdlib", "VERSIONS")] -package_data += find_package_data(os.path.join('mypy', 'xml'), ['*.xsd', '*.xslt', '*.css']) +package_data += find_package_data(os.path.join("mypy", "xml"), ["*.xsd", "*.xslt", "*.css"]) USE_MYPYC = False # To compile with mypyc, a mypyc checkout must be present on the PYTHONPATH -if len(sys.argv) > 1 and sys.argv[1] == '--use-mypyc': +if len(sys.argv) > 1 and sys.argv[1] == "--use-mypyc": sys.argv.pop(1) USE_MYPYC = True -if os.getenv('MYPY_USE_MYPYC', None) == '1': +if os.getenv("MYPY_USE_MYPYC", None) == "1": USE_MYPYC = True if USE_MYPYC: - MYPYC_BLACKLIST = tuple(os.path.join('mypy', x) for x in ( - # Need to be runnable as scripts - '__main__.py', - 'pyinfo.py', - os.path.join('dmypy', '__main__.py'), - - # Uses __getattr__/__setattr__ - 'split_namespace.py', - - # Lies to mypy about code reachability - 'bogus_type.py', - - # We don't populate __file__ properly at the top level or something? - # Also I think there would be problems with how we generate version.py. - 'version.py', - - # Skip these to reduce the size of the build - 'stubtest.py', - 'stubgenc.py', - 'stubdoc.py', - 'stubutil.py', - )) + ( + MYPYC_BLACKLIST = tuple( + os.path.join("mypy", x) + for x in ( + # Need to be runnable as scripts + "__main__.py", + "pyinfo.py", + os.path.join("dmypy", "__main__.py"), + # Uses __getattr__/__setattr__ + "split_namespace.py", + # Lies to mypy about code reachability + "bogus_type.py", + # We don't populate __file__ properly at the top level or something? + # Also I think there would be problems with how we generate version.py. + "version.py", + # Skip these to reduce the size of the build + "stubtest.py", + "stubgenc.py", + "stubdoc.py", + "stubutil.py", + ) + ) + ( # Don't want to grab this accidentally - os.path.join('mypyc', 'lib-rt', 'setup.py'), + os.path.join("mypyc", "lib-rt", "setup.py"), # Uses __file__ at top level https://github.com/mypyc/mypyc/issues/700 - os.path.join('mypyc', '__main__.py'), + os.path.join("mypyc", "__main__.py"), ) - everything = ( - [os.path.join('mypy', x) for x in find_package_data('mypy', ['*.py'])] + - [os.path.join('mypyc', x) for x in find_package_data('mypyc', ['*.py'], root='mypyc')]) + everything = [os.path.join("mypy", x) for x in find_package_data("mypy", ["*.py"])] + [ + os.path.join("mypyc", x) for x in find_package_data("mypyc", ["*.py"], root="mypyc") + ] # Start with all the .py files - all_real_pys = [x for x in everything - if not x.startswith(os.path.join('mypy', 'typeshed') + os.sep)] + all_real_pys = [ + x for x in everything if not x.startswith(os.path.join("mypy", "typeshed") + os.sep) + ] # Strip out anything in our blacklist mypyc_targets = [x for x in all_real_pys if x not in MYPYC_BLACKLIST] # Strip out any test code - mypyc_targets = [x for x in mypyc_targets - if not x.startswith((os.path.join('mypy', 'test') + os.sep, - os.path.join('mypyc', 'test') + os.sep, - os.path.join('mypyc', 'doc') + os.sep, - os.path.join('mypyc', 'test-data') + os.sep, - ))] + mypyc_targets = [ + x + for x in mypyc_targets + if not x.startswith( + ( + os.path.join("mypy", "test") + os.sep, + os.path.join("mypyc", "test") + os.sep, + os.path.join("mypyc", "doc") + os.sep, + os.path.join("mypyc", "test-data") + os.sep, + ) + ) + ] # ... and add back in the one test module we need - mypyc_targets.append(os.path.join('mypy', 'test', 'visitors.py')) + mypyc_targets.append(os.path.join("mypy", "test", "visitors.py")) # The targets come out of file system apis in an unspecified # order. Sort them so that the mypyc output is deterministic. mypyc_targets.sort() - use_other_mypyc = os.getenv('ALTERNATE_MYPYC_PATH', None) + use_other_mypyc = os.getenv("ALTERNATE_MYPYC_PATH", None) if use_other_mypyc: # This bit is super unfortunate: we want to use a different # mypy/mypyc version, but we've already imported parts, so we # remove the modules that we've imported already, which will # let the right versions be imported by mypyc. - del sys.modules['mypy'] - del sys.modules['mypy.version'] - del sys.modules['mypy.git'] + del sys.modules["mypy"] + del sys.modules["mypy.version"] + del sys.modules["mypy.git"] sys.path.insert(0, use_other_mypyc) from mypyc.build import mypycify - opt_level = os.getenv('MYPYC_OPT_LEVEL', '3') - debug_level = os.getenv('MYPYC_DEBUG_LEVEL', '1') - force_multifile = os.getenv('MYPYC_MULTI_FILE', '') == '1' + + opt_level = os.getenv("MYPYC_OPT_LEVEL", "3") + debug_level = os.getenv("MYPYC_DEBUG_LEVEL", "1") + force_multifile = os.getenv("MYPYC_MULTI_FILE", "") == "1" ext_modules = mypycify( - mypyc_targets + ['--config-file=mypy_bootstrap.ini'], + mypyc_targets + ["--config-file=mypy_bootstrap.ini"], opt_level=opt_level, debug_level=debug_level, # Use multi-file compilation mode on windows because without it # our Appveyor builds run out of memory sometimes. - multi_file=sys.platform == 'win32' or force_multifile, + multi_file=sys.platform == "win32" or force_multifile, ) else: ext_modules = [] classifiers = [ - 'Development Status :: 4 - Beta', - 'Environment :: Console', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Topic :: Software Development', + "Development Status :: 4 - Beta", + "Environment :: Console", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Software Development", ] -setup(name='mypy', - version=version, - description=description, - long_description=long_description, - author='Jukka Lehtosalo', - author_email='jukka.lehtosalo@iki.fi', - url='http://www.mypy-lang.org/', - license='MIT License', - py_modules=[], - ext_modules=ext_modules, - packages=find_packages(), - package_data={'mypy': package_data}, - entry_points={'console_scripts': ['mypy=mypy.__main__:console_entry', - 'stubgen=mypy.stubgen:main', - 'stubtest=mypy.stubtest:main', - 'dmypy=mypy.dmypy.client:console_entry', - 'mypyc=mypyc.__main__:main', - ]}, - classifiers=classifiers, - cmdclass=cmdclass, - # When changing this, also update mypy-requirements.txt. - install_requires=["typed_ast >= 1.4.0, < 2; python_version<'3.8'", - 'typing_extensions>=3.10', - 'mypy_extensions >= 0.4.3', - "tomli>=1.1.0; python_version<'3.11'", - ], - # Same here. - extras_require={ - 'dmypy': 'psutil >= 4.0', - 'python2': 'typed_ast >= 1.4.0, < 2', - 'reports': 'lxml' - }, - python_requires=">=3.6", - include_package_data=True, - project_urls={ - 'News': 'http://mypy-lang.org/news.html', - 'Documentation': 'https://mypy.readthedocs.io/en/stable/index.html', - 'Repository': 'https://github.com/python/mypy', - }, - ) +setup( + name="mypy", + version=version, + description=description, + long_description=long_description, + author="Jukka Lehtosalo", + author_email="jukka.lehtosalo@iki.fi", + url="http://www.mypy-lang.org/", + license="MIT License", + py_modules=[], + ext_modules=ext_modules, + packages=find_packages(), + package_data={"mypy": package_data}, + entry_points={ + "console_scripts": [ + "mypy=mypy.__main__:console_entry", + "stubgen=mypy.stubgen:main", + "stubtest=mypy.stubtest:main", + "dmypy=mypy.dmypy.client:console_entry", + "mypyc=mypyc.__main__:main", + ] + }, + classifiers=classifiers, + cmdclass=cmdclass, + # When changing this, also update mypy-requirements.txt. + install_requires=[ + "typed_ast >= 1.4.0, < 2; python_version<'3.8'", + "typing_extensions>=3.10", + "mypy_extensions >= 0.4.3", + "tomli>=1.1.0; python_version<'3.11'", + ], + # Same here. + extras_require={ + "dmypy": "psutil >= 4.0", + "python2": "typed_ast >= 1.4.0, < 2", + "reports": "lxml", + }, + python_requires=">=3.6", + include_package_data=True, + project_urls={ + "News": "http://mypy-lang.org/news.html", + "Documentation": "https://mypy.readthedocs.io/en/stable/index.html", + "Repository": "https://github.com/python/mypy", + }, +) diff --git a/tox.ini b/tox.ini index c2159cd6fdba0..d2284813195e4 100644 --- a/tox.ini +++ b/tox.ini @@ -49,8 +49,8 @@ parallel_show_output = True description = check the code style commands = flake8 {posargs} - black --check --diff . - isort --check --diff . + black --check --diff --color . + isort --check --diff --color . [testenv:type] description = type check ourselves