Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 14 additions & 15 deletions src/async_kernel/caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def queue_call(
async def queue_loop() -> None:
pen = self.current_pending()
assert pen
item = result = None
result = None
try:
while True:
await checkpoint(self.backend)
Expand All @@ -783,14 +783,14 @@ async def queue_loop() -> None:
try:
result = item[0](*item[1], **item[2])
if inspect.iscoroutine(object=result):
await result
result = await result
except (anyio.get_cancelled_exc_class(), Exception) as e:
if pen.cancelled():
raise
self.log.exception("Execution %s failed", item, exc_info=e)
else:
pen.set_result(result, reset=True)
del item # pyright: ignore[reportPossiblyUnboundVariable]
item = result = None
event = create_async_event()
pen.metadata["resume"] = event.set
await checkpoint(self.backend)
Expand Down Expand Up @@ -840,7 +840,7 @@ async def as_completed(
resume = noop
result_ready = noop
done_results: deque[Pending[T]] = deque()
results: set[Pending[T]] = set()
unfinished: set[Pending[T]] = set()
done = False
current_pending = self.current_pending()
if isinstance(items, set | list | tuple):
Expand All @@ -863,11 +863,11 @@ async def iter_items():
pen = cast("Pending[T]", self.call_soon(await_for, pen))
pen.add_done_callback(result_done)
if not pen.done():
results.add(pen)
if max_concurrent_ and (len(results) == max_concurrent_):
unfinished.add(pen)
if max_concurrent_ and len(unfinished) == max_concurrent_:
event = create_async_event()
resume = event.set
if len(results) == max_concurrent_:
if len(unfinished) == max_concurrent_:
await event
resume = noop
await checkpoint(self.backend)
Expand All @@ -881,24 +881,22 @@ async def iter_items():

pen_ = self.call_soon(iter_items)
try:
while (not done) or results or done_results:
while (not done) or unfinished or done_results:
if done_results:
pen = done_results.popleft()
results.discard(pen)
# Ensure all done callbacks are complete.
await pen.wait(result=False)
unfinished.discard(pen)
yield pen
else:
if max_concurrent_ and len(results) < max_concurrent_:
if max_concurrent_ and len(unfinished) < max_concurrent_:
resume()
event = create_async_event()
result_ready = event.set
if not done or results:
if not done or unfinished:
await event
result_ready = noop
finally:
pen_.cancel()
for pen in results:
for pen in unfinished:
pen.remove_done_callback(result_done)
if cancel_unfinished:
pen.cancel("Cancelled by as_completed")
Expand Down Expand Up @@ -934,7 +932,8 @@ async def wait(
if pending:
with anyio.move_on_after(timeout):
async for pen in self.as_completed(pending.copy(), cancel_unfinished=False):
_ = (pending.discard(pen), done.add(pen))
pending.discard(pen)
done.add(pen)
if return_when == "FIRST_COMPLETED":
break
if return_when == "FIRST_EXCEPTION" and (pen.cancelled() or pen.exception()):
Expand Down
16 changes: 13 additions & 3 deletions tests/test_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ async def func(a, b, /, *, results, firstcall=firstcall):
assert not caller.queue_get(func)

async def test_queue_call_result(self, caller: Caller):
def pass_through(n):
async def pass_through(n):
return n

pen = Pending()
Expand Down Expand Up @@ -631,14 +631,24 @@ async def f(i: int):
results.add(await pen)
assert results == {0, 1}

async def test_as_completed_queue(self, caller: Caller):
async def f(i: int):
await anyio.sleep(i * 0.001)
return i

results = set()
async for pen in caller.as_completed(caller.queue_call(f, i) for i in range(2)):
results.add(pen.result())
assert results == {1}

async def test_wait_awaitables(self, caller: Caller):
async def f(i: int):
await anyio.sleep(i * 0.001)
return i

done, pending = await caller.wait(f(i) for i in range(2))
done, pending = await caller.wait((caller.queue_call(f, 1), caller.call_soon(f, 2), caller.to_thread(f, 3)))
assert not pending
assert {pen.result() for pen in done} == {0, 1}
assert {pen.result() for pen in done} == {1, 2, 3}

async def test_worker_in_pool_shutdown(self, caller: Caller, mocker):
pen1 = caller.to_thread(threading.get_ident)
Expand Down
Loading