Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/user_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ There are several interesting things to note here:
- Optional arguments that default to `None` are recognized and a `| None` is appended automatically if the type doesn't include it already.
The `optional` or `default = ...` part don't influence the annotation.

- Common container types from Python's standard library such as `Iterable` can be used and a necessary import will be added automatically.
- Referencing the `float` and `Iterable` types worked out of the box.
All builtin types as well as types from the standard libraries `typing` and `collections.abc` module can be used.
Necessary imports will be added automatically to the stub file.


## Using types & nicknames
Expand Down
2 changes: 1 addition & 1 deletion examples/example_pkg-stubs/_basic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def func_literals(
def func_use_from_elsewhere(
a1: CustomException,
a2: ExampleClass,
a3: CustomException.NestedClass,
a3: ExampleClass.NestedClass,
a4: ExampleClass.NestedClass,
) -> tuple[CustomException, ExampleClass.NestedClass]: ...

Expand Down
3 changes: 2 additions & 1 deletion examples/example_pkg-stubs/_numpy.pyi
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# File generated with docstub

import numpy
import numpy as np
from numpy.typing import ArrayLike, NDArray

def func_object_with_numpy_objects(
a1: np.int8, a2: np.int16, a3: np.typing.DTypeLike, a4: np.typing.DTypeLike
a1: numpy.int8, a2: np.int16, a3: numpy.typing.DTypeLike, a4: np.typing.DTypeLike
) -> None: ...
def func_ndarray(
a1: NDArray, a2: np.NDArray, a3: NDArray[float], a4: NDArray[np.uint8] | None = ...
Expand Down
2 changes: 1 addition & 1 deletion examples/example_pkg/_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def func_contains(a1, a2, a3, a4, a5, a6, a7):
----------
a1 : list[float]
a2 : dict[str, Union[int, str]]
a3 : Sequence[int | float]
a3 : collections.abc.Sequence[int | float]
a4 : frozenset[bytes]
a5 : tuple of int
a6 : list of (int, str)
Expand Down
67 changes: 28 additions & 39 deletions src/docstub/_analysis.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""Collect type information."""

import builtins
import collections.abc
import importlib
import json
import logging
import re
import typing
from dataclasses import asdict, dataclass
from functools import cache
from pathlib import Path
Expand Down Expand Up @@ -227,7 +226,7 @@ def _is_type(value):
return is_type


def _builtin_imports():
def _builtin_types():
"""Return known imports for all builtins (in the current runtime).

Returns
Expand All @@ -248,45 +247,24 @@ def _builtin_imports():
return known_imports


def _typing_imports():
"""Return known imports for public types in the `typing` module.

Returns
-------
known_imports : dict[str, KnownImport]
"""
known_imports = {}
for name in typing.__all__:
def _runtime_types_in_module(module_name):
module = importlib.import_module(module_name)
types = {}
for name in module.__all__:
if name.startswith("_"):
continue
value = getattr(typing, name)
value = getattr(module, name)
if not _is_type(value):
continue
known_imports[name] = KnownImport.one_from_config(name, info={"from": "typing"})
return known_imports


def _collections_abc_imports():
"""Return known imports for public types in the `collections.abc` module.
import_ = KnownImport(import_path=module_name, import_name=name)
types[name] = import_
types[f"{module_name}.{name}"] = import_

Returns
-------
known_imports : dict[str, KnownImport]
"""
known_imports = {}
for name in collections.abc.__all__:
if name.startswith("_"):
continue
value = getattr(collections.abc, name)
if not _is_type(value):
continue
known_imports[name] = KnownImport.one_from_config(
name, info={"from": "collections.abc"}
)
return known_imports
return types


def common_known_imports():
def common_known_types():
"""Return known imports for commonly supported types.

This includes builtin types, and types from the `typing` or
Expand All @@ -295,10 +273,21 @@ def common_known_imports():
Returns
-------
known_imports : dict[str, KnownImport]

Examples
--------
>>> types = common_known_types()
>>> types["str"]
<KnownImport str (builtin)>
>>> types["Iterable"]
<KnownImport 'from collections.abc import Iterable'>
>>> types["collections.abc.Iterable"]
<KnownImport 'from collections.abc import Iterable'>
"""
known_imports = _builtin_imports()
known_imports |= _typing_imports()
known_imports |= _collections_abc_imports() # Overrides containers from typing
known_imports = _builtin_types()
known_imports |= _runtime_types_in_module("typing")
# Overrides containers from typing
known_imports |= _runtime_types_in_module("collections.abc")
return known_imports


Expand Down Expand Up @@ -426,7 +415,7 @@ class TypeMatcher:

Examples
--------
>>> from docstub._analysis import TypeMatcher, common_known_imports
>>> from docstub._analysis import TypeMatcher, common_known_types
>>> db = TypeMatcher()
>>> db.match("Any")
('Any', <KnownImport 'from typing import Any'>)
Expand All @@ -446,7 +435,7 @@ def __init__(
type_prefixes : dict[str, KnownImport]
type_nicknames : dict[str, str]
"""
self.types = types or common_known_imports()
self.types = types or common_known_types()
self.type_prefixes = type_prefixes or {}
self.type_nicknames = type_nicknames or {}
self.successful_queries = 0
Expand Down
6 changes: 3 additions & 3 deletions src/docstub/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
KnownImport,
TypeCollector,
TypeMatcher,
common_known_imports,
common_known_types,
)
from ._cache import FileCache
from ._config import Config
Expand Down Expand Up @@ -89,7 +89,7 @@ def _collect_types(root_path):
-------
types : dict[str, ~.KnownImport]
"""
types = common_known_imports()
types = common_known_types()

collect_cached_types = FileCache(
func=TypeCollector.collect,
Expand Down Expand Up @@ -213,7 +213,7 @@ def run(root_path, out_dir, config_paths, group_errors, allow_errors, verbose):

config = _load_configuration(config_paths)

types = common_known_imports()
types = common_known_types()
types |= _collect_types(root_path)
types |= {
type_name: KnownImport(import_path=module, import_name=type_name)
Expand Down
23 changes: 22 additions & 1 deletion tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import pytest

from docstub._analysis import KnownImport, TypeCollector, TypeMatcher
from docstub._analysis import (
KnownImport,
TypeCollector,
TypeMatcher,
)


class Test_KnownImport:
Expand Down Expand Up @@ -182,3 +186,20 @@ def test_query_prefix(self, search_name, expected_name, expected_origin):
assert type_name.startswith(type_origin.target)
assert type_name == expected_name
# fmt: on

@pytest.mark.parametrize(
("search_name", "import_path"),
[
("Iterable", "collections.abc"),
("collections.abc.Iterable", "collections.abc"),
("Literal", "typing"),
("typing.Literal", "typing"),
],
)
def test_common_known_types(self, search_name, import_path):
matcher = TypeMatcher()
type_name, type_origin = matcher.match(search_name)

assert type_name == search_name.split(".")[-1]
assert type_origin is not None
assert type_origin.import_path == import_path