Skip to content

Commit acd4562

Browse files
committed
Add @asynccontextmanager and get tests passing again
1 parent 7d7d5d5 commit acd4562

File tree

4 files changed

+260
-7
lines changed

4 files changed

+260
-7
lines changed

async_generator/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
isasyncgen,
77
isasyncgenfunction,
88
)
9-
from ._util import aclosing
9+
from ._util import aclosing, asynccontextmanager
1010

1111
__all__ = [
1212
"async_generator", "yield_", "yield_from_", "aclosing", "isasyncgen",
13-
"isasyncgenfunction"
13+
"isasyncgenfunction", "asynccontextmanager",
1414
]

async_generator/_tests/test_async_generator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from functools import wraps
88
import gc
99

10-
from . import (
10+
from .. import (
1111
async_generator,
1212
yield_,
1313
yield_from_,
@@ -717,7 +717,7 @@ async def f():
717717
# Finicky tests to check that the overly clever ctype stuff has plausible
718718
# refcounting
719719

720-
from . import impl
720+
from .. import _impl
721721

722722

723723
@pytest.mark.skipif(not hasattr(sys, "getrefcount"), reason="CPython only")
@@ -728,12 +728,12 @@ def test_refcnt():
728728
print(sys.getrefcount(x))
729729
print(sys.getrefcount(x))
730730
base_count = sys.getrefcount(x)
731-
l = [impl._wrap(x) for _ in range(100)]
731+
l = [_impl._wrap(x) for _ in range(100)]
732732
print(sys.getrefcount(x))
733733
print(sys.getrefcount(x))
734734
print(sys.getrefcount(x))
735735
assert sys.getrefcount(x) >= base_count + 100
736-
l2 = [impl._unwrap(box) for box in l]
736+
l2 = [_impl._unwrap(box) for box in l]
737737
assert sys.getrefcount(x) >= base_count + 200
738738
print(sys.getrefcount(x))
739739
print(sys.getrefcount(x))

async_generator/_tests/test_util.py

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from . import aclosing, async_generator, yield_
3+
from .. import aclosing, async_generator, yield_, asynccontextmanager
44

55

66
@async_generator
@@ -34,3 +34,158 @@ async def test_aclosing():
3434
except ValueError:
3535
pass
3636
assert closed_slot[0]
37+
38+
39+
@pytest.mark.asyncio
40+
async def test_contextmanager_do_not_unchain_non_stopiteration_exceptions():
41+
@asynccontextmanager
42+
@async_generator
43+
async def manager_issue29692():
44+
try:
45+
await yield_()
46+
except Exception as exc:
47+
raise RuntimeError('issue29692:Chained') from exc
48+
49+
with pytest.raises(RuntimeError) as excinfo:
50+
async with manager_issue29692():
51+
raise ZeroDivisionError
52+
assert excinfo.value.args[0] == 'issue29692:Chained'
53+
assert isinstance(excinfo.value.__cause__, ZeroDivisionError)
54+
55+
# This is a little funky because of implementation details in
56+
# async_generator It can all go away once we stop supporting Python3.5
57+
with pytest.raises(RuntimeError) as excinfo:
58+
async with manager_issue29692():
59+
exc = StopIteration('issue29692:Unchained')
60+
raise exc
61+
assert excinfo.value.args[0] == 'issue29692:Chained'
62+
cause = excinfo.value.__cause__
63+
assert cause.args[0] == 'generator raised StopIteration'
64+
assert cause.__cause__ is exc
65+
66+
with pytest.raises(StopAsyncIteration) as excinfo:
67+
async with manager_issue29692():
68+
raise StopAsyncIteration('issue29692:Unchained')
69+
assert excinfo.value.args[0] == 'issue29692:Unchained'
70+
assert excinfo.value.__cause__ is None
71+
72+
@asynccontextmanager
73+
@async_generator
74+
async def noop_async_context_manager():
75+
await yield_()
76+
77+
with pytest.raises(StopIteration):
78+
async with noop_async_context_manager():
79+
raise StopIteration
80+
81+
82+
# Native async generators are only available from Python 3.6 and onwards
83+
nativeasyncgenerators = True
84+
try:
85+
exec(
86+
"""
87+
@asynccontextmanager
88+
async def manager_issue29692_2():
89+
try:
90+
yield
91+
except Exception as exc:
92+
raise RuntimeError('issue29692:Chained') from exc
93+
"""
94+
)
95+
except SyntaxError:
96+
nativeasyncgenerators = False
97+
98+
99+
@pytest.mark.skipif(
100+
not nativeasyncgenerators,
101+
reason="Python < 3.6 doesn't have native async generators"
102+
)
103+
@pytest.mark.asyncio
104+
async def test_native_contextmanager_do_not_unchain_non_stopiteration_exceptions(
105+
):
106+
107+
with pytest.raises(RuntimeError) as excinfo:
108+
async with manager_issue29692_2():
109+
raise ZeroDivisionError
110+
assert excinfo.value.args[0] == 'issue29692:Chained'
111+
assert isinstance(excinfo.value.__cause__, ZeroDivisionError)
112+
113+
for cls in [StopIteration, StopAsyncIteration]:
114+
with pytest.raises(cls) as excinfo:
115+
async with manager_issue29692_2():
116+
raise cls('issue29692:Unchained')
117+
assert excinfo.value.args[0] == 'issue29692:Unchained'
118+
assert excinfo.value.__cause__ is None
119+
120+
121+
@pytest.mark.asyncio
122+
async def test_asynccontextmanager_exception_passthrough():
123+
# This was the cause of annoying coverage flapping, see gh-140
124+
@asynccontextmanager
125+
@async_generator
126+
async def noop_async_context_manager():
127+
await yield_()
128+
129+
for exc_type in [StopAsyncIteration, RuntimeError, ValueError]:
130+
with pytest.raises(exc_type):
131+
async with noop_async_context_manager():
132+
raise exc_type
133+
134+
135+
@pytest.mark.asyncio
136+
async def test_asynccontextmanager_catches_exception():
137+
@asynccontextmanager
138+
@async_generator
139+
async def catch_it():
140+
with pytest.raises(ValueError):
141+
await yield_()
142+
143+
async with catch_it():
144+
raise ValueError
145+
146+
147+
@pytest.mark.asyncio
148+
async def test_asynccontextmanager_no_yield():
149+
@asynccontextmanager
150+
@async_generator
151+
async def yeehaw():
152+
pass
153+
154+
with pytest.raises(RuntimeError) as excinfo:
155+
async with yeehaw():
156+
assert False # pragma: no cover
157+
158+
assert "didn't yield" in str(excinfo.value)
159+
160+
161+
@pytest.mark.asyncio
162+
async def test_asynccontextmanager_too_many_yields():
163+
@asynccontextmanager
164+
@async_generator
165+
async def doubleyield():
166+
try:
167+
await yield_()
168+
except Exception:
169+
pass
170+
await yield_()
171+
172+
with pytest.raises(RuntimeError) as excinfo:
173+
async with doubleyield():
174+
pass
175+
176+
assert "didn't stop" in str(excinfo.value)
177+
178+
with pytest.raises(RuntimeError) as excinfo:
179+
async with doubleyield():
180+
raise ValueError
181+
182+
assert "didn't stop after athrow" in str(excinfo.value)
183+
184+
185+
@pytest.mark.asyncio
186+
async def test_asynccontextmanager_requires_asyncgenfunction():
187+
with pytest.raises(TypeError):
188+
189+
@asynccontextmanager
190+
def syncgen(): # pragma: no cover
191+
yield

async_generator/_util.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import sys
2+
from functools import wraps
3+
from ._impl import isasyncgenfunction
4+
15
class aclosing:
26
def __init__(self, aiter):
37
self._aiter = aiter
@@ -7,3 +11,97 @@ async def __aenter__(self):
711

812
async def __aexit__(self, *args):
913
await self._aiter.aclose()
14+
15+
16+
# Very much derived from the one in contextlib, by copy/pasting and then
17+
# asyncifying everything. (Also I dropped the obscure support for using
18+
# context managers as function decorators. It could be re-added; I just
19+
# couldn't be bothered.)
20+
# So this is a derivative work licensed under the PSF License, which requires
21+
# the following notice:
22+
#
23+
# Copyright © 2001-2017 Python Software Foundation; All Rights Reserved
24+
class _AsyncGeneratorContextManager:
25+
def __init__(self, func, args, kwds):
26+
self._func_name = func.__name__
27+
self._agen = func(*args, **kwds).__aiter__()
28+
29+
async def __aenter__(self):
30+
if sys.version_info < (3, 5, 2):
31+
self._agen = await self._agen
32+
try:
33+
return await self._agen.asend(None)
34+
except StopAsyncIteration:
35+
raise RuntimeError("async generator didn't yield") from None
36+
37+
async def __aexit__(self, type, value, traceback):
38+
if type is None:
39+
try:
40+
await self._agen.asend(None)
41+
except StopAsyncIteration:
42+
return False
43+
else:
44+
raise RuntimeError("async generator didn't stop")
45+
else:
46+
# It used to be possible to have type != None, value == None:
47+
# https://bugs.python.org/issue1705170
48+
# but AFAICT this can't happen anymore.
49+
assert value is not None
50+
try:
51+
await self._agen.athrow(type, value, traceback)
52+
raise RuntimeError(
53+
"async generator didn't stop after athrow()"
54+
)
55+
except StopAsyncIteration as exc:
56+
# Suppress StopIteration *unless* it's the same exception that
57+
# was passed to throw(). This prevents a StopIteration
58+
# raised inside the "with" statement from being suppressed.
59+
return (exc is not value)
60+
except RuntimeError as exc:
61+
# Don't re-raise the passed in exception. (issue27112)
62+
if exc is value:
63+
return False
64+
# Likewise, avoid suppressing if a StopIteration exception
65+
# was passed to throw() and later wrapped into a RuntimeError
66+
# (see PEP 479).
67+
if (isinstance(value, (StopIteration, StopAsyncIteration))
68+
and exc.__cause__ is value):
69+
return False
70+
raise
71+
except:
72+
# only re-raise if it's *not* the exception that was
73+
# passed to throw(), because __exit__() must not raise
74+
# an exception unless __exit__() itself failed. But throw()
75+
# has to raise the exception to signal propagation, so this
76+
# fixes the impedance mismatch between the throw() protocol
77+
# and the __exit__() protocol.
78+
#
79+
if sys.exc_info()[1] is value:
80+
return False
81+
raise
82+
83+
def __enter__(self):
84+
raise RuntimeError(
85+
"use 'async with {func_name}(...)', not 'with {func_name}(...)'".
86+
format(func_name=self._func_name)
87+
)
88+
89+
def __exit__(self): # pragma: no cover
90+
assert False, """Never called, but should be defined"""
91+
92+
93+
def asynccontextmanager(func):
94+
"""Like @contextmanager, but async."""
95+
if not isasyncgenfunction(func):
96+
raise TypeError(
97+
"must be an async generator (native or from async_generator; "
98+
"if using @async_generator then @acontextmanager must be on top."
99+
)
100+
101+
@wraps(func)
102+
def helper(*args, **kwds):
103+
return _AsyncGeneratorContextManager(func, args, kwds)
104+
105+
# A hint for sphinxcontrib-trio:
106+
helper.__returns_acontextmanager__ = True
107+
return helper

0 commit comments

Comments
 (0)