Skip to content

Commit

Permalink
[tasks] Improve typing parity
Browse files Browse the repository at this point in the history
  • Loading branch information
NCPlayz authored Aug 27, 2021
1 parent b2ac327 commit a2a7b0f
Showing 1 changed file with 67 additions and 61 deletions.
128 changes: 67 additions & 61 deletions discord/ext/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,27 @@
import asyncio
import datetime
from typing import (
Any,
Awaitable,
Callable,
Any,
Awaitable,
Callable,
Generic,
List,
Optional,
Type,
List,
Optional,
Type,
TypeVar,
Union,
cast,
)

import aiohttp
import discord
import inspect
import logging
import sys
import traceback

from collections.abc import Sequence
from discord.backoff import ExponentialBackoff
from discord.utils import MISSING

_log = logging.getLogger(__name__)

__all__ = (
'loop',
)
Expand All @@ -61,7 +57,6 @@
LF = TypeVar('LF', bound=_func)
FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
LT = TypeVar('LT', bound='Loop')


class SleepHandle:
Expand All @@ -78,7 +73,7 @@ def recalculate(self, dt: datetime.datetime) -> None:
relative_delta = discord.utils.compute_timedelta(dt)
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)

def wait(self) -> asyncio.Future:
def wait(self) -> asyncio.Future[Any]:
return self.future

def done(self) -> bool:
Expand All @@ -94,23 +89,25 @@ class Loop(Generic[LF]):
The main interface to create this is through :func:`loop`.
"""
def __init__(self,

def __init__(
self,
coro: LF,
seconds: float,
hours: float,
minutes: float,
time: Union[datetime.time, Sequence[datetime.time]],
count: Optional[int],
reconnect: bool,
loop: Optional[asyncio.AbstractEventLoop],
loop: asyncio.AbstractEventLoop,
) -> None:
self.coro: LF = coro
self.reconnect: bool = reconnect
self.loop: Optional[asyncio.AbstractEventLoop] = loop
self.loop: asyncio.AbstractEventLoop = loop
self.count: Optional[int] = count
self._current_loop = 0
self._handle = None
self._task = None
self._handle: SleepHandle = MISSING
self._task: asyncio.Task[None] = MISSING
self._injected = None
self._valid_exception = (
OSError,
Expand All @@ -131,7 +128,7 @@ def __init__(self,

self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
self._last_iteration_failed = False
self._last_iteration = None
self._last_iteration: datetime.datetime = MISSING
self._next_iteration = None

if not inspect.iscoroutinefunction(self.coro):
Expand All @@ -147,9 +144,8 @@ async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> Non
else:
await coro(*args, **kwargs)


def _try_sleep_until(self, dt: datetime.datetime):
self._handle = SleepHandle(dt=dt, loop=self.loop) # type: ignore
self._handle = SleepHandle(dt=dt, loop=self.loop)
return self._handle.wait()

async def _loop(self, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -178,7 +174,7 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None:
await asyncio.sleep(backoff.delay())
else:
await self._try_sleep_until(self._next_iteration)

if self._stop_next_iteration:
return

Expand Down Expand Up @@ -211,14 +207,14 @@ def __get__(self, obj: T, objtype: Type[T]) -> Loop[LF]:
if obj is None:
return self

copy = Loop(
self.coro,
seconds=self._seconds,
hours=self._hours,
copy: Loop[LF] = Loop(
self.coro,
seconds=self._seconds,
hours=self._hours,
minutes=self._minutes,
time=self._time,
time=self._time,
count=self.count,
reconnect=self.reconnect,
reconnect=self.reconnect,
loop=self.loop,
)
copy._injected = obj
Expand All @@ -237,7 +233,7 @@ def seconds(self) -> Optional[float]:
"""
if self._seconds is not MISSING:
return self._seconds

