Skip to content

Commit 84ddead

Browse files
ChrisLoveringNumerlorMarkKoz
committed
Add lock utils
This includes some additional function utils too. Co-authored-by: Numerlor <numerlor@numerlor.me> Co-authored-by: MarkKoz <KozlovMark@gmail.com>
1 parent a6760a6 commit 84ddead

File tree

3 files changed

+232
-0
lines changed

3 files changed

+232
-0
lines changed

pydis_core/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
cooldown,
99
function,
1010
interactions,
11+
lock,
1112
logging,
1213
members,
1314
regex,
@@ -40,6 +41,7 @@ def apply_monkey_patches() -> None:
4041
cooldown,
4142
function,
4243
interactions,
44+
lock,
4345
logging,
4446
members,
4547
regex,

pydis_core/utils/function.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import functools
6+
import inspect
67
import types
78
import typing
89
from collections.abc import Callable, Sequence, Set
@@ -14,11 +15,89 @@
1415
_P = typing.ParamSpec("_P")
1516
_R = typing.TypeVar("_R")
1617

18+
Argument = typing.Union[int, str]
19+
BoundArgs = typing.OrderedDict[str, typing.Any]
20+
Decorator = typing.Callable[[typing.Callable], typing.Callable]
21+
ArgValGetter = typing.Callable[[BoundArgs], typing.Any]
22+
1723

1824
class GlobalNameConflictError(Exception):
1925
"""Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper."""
2026

2127

