From 0c8b76195a773363721d5521653bcdf9989d8768 Mon Sep 17 00:00:00 2001 From: Ali Hamdan Date: Mon, 25 Sep 2023 02:28:08 +0200 Subject: [PATCH] stubgen: multiple fixes to the generated imports (#15624) * Fix handling of nested imports. Instead of assuming that a name is imported from a top level package, look in the imports for this name starting from the parent submodule up until the import is found * Fix "from imports" getting reexported unnecessarily * Fix import sorting when having import aliases Fixes #13661 Fixes #7006 --- mypy/stubgen.py | 24 ++++++++++----- test-data/unit/stubgen.test | 60 +++++++++++++++++++++++++++++++++++-- 2 files changed, 74 insertions(+), 10 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index ca7249465746..e8c12ee4d99b 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -496,7 +496,9 @@ def add_import(self, module: str, alias: str | None = None) -> None: name = name.rpartition(".")[0] def require_name(self, name: str) -> None: - self.required_names.add(name.split(".")[0]) + while name not in self.direct_imports and "." in name: + name = name.rsplit(".", 1)[0] + self.required_names.add(name) def reexport(self, name: str) -> None: """Mark a given non qualified name as needed in __all__. @@ -516,7 +518,10 @@ def import_lines(self) -> list[str]: # be imported from it. the names can also be alias in the form 'original as alias' module_map: Mapping[str, list[str]] = defaultdict(list) - for name in sorted(self.required_names): + for name in sorted( + self.required_names, + key=lambda n: (self.reverse_alias[n], n) if n in self.reverse_alias else (n, ""), + ): # If we haven't seen this name in an import statement, ignore it if name not in self.module_for: continue @@ -540,7 +545,7 @@ def import_lines(self) -> list[str]: assert "." not in name # Because reexports only has nonqualified names result.append(f"import {name} as {name}\n") else: - result.append(f"import {self.direct_imports[name]}\n") + result.append(f"import {name}\n") # Now generate all the from ... import ... lines collected in module_map for module, names in sorted(module_map.items()): @@ -595,7 +600,7 @@ def visit_name_expr(self, e: NameExpr) -> None: self.refs.add(e.name) def visit_instance(self, t: Instance) -> None: - self.add_ref(t.type.fullname) + self.add_ref(t.type.name) super().visit_instance(t) def visit_unbound_type(self, t: UnboundType) -> None: @@ -614,7 +619,10 @@ def visit_callable_type(self, t: CallableType) -> None: t.ret_type.accept(self) def add_ref(self, fullname: str) -> None: - self.refs.add(fullname.split(".")[-1]) + self.refs.add(fullname) + while "." in fullname: + fullname = fullname.rsplit(".", 1)[0] + self.refs.add(fullname) class StubGenerator(mypy.traverser.TraverserVisitor): @@ -1295,6 +1303,7 @@ def visit_import_from(self, o: ImportFrom) -> None: if ( as_name is None and name not in self.referenced_names + and not any(n.startswith(name + ".") for n in self.referenced_names) and (not self._all_ or name in IGNORED_DUNDERS) and not is_private and module not in ("abc", "asyncio") + TYPING_MODULE_NAMES @@ -1303,14 +1312,15 @@ def visit_import_from(self, o: ImportFrom) -> None: # exported, unless there is an explicit __all__. Note that we need to special # case 'abc' since some references are deleted during semantic analysis. exported = True - top_level = full_module.split(".")[0] + top_level = full_module.split(".", 1)[0] + self_top_level = self.module.split(".", 1)[0] if ( as_name is None and not self.export_less and (not self._all_ or name in IGNORED_DUNDERS) and self.module and not is_private - and top_level in (self.module.split(".")[0], "_" + self.module.split(".")[0]) + and top_level in (self_top_level, "_" + self_top_level) ): # Export imports from the same package, since we can't reliably tell whether they # are part of the public API. diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 828680fadcf2..23dbf36a551b 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -2772,9 +2772,9 @@ y: b.Y z: p.a.X [out] +import p.a import p.a as a import p.b as b -import p.a x: a.X y: b.Y @@ -2787,7 +2787,7 @@ from p import a x: a.X [out] -from p import a as a +from p import a x: a.X @@ -2809,7 +2809,7 @@ from p import a x: a.X [out] -from p import a as a +from p import a x: a.X @@ -2859,6 +2859,60 @@ import p.a x: a.X y: p.a.Y +[case testNestedImports] +import p +import p.m1 +import p.m2 + +x: p.X +y: p.m1.Y +z: p.m2.Z + +[out] +import p +import p.m1 +import p.m2 + +x: p.X +y: p.m1.Y +z: p.m2.Z + +[case testNestedImportsAliased] +import p as t +import p.m1 as pm1 +import p.m2 as pm2 + +x: t.X +y: pm1.Y +z: pm2.Z + +[out] +import p as t +import p.m1 as pm1 +import p.m2 as pm2 + +x: t.X +y: pm1.Y +z: pm2.Z + +[case testNestedFromImports] +from p import m1 +from p.m1 import sm1 +from p.m2 import sm2 + +x: m1.X +y: sm1.Y +z: sm2.Z + +[out] +from p import m1 +from p.m1 import sm1 +from p.m2 import sm2 + +x: m1.X +y: sm1.Y +z: sm2.Z + [case testOverload_fromTypingImport] from typing import Tuple, Union, overload