27
27
import asyncio
28
28
import datetime
29
29
from typing import (
30
- Any ,
31
- Awaitable ,
32
- Callable ,
30
+ Any ,
31
+ Awaitable ,
32
+ Callable ,
33
33
Generic ,
34
- List ,
35
- Optional ,
36
- Type ,
34
+ List ,
35
+ Optional ,
36
+ Type ,
37
37
TypeVar ,
38
38
Union ,
39
- cast ,
40
39
)
41
40
42
41
import aiohttp
43
42
import discord
44
43
import inspect
45
- import logging
46
44
import sys
47
45
import traceback
48
46
49
47
from collections .abc import Sequence
50
48
from discord .backoff import ExponentialBackoff
51
49
from discord .utils import MISSING
52
50
53
- _log = logging .getLogger (__name__ )
54
-
55
51
__all__ = (
56
52
'loop' ,
57
53
)
61
57
LF = TypeVar ('LF' , bound = _func )
62
58
FT = TypeVar ('FT' , bound = _func )
63
59
ET = TypeVar ('ET' , bound = Callable [[Any , BaseException ], Awaitable [Any ]])
64
- LT = TypeVar ('LT' , bound = 'Loop' )
65
60
66
61
67
62
class SleepHandle :
@@ -78,7 +73,7 @@ def recalculate(self, dt: datetime.datetime) -> None:
78
73
relative_delta = discord .utils .compute_timedelta (dt )
79
74
self .handle = self .loop .call_later (relative_delta , self .future .set_result , True )
80
75
81
- def wait (self ) -> asyncio .Future :
76
+ def wait (self ) -> asyncio .Future [ Any ] :
82
77
return self .future
83
78
84
79
def done (self ) -> bool :
@@ -94,23 +89,25 @@ class Loop(Generic[LF]):
94
89
95
90
The main interface to create this is through :func:`loop`.
96
91
"""
97
- def __init__ (self ,
92
+
93
+ def __init__ (
94
+ self ,
98
95
coro : LF ,
99
96
seconds : float ,
100
97
hours : float ,
101
98
minutes : float ,
102
99
time : Union [datetime .time , Sequence [datetime .time ]],
103
100
count : Optional [int ],
104
101
reconnect : bool ,
105
- loop : Optional [ asyncio .AbstractEventLoop ] ,
102
+ loop : asyncio .AbstractEventLoop ,
106
103
) -> None :
107
104
self .coro : LF = coro
108
105
self .reconnect : bool = reconnect
109
- self .loop : Optional [ asyncio .AbstractEventLoop ] = loop
106
+ self .loop : asyncio .AbstractEventLoop = loop
110
107
self .count : Optional [int ] = count
111
108
self ._current_loop = 0
112
- self ._handle = None
113
- self ._task = None
109
+ self ._handle : SleepHandle = MISSING
110
+ self ._task : asyncio . Task [ None ] = MISSING
114
111
self ._injected = None
115
112
self ._valid_exception = (
116
113
OSError ,
@@ -131,7 +128,7 @@ def __init__(self,
131
128
132
129
self .change_interval (seconds = seconds , minutes = minutes , hours = hours , time = time )
133
130
self ._last_iteration_failed = False
134
- self ._last_iteration = None
131
+ self ._last_iteration : datetime . datetime = MISSING
135
132
self ._next_iteration = None
136
133
137
134
if not inspect .iscoroutinefunction (self .coro ):
@@ -147,9 +144,8 @@ async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> Non
147
144
else :
148
145
await coro (* args , ** kwargs )
149
146
150
-
151
147
def _try_sleep_until (self , dt : datetime .datetime ):
152
- self ._handle = SleepHandle (dt = dt , loop = self .loop ) # type: ignore
148
+ self ._handle = SleepHandle (dt = dt , loop = self .loop )
153
149
return self ._handle .wait ()
154
150
155
151
async def _loop (self , * args : Any , ** kwargs : Any ) -> None :
@@ -178,7 +174,7 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None:
178
174
await asyncio .sleep (backoff .delay ())
179
175
else :
180
176
await self ._try_sleep_until (self ._next_iteration )
181
-
177
+
182
178
if self ._stop_next_iteration :
183
179
return
184
180
@@ -211,14 +207,14 @@ def __get__(self, obj: T, objtype: Type[T]) -> Loop[LF]:
211
207
if obj is None :
212
208
return self
213
209
214
- copy = Loop (
215
- self .coro ,
216
- seconds = self ._seconds ,
217
- hours = self ._hours ,
210
+ copy : Loop [ LF ] = Loop (
211
+ self .coro ,
212
+ seconds = self ._seconds ,
213
+ hours = self ._hours ,
218
214
minutes = self ._minutes ,
219
- time = self ._time ,
215
+ time = self ._time ,
220
216
count = self .count ,
221
- reconnect = self .reconnect ,
217
+ reconnect = self .reconnect ,
222
218
loop = self .loop ,
223
219
)
224
220
copy ._injected = obj
@@ -237,7 +233,7 @@ def seconds(self) -> Optional[float]:
237
233
"""
238
234
if self ._seconds is not MISSING :
239
235
return self ._seconds
240
-
236
+
241
237
@property
242
238
def minutes (self ) -> Optional [float ]:
243
239
"""Optional[:class:`float`]: Read-only value for the number of minutes
@@ -247,7 +243,7 @@ def minutes(self) -> Optional[float]:
247
243
"""
248
244
if self ._minutes is not MISSING :
249
245
return self ._minutes
250
-
246
+
251
247
@property
252
248
def hours (self ) -> Optional [float ]:
253
249
"""Optional[:class:`float`]: Read-only value for the number of hours
@@ -279,7 +275,7 @@ def next_iteration(self) -> Optional[datetime.datetime]:
279
275
280
276
.. versionadded:: 1.3
281
277
"""
282
- if self ._task is None :
278
+ if self ._task is MISSING :
283
279
return None
284
280
elif self ._task and self ._task .done () or self ._stop_next_iteration :
285
281
return None
@@ -305,7 +301,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
305
301
306
302
return await self .coro (* args , ** kwargs )
307
303
308
- def start (self , * args : Any , ** kwargs : Any ) -> asyncio .Task :
304
+ def start (self , * args : Any , ** kwargs : Any ) -> asyncio .Task [ None ] :
309
305
r"""Starts the internal task in the event loop.
310
306
311
307
Parameters
@@ -326,13 +322,13 @@ def start(self, *args: Any, **kwargs: Any) -> asyncio.Task:
326
322
The task that has been created.
327
323
"""
328
324
329
- if self ._task is not None and not self ._task .done ():
325
+ if self ._task is not MISSING and not self ._task .done ():
330
326
raise RuntimeError ('Task is already launched and is not completed.' )
331
327
332
328
if self ._injected is not None :
333
329
args = (self ._injected , * args )
334
330
335
- if self .loop is None :
331
+ if self .loop is MISSING :
336
332
self .loop = asyncio .get_event_loop ()
337
333
338
334
self ._task = self .loop .create_task (self ._loop (* args , ** kwargs ))
@@ -356,7 +352,7 @@ def stop(self) -> None:
356
352
357
353
.. versionadded:: 1.2
358
354
"""
359
- if self ._task and not self ._task .done ():
355
+ if self ._task is not MISSING and not self ._task .done ():
360
356
self ._stop_next_iteration = True
361
357
362
358
def _can_be_cancelled (self ) -> bool :
@@ -383,7 +379,7 @@ def restart(self, *args: Any, **kwargs: Any) -> None:
383
379
The keyword arguments to use.
384
380
"""
385
381
386
- def restart_when_over (fut , * , args = args , kwargs = kwargs ):
382
+ def restart_when_over (fut : Any , * , args : Any = args , kwargs : Any = kwargs ) -> None :
387
383
self ._task .remove_done_callback (restart_when_over )
388
384
self .start (* args , ** kwargs )
389
385
@@ -446,9 +442,9 @@ def remove_exception_type(self, *exceptions: Type[BaseException]) -> bool:
446
442
self ._valid_exception = tuple (x for x in self ._valid_exception if x not in exceptions )
447
443
return len (self ._valid_exception ) == old_length - len (exceptions )
448
444
449
- def get_task (self ) -> Optional [asyncio .Task ]:
445
+ def get_task (self ) -> Optional [asyncio .Task [ None ] ]:
450
446
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
451
- return self ._task
447
+ return self ._task if self . _task is not MISSING else None
452
448
453
449
def is_being_cancelled (self ) -> bool :
454
450
"""Whether the task is being cancelled."""
@@ -466,7 +462,7 @@ def is_running(self) -> bool:
466
462
467
463
.. versionadded:: 1.4
468
464
"""
469
- return not bool (self ._task .done ()) if self ._task else False
465
+ return not bool (self ._task .done ()) if self ._task is not MISSING else False
470
466
471
467
async def _error (self , * args : Any ) -> None :
472
468
exception : Exception = args [- 1 ]
@@ -560,28 +556,32 @@ def _get_next_sleep_time(self) -> datetime.datetime:
560
556
self ._time_index = 0
561
557
if self ._current_loop == 0 :
562
558
# if we're at the last index on the first iteration, we need to sleep until tomorrow
563
- return datetime .datetime .combine (datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (days = 1 ), self ._time [0 ])
559
+ return datetime .datetime .combine (
560
+ datetime .datetime .now (datetime .timezone .utc ) + datetime .timedelta (days = 1 ), self ._time [0 ]
561
+ )
564
562
565
563
next_time = self ._time [self ._time_index ]
566
564
567
565
if self ._current_loop == 0 :
568
566
self ._time_index += 1
569
567
return datetime .datetime .combine (datetime .datetime .now (datetime .timezone .utc ), next_time )
570
568
571
- next_date = cast ( datetime . datetime , self ._last_iteration )
569
+ next_date = self ._last_iteration
572
570
if self ._time_index == 0 :
573
571
# we can assume that the earliest time should be scheduled for "tomorrow"
574
572
next_date += datetime .timedelta (days = 1 )
575
573
576
574
self ._time_index += 1
577
575
return datetime .datetime .combine (next_date , next_time )
578
576
579
- def _prepare_time_index (self , now : Optional [ datetime .datetime ] = None ) -> None :
577
+ def _prepare_time_index (self , now : datetime .datetime = MISSING ) -> None :
580
578
# now kwarg should be a datetime.datetime representing the time "now"
581
579
# to calculate the next time index from
582
580
583
581
# pre-condition: self._time is set
584
- time_now = (now or datetime .datetime .now (datetime .timezone .utc ).replace (microsecond = 0 )).timetz ()
582
+ time_now = (
583
+ now if now is not MISSING else datetime .datetime .now (datetime .timezone .utc ).replace (microsecond = 0 )
584
+ ).timetz ()
585
585
for idx , time in enumerate (self ._time ):
586
586
if time >= time_now :
587
587
self ._time_index = idx
@@ -597,20 +597,24 @@ def _get_time_parameter(
597
597
utc : datetime .timezone = datetime .timezone .utc ,
598
598
) -> List [datetime .time ]:
599
599
if isinstance (time , dt ):
600
- ret = time if time .tzinfo is not None else time .replace (tzinfo = utc )
601
- return [ret ]
600
+ inner = time if time .tzinfo is not None else time .replace (tzinfo = utc )
601
+ return [inner ]
602
602
if not isinstance (time , Sequence ):
603
- raise TypeError (f'Expected datetime.time or a sequence of datetime.time for ``time``, received { type (time )!r} instead.' )
603
+ raise TypeError (
604
+ f'Expected datetime.time or a sequence of datetime.time for ``time``, received { type (time )!r} instead.'
605
+ )
604
606
if not time :
605
607
raise ValueError ('time parameter must not be an empty sequence.' )
606
608
607
- ret = []
609
+ ret : List [ datetime . time ] = []
608
610
for index , t in enumerate (time ):
609
611
if not isinstance (t , dt ):
610
- raise TypeError (f'Expected a sequence of { dt !r} for ``time``, received { type (t ).__name__ !r} at index { index } instead.' )
612
+ raise TypeError (
613
+ f'Expected a sequence of { dt !r} for ``time``, received { type (t ).__name__ !r} at index { index } instead.'
614
+ )
611
615
ret .append (t if t .tzinfo is not None else t .replace (tzinfo = utc ))
612
616
613
- ret = sorted (set (ret )) # de-dupe and sort times
617
+ ret = sorted (set (ret )) # de-dupe and sort times
614
618
return ret
615
619
616
620
def change_interval (
@@ -691,7 +695,7 @@ def loop(
691
695
time : Union [datetime .time , Sequence [datetime .time ]] = MISSING ,
692
696
count : Optional [int ] = None ,
693
697
reconnect : bool = True ,
694
- loop : Optional [ asyncio .AbstractEventLoop ] = None ,
698
+ loop : asyncio .AbstractEventLoop = MISSING ,
695
699
) -> Callable [[LF ], Loop [LF ]]:
696
700
"""A decorator that schedules a task in the background for you with
697
701
optional reconnect logic. The decorator returns a :class:`Loop`.
@@ -707,7 +711,7 @@ def loop(
707
711
time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
708
712
The exact times to run this loop at. Either a non-empty list or a single
709
713
value of :class:`datetime.time` should be passed. Timezones are supported.
710
- If no timezone is given for the times, it is assumed to represent UTC time.
714
+ If no timezone is given for the times, it is assumed to represent UTC time.
711
715
712
716
This cannot be used in conjunction with the relative time parameters.
713
717
@@ -724,7 +728,7 @@ def loop(
724
728
Whether to handle errors and restart the task
725
729
using an exponential back-off algorithm similar to the
726
730
one used in :meth:`discord.Client.connect`.
727
- loop: Optional[ :class:`asyncio.AbstractEventLoop`]
731
+ loop: :class:`asyncio.AbstractEventLoop`
728
732
The loop to use to register the task, if not given
729
733
defaults to :func:`asyncio.get_event_loop`.
730
734
@@ -736,15 +740,17 @@ def loop(
736
740
The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
737
741
or ``time`` parameter was passed in conjunction with relative time parameters.
738
742
"""
743
+
739
744
def decorator (func : LF ) -> Loop [LF ]:
740
- kwargs = {
741
- 'seconds' : seconds ,
742
- 'minutes' : minutes ,
743
- 'hours' : hours ,
744
- 'count' : count ,
745
- 'time' : time ,
746
- 'reconnect' : reconnect ,
747
- 'loop' : loop ,
748
- }
749
- return Loop (func , ** kwargs )
745
+ return Loop [LF ](
746
+ func ,
747
+ seconds = seconds ,
748
+ minutes = minutes ,
749
+ hours = hours ,
750
+ count = count ,
751
+ time = time ,
752
+ reconnect = reconnect ,
753
+ loop = loop ,
754
+ )
755
+
750
756
return decorator
0 commit comments