Skip to content

Commit

Permalink
Don't swallow CancelledError in MapAsyncIterator
Browse files Browse the repository at this point in the history
As discussed in #131, better than conversion to StopAsyncIteration.
  • Loading branch information
Cito committed May 2, 2021
1 parent 62ddc6c commit c6f73a8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 28 deletions.
13 changes: 5 additions & 8 deletions src/graphql/subscription/map_async_iterator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from asyncio import Event, ensure_future, Future, wait, CancelledError
from asyncio import CancelledError, Event, Future, ensure_future, wait
from concurrent.futures import FIRST_COMPLETED
from inspect import isasyncgen, isawaitable
from typing import cast, Any, AsyncIterable, Callable, Optional, Set, Type, Union
Expand Down Expand Up @@ -43,19 +43,16 @@ async def __anext__(self) -> Any:
aclose = ensure_future(self._close_event.wait())
anext = ensure_future(self.iterator.__anext__())

# Suppress the StopAsyncIteration exception warning when the
# iterator is cancelled.
anext.add_done_callback(lambda *args: anext.exception())
try:
pending: Set[Future] = (
await wait([aclose, anext], return_when=FIRST_COMPLETED)
)[1]
except CancelledError as e:
# The iterator is cancelled
except CancelledError:
# cancel underlying tasks and close
aclose.cancel()
anext.cancel()
self.is_closed = True
raise StopAsyncIteration from e
await self.aclose()
raise # re-raise the cancellation

for task in pending:
task.cancel()
Expand Down
48 changes: 28 additions & 20 deletions tests/subscription/test_map_async_iterator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from asyncio import Event, ensure_future, CancelledError, sleep, Queue
from asyncio import CancelledError, Event, ensure_future, sleep

from pytest import mark, raises

Expand Down Expand Up @@ -459,39 +459,47 @@ async def aclose(self):
assert not iterator.is_closed

@mark.asyncio
async def cancel_async_iterator_while_waiting():
async def can_cancel_async_iterator_while_waiting():
class Iterator:
def __init__(self):
self.queue: Queue[int] = Queue()
self.queue.put_nowait(1) # suppress coverage warning
self.cancelled = False
self.is_closed = False
self.value = 1

def __aiter__(self):
return self

async def __anext__(self):
try:
return await self.queue.get()
except BaseException:
self.cancelled = True
await sleep(0.5)
return self.value # pragma: no cover
except CancelledError:
self.value = -1
raise

async def aclose(self):
self.is_closed = True

iterator = Iterator()
doubles = MapAsyncIterator(iterator, lambda x: x + x)
doubles = MapAsyncIterator(iterator, lambda x: x + x) # pragma: no cover exit
cancelled = False

async def iterator_task():
nonlocal cancelled
try:
async for double in doubles:
pass
# If cancellation is handled using StopAsyncIteration, it will reach
# here.
except CancelledError: # pragma: no cover
# Otherwise it should reach here
pass
async for _ in doubles:
assert False # pragma: no cover
except CancelledError:
cancelled = True

task = ensure_future(iterator_task())
await sleep(0.1)
await doubles.aclose()
await sleep(0.05)
assert not cancelled
assert not doubles.is_closed
assert iterator.value == 1
assert not iterator.is_closed
task.cancel()
await sleep(0.1)
assert iterator.cancelled
await sleep(0.05)
assert cancelled
assert iterator.value == -1
assert doubles.is_closed
assert iterator.is_closed

0 comments on commit c6f73a8

Please sign in to comment.