Skip to content

Add lock utils #158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Changelog
=========

- :feature:`158` Add locking utilities for controlling concurrency logic
- :support:`202` Bump various development dependencies and CI workflow action versions
- :feature:`194` Add the :obj:`pydis_core.utils.interactions.user_has_access` helper function, that returns whether the given user is in the allowed_users list, or has a role from allowed_roles.

Expand Down
2 changes: 2 additions & 0 deletions pydis_core/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
error_handling,
function,
interactions,
lock,
logging,
members,
messages,
Expand Down Expand Up @@ -47,6 +48,7 @@ def apply_monkey_patches() -> None:
error_handling,
function,
interactions,
lock,
logging,
members,
messages,
Expand Down
93 changes: 92 additions & 1 deletion pydis_core/utils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,113 @@
from __future__ import annotations

import functools
import inspect
import types
import typing
from collections.abc import Callable, Sequence, Set

__all__ = ["GlobalNameConflictError", "command_wraps", "update_wrapper_globals"]
__all__ = [
"GlobalNameConflictError",
"command_wraps",
"get_arg_value",
"get_arg_value_wrapper",
"get_bound_args",
"update_wrapper_globals",
]


if typing.TYPE_CHECKING:
_P = typing.ParamSpec("_P")
_R = typing.TypeVar("_R")

Argument = int | str
BoundArgs = typing.OrderedDict[str, typing.Any]
Decorator = typing.Callable[[typing.Callable], typing.Callable]
ArgValGetter = typing.Callable[[BoundArgs], typing.Any]


class GlobalNameConflictError(Exception):
"""Raised on a conflict between the globals used to resolve annotations of a wrapped function and its wrapper."""


def get_arg_value(name_or_pos: Argument, arguments: BoundArgs) -> typing.Any:
"""
Return a value from `arguments` based on a name or position.

Arguments:
arguments: An ordered mapping of parameter names to argument values.
Returns:
Value from `arguments` based on a name or position.
Raises:
TypeError: `name_or_pos` isn't a str or int.
ValueError: `name_or_pos` does not match any argument.
"""
if isinstance(name_or_pos, int):
# Convert arguments to a tuple to make them indexable.
arg_values = tuple(arguments.items())
arg_pos = name_or_pos

try:
_name, value = arg_values[arg_pos]
return value
except IndexError:
raise ValueError(f"Argument position {arg_pos} is out of bounds.")
elif isinstance(name_or_pos, str):
arg_name = name_or_pos
try:
return arguments[arg_name]
except KeyError:
raise ValueError(f"Argument {arg_name!r} doesn't exist.")
else:
raise TypeError("'arg' must either be an int (positional index) or a str (keyword).")


def get_arg_value_wrapper(
decorator_func: typing.Callable[[ArgValGetter], Decorator],
name_or_pos: Argument,
func: typing.Callable[[typing.Any], typing.Any] | None = None,
) -> Decorator:
"""
Call `decorator_func` with the value of the arg at the given name/position.

Arguments:
decorator_func: A function that must accept a callable as a parameter to which it will pass a mapping of
parameter names to argument values of the function it's decorating.
name_or_pos: The name/position of the arg to get the value from.
func: An optional callable which will return a new value given the argument's value.

Returns:
The decorator returned by `decorator_func`.
"""
def wrapper(args: BoundArgs) -> typing.Any:
value = get_arg_value(name_or_pos, args)
if func:
value = func(value)
return value

return decorator_func(wrapper)


def get_bound_args(func: typing.Callable, args: tuple, kwargs: dict[str, typing.Any]) -> BoundArgs:
"""
Bind `args` and `kwargs` to `func` and return a mapping of parameter names to argument values.

Default parameter values are also set.

Args:
args: The arguments to bind to ``func``
kwargs: The keyword arguments to bind to ``func``
func: The function to bind ``args`` and ``kwargs`` to
Returns:
A mapping of parameter names to argument values.
"""
sig = inspect.signature(func)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()

return bound_args.arguments


