Skip to content

Commit b49eb37

Browse files
authored
chore(typing): improve typing of WrappedFn (#390)
This change improves the typing of WrappedFn. It makes explictly the two signatures of tenacity.retry() with overload. This avoids mypy thinking the return type is `<nothing>`
1 parent 78c8d4b commit b49eb37

File tree

3 files changed

+76
-47
lines changed

3 files changed

+76
-47
lines changed

tenacity/__init__.py

Lines changed: 65 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# See the License for the specific language governing permissions and
1717
# limitations under the License.
1818

19+
1920
import functools
2021
import sys
2122
import threading
@@ -91,37 +92,8 @@
9192
from .wait import WaitBaseT
9293

9394

95+
WrappedFnReturnT = t.TypeVar("WrappedFnReturnT")
9496
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Any])
95-
_RetValT = t.TypeVar("_RetValT")
96-
97-
98-
def retry(*dargs: t.Any, **dkw: t.Any) -> t.Union[WrappedFn, t.Callable[[WrappedFn], WrappedFn]]: # noqa
99-
"""Wrap a function with a new `Retrying` object.
100-
101-
:param dargs: positional arguments passed to Retrying object
102-
:param dkw: keyword arguments passed to the Retrying object
103-
"""
104-
# support both @retry and @retry() as valid syntax
105-
if len(dargs) == 1 and callable(dargs[0]):
106-
return retry()(dargs[0])
107-
else:
108-
109-
def wrap(f: WrappedFn) -> WrappedFn:
110-
if isinstance(f, retry_base):
111-
warnings.warn(
112-
f"Got retry_base instance ({f.__class__.__name__}) as callable argument, "
113-
f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)"
114-
)
115-
if iscoroutinefunction(f):
116-
r: "BaseRetrying" = AsyncRetrying(*dargs, **dkw)
117-
elif tornado and hasattr(tornado.gen, "is_coroutine_function") and tornado.gen.is_coroutine_function(f):
118-
r = TornadoRetrying(*dargs, **dkw)
119-
else:
120-
r = Retrying(*dargs, **dkw)
121-
122-
return r.wraps(f)
123-
124-
return wrap
12597

12698

12799
class TryAgain(Exception):
@@ -382,14 +354,24 @@ def __iter__(self) -> t.Generator[AttemptManager, None, None]:
382354
break
383355

384356
@abstractmethod
385-
def __call__(self, fn: t.Callable[..., _RetValT], *args: t.Any, **kwargs: t.Any) -> _RetValT:
357+
def __call__(
358+
self,
359+
fn: t.Callable[..., WrappedFnReturnT],
360+
*args: t.Any,
361+
**kwargs: t.Any,
362+
) -> WrappedFnReturnT:
386363
pass
387364

388365

389366
class Retrying(BaseRetrying):
390367
"""Retrying controller."""
391368

392-
def __call__(self, fn: t.Callable[..., _RetValT], *args: t.Any, **kwargs: t.Any) -> _RetValT:
369+
def __call__(
370+
self,
371+
fn: t.Callable[..., WrappedFnReturnT],
372+
*args: t.Any,
373+
**kwargs: t.Any,
374+
) -> WrappedFnReturnT:
393375
self.begin()
394376

395377
retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
@@ -510,6 +492,57 @@ def __repr__(self) -> str:
510492
return f"<{clsname} {id(self)}: attempt #{self.attempt_number}; slept for {slept}; last result: {result}>"
511493

512494

