Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 60 additions & 12 deletions src/async_kernel/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from async_kernel import utils
from async_kernel.common import Fixed
from async_kernel.pending import Pending, PendingCancelled, PendingGroup, checkpoint
from async_kernel.pending import Pending, PendingCancelled, PendingGroup, PendingTracker, checkpoint
from async_kernel.typing import Backend, CallerCreateOptions, CallerState, NoValue, PendingCreateOptions, T

with contextlib.suppress(ImportError):
Expand Down Expand Up @@ -735,23 +735,32 @@ def queue_get(self, func: Callable) -> Pending[None] | None:
"""
return self._queue_map.get(hash(func))

def queue_call(
def queue_call_advanced(
self,
func: Callable[P, T | CoroutineType[Any, Any, T]],
args: tuple,
kwargs: dict,
/,
*args: P.args,
**kwargs: P.kwargs,
*,
allow_tracking: NoValue | bool = NoValue, # pyright: ignore[reportInvalidTypeForm]
track: bool = True,
) -> Pending[T]:
"""
Queue the execution of `func` in a queue unique to it and the caller instance (thread-safe).

Args:
func: The function.
*args: Arguments to use with `func`.
**kwargs: Keyword arguments to use with `func`.
args: Arguments to use with `func`.
kwargs: Keyword arguments to use with `func`.
allow_tracking: Used for the first call for `func`. Defaults to the value of `track` if not provided.
track: Allow the present call to be tracked by a [PendingTracker][async_kernel.pending.PendingTracker].
This includes the subclasses [PendingManager][async_kernel.pending.PendingManager] and [PendingGroup][async_kernel.pending.PendingGroup].

Returns:
Pending: The pending where the queue loop is running.

Warning:
- Do not assume the result matches the function call.
- Do not assume the result corresponds to the function call.
- The returned pending returns the last result of the queue call once the queue becomes empty.

Notes:
Expand All @@ -761,9 +770,6 @@ def queue_call(
3. `func` is deleted (utilising [weakref.finalize][]).
- The [context][contextvars.Context] of the initial call is used for subsequent queue calls.
- Exceptions are 'swallowed'; the last successful result is set on the pending.

Returns:
Pending: The pending where the queue loop is running.
"""
key = hash(func)
if not (pen_ := self._queue_map.get(key)):
Expand Down Expand Up @@ -797,14 +803,43 @@ async def queue_loop() -> None:
if not queue:
await event
pen.metadata["resume"] = noop
del event
finally:
self._queue_map.pop(key)

self._queue_map[key] = pen_ = self.schedule_call(queue_loop, (), {}, key=key, queue=queue, resume=noop)
options = PendingCreateOptions(allow_tracking=track if allow_tracking is NoValue else allow_tracking)
self._queue_map[key] = pen_ = self.schedule_call(
queue_loop, (), {}, options, key=key, queue=queue, resume=noop
)
pen_.metadata["queue"].append((func, args, kwargs))
pen_.metadata["resume"]()
if track:
PendingTracker.add_to_pending_trackers(pen_)
return pen_ # pyright: ignore[reportReturnType]

def queue_call(
self,
func: Callable[P, T | CoroutineType[Any, Any, T]],
/,
*args: P.args,
**kwargs: P.kwargs,
) -> Pending[T]:
"""
Queue the execution of `func` in a queue unique to it and the caller instance (thread-safe).

The returned pending is 'resettable' and will provide the result of the most recent successful
call once the queue has been emptied. Exceptions are not set, instead the result would be `None`.

Args:
func: The function.
*args: Arguments to use with `func`.
**kwargs: Keyword arguments to use with `func`.

Returns:
Pending: The pending where the queue loop is running.
"""
return self.queue_call_advanced(func, args, kwargs, track=True)

def queue_close(self, func: Callable | int) -> None:
"""
Close the execution queue associated with `func` (thread-safe).
Expand Down Expand Up @@ -953,9 +988,22 @@ async def wait(

def create_pending_group(self, *, shield: bool = False):
"""
Create a new [pending group][async_kernel.pending.PendingGroup].
Create a new [PendingGroup][async_kernel.pending.PendingGroup] instance.

The pending group will wait for all pending created in its context to complete (except for those that opt out).
If any pending result in exception, the pending group and all registered pending are cancelled.
If the pending group context is cancelled or results in exception, all pending in the group are
also cancelled.

Args:
shield: Shield the pending group from external cancellation.

Usage:

```python
async with Caller().create_pending_group() as pg:
pg.caller.to_thread(my_func)
...
```
"""
return PendingGroup(shield=shield)
2 changes: 1 addition & 1 deletion src/async_kernel/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def msg_handler(
case RunMode.direct:
self.callers[channel].call_direct(handler, job)
case RunMode.queue:
self.callers[channel].queue_call(handler, job)
self.callers[channel].queue_call_advanced(handler, (job,), {}, allow_tracking=True, track=False)
case RunMode.task:
self.callers[channel].call_soon(handler, job)
case RunMode.thread:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_pending.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,3 +479,13 @@ async def test_nested_raises(self, caller: Caller):
pen = pg2.caller.call_soon(lambda: 1 / 0)

assert pen.exception() # pyright: ignore[reportPossiblyUnboundVariable]

async def test_queue(self, caller: Caller):
def func(val):
return val

pen = caller.queue_call(func, 1)
async with caller.create_pending_group() as pg:
assert pg.caller is caller
assert caller.queue_call(func, 2) is pen
assert pen in pg.pending
Loading