Skip to content

Commit

Permalink
Use ParamSpec to properly annotate decorators
Browse files Browse the repository at this point in the history
Signed-off-by: Iurii Pliner <yury.pliner@gmail.com>
  • Loading branch information
Pliner committed Dec 11, 2024
1 parent d9aae14 commit 53b6afb
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
23 changes: 15 additions & 8 deletions prometheus_client/context_managers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import functools
import sys
from timeit import default_timer
from types import TracebackType
from typing import (
Any, Callable, Literal, Optional, Tuple, Type, TYPE_CHECKING, TypeVar,
Callable, Literal, Optional, Tuple, Type, TYPE_CHECKING, TypeVar,
Union,
)
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

if TYPE_CHECKING:
from . import Counter
F = TypeVar("F", bound=Callable[..., Any])

TParam = ParamSpec("TParam")
TResult = TypeVar("TResult")


class ExceptionCounter:
Expand All @@ -24,9 +31,9 @@ def __exit__(self, typ: Optional[Type[BaseException]], value: Optional[BaseExcep
self._counter.inc()
return False

def __call__(self, f: "F") -> "F":
def __call__(self, f: Callable[TParam, TResult]) -> Callable[TParam, TResult]:
@functools.wraps(f)
def wrapped(*args: Any, **kwargs: Any) -> Any:
def wrapped(*args: TParam.args, **kwargs: TParam.kwargs) -> TResult:
with self:
return f(*args, **kwargs)
return wrapped # type: ignore
Expand All @@ -42,9 +49,9 @@ def __enter__(self):
def __exit__(self, typ, value, traceback):
self._gauge.dec()

def __call__(self, f: "F") -> "F":
def __call__(self, f: Callable[TParam, TResult]) -> Callable[TParam, TResult]:
@functools.wraps(f)
def wrapped(*args: Any, **kwargs: Any) -> Any:
def wrapped(*args: TParam.args, **kwargs: TParam.kwargs) -> TResult:
with self:
return f(*args, **kwargs)
return wrapped # type: ignore
Expand All @@ -71,9 +78,9 @@ def __exit__(self, typ, value, traceback):
def labels(self, *args, **kw):
self._metric = self._metric.labels(*args, **kw)

def __call__(self, f: "F") -> "F":
def __call__(self, f: Callable[TParam, TResult]) -> Callable[TParam, TResult]:
@functools.wraps(f)
def wrapped(*args: Any, **kwargs: Any) -> Any:
def wrapped(*args: TParam.args, **kwargs: TParam.kwargs) -> TResult:
# Obtaining new instance of timer every time
# ensures thread safety and reentrancy.
with self._new_timer():
Expand Down
1 change: 0 additions & 1 deletion prometheus_client/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)

T = TypeVar('T', bound='MetricWrapperBase')
F = TypeVar("F", bound=Callable[..., Any])


def _build_full_name(metric_type, name, namespace, subsystem, unit):
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
extras_require={
'twisted': ['twisted'],
},
install_requires=[
'typing-extensions>=4; python_version<"3.10"',
],
test_suite="tests",
python_requires=">=3.9",
classifiers=[
Expand Down

0 comments on commit 53b6afb

Please sign in to comment.