28+
def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> typing.Any:
29+
"""
30+
Return a value from `arguments` based on a name or position.
31+
32+
`arguments` is an ordered mapping of parameter names to argument values.
33+
34+
Raise TypeError if `name_or_pos` isn't a str or int.
35+
Raise ValueError if `name_or_pos` does not match any argument.
36+
"""
37+
if isinstance(name_or_pos, int):
38+
# Convert arguments to a tuple to make them indexable.
39+
arg_values = tuple(arguments.items())
40+
arg_pos = name_or_pos
41+
42+
try:
43+
name, value = arg_values[arg_pos]
44+
return value
45+
except IndexError:
46+
raise ValueError(f"Argument position {arg_pos} is out of bounds.")
47+
elif isinstance(name_or_pos, str):
48+
arg_name = name_or_pos
49+
try:
50+
return arguments[arg_name]
51+
except KeyError:
52+
raise ValueError(f"Argument {arg_name!r} doesn't exist.")
53+
else:
54+
raise TypeError("'arg' must either be an int (positional index) or a str (keyword).")
55+
56+
57+
def get_arg_value_wrapper(
58+
decorator_func: typing.Callable[[ArgValGetter], Decorator],
59+
name_or_pos: Argument,
60+
func: typing.Callable[[typing.Any], typing.Any] = None,
61+
) -> Decorator:
62+
"""
63+
Call `decorator_func` with the value of the arg at the given name/position.
64+
65+
`decorator_func` must accept a callable as a parameter to which it will pass a mapping of
66+
parameter names to argument values of the function it's decorating.
67+
68+
`func` is an optional callable which will return a new value given the argument's value.
69+
70+
Return the decorator returned by `decorator_func`.
71+
"""
72+
def wrapper(args: BoundArgs) -> typing.Any:
73+
value = get_arg_value(name_or_pos, args)
74+
if func:
75+
value = func(value)
76+
return value
77+
78+
return decorator_func(wrapper)
79+
80+
81+
def get_bound_args(func: typing.Callable, args: typing.Tuple, kwargs: typing.Dict[str, typing.Any]) -> BoundArgs:
82+
"""
83+
Bind `args` and `kwargs` to `func` and return a mapping of parameter names to argument values.
84+
85+
Default parameter values are also set.
86+
87+
Args:
88+
args: The arguments to bind to ``func``
89+
kwargs: The keyword arguments to bind to ``func``
90+
func: The function to bind ``args`` and ``kwargs`` to
91+
Returns:
92+
A mapping of parameter names to argument values.
93+
"""
94+
sig = inspect.signature(func)
95+
bound_args = sig.bind(*args, **kwargs)
96+
bound_args.apply_defaults()
97+
98+
return bound_args.arguments
99+
100+
22101
def update_wrapper_globals(
23102
wrapper: Callable[_P, _R],
24103
wrapped: Callable[_P, _R],

pydis_core/utils/lock.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import asyncio
2+
import inspect
3+
import types
4+
from collections import defaultdict
5+
from functools import partial
6+
from typing import Any, Awaitable, Callable, Hashable, Union
7+
from weakref import WeakValueDictionary
8+
9+
from pydis_core.utils import function
10+
from pydis_core.utils.function import command_wraps
11+
from pydis_core.utils.logging import get_logger
12+
13+
log = get_logger(__name__)
14+
__lock_dicts = defaultdict(WeakValueDictionary)
15+
16+
_IdCallableReturn = Union[Hashable, Awaitable[Hashable]]
17+
_IdCallable = Callable[[function.BoundArgs], _IdCallableReturn]
18+
ResourceId = Union[Hashable, _IdCallable]
19+
20+
21+
class LockedResourceError(RuntimeError):
22+
"""
23+
Exception raised when an operation is attempted on a locked resource.
24+
25+
Attributes:
26+
`type` -- name of the locked resource's type
27+
`id` -- ID of the locked resource
28+
"""
29+
30+
def __init__(self, resource_type: str, resource_id: Hashable):
31+
self.type = resource_type
32+
self.id = resource_id
33+
34+
super().__init__(
35+
f"Cannot operate on {self.type.lower()} `{self.id}`; "
36+
"it is currently locked and in use by another operation."
37+
)
38+
39+
40+
class SharedEvent:
41+
"""
42+
Context manager managing an internal event exposed through the wait coro.
43+
44+
While any code is executing in this context manager, the underlying event will not be set;
45+
when all of the holders finish the event will be set.
46+
"""
47+
48+
def __init__(self):
49+
self._active_count = 0
50+
self._event = asyncio.Event()
51+
self._event.set()
52+
53+
def __enter__(self):
54+
"""Increment the count of the active holders and clear the internal event."""
55+
self._active_count += 1
56+
self._event.clear()
57+
58+
def __exit__(self, _exc_type, _exc_val, _exc_tb): # noqa: ANN001
59+
"""Decrement the count of the active holders; if 0 is reached set the internal event."""
60+
self._active_count -= 1
61+
if not self._active_count:
62+
self._event.set()
63+
64+
async def wait(self) -> None:
65+
"""Wait for all active holders to exit."""
66+
await self._event.wait()
67+
68+
69+
def lock(
70+
namespace: Hashable,
71+
resource_id: ResourceId,
72+
*,
73+
raise_error: bool = False,
74+
wait: bool = False,
75+
) -> Callable:
76+
"""
77+
Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`.
78+
79+
If `wait` is True, wait until the lock becomes available. Otherwise, if any other mutually
80+
exclusive function currently holds the lock for a resource, do not run the decorated function
81+
and return None.
82+
83+
If `raise_error` is True, raise `LockedResourceError` if the lock cannot be acquired.
84+
85+
`namespace` is an identifier used to prevent collisions among resource IDs.
86+
87+
`resource_id` identifies a resource on which to perform a mutually exclusive operation.
88+
It may also be a callable or awaitable which will return the resource ID given an ordered
89+
mapping of the parameters' names to arguments' values.
90+
91+
If decorating a command, this decorator must go before (below) the `command` decorator.
92+
"""
93+
def decorator(func: types.FunctionType) -> types.FunctionType:
94+
name = func.__name__
95+
96+
@command_wraps(func)
97+
async def wrapper(*args, **kwargs) -> Any:
98+
log.trace(f"{name}: mutually exclusive decorator called")
99+
100+
if callable(resource_id):
101+
log.trace(f"{name}: binding args to signature")
102+
bound_args = function.get_bound_args(func, args, kwargs)
103+
104+
log.trace(f"{name}: calling the given callable to get the resource ID")
105+
id_ = resource_id(bound_args)
106+
107+
if inspect.isawaitable(id_):
108+
log.trace(f"{name}: awaiting to get resource ID")
109+
id_ = await id_
110+
else:
111+
id_ = resource_id
112+
113+
log.trace(f"{name}: getting the lock object for resource {namespace!r}:{id_!r}")
114+
115+
# Get the lock for the ID. Create a lock if one doesn't exist yet.
116+
locks = __lock_dicts[namespace]
117+
lock_ = locks.setdefault(id_, asyncio.Lock())
118+
119+
# It's safe to check an asyncio.Lock is free before acquiring it because:
120+
# 1. Synchronous code like `if not lock_.locked()` does not yield execution
121+
# 2. `asyncio.Lock.acquire()` does not internally await anything if the lock is free
122+
# 3. awaits only yield execution to the event loop at actual I/O boundaries
123+
if wait or not lock_.locked():
124+
log.debug(f"{name}: acquiring lock for resource {namespace!r}:{id_!r}...")
125+
async with lock_:
126+
return await func(*args, **kwargs)
127+
else:
128+
log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked")
129+
if raise_error:
130+
raise LockedResourceError(str(namespace), id_)
131+
132+
return wrapper
133+
return decorator
134+
135+
136+
def lock_arg(
137+
namespace: Hashable,
138+
name_or_pos: function.Argument,
139+
func: Callable[[Any], _IdCallableReturn] = None,
140+
*,
141+
raise_error: bool = False,
142+
wait: bool = False,
143+
) -> Callable:
144+
"""
145+
Apply the `lock` decorator using the value of the arg at the given name/position as the ID.
146+
147+
`func` is an optional callable or awaitable which will return the ID given the argument value.
148+
See `lock` docs for more information.
149+
"""
150+
decorator_func = partial(lock, namespace, raise_error=raise_error, wait=wait)
151+
return function.get_arg_value_wrapper(decorator_func, name_or_pos, func)

0 commit comments

Comments
 (0)