Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/async_kernel/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ def schedule_call(
args: tuple,
kwargs: dict,
context: contextvars.Context | None = None,
trackers: type[PendingTracker] | tuple[type[PendingTracker], ...] = PendingTracker,
/,
**metadata: Any,
) -> Pending[T]:
Expand All @@ -628,12 +629,13 @@ def schedule_call(
args: Arguments corresponding to in the call to `func`.
kwargs: Keyword arguments to use with in the call to `func`.
context: The context to use, if not provided the current context is used.
trackers: The tracker subclasses of active trackers which to add the pending.
**metadata: Additional metadata to store in the instance.
"""
if self._state in {CallerState.stopping, CallerState.stopped}:
msg = f"{self} is {self._state.name}!"
raise RuntimeError(msg)
pen = Pending(func=func, args=args, kwargs=kwargs, caller=self, **metadata)
pen = Pending(trackers, func=func, args=args, kwargs=kwargs, caller=self, **metadata)
self._queue.append((context or contextvars.copy_context(), pen))
self._resume()
return pen
Expand Down
4 changes: 1 addition & 3 deletions src/async_kernel/pending.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,7 @@ def metadata(self) -> dict[str, Any]:
"""
return self._metadata_mappings[id(self)]

def __init__(
self, *, trackers: type[PendingTracker] | tuple[type[PendingTracker], ...] = PendingTracker, **metadata: Any
):
def __init__(self, trackers: type[PendingTracker] | tuple[type[PendingTracker], ...] = (), **metadata: Any):
"""
Initializes a new Pending object with optional creation options and metadata.

Expand Down
23 changes: 15 additions & 8 deletions tests/test_pending.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
from aiologic.meta import await_for

from async_kernel.caller import Caller
from async_kernel.pending import InvalidStateError, Pending, PendingCancelled, PendingGroup, PendingManager
from async_kernel.pending import (
InvalidStateError,
Pending,
PendingCancelled,
PendingGroup,
PendingManager,
PendingTracker,
)
from async_kernel.typing import Backend


Expand Down Expand Up @@ -318,7 +325,7 @@ async def recursive():

async def test_discard(self, pm: PendingManager, caller: Caller):
pm.add(pen1 := caller.call_soon(lambda: 1 + 1))
pm.add(pen2 := Pending())
pm.add(pen2 := Pending(PendingManager))
assert await pen1 == 2
pm.discard(pen2)

Expand Down Expand Up @@ -427,7 +434,7 @@ async def test_wait_exception(self, caller: Caller):
with pytest.raises(PendingCancelled): # noqa: PT012
async with PendingGroup():
pen = caller.call_soon(anyio.sleep_forever)
Pending().set_exception(RuntimeError("stop"))
Pending(PendingGroup).set_exception(RuntimeError("stop"))
assert pen.cancelled() # pyright: ignore[reportPossiblyUnboundVariable]

async def test_cancelled_by_pending(self, caller: Caller):
Expand All @@ -439,7 +446,7 @@ async def test_cancelled_by_pending(self, caller: Caller):

async def test_discard(self, caller: Caller):
async with PendingGroup() as pg:
pen = Pending()
pen = Pending(PendingGroup)
pg.discard(pen)
assert pen not in pg.pending

Expand Down Expand Up @@ -499,19 +506,19 @@ async def test_tracking(self, caller: Caller):
token = pm.start_tracking()
try:
async with caller.create_pending_group() as pg:
pen = Pending()
pen = Pending(PendingTracker)
assert pen in pg.pending
assert pen in pm.pending

pen_no_track = Pending(trackers=())
pen_no_track = Pending()
assert pen_no_track not in pm.pending
assert pen_no_track not in pg.pending

pen_pg = Pending(trackers=(PendingGroup,))
pen_pg = Pending(PendingGroup)
assert pen_pg in pg.pending
assert pen_pg not in pm.pending

pen_pm = Pending(trackers=(PendingManager,))
pen_pm = Pending(PendingManager)
assert pen_pm in pm.pending
assert pen_pm not in pg.pending
pg._pending.clear() # pyright: ignore[reportPrivateUsage]
Expand Down
Loading