Skip to content

Commit b2c878d

Browse files
brettimusBrett Beutell
andauthored
Add support for decorating async functions (#33)
* Add support for decorating async functions * Update types for async * Add a test of a basic async function * Fix typo in fastapi-example.py * Format code * Update fastapi-example.py with async code as well * Implement PR feedback (fix docstrings, rename symbol) * Add test for async exceptions * Update django dep in example to resolve dependabot alert --------- Co-authored-by: Brett Beutell <brett@fiberplane.com>
1 parent 0a11a33 commit b2c878d

File tree

10 files changed

+213
-29
lines changed

10 files changed

+213
-29
lines changed

examples/django_example/django_example/asgi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@
1111

1212
from django.core.asgi import get_asgi_application
1313

14-
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'django_example.settings')
14+
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_example.settings")
1515

1616
application = get_asgi_application()

examples/django_example/django_example/wsgi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@
1111

1212
from django.core.wsgi import get_wsgi_application
1313

14-
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'django_example.settings')
14+
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_example.settings")
1515

1616
application = get_wsgi_application()

examples/django_example/manage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
def main():
88
"""Run administrative tasks."""
9-
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'django_example.settings')
9+
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_example.settings")
1010
try:
1111
from django.core.management import execute_from_command_line
1212
except ImportError as exc:
@@ -18,5 +18,5 @@ def main():
1818
execute_from_command_line(sys.argv)
1919

2020

21-
if __name__ == '__main__':
21+
if __name__ == "__main__":
2222
main()

examples/example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,6 @@ def random_error():
9494

9595
try:
9696
# Call random_error. It will randomly raise an error or return "ok"
97-
random_error()
97+
random_error()
9898
except Exception:
9999
pass

examples/fastapi-example.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from fastapi import FastAPI, Response
23
import uvicorn
34
from autometrics import autometrics
@@ -7,24 +8,40 @@
78

89

910
# Set up a metrics endpoint for Prometheus to scrape
10-
# `generate_lates` returns the latest metrics data in the Prometheus text format
11+
# `generate_latest` returns the latest metrics data in the Prometheus text format
1112
@app.get("/metrics")
1213
def metrics():
1314
return Response(generate_latest())
1415

1516

1617
# Set up the root endpoint of the API
17-
@autometrics
1818
@app.get("/")
19+
@autometrics
1920
def read_root():
2021
do_something()
2122
return {"Hello": "World"}
2223

2324

25+
# Set up an async handler
26+
@app.get("/async")
27+
@autometrics
28+
async def async_route():
29+
message = await do_something_async()
30+
return {"Hello": message}
31+
32+
2433
@autometrics
2534
def do_something():
2635
print("done")
2736

2837

38+
@autometrics
39+
async def do_something_async():
40+
print("async start")
41+
await asyncio.sleep(2.0)
42+
print("async done")
43+
return "async world"
44+
45+
2946
if __name__ == "__main__":
3047
uvicorn.run(app, host="localhost", port=8080)

poetry.lock

Lines changed: 23 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ optional = true
2929
[tool.poetry.group.dev.dependencies]
3030
pyright = "^1.1.302"
3131
pytest = "^7.3.0"
32+
pytest-asyncio = "^0.21.0"
3233
black = "^23.3.0"
3334

3435
[tool.poetry.group.examples]

src/autometrics/decorator.py

