Skip to content

Support attributes and type aliases #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/example_pkg-stubs/_basic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import logging
from collections.abc import Sequence
from typing import Any, Literal, Self, Union

from _typeshed import Incomplete

from . import CustomException

logger = ...
logger: Incomplete

__all__ = [
"func_empty",
Expand Down
74 changes: 60 additions & 14 deletions src/docstub/_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import re
import typing
from dataclasses import asdict, dataclass
from functools import cache
from pathlib import Path

import libcst as cst
import libcst.matchers as cstm

from ._utils import accumulate_qualname, module_name_from_path, pyfile_checksum

Expand Down Expand Up @@ -45,13 +47,13 @@ class KnownImport:

Attributes
----------
import_path : str, optional
import_path
Dotted names after "from".
import_name : str, optional
import_name
Dotted names after "import".
import_alias : str, optional
import_alias
Name (without ".") after "as".
builtin_name : str, optional
builtin_name
Names an object that's builtin and doesn't need an import.

Examples
Expand All @@ -65,6 +67,26 @@ class KnownImport:
import_alias: str = None
builtin_name: str = None

@classmethod
@cache
def typeshed_Incomplete(cls):
"""Create import corresponding to ``from _typeshed import Incomplete``.

This type is not actually available at runtime and only intended to be
used in stub files [1]_.

Returns
-------
import : KnownImport
The import corresponding to ``from _typeshed import Incomplete``.

References
----------
.. [1] https://typing.readthedocs.io/en/latest/guides/writing_stubs.html#incomplete-stubs
"""
import_ = cls(import_path="_typeshed", import_name="Incomplete")
return import_

@classmethod
def one_from_config(cls, name, *, info):
"""Create one KnownImport from the configuration format.
Expand Down Expand Up @@ -327,23 +349,47 @@ def __init__(self, *, module_name):

def visit_ClassDef(self, node: cst.ClassDef) -> bool:
self._stack.append(node.name.value)

class_name = ".".join(self._stack[:1])
qualname = f"{self.module_name}.{'.'.join(self._stack)}"
known_import = KnownImport(import_path=self.module_name, import_name=class_name)
self.known_imports[qualname] = known_import

self._collect_type_annotation(self._stack)
return True

def leave_ClassDef(self, original_node: cst.ClassDef) -> None:
self._stack.pop()

def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
self._stack.append(node.name.value)
return True
return False

def visit_TypeAlias(self, node: cst.TypeAlias) -> bool:
"""Collect type alias with 3.12 syntax."""
stack = [*self._stack, node.name.value]
self._collect_type_annotation(stack)
return False

def visit_AnnAssign(self, node: cst.AnnAssign) -> bool:
"""Collect type alias annotated with `TypeAlias`."""
is_type_alias = cstm.matches(
node,
cstm.AnnAssign(
annotation=cstm.Annotation(annotation=cstm.Name(value="TypeAlias"))
),
)
if is_type_alias and node.value is not None:
names = cstm.findall(node.target, cstm.Name())
assert len(names) == 1
stack = [*self._stack, names[0].value]
self._collect_type_annotation(stack)
return False

def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None:
self._stack.pop()
def _collect_type_annotation(self, stack):
"""Collect an importable type annotation.

Parameters
----------
stack : Iterable[str]
A list of names that form the path to the collected type.
"""
qualname = ".".join([self.module_name, *stack])
known_import = KnownImport(import_path=self.module_name, import_name=stack[0])
self.known_imports[qualname] = known_import


class TypesDatabase:
Expand Down
97 changes: 90 additions & 7 deletions src/docstub/_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from numpydoc.docscrape import NumpyDocString

from ._analysis import KnownImport
from ._utils import ContextFormatter, accumulate_qualname, escape_qualname
from ._utils import ContextFormatter, DocstubError, accumulate_qualname, escape_qualname

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -135,16 +135,29 @@ def _aggregate_annotations(*types):
return values, imports


GrammarErrorFallback = Annotation(
value="Any",
imports=frozenset((KnownImport(import_path="typing", import_name="Any"),)),
FallbackAnnotation = Annotation(
value="Incomplete", imports=frozenset([KnownImport.typeshed_Incomplete()])
)


class QualnameIsKeyword(DocstubError):
"""Raised when a qualname is a blacklisted Python keyword."""


@lark.visitors.v_args(tree=True)
class DoctypeTransformer(lark.visitors.Transformer):
"""Transformer for docstring type descriptions (doctypes).

Attributes
----------
blacklisted_qualnames : frozenset[str]
All Python keywords [1]_ are blacklisted from use in qualnames except for ``True``
``False`` and ``None``.

References
----------
.. [1] https://docs.python.org/3/reference/lexical_analysis.html#keywords

Examples
--------
>>> transformer = DoctypeTransformer()
Expand All @@ -155,6 +168,43 @@ class DoctypeTransformer(lark.visitors.Transformer):
[('tuple', 0, 5), ('int', 9, 12)]
"""

blacklisted_qualnames = frozenset(
{
"await",
"else",
"import",
"pass",
"break",
"except",
"in",
"raise",
"class",
"finally",
"is",
"return",
"and",
"continue",
"for",
"lambda",
"try",
"as",
"def",
"from",
"nonlocal",
"while",
"assert",
"del",
"global",
"not",
"with",
"async",
"elif",
"if",
"or",
"yield",
}
)

def __init__(self, *, types_db=None, replace_doctypes=None, **kwargs):
"""
Parameters
Expand Down Expand Up @@ -204,7 +254,11 @@ def doctype_to_annotation(self, doctype):
value=value, imports=frozenset(self._collected_imports)
)
return annotation, self._unknown_qualnames
except (lark.exceptions.LexError, lark.exceptions.ParseError):
except (
lark.exceptions.LexError,
lark.exceptions.ParseError,
QualnameIsKeyword,
):
self.stats["grammar_errors"] += 1
raise
finally:
Expand Down Expand Up @@ -274,6 +328,13 @@ def qualname(self, tree):