def update_wrapper_globals(
wrapper: Callable[_P, _R],
wrapped: Callable[_P, _R],
Expand Down
156 changes: 156 additions & 0 deletions pydis_core/utils/lock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import asyncio
import inspect
import types
from collections import defaultdict
from collections.abc import Awaitable, Callable, Hashable
from functools import partial
from typing import Any
from weakref import WeakValueDictionary

from pydis_core.utils import function
from pydis_core.utils.function import command_wraps
from pydis_core.utils.logging import get_logger

log = get_logger(__name__)
__lock_dicts = defaultdict(WeakValueDictionary)

_IdCallableReturn = Hashable | Awaitable[Hashable]
_IdCallable = Callable[[function.BoundArgs], _IdCallableReturn]
ResourceId = Hashable | _IdCallable


class LockedResourceError(RuntimeError):
"""
Exception raised when an operation is attempted on a locked resource.

Attributes:
type (str): Name of the locked resource's type
id (typing.Hashable): ID of the locked resource
"""

def __init__(self, resource_type: str, resource_id: Hashable):
self.type = resource_type
self.id = resource_id

super().__init__(
f"Cannot operate on {self.type.lower()} `{self.id}`; "
"it is currently locked and in use by another operation."
)


class SharedEvent:
"""
Context manager managing an internal event exposed through the wait coro.

While any code is executing in this context manager, the underlying event will not be set;
when all of the holders finish the event will be set.
"""

def __init__(self):
self._active_count = 0
self._event = asyncio.Event()
self._event.set()

def __enter__(self):
"""Increment the count of the active holders and clear the internal event."""
self._active_count += 1
self._event.clear()

def __exit__(self, _exc_type, _exc_val, _exc_tb): # noqa: ANN001
"""Decrement the count of the active holders; if 0 is reached set the internal event."""
self._active_count -= 1
if not self._active_count:
self._event.set()

async def wait(self) -> None:
"""Wait for all active holders to exit."""
await self._event.wait()


def lock(
namespace: Hashable,
resource_id: ResourceId,
*,
raise_error: bool = False,
wait: bool = False,
) -> Callable:
"""
Turn the decorated coroutine function into a mutually exclusive operation on a `resource_id`.

If decorating a command, this decorator must go before (below) the `command` decorator.

Arguments:
namespace (typing.Hashable): An identifier used to prevent collisions among resource IDs.
resource_id: identifies a resource on which to perform a mutually exclusive operation.
It may also be a callable or awaitable which will return the resource ID given an ordered
mapping of the parameters' names to arguments' values.
raise_error (bool): If True, raise `LockedResourceError` if the lock cannot be acquired.
wait (bool): If True, wait until the lock becomes available. Otherwise, if any other mutually
exclusive function currently holds the lock for a resource, do not run the decorated function
and return None.

Raises:
:exc:`LockedResourceError`: If the lock can't be acquired and `raise_error` is set to True.
"""
def decorator(func: types.FunctionType) -> types.FunctionType:
name = func.__name__

@command_wraps(func)
async def wrapper(*args, **kwargs) -> Any:
log.trace(f"{name}: mutually exclusive decorator called")

if callable(resource_id):
log.trace(f"{name}: binding args to signature")
bound_args = function.get_bound_args(func, args, kwargs)

log.trace(f"{name}: calling the given callable to get the resource ID")
id_ = resource_id(bound_args)

if inspect.isawaitable(id_):
log.trace(f"{name}: awaiting to get resource ID")
id_ = await id_
else:
id_ = resource_id

log.trace(f"{name}: getting the lock object for resource {namespace!r}:{id_!r}")

# Get the lock for the ID. Create a lock if one doesn't exist yet.
locks = __lock_dicts[namespace]
lock_ = locks.setdefault(id_, asyncio.Lock())

# It's safe to check an asyncio.Lock is free before acquiring it because:
# 1. Synchronous code like `if not lock_.locked()` does not yield execution
# 2. `asyncio.Lock.acquire()` does not internally await anything if the lock is free
# 3. awaits only yield execution to the event loop at actual I/O boundaries
if wait or not lock_.locked():
log.debug(f"{name}: acquiring lock for resource {namespace!r}:{id_!r}...")
async with lock_:
return await func(*args, **kwargs)
else:
log.info(f"{name}: aborted because resource {namespace!r}:{id_!r} is locked")
if raise_error:
raise LockedResourceError(str(namespace), id_)
return None

return wrapper
return decorator


def lock_arg(
namespace: Hashable,
name_or_pos: function.Argument,
func: Callable[[Any], _IdCallableReturn] | None = None,
*,
raise_error: bool = False,
wait: bool = False,
) -> Callable:
"""
Apply the `lock` decorator using the value of the arg at the given name/position as the ID.

See `lock` docs for more information.

Arguments:
func: An optional callable or awaitable which will return the ID given the argument value.
"""
decorator_func = partial(lock, namespace, raise_error=raise_error, wait=wait)
return function.get_arg_value_wrapper(decorator_func, name_or_pos, func)