Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 36 additions & 31 deletions aioitertools/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import asyncio
from contextlib import suppress
import time
from typing import (
Any,
Expand Down Expand Up @@ -119,56 +120,60 @@ async def generator(x):
... # intermixed values yielded from gen1 and gen2
"""

exc_queue: asyncio.Queue[Exception] = asyncio.Queue()
queue: asyncio.Queue[T] = asyncio.Queue()
queue: asyncio.Queue[dict] = asyncio.Queue()

tailer_count: int = 0

async def tailer(iterable: AsyncIterable[T]) -> None:
nonlocal tailer_count

async def tailer(iter: AsyncIterable[T]) -> None:
try:
async for item in iter:
await queue.put(item)
async for item in iterable:
await queue.put({"value": item})
except asyncio.CancelledError:
if isinstance(iter, AsyncGenerator): # pragma:nocover
await iter.aclose()
if isinstance(iterable, AsyncGenerator): # pragma:nocover
with suppress(Exception):
await iterable.aclose()
raise
except Exception as e:
await exc_queue.put(e)
except Exception as exc:
await queue.put({"exception": exc})
finally:
tailer_count -= 1

if tailer_count == 0:
await queue.put({"done": True})

tasks = [asyncio.ensure_future(tailer(iter)) for iter in iterables]
pending = set(tasks)

if not tasks:
# Nothing to do
return

tailer_count = len(tasks)

try:
while pending:
try:
exc = exc_queue.get_nowait()
while True:
i = await queue.get()

if "value" in i:
yield i["value"]
elif "exception" in i:
if return_exceptions:
yield exc # type: ignore
yield i["exception"]
else:
raise exc
except asyncio.QueueEmpty:
pass

try:
value = queue.get_nowait()
yield value
except asyncio.QueueEmpty:
for task in list(pending):
if task.done():
pending.remove(task)
await asyncio.sleep(0.001)

raise i["exception"]
elif "done" in i:
break
except (asyncio.CancelledError, GeneratorExit):
pass

finally:
for task in tasks:
if not task.done():
task.cancel()

for task in tasks:
try:
with suppress(asyncio.CancelledError):
await task
except asyncio.CancelledError:
pass


@deprecated_wait_param
Expand Down
25 changes: 25 additions & 0 deletions aioitertools/tests/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,31 @@ async def gen():
self.assertEqual(30, len(results))
self.assertListEqual(sorted(expected), sorted(results))

@async_test
async def test_as_generated_no_iterables(self):
gens = []
expected = []
results = []
async for value in aio.as_generated(gens):
results.append(value)
self.assertEqual(0, len(results))
self.assertListEqual(sorted(expected), sorted(results))

@async_test
async def test_as_generated_empty_iterables(self):
async def gen(stop):
for i in range(stop):
yield i
await asyncio.sleep(0)

gens = [gen(0), gen(1), gen(2)]
expected = [0, 0, 1]
results = []
async for value in aio.as_generated(gens):
results.append(value)
self.assertEqual(3, len(results))
self.assertListEqual(sorted(expected), sorted(results))

@async_test
async def test_as_generated_exception(self):
async def gen1():
Expand Down