Lines changed: 76 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Autometrics module."""
22
import time
3+
import inspect
4+
35
from functools import wraps
4-
from typing import overload, TypeVar, Callable, Optional
6+
from typing import overload, TypeVar, Callable, Optional, Awaitable
57
from typing_extensions import ParamSpec
68
from .objectives import Objective
79
from .tracker import get_tracker, Result
8-
from .utils import get_module_name, get_caller_function, write_docs
10+
from .utils import get_module_name, get_caller_function, append_docs_to_docstring
911

1012

1113
P = ParamSpec("P")
@@ -18,6 +20,11 @@ def autometrics(func: Callable[P, T]) -> Callable[P, T]:
1820
...
1921

2022

23+
@overload
24+
def autometrics(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
25+
...
26+
27+
2128
# Decorator with arguments
2229
@overload
2330
def autometrics(*, objective: Optional[Objective] = None) -> Callable:
@@ -29,49 +36,100 @@ def autometrics(
2936
*,
3037
objective: Optional[Objective] = None,
3138
):
32-
"""Decorator for tracking function calls and duration."""
39+
"""Decorator for tracking function calls and duration. Supports synchronous and async functions."""
40+
41+
def track_result_ok(start_time: float, function: str, module: str, caller: str):
42+
get_tracker().finish(
43+
start_time,
44+
function=function,
45+
module=module,
46+
caller=caller,
47+
objective=objective,
48+
result=Result.OK,
49+
)
50+
51+
def track_result_error(
52+
start_time: float,
53+
function: str,
54+
module: str,
55+
caller: str,
56+
):
57+
get_tracker().finish(
58+
start_time,
59+
function=function,
60+
module=module,
61+
caller=caller,
62+
objective=objective,
63+
result=Result.ERROR,
64+
)
65+
66+
def sync_decorator(func: Callable[P, T]) -> Callable[P, T]:
67+
"""Helper for decorating synchronous functions, to track calls and duration."""
3368

34-
def decorator(func: Callable[P, T]) -> Callable[P, T]:
3569
module_name = get_module_name(func)
3670
func_name = func.__name__
3771

3872
@wraps(func)
39-
def wrapper(*args: P.args, **kwds: P.kwargs) -> T:
73+
def sync_wrapper(*args: P.args, **kwds: P.kwargs) -> T:
4074
start_time = time.time()
4175
caller = get_caller_function()
4276

4377
try:
4478
result = func(*args, **kwds)
45-
get_tracker().finish(
79+
track_result_ok(
80+
start_time, function=func_name, module=module_name, caller=caller
81+
)
82+
83+
except Exception as exception:
84+
result = exception.__class__.__name__
85+
track_result_error(
4686
start_time,
4787
function=func_name,
4888
module=module_name,
4989
caller=caller,
50-
objective=objective,
51-
result=Result.OK,
90+
)
91+
# Reraise exception
92+
raise exception
93+
return result
94+
95+
sync_wrapper.__doc__ = append_docs_to_docstring(func, func_name, module_name)
96+
return sync_wrapper
97+
98+
def async_decorator(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
99+
"""Helper for decorating async functions, to track calls and duration."""
100+
101+
module_name = get_module_name(func)
102+
func_name = func.__name__
103+
104+
@wraps(func)
105+
async def async_wrapper(*args: P.args, **kwds: P.kwargs) -> T:
106+
start_time = time.time()
107+
caller = get_caller_function()
108+
109+
try:
110+
result = await func(*args, **kwds)
111+
track_result_ok(
112+
start_time, function=func_name, module=module_name, caller=caller
52113
)
53114

54115
except Exception as exception:
55116
result = exception.__class__.__name__
56-
get_tracker().finish(
117+
track_result_error(
57118
start_time,
58119
function=func_name,
59120
module=module_name,
60121
caller=caller,
61-
objective=objective,
62-
result=Result.ERROR,
63122
)
64123
# Reraise exception
65124
raise exception
66125
return result
67126

68-
if func.__doc__ is None:
69-
wrapper.__doc__ = write_docs(func_name, module_name)
70-
else:
71-
wrapper.__doc__ = f"{func.__doc__}\n{write_docs(func_name, module_name)}"
72-
return wrapper
127+
async_wrapper.__doc__ = append_docs_to_docstring(func, func_name, module_name)
128+
return async_wrapper
73129

74130
if func is None:
75-
return decorator
131+
return sync_decorator
132+
elif inspect.iscoroutinefunction(func):
133+
return async_decorator(func)
76134
else:
77-
return decorator(func)
135+
return sync_decorator(func)

src/autometrics/test_decorator.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Test the autometrics decorator."""
22
import time
3+
import asyncio
34
from prometheus_client.exposition import generate_latest
45
import pytest
56

@@ -21,6 +22,18 @@ def error_function():
2122
raise RuntimeError("This is a test error")
2223

2324

