Skip to content

Commit a2a7b0f

Browse files
authored
[tasks] Improve typing parity
1 parent b2ac327 commit a2a7b0f

File tree

1 file changed

+67
-61
lines changed

1 file changed

+67
-61
lines changed

discord/ext/tasks/__init__.py

Lines changed: 67 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,27 @@
2727
import asyncio
2828
import datetime
2929
from typing import (
30-
Any,
31-
Awaitable,
32-
Callable,
30+
Any,
31+
Awaitable,
32+
Callable,
3333
Generic,
34-
List,
35-
Optional,
36-
Type,
34+
List,
35+
Optional,
36+
Type,
3737
TypeVar,
3838
Union,
39-
cast,
4039
)
4140

4241
import aiohttp
4342
import discord
4443
import inspect
45-
import logging
4644
import sys
4745
import traceback
4846

4947
from collections.abc import Sequence
5048
from discord.backoff import ExponentialBackoff
5149
from discord.utils import MISSING
5250

53-
_log = logging.getLogger(__name__)
54-
5551
__all__ = (
5652
'loop',
5753
)
@@ -61,7 +57,6 @@
6157
LF = TypeVar('LF', bound=_func)
6258
FT = TypeVar('FT', bound=_func)
6359
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
64-
LT = TypeVar('LT', bound='Loop')
6560

6661

6762
class SleepHandle:
@@ -78,7 +73,7 @@ def recalculate(self, dt: datetime.datetime) -> None:
7873
relative_delta = discord.utils.compute_timedelta(dt)
7974
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
8075

81-
def wait(self) -> asyncio.Future:
76+
def wait(self) -> asyncio.Future[Any]:
8277
return self.future
8378

8479
def done(self) -> bool:
@@ -94,23 +89,25 @@ class Loop(Generic[LF]):
9489
9590
The main interface to create this is through :func:`loop`.
9691
"""
97-
def __init__(self,
92+
93+
def __init__(
94+
self,
9895
coro: LF,
9996
seconds: float,
10097
hours: float,
10198
minutes: float,
10299
time: Union[datetime.time, Sequence[datetime.time]],
103100
count: Optional[int],
104101
reconnect: bool,
105-
loop: Optional[asyncio.AbstractEventLoop],
102+
loop: asyncio.AbstractEventLoop,
106103
) -> None:
107104
self.coro: LF = coro
108105
self.reconnect: bool = reconnect
109-
self.loop: Optional[asyncio.AbstractEventLoop] = loop
106+
self.loop: asyncio.AbstractEventLoop = loop
110107
self.count: Optional[int] = count
111108
self._current_loop = 0
112-
self._handle = None
113-
self._task = None
109+
self._handle: SleepHandle = MISSING
110+
self._task: asyncio.Task[None] = MISSING
114111
self._injected = None
115112
self._valid_exception = (
116113
OSError,
@@ -131,7 +128,7 @@ def __init__(self,
131128

132129
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
133130
self._last_iteration_failed = False
134-
self._last_iteration = None
131+
self._last_iteration: datetime.datetime = MISSING
135132
self._next_iteration = None
136133

137134
if not inspect.iscoroutinefunction(self.coro):
@@ -147,9 +144,8 @@ async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> Non
147144
else:
148145
await coro(*args, **kwargs)
149146

150-
151147
def _try_sleep_until(self, dt: datetime.datetime):
152-
self._handle = SleepHandle(dt=dt, loop=self.loop) # type: ignore
148+
self._handle = SleepHandle(dt=dt, loop=self.loop)
153149
return self._handle.wait()
154150

155151
async def _loop(self, *args: Any, **kwargs: Any) -> None:
@@ -178,7 +174,7 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None:
178174
await asyncio.sleep(backoff.delay())
179175
else:
180176
await self._try_sleep_until(self._next_iteration)
181-
177+
182178
if self._stop_next_iteration:
183179
return
184180

@@ -211,14 +207,14 @@ def __get__(self, obj: T, objtype: Type[T]) -> Loop[LF]:
211207
if obj is None:
212208
return self
213209

214-
copy = Loop(
215-
self.coro,
216-
seconds=self._seconds,
217-
hours=self._hours,
210+
copy: Loop[LF] = Loop(
211+
self.coro,
212+
seconds=self._seconds,
213+
hours=self._hours,
218214
minutes=self._minutes,
219-
time=self._time,
215+
time=self._time,
220216
count=self.count,
221-
reconnect=self.reconnect,
217+
reconnect=self.reconnect,
222218
loop=self.loop,
223219
)
224220
copy._injected = obj
@@ -237,7 +233,7 @@ def seconds(self) -> Optional[float]:
237233
"""
238234
if self._seconds is not MISSING:
239235
return self._seconds
240-
236+
241237
@property
242238
def minutes(self) -> Optional[float]:
243239
"""Optional[:class:`float`]: Read-only value for the number of minutes
@@ -247,7 +243,7 @@ def minutes(self) -> Optional[float]:
247243
"""
248244
if self._minutes is not MISSING:
249245
return self._minutes
250-
246+
251247
@property
252248
def hours(self) -> Optional[float]:
253249
"""Optional[:class:`float`]: Read-only value for the number of hours
@@ -279,7 +275,7 @@ def next_iteration(self) -> Optional[datetime.datetime]:
279275
280276
.. versionadded:: 1.3
281277
"""
282-
if self._task is None:
278+
if self._task is MISSING:
283279
return None
284280
elif self._task and self._task.done() or self._stop_next_iteration:
285281
return None
@@ -305,7 +301,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
305301