495+
@t.overload
496+
def retry(func: WrappedFn) -> WrappedFn:
497+
...
498+
499+
500+
@t.overload
501+
def retry(
502+
sleep: t.Callable[[t.Union[int, float]], None] = sleep,
503+
stop: "StopBaseT" = stop_never,
504+
wait: "WaitBaseT" = wait_none(),
505+
retry: "RetryBaseT" = retry_if_exception_type(),
506+
before: t.Callable[["RetryCallState"], None] = before_nothing,
507+
after: t.Callable[["RetryCallState"], None] = after_nothing,
508+
before_sleep: t.Optional[t.Callable[["RetryCallState"], None]] = None,
509+
reraise: bool = False,
510+
retry_error_cls: t.Type["RetryError"] = RetryError,
511+
retry_error_callback: t.Optional[t.Callable[["RetryCallState"], t.Any]] = None,
512+
) -> t.Callable[[WrappedFn], WrappedFn]:
513+
...
514+
515+
516+
def retry(*dargs: t.Any, **dkw: t.Any) -> t.Any:
517+
"""Wrap a function with a new `Retrying` object.
518+
519+
:param dargs: positional arguments passed to Retrying object
520+
:param dkw: keyword arguments passed to the Retrying object
521+
"""
522+
# support both @retry and @retry() as valid syntax
523+
if len(dargs) == 1 and callable(dargs[0]):
524+
return retry()(dargs[0])
525+
else:
526+
527+
def wrap(f: WrappedFn) -> WrappedFn:
528+
if isinstance(f, retry_base):
529+
warnings.warn(
530+
f"Got retry_base instance ({f.__class__.__name__}) as callable argument, "
531+
f"this will probably hang indefinitely (did you mean retry={f.__class__.__name__}(...)?)"
532+
)
533+
r: "BaseRetrying"
534+
if iscoroutinefunction(f):
535+
r = AsyncRetrying(*dargs, **dkw)
536+
elif tornado and hasattr(tornado.gen, "is_coroutine_function") and tornado.gen.is_coroutine_function(f):
537+
r = TornadoRetrying(*dargs, **dkw)
538+
else:
539+
r = Retrying(*dargs, **dkw)
540+
541+
return r.wraps(f)
542+
543+
return wrap
544+
545+
513546
from tenacity._asyncio import AsyncRetrying # noqa:E402,I100
514547

515548
if tornado:

tenacity/_asyncio.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import functools
1919
import sys
20-
import typing
20+
import typing as t
2121
from asyncio import sleep
2222

2323
from tenacity import AttemptManager
@@ -26,24 +26,20 @@
2626
from tenacity import DoSleep
2727
from tenacity import RetryCallState
2828

29-
30-
WrappedFn = typing.TypeVar("WrappedFn", bound=typing.Callable[..., typing.Any])
31-
_RetValT = typing.TypeVar("_RetValT")
29+
WrappedFnReturnT = t.TypeVar("WrappedFnReturnT")
30+
WrappedFn = t.TypeVar("WrappedFn", bound=t.Callable[..., t.Awaitable[t.Any]])
3231

3332

3433
class AsyncRetrying(BaseRetrying):
35-
def __init__(
36-
self, sleep: typing.Callable[[float], typing.Awaitable[typing.Any]] = sleep, **kwargs: typing.Any
37-
) -> None:
34+
sleep: t.Callable[[float], t.Awaitable[t.Any]]
35+
36+
def __init__(self, sleep: t.Callable[[float], t.Awaitable[t.Any]] = sleep, **kwargs: t.Any) -> None:
3837
super().__init__(**kwargs)
3938
self.sleep = sleep
4039

4140
async def __call__( # type: ignore[override]
42-
self,
43-
fn: typing.Callable[..., typing.Awaitable[_RetValT]],
44-
*args: typing.Any,
45-
**kwargs: typing.Any,
46-
) -> _RetValT:
41+
self, fn: WrappedFn, *args: t.Any, **kwargs: t.Any
42+
) -> WrappedFnReturnT:
4743
self.begin()
4844

4945
retry_state = RetryCallState(retry_object=self, fn=fn, args=args, kwargs=kwargs)
@@ -62,7 +58,7 @@ async def __call__( # type: ignore[override]
6258
else:
6359
return do # type: ignore[no-any-return]
6460

65-
def __iter__(self) -> typing.Generator[AttemptManager, None, None]:
61+
def __iter__(self) -> t.Generator[AttemptManager, None, None]:
6662
raise TypeError("AsyncRetrying object is not iterable")
6763

6864
def __aiter__(self) -> "AsyncRetrying":
@@ -88,7 +84,7 @@ def wraps(self, fn: WrappedFn) -> WrappedFn:
8884
# Ensure wrapper is recognized as a coroutine function.
8985

9086
@functools.wraps(fn)
91-
async def async_wrapped(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
87+
async def async_wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
9288
return await fn(*args, **kwargs)
9389

9490
# Preserve attributes

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ commands =
3333

3434
[testenv:mypy]
3535
deps =
36-
mypy
36+
mypy>=1.0.0
3737
commands =
3838
mypy tenacity
3939

0 commit comments

Comments
 (0)