@property
def minutes(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of minutes
Expand All @@ -247,7 +243,7 @@ def minutes(self) -> Optional[float]:
"""
if self._minutes is not MISSING:
return self._minutes

@property
def hours(self) -> Optional[float]:
"""Optional[:class:`float`]: Read-only value for the number of hours
Expand Down Expand Up @@ -279,7 +275,7 @@ def next_iteration(self) -> Optional[datetime.datetime]:
.. versionadded:: 1.3
"""
if self._task is None:
if self._task is MISSING:
return None
elif self._task and self._task.done() or self._stop_next_iteration:
return None
Expand All @@ -305,7 +301,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:

return await self.coro(*args, **kwargs)

def start(self, *args: Any, **kwargs: Any) -> asyncio.Task:
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
r"""Starts the internal task in the event loop.
Parameters
Expand All @@ -326,13 +322,13 @@ def start(self, *args: Any, **kwargs: Any) -> asyncio.Task:
The task that has been created.
"""

if self._task is not None and not self._task.done():
if self._task is not MISSING and not self._task.done():
raise RuntimeError('Task is already launched and is not completed.')

if self._injected is not None:
args = (self._injected, *args)

if self.loop is None:
if self.loop is MISSING:
self.loop = asyncio.get_event_loop()

self._task = self.loop.create_task(self._loop(*args, **kwargs))
Expand All @@ -356,7 +352,7 @@ def stop(self) -> None:
.. versionadded:: 1.2
"""
if self._task and not self._task.done():
if self._task is not MISSING and not self._task.done():
self._stop_next_iteration = True

def _can_be_cancelled(self) -> bool:
Expand All @@ -383,7 +379,7 @@ def restart(self, *args: Any, **kwargs: Any) -> None:
The keyword arguments to use.
"""

def restart_when_over(fut, *, args=args, kwargs=kwargs):
def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
self._task.remove_done_callback(restart_when_over)
self.start(*args, **kwargs)

Expand Down Expand Up @@ -446,9 +442,9 @@ def remove_exception_type(self, *exceptions: Type[BaseException]) -> bool:
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
return len(self._valid_exception) == old_length - len(exceptions)

def get_task(self) -> Optional[asyncio.Task]:
def get_task(self) -> Optional[asyncio.Task[None]]:
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
return self._task
return self._task if self._task is not MISSING else None

def is_being_cancelled(self) -> bool:
"""Whether the task is being cancelled."""
Expand All @@ -466,7 +462,7 @@ def is_running(self) -> bool:
.. versionadded:: 1.4
"""
return not bool(self._task.done()) if self._task else False
return not bool(self._task.done()) if self._task is not MISSING else False

async def _error(self, *args: Any) -> None:
exception: Exception = args[-1]
Expand Down Expand Up @@ -560,28 +556,32 @@ def _get_next_sleep_time(self) -> datetime.datetime:
self._time_index = 0
if self._current_loop == 0:
# if we're at the last index on the first iteration, we need to sleep until tomorrow
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0])
return datetime.datetime.combine(
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
)

next_time = self._time[self._time_index]

if self._current_loop == 0:
self._time_index += 1
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)

next_date = cast(datetime.datetime, self._last_iteration)
next_date = self._last_iteration
if self._time_index == 0:
# we can assume that the earliest time should be scheduled for "tomorrow"
next_date += datetime.timedelta(days=1)

self._time_index += 1
return datetime.datetime.combine(next_date, next_time)

def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None:
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
# now kwarg should be a datetime.datetime representing the time "now"
# to calculate the next time index from

# pre-condition: self._time is set
time_now = (now or datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)).timetz()
time_now = (
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
).timetz()
for idx, time in enumerate(self._time):
if time >= time_now:
self._time_index = idx
Expand All @@ -597,20 +597,24 @@ def _get_time_parameter(
utc: datetime.timezone = datetime.timezone.utc,
) -> List[datetime.time]:
if isinstance(time, dt):
ret = time if time.tzinfo is not None else time.replace(tzinfo=utc)
return [ret]
inner = time if time.tzinfo is not None else time.replace(tzinfo=utc)
return [inner]
if not isinstance(time, Sequence):
raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.')
raise TypeError(
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
)
if not time:
raise ValueError('time parameter must not be an empty sequence.')

ret = []
ret: List[datetime.time] = []
for index, t in enumerate(time):
if not isinstance(t, dt):
raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.')
raise TypeError(
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
)
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))

ret = sorted(set(ret)) # de-dupe and sort times
ret = sorted(set(ret)) # de-dupe and sort times
return ret

def change_interval(
Expand Down Expand Up @@ -691,7 +695,7 @@ def loop(
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
count: Optional[int] = None,
reconnect: bool = True,
loop: Optional[asyncio.AbstractEventLoop] = None,
loop: asyncio.AbstractEventLoop = MISSING,
) -> Callable[[LF], Loop[LF]]:
"""A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`.
Expand All @@ -707,7 +711,7 @@ def loop(
time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
The exact times to run this loop at. Either a non-empty list or a single
value of :class:`datetime.time` should be passed. Timezones are supported.
If no timezone is given for the times, it is assumed to represent UTC time.
If no timezone is given for the times, it is assumed to represent UTC time.
This cannot be used in conjunction with the relative time parameters.
Expand All @@ -724,7 +728,7 @@ def loop(
Whether to handle errors and restart the task
using an exponential back-off algorithm similar to the
one used in :meth:`discord.Client.connect`.
loop: Optional[:class:`asyncio.AbstractEventLoop`]
loop: :class:`asyncio.AbstractEventLoop`
The loop to use to register the task, if not given
defaults to :func:`asyncio.get_event_loop`.
Expand All @@ -736,15 +740,17 @@ def loop(
The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
or ``time`` parameter was passed in conjunction with relative time parameters.
"""

def decorator(func: LF) -> Loop[LF]:
kwargs = {
'seconds': seconds,
'minutes': minutes,
'hours': hours,
'count': count,
'time': time,
'reconnect': reconnect,
'loop': loop,
}
return Loop(func, **kwargs)
return Loop[LF](
func,
seconds=seconds,
minutes=minutes,
hours=hours,
count=count,
time=time,
reconnect=reconnect,
loop=loop,
)

return decorator

0 comments on commit a2a7b0f

Please sign in to comment.