Skip to content

Commit

Permalink
Fixed docstrings being unavailable after instrumentation
Browse files Browse the repository at this point in the history
Fixes #359.
  • Loading branch information
agronholm committed Jun 12, 2023
1 parent 845fcbc commit 05b1c67
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 16 deletions.
2 changes: 2 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ This library adheres to `Semantic Versioning 2.0 <https://semver.org/#semantic-v
(`#363 <https://github.com/agronholm/typeguard/issues/363>`_; PR by Alex Waygood)
- Fixed ``NameError`` when generated type checking code references an imported name from
a method (`#362 <https://github.com/agronholm/typeguard/issues/362>`_)
- Fixed docstrings disappearing from instrumented functions
(`#359 <https://github.com/agronholm/typeguard/issues/359>`_)

**4.0.0** (2023-05-12)

Expand Down
39 changes: 23 additions & 16 deletions src/typeguard/_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class TransformMemo:
should_instrument: bool = field(init=False, default=True)
variable_annotations: dict[str, expr] = field(init=False, default_factory=dict)
configuration_overrides: dict[str, Any] = field(init=False, default_factory=dict)
code_inject_index: int = field(init=False, default=0)

def __post_init__(self) -> None:
elements: list[str] = []
Expand All @@ -152,6 +153,18 @@ def __post_init__(self) -> None:

self.joined_path = Constant(".".join(elements))

# Figure out where to insert instrumentation code
if self.node:
for index, child in enumerate(self.node.body):
if isinstance(child, ImportFrom) and child.module == "__future__":
# (module only) __future__ imports must come first
continue
elif isinstance(child, Expr) and isinstance(child.value, Str):
continue # docstring

self.code_inject_index = index
break

def get_unused_name(self, name: str) -> str:
memo: TransformMemo | None = self
while memo is not None:
Expand Down Expand Up @@ -212,20 +225,12 @@ def insert_imports(self, node: Module | FunctionDef | AsyncFunctionDef) -> None:
return

# Insert imports after any "from __future__ ..." imports and any docstring
for i, child in enumerate(node.body):
if isinstance(child, ImportFrom) and child.module == "__future__":
continue
elif isinstance(child, Expr) and isinstance(child.value, Str):
continue # module docstring

for modulename, names in self.load_names.items():
aliases = [
alias(orig_name, new_name.id if orig_name != new_name.id else None)
for orig_name, new_name in sorted(names.items())
]
node.body.insert(i, ImportFrom(modulename, aliases, 0))

break
for modulename, names in self.load_names.items():
aliases = [
alias(orig_name, new_name.id if orig_name != new_name.id else None)
for orig_name, new_name in sorted(names.items())
]
node.body.insert(self.code_inject_index, ImportFrom(modulename, aliases, 0))

def name_matches(self, expression: expr | Expr | None, *names: str) -> bool:
if expression is None:
Expand Down Expand Up @@ -757,7 +762,9 @@ def visit_FunctionDef(
annotations_dict,
self._memo.get_memo_name(),
]
node.body.insert(0, Expr(Call(func_name, args, [])))
node.body.insert(
self._memo.code_inject_index, Expr(Call(func_name, args, []))
)

# Add a checked "return None" to the end if there's no explicit return
# Skip if the return annotation is None or Any
Expand Down Expand Up @@ -859,7 +866,7 @@ def visit_FunctionDef(
[keyword(key, value) for key, value in memo_kwargs.items()],
)
node.body.insert(
0,
self._memo.code_inject_index,
Assign([memo_store_name], memo_expr),
)

Expand Down
27 changes: 27 additions & 0 deletions tests/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,3 +1521,30 @@ def foo(x: Annotated[str, 'foo bar']) -> None:
"""
).strip()
)


def test_respect_docstring() -> None:
# Regression test for #359
node = parse(
dedent(
'''
def foo() -> int:
"""This is a docstring."""
return 1
'''
)
)
TypeguardTransformer(["foo"]).visit(node)
assert (
unparse(node)
== dedent(
'''
def foo() -> int:
"""This is a docstring."""
from typeguard import TypeCheckMemo
from typeguard._functions import check_return_type
memo = TypeCheckMemo(globals(), locals())
return check_return_type('foo', 1, int, memo)
'''
).strip()
)

0 comments on commit 05b1c67

Please sign in to comment.