Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/async_kernel/asyncshell.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,7 @@ def _list_matplotlib_backends_and_gui_loops(self) -> list[str | None]:

@contextlib.contextmanager
def context(self) -> Generator[None, Any, None]:
"A context manager where the shell is active."
with self.pending_manager.context():
yield

Expand Down
41 changes: 17 additions & 24 deletions src/async_kernel/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from async_kernel import utils
from async_kernel.common import Fixed
from async_kernel.event_loop.run import Host, get_start_guest_run
from async_kernel.pending import Pending, PendingCancelled, PendingGroup, PendingTracker
from async_kernel.pending import Pending, PendingCancelled, PendingGroup, PendingManager, PendingTracker
from async_kernel.typing import Backend, CallerCreateOptions, CallerState, Loop, NoValue, RunSettings, T

with contextlib.suppress(ImportError):
Expand Down Expand Up @@ -534,7 +534,10 @@ async def _scheduler(self, backend: Backend, queue: SingleConsumerAsyncQueue, tg
task.add_done_callback(self._tasks.discard)
del task
else:
item.context.run(tg.start_soon, self._call_scheduled, item)
if context := item.context:
context.run(tg.start_soon, self._call_scheduled, item)
else:
tg.start_soon(self._call_scheduled, item)
del item, result
finally:
if asyncio_backend:
Expand Down Expand Up @@ -596,7 +599,8 @@ async def _call_scheduled(self, pen: Pending) -> None:
@staticmethod
def _reject(item: tuple | Pending) -> None:
if isinstance(item, Pending):
item.cancel("The caller has been closed", _force=True)
item.cancel("The caller has been closed")
item.set_result(None)

@classmethod
def _start_idle_worker_cleanup_thead(cls) -> None:
Expand Down Expand Up @@ -725,7 +729,7 @@ def schedule_call(
trackers: The tracker subclasses of active trackers which to add the pending.
**metadata: Additional metadata to store in the instance.
"""
pen = Pending(trackers, context, func=func, args=args, kwargs=kwargs, caller=self, **metadata)
pen = Pending(context, trackers, func=func, args=args, kwargs=kwargs, caller=self, **metadata)
if backend is NoValue or (backend := Backend(backend)) is self.backend:
self._queue.append(pen)
else:
Expand Down Expand Up @@ -908,32 +912,23 @@ def queue_call(
/,
*args: P.args,
**kwargs: P.kwargs,
) -> Pending[T]:
) -> None:
"""
Queue the execution of `func` in a queue unique to it and the caller instance (thread-safe).

The returned pending is 'resettable' and will provide the result of the most recent successful
call once the queue has been emptied. Exceptions are not set, instead the result would be `None`.
A low level function to queue the execution of `func` in a queue unique to it and the caller instance (thread-safe).

Args:
func: The function.
*args: Arguments to use with `func`.
**kwargs: Keyword arguments to use with `func`.

Returns:
Pending: The pending where the queue loop is running.

Warning:
- Do not assume the result corresponds to the function call.
- The returned pending returns the last result of the queue call once the queue becomes empty.

Notes:
- The queue runs in a *task* wrapped with a [async_kernel.pending.Pending][] that remains running until one of the following occurs:
1. The pending is cancelled.
- The queue runs inside a pending that remains running until one of the following occurs:
1. The queue is stopped.
2. The method [Caller.queue_close][] is called with `func` or `func`'s hash.
3. `func` is deleted (utilising [weakref.finalize][]).
- The [context][contextvars.Context] of the initial call is used for subsequent queue calls.
- Exceptions are 'swallowed'; the last successful result is set on the pending.
- Exceptions are logged to caller.log but not propagated.
- The pending created on the first call will only registered with PendingManager subclassed trackers and **not** PendingGroup.
"""
key = hash(func)
if not (pen_ := self._queue_map.get(key)):
Expand All @@ -955,15 +950,13 @@ async def queue_loop() -> None:
if pen.cancelled():
raise
self.log.exception("Execution %s failed", item, exc_info=e)
if not queue.queue:
pen.set_result(result, reset=True)
item = result = None # noqa: PLW2901
item = result = None # noqa: PLW2901
finally:
self._queue_map.pop(key)

self._queue_map[key] = pen_ = self.schedule_call(queue_loop, (), {}, key=key, queue=queue)
pen_ = self.schedule_call(queue_loop, (), {}, None, PendingManager, key=key, queue=queue)
self._queue_map[key] = pen_
pen_.metadata["queue"].append((func, args, kwargs))
return pen_.add_to_trackers() # pyright: ignore[reportReturnType]

def queue_close(self, func: Callable | int) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/async_kernel/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def msg_handler(
case RunMode.direct:
self.callers[channel].call_direct(handler, job)
case RunMode.queue:
self.callers[channel].queue_call(handler, job).trackers = () # A slight optimisation
self.callers[channel].queue_call(handler, job)
case RunMode.task:
self.callers[channel].call_soon(handler, job)
case RunMode.thread:
Expand Down
115 changes: 51 additions & 64 deletions src/async_kernel/pending.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import weakref
from collections import deque
from collections.abc import AsyncGenerator, Awaitable, Callable, Generator
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, overload
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, final, overload

import anyio
from aiologic import Event
Expand Down Expand Up @@ -80,15 +80,19 @@ def __init__(self) -> None:
self._instances[self.id] = self

def _activate(self) -> Token[str | None]:
return self._id_contextvar.set(self.id)
try:
return self._id_contextvar.set(self.id)
except AttributeError as e:
e.add_note("Pending tracker must be subclassed to use it!")
raise

def _deactivate(self, token: contextvars.Token[str | None]) -> None:
self._id_contextvar.reset(token)

def add(self, pen: Pending) -> None:
"Track `Pending` until it is done."

if isinstance(self, pen.trackers) and (pen not in self._pending):
if pen not in self._pending:
self._pending.add(pen)
pen.add_done_callback(self._on_pending_done)

Expand All @@ -99,11 +103,13 @@ def _on_pending_done(self, pen: Pending) -> None:

class PendingManager(PendingTracker):
"""
PendingManager is a class that can be used to capture the creation of [async_kernel.pending.Pending][]
in any specific context.
PendingManager is a `PendingTracker` subclass for tracking the creation of [async_kernel.pending.Pending][]
in multiple contexts.

This class can not be used directly and must be subclassed to be useful. For any
subclass, there is only one active instance in that context.
This class must be subclassed to be useful.

For each subclass there is zero or one active trackers at a time. Activating a manager will 'replace' a
previously active pending manager.

Notes:

Expand Down Expand Up @@ -158,16 +164,23 @@ def deactivate(self, token: contextvars.Token[str | None]) -> None:
"""
self._deactivate(token)

def remove(self, pen: Pending) -> None:
"""
Remove a pending from the manager.
"""
self._pending.remove(pen)

@contextlib.contextmanager
def context(self) -> Generator[None, Any, None]:
"""A context manager to activate this instance."""
"""A context manager where the pending manager is activated."""
token = self.activate()
try:
yield
finally:
self.deactivate(token)


@final
class PendingGroup(PendingTracker, anyio.AsyncContextManagerMixin):
"""
An asynchronous context manager for tracking [async_kernel.pending.Pending][] created in the context.
Expand Down Expand Up @@ -314,7 +327,6 @@ class Pending(Awaitable[T]):
"_exception",
"_result",
"context",
"trackers",
]

REPR_OMIT: ClassVar[set[str]] = {"func", "args", "kwargs"}
Expand All @@ -328,15 +340,11 @@ class Pending(Awaitable[T]):
_exception: Exception
_done: bool
_result: T
context: contextvars.Context
trackers: type[PendingTracker] | tuple[type[PendingTracker], ...]
"""
A tuple of [async_kernel.pending.PendingTracker][] subclasses that the pending is permitted to register with.
context: contextvars.Context | None
"""
The context associated with Pending.

Should be specified during init.

For some pending it may not make sense for it to be added to a [PendingGroup][]
Instead specify `PendingManager` instead of `PendingTracker`.
The context is updated for the Trackers during init.
"""

@property
Expand All @@ -348,17 +356,18 @@ def metadata(self) -> dict[str, Any]:

def __init__(
self,
trackers: type[PendingTracker] | tuple[type[PendingTracker], ...] = (),
context: contextvars.Context | None = None,
trackers: type[PendingTracker] | tuple[type[PendingTracker], ...] = (),
/,
**metadata: Any,
) -> None:
"""
Initializes a new Pending object with optional creation options and metadata.

Args:
trackers: A subclass or tuple of `PendingTracker` subclasses to which the pending can be added given the context.
context: A context to associate with the pending, if provided it is copied.
trackers: A subclass or tuple of `PendingTracker` subclasses to which the pending can be added in the current context.
**metadata: Arbitrary keyword arguments containing metadata to associate with this Pending instance.
trackers: Enabled by default. To deactivate tracking pass `trackers=False`

Behavior:
- Initializes internal state for tracking completion and cancellation
Expand All @@ -368,19 +377,21 @@ def __init__(
self._metadata_mappings[id(self)] = metadata
self._done = False
self._cancelled = None
self.trackers = trackers

# A copy the context is required to avoid `PendingTracker.id` leakage.
self.context = context = context.copy() if context else contextvars.copy_context()
if trackers or context:
# A copy the context is required to avoid `PendingTracker.id` leakage.
context = context.copy() if context else contextvars.copy_context()
self.context = context

# PendingTacker registration.
for cls in PendingTracker._subclasses: # pyright: ignore[reportPrivateUsage]
if id_ := context.get(cls._id_contextvar): # pyright: ignore[reportPrivateUsage]
if trackers and issubclass(cls, trackers) and (tracker := PendingTracker._instances.get(id_)): # pyright: ignore[reportPrivateUsage]
tracker.add(self)
else:
# Clear `PendingTracker.id`.
context.run(cls._id_contextvar.set, None) # pyright: ignore[reportPrivateUsage]
if context:
for cls in PendingTracker._subclasses: # pyright: ignore[reportPrivateUsage]
if id_ := context.get(cls._id_contextvar): # pyright: ignore[reportPrivateUsage]
if trackers and issubclass(cls, trackers) and (tracker := PendingTracker._instances.get(id_)): # pyright: ignore[reportPrivateUsage]
tracker.add(self)
else:
# Clear `PendingTracker.id`.
context.run(cls._id_contextvar.set, None) # pyright: ignore[reportPrivateUsage]

def __del__(self):
self._metadata_mappings.pop(id(self), None)
Expand Down Expand Up @@ -497,20 +508,14 @@ def _set_done(self, mode: Literal["result", "exception"], value) -> None:
except Exception:
pass

def set_result(self, value: T, *, reset: bool = False) -> None:
def set_result(self, value: T) -> None:
"""
Set the result (low-level-thread-safe).

Args:
value: The result.
reset: Revert to being not done.

Warning:
- When using reset ensure to proivide sufficient time for any waiters to retrieve the result.
value: The result to set.
"""
self._set_done("result", value)
if reset:
self._done = False

def set_exception(self, exception: BaseException) -> None:
"""
Expand All @@ -519,7 +524,7 @@ def set_exception(self, exception: BaseException) -> None:
self._set_done("exception", exception)

@enable_signal_safety
def cancel(self, msg: str | None = None, *, _force: bool = False) -> bool:
def cancel(self, msg: str | None = None) -> bool:
"""
Cancel the instance.

Expand All @@ -532,17 +537,13 @@ def cancel(self, msg: str | None = None, *, _force: bool = False) -> bool:

Returns: If it has been cancelled.
"""
try:
if not self._done:
if (cancelled := self._cancelled or "") and msg:
msg = f"{cancelled}\n{msg}"
self._cancelled = msg or cancelled
if canceller := getattr(self, "_canceller", None):
canceller(msg)
return self.cancelled()
finally:
if _force and not self._done:
self._set_done("result", None)
if not self._done:
if (cancelled := self._cancelled or "") and msg:
msg = f"{cancelled}\n{msg}"
self._cancelled = msg or cancelled
if canceller := getattr(self, "_canceller", None):
canceller(msg)
return self.cancelled()

def cancelled(self) -> bool:
"""Return True if the pending is cancelled."""
Expand Down Expand Up @@ -629,17 +630,3 @@ def exception(self) -> BaseException | None:
if self._cancelled is not None:
raise PendingCancelled(self._cancelled)
return getattr(self, "_exception", None)

def add_to_trackers(self) -> Self:
"""
Add this pending to the trackers active in the current context that are include.
"""
if trackers := self.trackers:
for cls_ in PendingTracker._subclasses: # pyright: ignore[reportPrivateUsage]
if (
issubclass(cls_, trackers)
and (id_ := cls_._id_contextvar.get()) # pyright: ignore[reportPrivateUsage]
and (pm := PendingTracker._instances.get(id_)) # pyright: ignore[reportPrivateUsage]
):
pm.add(self)
return self
7 changes: 3 additions & 4 deletions tests/test_callable_kernel_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,10 @@ async def test_msg(self, interface: CallableKernelInterface, mocker):
msg = interface.msg("execute_request", content={"code": code})
msg["header"]["session"] = "test session"
buffers = [b"123"]
async with interface.kernel.caller.create_pending_group() as pg:
interface._handle_msg(orjson.dumps(msg).decode(), buffers) # pyright: ignore[reportPrivateUsage]
assert len(pg.pending) == 1
interface._handle_msg(orjson.dumps(msg).decode(), buffers) # pyright: ignore[reportPrivateUsage]

assert sender.call_count == 4
while sender.call_count != 4:
await anyio.sleep(0.01)
reply = orjson.loads(sender.call_args_list[2][0][0])
assert reply["header"]["msg_type"] == "execute_reply"
assert reply["content"]["status"] == "ok"
Expand Down
Loading