_qualname = self._find_import(_qualname, meta=tree.meta)

if _qualname in self.blacklisted_qualnames:
msg = (
f"qualname {_qualname!r} in docstring type description "
"is a reserved Python keyword and not allowed"
)
raise QualnameIsKeyword(msg)

_qualname = lark.Token(type="QUALNAME", value=_qualname)
return _qualname

Expand Down Expand Up @@ -399,7 +460,7 @@ def _doctype_to_annotation(self, doctype, ds_line=0):
details = details.replace("^", click.style("^", fg="red", bold=True))
if ctx:
ctx.print_message("invalid syntax in doctype", details=details)
return GrammarErrorFallback
return FallbackAnnotation

except lark.visitors.VisitError as e:
tb = "\n".join(traceback.format_exception(e.orig_exc))
Expand All @@ -408,7 +469,7 @@ def _doctype_to_annotation(self, doctype, ds_line=0):
ctx.print_message(
"unexpected error while parsing doctype", details=details
)
return GrammarErrorFallback
return FallbackAnnotation

else:
for name, start_col, stop_col in unknown_qualnames:
Expand All @@ -421,6 +482,28 @@ def _doctype_to_annotation(self, doctype, ds_line=0):
)
return annotation

@cached_property
def attributes(self) -> dict[str, Annotation]:
annotations = {}
for attribute in self.np_docstring["Attributes"]:
if not attribute.type:
continue

ds_line = 0
for i, line in enumerate(self.docstring.split("\n")):
if attribute.name in line and attribute.type in line:
ds_line = i
break

if attribute.name in annotations:
logger.warning("duplicate parameter name %r, ignoring", attribute.name)
continue

annotation = self._doctype_to_annotation(attribute.type, ds_line=ds_line)
annotations[attribute.name] = annotation

return annotations

@cached_property
def parameters(self) -> dict[str, Annotation]:
all_params = chain(
Expand Down
Loading