306302
return await self.coro(*args, **kwargs)
307303

308-
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task:
304+
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
309305
r"""Starts the internal task in the event loop.
310306
311307
Parameters
@@ -326,13 +322,13 @@ def start(self, *args: Any, **kwargs: Any) -> asyncio.Task:
326322
The task that has been created.
327323
"""
328324

329-
if self._task is not None and not self._task.done():
325+
if self._task is not MISSING and not self._task.done():
330326
raise RuntimeError('Task is already launched and is not completed.')
331327

332328
if self._injected is not None:
333329
args = (self._injected, *args)
334330

335-
if self.loop is None:
331+
if self.loop is MISSING:
336332
self.loop = asyncio.get_event_loop()
337333

338334
self._task = self.loop.create_task(self._loop(*args, **kwargs))
@@ -356,7 +352,7 @@ def stop(self) -> None:
356352
357353
.. versionadded:: 1.2
358354
"""
359-
if self._task and not self._task.done():
355+
if self._task is not MISSING and not self._task.done():
360356
self._stop_next_iteration = True
361357

362358
def _can_be_cancelled(self) -> bool:
@@ -383,7 +379,7 @@ def restart(self, *args: Any, **kwargs: Any) -> None:
383379
The keyword arguments to use.
384380
"""
385381

386-
def restart_when_over(fut, *, args=args, kwargs=kwargs):
382+
def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
387383
self._task.remove_done_callback(restart_when_over)
388384
self.start(*args, **kwargs)
389385

@@ -446,9 +442,9 @@ def remove_exception_type(self, *exceptions: Type[BaseException]) -> bool:
446442
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
447443
return len(self._valid_exception) == old_length - len(exceptions)
448444

449-
def get_task(self) -> Optional[asyncio.Task]:
445+
def get_task(self) -> Optional[asyncio.Task[None]]:
450446
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
451-
return self._task
447+
return self._task if self._task is not MISSING else None
452448

