Skip to content

Commit

Permalink
stubgen: multiple fixes to the generated imports (#15624)
Browse files Browse the repository at this point in the history
* Fix handling of nested imports.
Instead of assuming that a name is imported from a top level package,
look in the imports for this name starting from the parent submodule up
until the import is found
* Fix "from imports" getting reexported unnecessarily
* Fix import sorting when having import aliases

Fixes #13661
Fixes #7006
  • Loading branch information
hamdanal authored Sep 25, 2023
1 parent 9edda9a commit 0c8b761
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 10 deletions.
24 changes: 17 additions & 7 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,9 @@ def add_import(self, module: str, alias: str | None = None) -> None:
name = name.rpartition(".")[0]

def require_name(self, name: str) -> None:
self.required_names.add(name.split(".")[0])
while name not in self.direct_imports and "." in name:
name = name.rsplit(".", 1)[0]
self.required_names.add(name)

def reexport(self, name: str) -> None:
"""Mark a given non qualified name as needed in __all__.
Expand All @@ -516,7 +518,10 @@ def import_lines(self) -> list[str]:
# be imported from it. the names can also be alias in the form 'original as alias'
module_map: Mapping[str, list[str]] = defaultdict(list)

for name in sorted(self.required_names):
for name in sorted(
self.required_names,
key=lambda n: (self.reverse_alias[n], n) if n in self.reverse_alias else (n, ""),
):
# If we haven't seen this name in an import statement, ignore it
if name not in self.module_for:
continue
Expand All @@ -540,7 +545,7 @@ def import_lines(self) -> list[str]:
assert "." not in name # Because reexports only has nonqualified names
result.append(f"import {name} as {name}\n")
else:
result.append(f"import {self.direct_imports[name]}\n")
result.append(f"import {name}\n")

# Now generate all the from ... import ... lines collected in module_map
for module, names in sorted(module_map.items()):
Expand Down Expand Up @@ -595,7 +600,7 @@ def visit_name_expr(self, e: NameExpr) -> None:
self.refs.add(e.name)

def visit_instance(self, t: Instance) -> None:
self.add_ref(t.type.fullname)
self.add_ref(t.type.name)
super().visit_instance(t)

def visit_unbound_type(self, t: UnboundType) -> None:
Expand All @@ -614,7 +619,10 @@ def visit_callable_type(self, t: CallableType) -> None:
t.ret_type.accept(self)

def add_ref(self, fullname: str) -> None:
self.refs.add(fullname.split(".")[-1])
self.refs.add(fullname)
while "." in fullname:
fullname = fullname.rsplit(".", 1)[0]
self.refs.add(fullname)


class StubGenerator(mypy.traverser.TraverserVisitor):
Expand Down Expand Up @@ -1295,6 +1303,7 @@ def visit_import_from(self, o: ImportFrom) -> None:
if (
as_name is None
and name not in self.referenced_names
and not any(n.startswith(name + ".") for n in self.referenced_names)
and (not self._all_ or name in IGNORED_DUNDERS)
and not is_private
and module not in ("abc", "asyncio") + TYPING_MODULE_NAMES
Expand All @@ -1303,14 +1312,15 @@ def visit_import_from(self, o: ImportFrom) -> None:
# exported, unless there is an explicit __all__. Note that we need to special
# case 'abc' since some references are deleted during semantic analysis.
exported = True
top_level = full_module.split(".")[0]
top_level = full_module.split(".", 1)[0]
self_top_level = self.module.split(".", 1)[0]
if (
as_name is None
and not self.export_less
and (not self._all_ or name in IGNORED_DUNDERS)
and self.module
and not is_private
and top_level in (self.module.split(".")[0], "_" + self.module.split(".")[0])
and top_level in (self_top_level, "_" + self_top_level)
):
# Export imports from the same package, since we can't reliably tell whether they
# are part of the public API.
Expand Down
60 changes: 57 additions & 3 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -2772,9 +2772,9 @@ y: b.Y
z: p.a.X

[out]
import p.a
import p.a as a
import p.b as b
import p.a

x: a.X
y: b.Y
Expand All @@ -2787,7 +2787,7 @@ from p import a
x: a.X

[out]
from p import a as a
from p import a

x: a.X

Expand All @@ -2809,7 +2809,7 @@ from p import a
x: a.X

[out]
from p import a as a
from p import a

x: a.X

Expand Down Expand Up @@ -2859,6 +2859,60 @@ import p.a
x: a.X
y: p.a.Y

[case testNestedImports]
import p
import p.m1
import p.m2

x: p.X
y: p.m1.Y
z: p.m2.Z

[out]
import p
import p.m1
import p.m2

x: p.X
y: p.m1.Y
z: p.m2.Z

[case testNestedImportsAliased]
import p as t
import p.m1 as pm1
import p.m2 as pm2

x: t.X
y: pm1.Y
z: pm2.Z

[out]
import p as t
import p.m1 as pm1
import p.m2 as pm2

x: t.X
y: pm1.Y
z: pm2.Z

[case testNestedFromImports]
from p import m1
from p.m1 import sm1
from p.m2 import sm2

x: m1.X
y: sm1.Y
z: sm2.Z

[out]
from p import m1
from p.m1 import sm1
from p.m2 import sm2

x: m1.X
y: sm1.Y
z: sm2.Z

[case testOverload_fromTypingImport]
from typing import Tuple, Union, overload

Expand Down

0 comments on commit 0c8b761

Please sign in to comment.