Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,3 @@ These are the currently assigned run modes.

Async kernel started as a [fork](https://github.com/ipython/ipykernel/commit/8322a7684b004ee95f07b2f86f61e28146a5996d)
of [IPyKernel](https://github.com/ipython/ipykernel). Thank you to the original contributors of IPyKernel that made Async kernel possible.

[^non-main-thread]: The Shell can run in other threads with the associated limitations with regard to signalling and interrupts.
2 changes: 1 addition & 1 deletion src/async_kernel/asyncshell.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ def __init__(self, *, protected: bool = True) -> None:
def stop(self, *, force=False) -> None:
"Stop this subshell."
if force or not self.protected:
self.pending_manager.deactivate(cancel_pending=True)
self.pending_manager.deactivate()
self.reset(new_session=False)
self.kernel.subshell_manager.subshells.pop(self.subshell_id, None)
self.set_trait("stopped", True)
Expand Down
7 changes: 6 additions & 1 deletion src/async_kernel/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -952,5 +952,10 @@ async def wait(
return done, pending

def create_pending_group(self, *, shield: bool = False):
"Create a new [pending group][async_kernel.pending.PendingGroup]."
"""
Create a new [pending group][async_kernel.pending.PendingGroup].

Args:
shield: Shield the pending group from external cancellation.
"""
return PendingGroup(shield=shield)
195 changes: 101 additions & 94 deletions src/async_kernel/pending.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Self, overload

import anyio
from aiologic import CountdownEvent, Event
from aiologic import Event
from aiologic.lowlevel import async_checkpoint, create_async_event, enable_signal_safety, green_checkpoint
from typing_extensions import override
from wrapt import lazy_import

import async_kernel
from async_kernel.common import Fixed
from async_kernel.typing import Backend, PendingCreateOptions, PendingTrackerState, T
from async_kernel.typing import Backend, PendingCreateOptions, T

trio_checkpoint: Callable[[], Awaitable] = lazy_import("trio.lowlevel", "checkpoint") # pyright: ignore[reportAssignmentType]

Expand Down Expand Up @@ -131,7 +131,7 @@ def __repr__(self) -> str:
rep = (
"<Pending"
+ ((" ⛔" + (f"message={self._cancelled!s}" if self._cancelled else "")) if self.cancelled() else "")
+ (" 🏁" if self._done else " 🏃")
+ ((f" ❗ {e!r}" if (e := getattr(self, "_exception", None)) else " 🏁") if self._done else " 🏃")
)
rep = f"{rep} at {id(self)}"
if self._options:
Expand Down Expand Up @@ -376,16 +376,16 @@ class PendingTracker:
_active_contexts: ClassVar[dict[str, Self]] = {}
_contextvar: ClassVar[contextvars.ContextVar[str | None]] = contextvars.ContextVar("PendingManager", default=None)

_state = PendingTrackerState.idle
_active = False
_pending: Fixed[Self, set[Pending[Any]]] = Fixed(set)
_count_event = Fixed(CountdownEvent)
_tracking = False

context_id: Fixed[Self, str] = Fixed(lambda _: str(uuid.uuid4()))
"The context id (per instance)."

@property
def state(self) -> PendingTrackerState:
return self._state
def active(self) -> bool:
return self._active

@property
def pending(self) -> set[Pending[Any]]:
Expand All @@ -396,12 +396,6 @@ def __init_subclass__(cls) -> None:
cls._contextvar = contextvars.ContextVar(f"{cls.__module__}.{cls.__name__}", default=None)
return super().__init_subclass__()

def _set_context(self) -> contextvars.Token[str | None]:
assert self._state in {PendingTrackerState.active, PendingTrackerState.active}
self._active_classes.add(self.__class__)
self._active_contexts[self.context_id] = self
return self._contextvar.set(self.context_id)

@classmethod
def add_to_pending_trackers(cls, pen: Pending) -> None:
"""
Expand All @@ -412,8 +406,12 @@ def add_to_pending_trackers(cls, pen: Pending) -> None:
"""
# Called by `Pending` when a new instance for each new instance.
for cls_ in cls._active_classes:
if (id_ := cls_._contextvar.get()) and (pm := cls._active_contexts.get(id_)):
pm.track(pen)
if id_ := cls_._contextvar.get():
if pm := cls._active_contexts.get(id_):
pm.add(pen)
else:
msg = f"The context of {cls} no longer active!"
raise InvalidStateError(msg)

@classmethod
def current(cls) -> Self | None:
Expand All @@ -422,28 +420,50 @@ def current(cls) -> Self | None:
return current
return None

def track(self, pen: Pending) -> None:
def start_tracking(self) -> contextvars.Token[str | None]:
"""
Start tracking `Pending` in the current context.
"""
if self._tracking or not self.active:
raise InvalidStateError
assert self._active
self._active_classes.add(self.__class__)
self._active_contexts[self.context_id] = self
self._parent_context_id = self._contextvar.get()
self._tracking = True
return self._contextvar.set(self.context_id)

def stop_tracking(self, token: contextvars.Token[str | None]) -> None:
"""
Stop tracking using the token.

Args:
token: The token returned from [start_tracking][].
"""
self._contextvar.reset(token)
self._tracking = False
self._parent_context_id = None

def add(self, pen: Pending) -> None:
"Track `Pending` if it isn't done."
if self._state is not PendingTrackerState.active:
msg = f"{self.__class__.__name__} state is {self._state}!"
raise InvalidStateError(msg)
assert self._active
if not pen.done() and pen not in self._pending:
pen.add_done_callback(self.remove)
pen.add_done_callback(self.discard)
self._pending.add(pen)
self._count_event.up()
if self._tracking and (id_ := self._parent_context_id) and (parent := self._active_contexts.get(id_)):
parent.add(pen)

def remove(self, pen: Pending) -> None:
"Remove a `Pending` that is being tracked."
"Remove a `Pending`."
self._pending.remove(pen)
self._count_event.down()
if self._state is PendingTrackerState.exiting and not self._count_event.value:
self._state = PendingTrackerState.stopped
self._active_contexts.pop(self.context_id, None)
pen.remove_done_callback(self.discard)

def discard(self, pen: Pending) -> None:
"Remove a `Pending` if it is being tracked."
if pen.remove_done_callback(self.remove):
"Discard the `Pending`."
try:
self.remove(pen)
except IndexError:
pass


class PendingManager(PendingTracker):
Expand All @@ -459,52 +479,20 @@ def activate(self) -> Self:
"""
Enter the active state to begin tracking pending.
"""

if self._state not in {PendingTrackerState.idle, PendingTrackerState.stopped}:
raise InvalidStateError
self._state = PendingTrackerState.active
assert not self._active
self._active_contexts[self.context_id] = self
self._active_classes.add(self.__class__)
self._active = True
return self

def deactivate(self, *, cancel_pending: bool = True) -> CountdownEvent | None:
def deactivate(self) -> None:
"""
Leave the active state and remove pending.


Args:
cancel_pending: Cancel all `Pending` in the current 'pending' set.

Returns:
CountdownEvent: If there are pending that have not finished cancellation.
Leave the active state cancelling all pending.
"""
if self._state is PendingTrackerState.stopped:
return None
self._state = PendingTrackerState.exiting
for pen in self._pending.copy():
if cancel_pending:
pen.cancel(f"{self} has been deactivated")
else:
self.discard(pen)
if self._count_event.value:
return self._count_event
self._state = PendingTrackerState.stopped
self._active = False
self._active_contexts.pop(self.context_id, None)
return None

def start_tracking(self) -> contextvars.Token[str | None]:
"""
Start tracking `Pending` in the current context.
"""
if self.state is not PendingTrackerState.active:
raise InvalidStateError
return self._set_context()

def stop_tracking(self, token: contextvars.Token[str | None]) -> None:
"""
Stop tracking using the token returned using start tracking.
"""
self._contextvar.reset(token)
for pen in self._pending.copy():
pen.cancel(f"{self} has been deactivated")


class PendingGroup(PendingTracker, anyio.AsyncContextManagerMixin):
Expand Down Expand Up @@ -534,8 +522,10 @@ class PendingGroup(PendingTracker, anyio.AsyncContextManagerMixin):
```
"""

_cancel_scope: anyio.CancelScope | None
_cancel_scope: anyio.CancelScope
_cancelled: str | None = None
cancellation_timeout = 10
"The maximum time to wait for cancelled pending to be done."

caller = Fixed(lambda _: async_kernel.Caller())

Expand All @@ -546,43 +536,60 @@ def __init__(self, *, shield: bool = False) -> None:

@contextlib.asynccontextmanager
async def __asynccontextmanager__(self) -> AsyncGenerator[Self]:
if self._cancelled is None:
self._cancel_scope = anyio.CancelScope(shield=self._shield)
self._state = PendingTrackerState.active
token = self._set_context()
try:
with self._cancel_scope:
self._cancel_scope = anyio.CancelScope(shield=self._shield)
self._all_done = create_async_event()
self._active = True
self._leaving_context = False
token = self.start_tracking()
try:
with self._cancel_scope:
try:
yield self
await self._count_event
finally:
self._cancel_scope = None
self._state = PendingTrackerState.exiting
self._contextvar.reset(token)
if self._count_event.value:
for pen in self._pending.copy():
pen.cancel()
with anyio.CancelScope(shield=True):
await self._count_event
self._state = PendingTrackerState.stopped
self._leaving_context = True
if self._pending:
await self._all_done
except (anyio.get_cancelled_exc_class(), Exception) as e:
self.cancel(f"An error occurred: {e!r}")
raise
if self._cancelled is not None:
raise PendingCancelled(self._cancelled)
finally:
self._leaving_context = True
self.stop_tracking(token)
if self._pending:
if self._all_done or self._all_done.cancelled():
self._all_done = create_async_event()
if self._pending and not self._all_done:
with anyio.CancelScope(shield=True), anyio.move_on_after(self.cancellation_timeout):
await self._all_done
self._active = False

@override
def add(self, pen: Pending):
assert self._active
if self._cancelled is not None:
raise PendingCancelled(self._cancelled)
msg = f"Trying to add to a cancelled PendingGroup.\nCancellation messages: {self._cancelled}"
pen.cancel(msg)
else:
super().add(pen)

@override
def remove(self, pen: Pending) -> None:
"Remove pen from the group."
super().remove(pen)
if pen.done() and self._state is PendingTrackerState.active and (pen.cancelled() or pen.exception()):
msg = f"{pen} tracked by {self} failed or was cancelled"
for pen_ in self._pending.copy():
pen_.cancel(msg)
self.cancel(msg)
if pen.done() and self._active and (not pen.cancelled() and (pen.exception())):
self.cancel(f"Exception in member: {pen}")
if self._leaving_context and not self._pending:
self._all_done.set()

def cancel(self, msg: str | None = None) -> None:
"Cancel the pending group (thread-safe)."
self._cancelled = "\n".join(((self._cancelled if self._cancelled else ""), msg or ""))
if scope := self._cancel_scope:
self.caller.call_direct(scope.cancel, msg)
if self._active:
self._cancelled = "\n".join(((self._cancelled or ""), msg or ""))
if not self._cancel_scope.cancel_called:
self.caller.call_direct(self._cancel_scope.cancel, msg)
for pen_ in self.pending:
pen_.cancel(msg)

def cancelled(self) -> bool:
"""Return True if the pending group is cancelled."""
Expand Down
9 changes: 0 additions & 9 deletions src/async_kernel/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,6 @@ def get_string(self, value: str | Tags, default: str = "") -> str:
"""


class PendingTrackerState(enum.Enum):
"The state of a [async_kernel.pending.PendingManager][]."

idle = enum.auto()
active = enum.auto()
exiting = enum.auto()
stopped = enum.auto()


class CallerState(enum.Enum):
"The State of a [async_kernel.caller.Caller][]."

Expand Down
3 changes: 2 additions & 1 deletion tests/test_callable_kernel_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ async def test_msg(self, interface: CallableKernelInterface, mocker):
msg = interface.msg("execute_request", content={"code": code})
msg["header"]["session"] = "test session"
buffers = [b"123"]
async with interface.kernel.caller.create_pending_group():
async with interface.kernel.caller.create_pending_group() as pg:
interface._handle_msg(orjson.dumps(msg).decode(), buffers) # pyright: ignore[reportPrivateUsage]
assert len(pg.pending) == 1

assert sender.call_count == 4
reply = orjson.loads(sender.call_args_list[2][0][0])
Expand Down
Loading
Loading