453449
def is_being_cancelled(self) -> bool:
454450
"""Whether the task is being cancelled."""
@@ -466,7 +462,7 @@ def is_running(self) -> bool:
466462
467463
.. versionadded:: 1.4
468464
"""
469-
return not bool(self._task.done()) if self._task else False
465+
return not bool(self._task.done()) if self._task is not MISSING else False
470466

471467
async def _error(self, *args: Any) -> None:
472468
exception: Exception = args[-1]
@@ -560,28 +556,32 @@ def _get_next_sleep_time(self) -> datetime.datetime:
560556
self._time_index = 0
561557
if self._current_loop == 0:
562558
# if we're at the last index on the first iteration, we need to sleep until tomorrow
563-
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0])
559+
return datetime.datetime.combine(
560+
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
561+
)
564562

565563
next_time = self._time[self._time_index]
566564

567565
if self._current_loop == 0:
568566
self._time_index += 1
569567
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
570568

571-
next_date = cast(datetime.datetime, self._last_iteration)
569+
next_date = self._last_iteration
572570
if self._time_index == 0:
573571
# we can assume that the earliest time should be scheduled for "tomorrow"
574572
next_date += datetime.timedelta(days=1)
575573

576574
self._time_index += 1
577575
return datetime.datetime.combine(next_date, next_time)
578576

579-
def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None:
577+
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
580578
# now kwarg should be a datetime.datetime representing the time "now"
581579
# to calculate the next time index from
582580

583581
# pre-condition: self._time is set
584-
time_now = (now or datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)).timetz()
582+
time_now = (
583+
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
584+
).timetz()
585585
for idx, time in enumerate(self._time):
586586
if time >= time_now:
587587
self._time_index = idx
@@ -597,20 +597,24 @@ def _get_time_parameter(
597597
utc: datetime.timezone = datetime.timezone.utc,
598598
) -> List[datetime.time]:
599599
if isinstance(time, dt):
600-
ret = time if time.tzinfo is not None else time.replace(tzinfo=utc)
601-
return [ret]
600+
inner = time if time.tzinfo is not None else time.replace(tzinfo=utc)
601+
return [inner]
602602
if not isinstance(time, Sequence):
603-
raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.')
603+
raise TypeError(
604+
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
605+
)
604606
if not time:
605607
raise ValueError('time parameter must not be an empty sequence.')
606608

607-
ret = []
609+
ret: List[datetime.time] = []
608610
for index, t in enumerate(time):
609611
if not isinstance(t, dt):
610-
raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.')
612+
raise TypeError(
613+
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
614+
)
611615
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
612616

613-
ret = sorted(set(ret)) # de-dupe and sort times
617+
ret = sorted(set(ret)) # de-dupe and sort times
614618
return ret
615619

616620
def change_interval(
@@ -691,7 +695,7 @@ def loop(
691695
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
692696
count: Optional[int] = None,
693697
reconnect: bool = True,
694-
loop: Optional[asyncio.AbstractEventLoop] = None,
698+
loop: asyncio.AbstractEventLoop = MISSING,
695699
) -> Callable[[LF], Loop[LF]]:
696700
"""A decorator that schedules a task in the background for you with
697701
optional reconnect logic. The decorator returns a :class:`Loop`.
@@ -707,7 +711,7 @@ def loop(
707711
time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
708712
The exact times to run this loop at. Either a non-empty list or a single
709713
value of :class:`datetime.time` should be passed. Timezones are supported.
710-
If no timezone is given for the times, it is assumed to represent UTC time.
714+
If no timezone is given for the times, it is assumed to represent UTC time.
711715
712716
This cannot be used in conjunction with the relative time parameters.
713717
@@ -724,7 +728,7 @@ def loop(
724728
Whether to handle errors and restart the task
725729
using an exponential back-off algorithm similar to the
726730
one used in :meth:`discord.Client.connect`.
727-
loop: Optional[:class:`asyncio.AbstractEventLoop`]
731+
loop: :class:`asyncio.AbstractEventLoop`
728732
The loop to use to register the task, if not given
729733
defaults to :func:`asyncio.get_event_loop`.
730734
@@ -736,15 +740,17 @@ def loop(
736740
The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
737741
or ``time`` parameter was passed in conjunction with relative time parameters.
738742
"""
743+
739744
def decorator(func: LF) -> Loop[LF]:
740-
kwargs = {
741-
'seconds': seconds,
742-
'minutes': minutes,
743-
'hours': hours,
744-
'count': count,
745-
'time': time,
746-
'reconnect': reconnect,
747-
'loop': loop,
748-
}
749-
return Loop(func, **kwargs)
745+
return Loop[LF](
746+
func,
747+
seconds=seconds,
748+
minutes=minutes,
749+
hours=hours,
750+
count=count,
751+
time=time,
752+
reconnect=reconnect,
753+
loop=loop,
754+
)
755+
750756
return decorator

0 commit comments

Comments
 (0)