Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions src/async_kernel/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ def set_value():
else:
set_value()

def _make_cancelled_error(self) -> FutureCancelledError:
return FutureCancelledError(self._cancelled) if isinstance(self._cancelled, str) else FutureCancelledError()

def done(self) -> bool:
"""Return True if the Future is done.

Expand All @@ -163,13 +166,18 @@ def add_done_callback(self, fn: Callable[[Self], object]) -> None:
else:
self.get_caller().call_no_context(fn, self)

def cancel(self) -> bool:
def cancel(self, msg: str | None = None) -> bool:
"""Cancel the Future and schedule callbacks (thread-safe using Caller).

Args:
msg: The message to use when raising a FutureCancelledError.

Returns if it has been cancelled.
"""
if not self.done():
self._cancelled = True
if msg and isinstance(self._cancelled, str):
msg = f"{self._cancelled}\n{msg}"
self._cancelled = msg or self._cancelled or True
if scope := self._cancel_scope:
if threading.current_thread() is self.thread:
scope.cancel()
Expand All @@ -179,7 +187,7 @@ def cancel(self) -> bool:

def cancelled(self) -> bool:
"""Return True if the Future is cancelled."""
return self._cancelled
return bool(self._cancelled)

def exception(self) -> BaseException | None:
"""Return the exception that was set on the Future.
Expand All @@ -189,7 +197,7 @@ def exception(self) -> BaseException | None:
If the Future isn't done yet, this method raises an [InvalidStateError][async_kernel.caller.InvalidStateError] exception.
"""
if self._cancelled:
raise FutureCancelledError
raise self._make_cancelled_error()
if not self.done():
raise InvalidStateError
return self._exception
Expand All @@ -207,8 +215,8 @@ def remove_done_callback(self, fn: Callable[[Self], object], /) -> int:

def set_cancel_scope(self, scope: anyio.CancelScope) -> None:
"Provide a cancel scope for cancellation."
if self._cancelled:
scope.cancel()
if self._cancelled or self._cancel_scope:
raise InvalidStateError
self._cancel_scope = scope

def get_caller(self) -> Caller:
Expand Down Expand Up @@ -358,6 +366,9 @@ async def _wrap_call(
args: tuple,
kwargs: dict,
) -> None:
if fut.cancelled():
fut.set_exception(fut._make_cancelled_error()) # pyright: ignore[reportPrivateUsage]
return
try:
with anyio.CancelScope() as scope:
fut.set_cancel_scope(scope)
Expand All @@ -378,7 +389,7 @@ async def _wrap_call(
self._outstanding -= 1 # # update first for _to_thread_on_done
if not fut.done():
if isinstance(e, self._cancelled_exception_class):
e = FutureCancelledError()
e = fut._make_cancelled_error() # pyright: ignore[reportPrivateUsage]
else:
self.log.exception("Exception occurred while running %s", func, exc_info=e)
fut.set_exception(e)
Expand Down
27 changes: 17 additions & 10 deletions tests/test_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,17 @@ async def test_set_result_twice_raises(self):
with pytest.raises(RuntimeError):
fut.set_result(2)

async def test_set_cancel_scope_twice_raises(self):
fut = Future()
with anyio.CancelScope() as cancel_scope:
fut.set_cancel_scope(cancel_scope)
with pytest.raises(InvalidStateError):
fut.set_cancel_scope(cancel_scope)

async def test_set_exception_twice_raises(self):
fut = Future()
fut.set_exception(ValueError())
with pytest.raises(RuntimeError):
with pytest.raises(InvalidStateError):
fut.set_exception(ValueError())

async def test_set_result_after_exception_raises(self):
Expand Down Expand Up @@ -424,13 +431,10 @@ async def close_tsc():

@pytest.mark.parametrize("mode", ["async", "blocking"])
@pytest.mark.parametrize("cancel_mode", ["local", "thread"])
@pytest.mark.parametrize("msg", ["msg", None, "twice"])
async def test_cancel(
self, anyio_backend, mode: Literal["async", "blocking"], cancel_mode: Literal["local", "thread"]
self, anyio_backend, mode: Literal["async", "blocking"], cancel_mode: Literal["local", "thread"], msg
):
async def async_func():
await anyio.sleep(10)
raise RuntimeError

def blocking_func():
import time # noqa: PLC0415

Expand All @@ -439,18 +443,21 @@ def blocking_func():
my_func = blocking_func
match mode:
case "async":
my_func = async_func
my_func = anyio.sleep_forever
case "blocking":
my_func = blocking_func

async with Caller(create=True) as caller:
fut = caller.call_soon(my_func)
if cancel_mode == "local":
fut.cancel()
fut.cancel(msg)
if msg == "twice":
fut.cancel(msg)
msg = f"{msg}(?s:.){msg}"
else:
caller.to_thread(fut.cancel)
caller.to_thread(fut.cancel, msg)

with pytest.raises(anyio.ClosedResourceError):
with pytest.raises(FutureCancelledError, match=msg if msg else ""):
await fut

async def test_cancelled_waiter(self, anyio_backend):
Expand Down
Loading