Skip to content

Commit c7a6ddd

Browse files
Merge pull request #88 from Numerlor/no-duplicate-deco
2 parents ac156ec + a6f1ebf commit c7a6ddd

File tree

7 files changed

+731
-287
lines changed

7 files changed

+731
-287
lines changed

botcore/utils/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
"""Useful utilities and tools for Discord bot development."""
22

3-
from botcore.utils import _monkey_patches, caching, channel, commands, interactions, logging, members, regex, scheduling
3+
from botcore.utils import (
4+
_monkey_patches,
5+
caching,
6+
channel,
7+
commands,
8+
cooldown,
9+
function,
10+
interactions,
11+
logging,
12+
members,
13+
regex,
14+
scheduling,
15+
)
416
from botcore.utils._extensions import unqualify
517

618

@@ -25,6 +37,8 @@ def apply_monkey_patches() -> None:
2537
caching,
2638
channel,
2739
commands,
40+
cooldown,
41+
function,
2842
interactions,
2943
logging,
3044
members,

botcore/utils/cooldown.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
"""Helpers for setting a cooldown on commands."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import random
7+
import time
8+
import typing
9+
import weakref
10+
from collections.abc import Awaitable, Callable, Hashable, Iterable
11+
from contextlib import suppress
12+
from dataclasses import dataclass
13+
14+
import discord
15+
from discord.ext.commands import CommandError, Context
16+
17+
from botcore.utils import scheduling
18+
from botcore.utils.function import command_wraps
19+
20+
__all__ = ["CommandOnCooldown", "block_duplicate_invocations", "P", "R"]
21+
22+
_KEYWORD_SEP_SENTINEL = object()
23+
24+
_ArgsList = list[object]
25+
_HashableArgsTuple = tuple[Hashable, ...]
26+
27+
if typing.TYPE_CHECKING:
28+
import typing_extensions
29+
from botcore import BotBase
30+
31+
P = typing.ParamSpec("P")
32+
"""The command's signature."""
33+
R = typing.TypeVar("R")
34+
"""The command's return value."""
35+
36+
37+
class CommandOnCooldown(CommandError, typing.Generic[P, R]):
38+
"""Raised when a command is invoked while on cooldown."""
39+
40+
def __init__(
41+
self,
42+
message: str | None,
43+
function: Callable[P, Awaitable[R]],
44+
/,
45+
*args: P.args,
46+
**kwargs: P.kwargs,
47+
):
48+
super().__init__(message, function, args, kwargs)
49+
self._function = function
50+
self._args = args
51+
self._kwargs = kwargs
52+
53+
async def call_without_cooldown(self) -> R:
54+
"""
55+
Run the command this cooldown blocked.
56+
57+
Returns:
58+
The command's return value.
59+
"""
60+
return await self._function(*self._args, **self._kwargs)
61+
62+
63+
@dataclass
64+
class _CooldownItem:
65+
non_hashable_arguments: _ArgsList
66+
timeout_timestamp: float
67+
68+
69+
@dataclass
70+
class _SeparatedArguments:
71+
"""Arguments separated into their hashable and non-hashable parts."""
72+
73+
hashable: _HashableArgsTuple
74+
non_hashable: _ArgsList
75+
76+
@classmethod
77+
def from_full_arguments(cls, call_arguments: Iterable[object]) -> typing_extensions.Self:
78+
"""Create a new instance from full call arguments."""
79+
hashable = list[Hashable]()
80+
non_hashable = list[object]()
81+
82+
for item in call_arguments:
83+
try:
84+
hash(item)
85+
except TypeError:
86+
non_hashable.append(item)
87+
else:
88+
hashable.append(item)
89+
90+
return cls(tuple(hashable), non_hashable)
91+
92+
93+
class _CommandCooldownManager:
94+
"""
95+
Manage invocation cooldowns for a command through the arguments the command is called with.
96+
97+
Use `set_cooldown` to set a cooldown,
98+
and `is_on_cooldown` to check for a cooldown for a channel with the given arguments.
99+
A cooldown lasts for `cooldown_duration` seconds.
100+
"""
101+
102+
def __init__(self, *, cooldown_duration: float):
103+
self._cooldowns = dict[tuple[Hashable, _HashableArgsTuple], list[_CooldownItem]]()
104+
self._cooldown_duration = cooldown_duration
105+
self.cleanup_task = scheduling.create_task(
106+
self._periodical_cleanup(random.uniform(0, 10)),
107+
name="CooldownManager cleanup",
108+
)
109+
weakref.finalize(self, self.cleanup_task.cancel)
110+
111+
def set_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> None:
112+
"""Set `call_arguments` arguments on cooldown in `channel`."""
113+
timeout_timestamp = time.monotonic() + self._cooldown_duration
114+
separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments)
115+
cooldowns_list = self._cooldowns.setdefault(
116+
(channel, separated_arguments.hashable),
117+
[],
118+
)
119+
120+
for item in cooldowns_list:
121+
if item.non_hashable_arguments == separated_arguments.non_hashable:
122+
item.timeout_timestamp = timeout_timestamp
123+
return
124+
125+
cooldowns_list.append(_CooldownItem(separated_arguments.non_hashable, timeout_timestamp))
126+
127+
def is_on_cooldown(self, channel: Hashable, call_arguments: Iterable[object]) -> bool:
128+
"""Check whether `call_arguments` is on cooldown in `channel`."""
129+
current_time = time.monotonic()
130+
separated_arguments = _SeparatedArguments.from_full_arguments(call_arguments)
131+
cooldowns_list = self._cooldowns.get(
132+
(channel, separated_arguments.hashable),
133+
[],
134+
)
135+
136+
for item in cooldowns_list:
137+
if item.non_hashable_arguments == separated_arguments.non_hashable:
138+
return item.timeout_timestamp > current_time
139+
return False
140+
141+
async def _periodical_cleanup(self, initial_delay: float) -> None:
142+
"""
143+
Delete stale items every hour after waiting for `initial_delay`.
144+
145+
The `initial_delay` ensures cleanups are not running for every command at the same time.
146+
A strong reference to self is only kept while cleanup is running.
147+
"""
148+
weak_self = weakref.ref(self)
149+
del self
150+
151+
await asyncio.sleep(initial_delay)
152+
while True:
153+
await asyncio.sleep(60 * 60)
154+
weak_self()._delete_stale_items()
155+
156+
def _delete_stale_items(self) -> None:
157+
"""Remove expired items from internal collections."""
158+
current_time = time.monotonic()
159+
160+
for key, cooldowns_list in self._cooldowns.copy().items():
161+
filtered_cooldowns = [
162+
cooldown_item for cooldown_item in cooldowns_list if cooldown_item.timeout_timestamp < current_time
163+
]
164+
165+
if not filtered_cooldowns:
166+
del self._cooldowns[key]
167+
else:
168+
self._cooldowns[key] = filtered_cooldowns
169+
170+
171+
def _create_argument_tuple(*args: object, **kwargs: object) -> tuple[object, ...]:
172+
return (*args, _KEYWORD_SEP_SENTINEL, *kwargs.items())
173+
174+
175+
def block_duplicate_invocations(
176+
*,
177+
cooldown_duration: float = 5,
178+
send_notice: bool = False,
179+
args_preprocessor: Callable[P, Iterable[object]] | None = None,
180+
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
181+
"""
182+
Prevent duplicate invocations of a command with the same arguments in a channel for ``cooldown_duration`` seconds.
183+
184+
Args:
185+
cooldown_duration: Length of the cooldown in seconds.
186+
send_notice: If :obj:`True`, notify the user about the cooldown with a reply.
187+
args_preprocessor: If specified, this function is called with the args and kwargs the function is called with,
188+
its return value is then used to check for the cooldown instead of the raw arguments.
189+
190+
Returns:
191+
A decorator that adds a wrapper which applies the cooldowns.
192+
193+
Warning:
194+
The created wrapper raises :exc:`CommandOnCooldown` when the command is on cooldown.
195+
"""
196+
197+
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
198+
mgr = _CommandCooldownManager(cooldown_duration=cooldown_duration)
199+
200+
@command_wraps(func)
201+
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
202+
if args_preprocessor is not None:
203+
all_args = args_preprocessor(*args, **kwargs)
204+
else:
205+
all_args = _create_argument_tuple(*args[2:], **kwargs) # skip self and ctx from the command
206+
ctx = typing.cast("Context[BotBase]", args[1])
207+
208+
if not isinstance(ctx.channel, discord.DMChannel):
209+
if mgr.is_on_cooldown(ctx.channel, all_args):
210+
if send_notice:
211+
with suppress(discord.NotFound):
212+
await ctx.reply("The command is on cooldown with the given arguments.")
213+
raise CommandOnCooldown(ctx.message.content, func, *args, **kwargs)
214+
mgr.set_cooldown(ctx.channel, all_args)
215+
216+
return await func(*args, **kwargs)
217+
218+
return wrapper
219+
220+
return decorator

botcore/utils/function.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Utils for manipulating functions."""
2+
3+
from __future__ import annotations
4+
5+
import functools
6+
import types
7+
import typing
8+
from collections.abc import Callable, Sequence, Set
9+
10+
__all__ = ["command_wraps", "GlobalNameConflictError", "update_wrapper_globals"]
11+
12+
13+
if typing.TYPE_CHECKING:
14+
_P = typing.ParamSpec("_P")
15+
_R = typing.TypeVar("_R")
16+
17+
18+
class GlobalNameConflictError(Exception):
19+
"""Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper."""
20+
21+
22+
def update_wrapper_globals(
23+
wrapper: Callable[_P, _R],
24+
wrapped: Callable[_P, _R],
25+
*,
26+
ignored_conflict_names: Set[str] = frozenset(),
27+
) -> Callable[_P, _R]:
28+
r"""
29+
Create a copy of ``wrapper``\, the copy's globals are updated with ``wrapped``\'s globals.
30+
31+
For forwardrefs in command annotations, discord.py uses the ``__global__`` attribute of the function
32+
to resolve their values. This breaks for decorators that replace the function because they have
33+
their own globals.
34+
35+
.. warning::
36+
This function captures the state of ``wrapped``\'s module's globals when it's called;
37+
changes won't be reflected in the new function's globals.
38+
39+
Args:
40+
wrapper: The function to wrap.
41+
wrapped: The function to wrap with.
42+
ignored_conflict_names: A set of names to ignore if a conflict between them is found.
43+
44+
Raises:
45+
:exc:`GlobalNameConflictError`:
46+
If ``wrapper`` and ``wrapped`` share a global name that's also used in ``wrapped``\'s typehints,
47+
and is not in ``ignored_conflict_names``.
48+
"""
49+
wrapped = typing.cast(types.FunctionType, wrapped)
50+
wrapper = typing.cast(types.FunctionType, wrapper)
51+
52+
annotation_global_names = (
53+
ann.split(".", maxsplit=1)[0] for ann in wrapped.__annotations__.values() if isinstance(ann, str)
54+
)
55+
# Conflicting globals from both functions' modules that are also used in the wrapper and in wrapped's annotations.
56+
shared_globals = (
57+
set(wrapper.__code__.co_names)
58+
& set(annotation_global_names)
59+
& set(wrapped.__globals__)
60+
& set(wrapper.__globals__)
61+
- ignored_conflict_names
62+
)
63+
if shared_globals:
64+
raise GlobalNameConflictError(
65+
f"wrapper and the wrapped function share the following "
66+
f"global names used by annotations: {', '.join(shared_globals)}. Resolve the conflicts or add "
67+
f"the name to the `ignored_conflict_names` set to suppress this error if this is intentional."
68+
)
69+
70+
new_globals = wrapper.__globals__.copy()
71+
new_globals.update((k, v) for k, v in wrapped.__globals__.items() if k not in wrapper.__code__.co_names)
72+
return types.FunctionType(
73+
code=wrapper.__code__,
74+
globals=new_globals,
75+
name=wrapper.__name__,
76+
argdefs=wrapper.__defaults__,
77+
closure=wrapper.__closure__,
78+
)
79+
80+
81+
def command_wraps(
82+
wrapped: Callable[_P, _R],
83+
assigned: Sequence[str] = functools.WRAPPER_ASSIGNMENTS,
84+
updated: Sequence[str] = functools.WRAPPER_UPDATES,
85+
*,
86+
ignored_conflict_names: Set[str] = frozenset(),
87+
) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
88+
r"""
89+
Update the decorated function to look like ``wrapped``\, and update globals for discord.py forwardref evaluation.
90+
91+
See :func:`update_wrapper_globals` for more details on how the globals are updated.
92+
93+
Args:
94+
wrapped: The function to wrap with.
95+
assigned: Sequence of attribute names that are directly assigned from ``wrapped`` to ``wrapper``.
96+
updated: Sequence of attribute names that are ``.update``d on ``wrapper`` from the attributes on ``wrapped``.
97+
ignored_conflict_names: A set of names to ignore if a conflict between them is found.
98+
99+
Returns:
100+
A decorator that behaves like :func:`functools.wraps`,
101+
with the wrapper replaced with the function :func:`update_wrapper_globals` returned.
102+
""" # noqa: D200
103+
def decorator(wrapper: Callable[_P, _R]) -> Callable[_P, _R]:
104+
return functools.update_wrapper(
105+
update_wrapper_globals(wrapper, wrapped, ignored_conflict_names=ignored_conflict_names),
106+
wrapped,
107+
assigned,
108+
updated,
109+
)
110+
111+
return decorator

0 commit comments

Comments
 (0)