25+
async def basic_async_function(sleep_duration: float = 1.0):
26+
"""This is a basic async function."""
27+
await asyncio.sleep(sleep_duration)
28+
return True
29+
30+
31+
async def error_async_function():
32+
"""This is an async function that raises an error."""
33+
await asyncio.sleep(0.5)
34+
raise RuntimeError("This is a test error")
35+
36+
2437
tracker_types = [TrackerType.PROMETHEUS, TrackerType.OPENTELEMETRY]
2538

2639

@@ -62,6 +75,39 @@ def test_basic(self):
6275
duration_sum = f"""function_calls_duration_sum{{function="{function_name}",module="test_decorator",objective_latency_threshold="",objective_name="",objective_percentile=""}}"""
6376
assert duration_sum in data
6477

78+
@pytest.mark.asyncio
79+
async def test_basic_async(self):
80+
"""This is a basic test."""
81+
82+
# set up the function + basic variables
83+
caller = get_caller_function(depth=1)
84+
assert caller is not None
85+
assert caller != ""
86+
function_name = basic_async_function.__name__
87+
wrapped_function = autometrics(basic_async_function)
88+
89+
# Test that the function is *still* async after we wrap it
90+
assert asyncio.iscoroutinefunction(wrapped_function) == True
91+
92+
await wrapped_function()
93+
94+
blob = generate_latest()
95+
assert blob is not None
96+
data = blob.decode("utf-8")
97+
98+
total_count = f"""function_calls_count_total{{caller="{caller}",function="{function_name}",module="test_decorator",objective_name="",objective_percentile="",result="ok"}} 1.0"""
99+
assert total_count in data
100+
101+
for latency in ObjectiveLatency:
102+
query = f"""function_calls_duration_bucket{{function="{function_name}",le="{latency.value}",module="test_decorator",objective_latency_threshold="",objective_name="",objective_percentile=""}}"""
103+
assert query in data
104+
105+
duration_count = f"""function_calls_duration_count{{function="{function_name}",module="test_decorator",objective_latency_threshold="",objective_name="",objective_percentile=""}}"""
106+
assert duration_count in data
107+
108+
duration_sum = f"""function_calls_duration_sum{{function="{function_name}",module="test_decorator",objective_latency_threshold="",objective_name="",objective_percentile=""}}"""
109+
assert duration_sum in data
110+
65111
def test_objectives(self):
66112
"""This is a test that covers objectives."""
67113

@@ -132,3 +178,38 @@ def test_exception(self):
132178

133179
duration_sum = f"""function_calls_duration_sum{{function="{function_name}",module="test_decorator",objective_latency_threshold="",objective_name="",objective_percentile=""}}"""
134180
assert duration_sum in data
181+
182+
@pytest.mark.asyncio
183+
async def test_async_exception(self):
184+
"""This is a test that covers exceptions."""
185+
caller = get_caller_function(depth=1)
186+
assert caller is not None
187+
assert caller != ""
188+
189+
function_name = error_async_function.__name__
190+
wrapped_function = autometrics(error_async_function)
191+
192+
# Test that the function is *still* async after we wrap it
193+
assert asyncio.iscoroutinefunction(wrapped_function) == True
194+
195+
with pytest.raises(RuntimeError) as exception:
196+
await wrapped_function()
197+
assert "This is a test error" in str(exception.value)
198+
199+
# get the metrics
200+
blob = generate_latest()
201+
assert blob is not None
202+
data = blob.decode("utf-8")
203+
204+
total_count = f"""function_calls_count_total{{caller="{caller}",function="{function_name}",module="test_decorator",objective_name="",objective_percentile="",result="error"}} 1.0"""
205+
assert total_count in data
206+
207+
for latency in ObjectiveLatency:
208+
query = f"""function_calls_duration_bucket{{function="{function_name}",le="{latency.value}",module="test_decorator",objective_latency_threshold="",objective_name="",objective_percentile=""}}"""
209+
assert query in data
210+
211+
duration_count = f"""function_calls_duration_count{{function="{function_name}",module="test_decorator",objective_latency_threshold="",objective_name="",objective_percentile=""}}"""
212+
assert duration_count in data
213+
214+
duration_sum = f"""function_calls_duration_sum{{function="{function_name}",module="test_decorator",objective_latency_threshold="",objective_name="",objective_percentile=""}}"""
215+
assert duration_sum in data

0 commit comments

Comments
 (0)