Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ Homepage = "https://fleming79.github.io/async-kernel"
Documentation = "https://fleming79.github.io/async-kernel"
Source = "https://github.com/fleming79/async-kernel"
Tracker = "https://github.com/fleming79/async-kernel/issues"
Changelog = "https://fleming79.github.io/async-kernel/latest/about/changelog/"


[project.scripts]
async-kernel = "async_kernel.command:command_line"
Expand Down
9 changes: 8 additions & 1 deletion src/async_kernel/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ class Caller:
_stopped = False
_protected = False
_running = False
_future_var: contextvars.ContextVar[Future | None] = contextvars.ContextVar("_future_var", default=None)
thread: threading.Thread
"The thread in which the caller will run."
backend: Backend
Expand Down Expand Up @@ -419,6 +420,7 @@ async def _wrap_call(
args: tuple,
kwargs: dict,
) -> None:
self._future_var.set(fut)
if fut.cancelled():
fut.set_result(cast("T", None)) # This will cancel
return
Expand All @@ -429,7 +431,7 @@ async def _wrap_call(
if (delay_ := delay - time.monotonic() + starttime) > 0:
await anyio.sleep(float(delay_))
result = func(*args, **kwargs) if callable(func) else func # pyright: ignore[reportAssignmentType]
if inspect.isawaitable(result):
if inspect.isawaitable(result) and result is not fut:
result: T = await result
if fut.cancelled() and not scope.cancel_called:
scope.cancel()
Expand Down Expand Up @@ -726,6 +728,11 @@ async def caller_context() -> None:
assert isinstance(caller, cls)
return caller

@classmethod
def current_future(cls) -> Future[Any] | None:
"""Return the current future when called from inside a function scheduled by Caller."""
return cls._future_var.get()

@classmethod
async def as_completed(
cls,
Expand Down
6 changes: 6 additions & 0 deletions tests/test_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,12 @@ async def my_func():
val = await fut
assert val is True

async def test_current_future(self, anyio_backend):
async with Caller(create=True) as caller:
fut = caller.call_soon(Caller.current_future)
res = await fut
assert res is fut

async def test_closed_in_call_soon(self, anyio_backend):
ready = threading.Event()
proceed = threading.Event()
Expand Down
Loading