Skip to content

Commit 0c60655

Browse files
committed
asyncontextmanager: Always close async generators
1 parent 0af86f3 commit 0c60655

File tree

2 files changed

+54
-44
lines changed

2 files changed

+54
-44
lines changed

async_generator/_tests/test_util.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,26 +153,34 @@ async def yeehaw():
153153

154154

155155
async def test_asynccontextmanager_too_many_yields():
156+
closed_count = 0
157+
156158
@asynccontextmanager
157159
@async_generator
158160
async def doubleyield():
159161
try:
160162
await yield_()
161163
except Exception:
162164
pass
163-
await yield_()
165+
try:
166+
await yield_()
167+
finally:
168+
nonlocal closed_count
169+
closed_count += 1
164170

165171
with pytest.raises(RuntimeError) as excinfo:
166172
async with doubleyield():
167173
pass
168174

169175
assert "didn't stop" in str(excinfo.value)
176+
assert closed_count == 1
170177

171178
with pytest.raises(RuntimeError) as excinfo:
172179
async with doubleyield():
173180
raise ValueError
174181

175182
assert "didn't stop after athrow" in str(excinfo.value)
183+
assert closed_count == 2
176184

177185

178186
async def test_asynccontextmanager_requires_asyncgenfunction():

async_generator/_util.py

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -36,50 +36,52 @@ async def __aenter__(self):
3636
raise RuntimeError("async generator didn't yield") from None
3737

3838
async def __aexit__(self, type, value, traceback):
39-
if type is None:
40-
try:
41-
await self._agen.asend(None)
42-
except StopAsyncIteration:
43-
return False
44-
else:
45-
raise RuntimeError("async generator didn't stop")
46-
else:
47-
# It used to be possible to have type != None, value == None:
48-
# https://bugs.python.org/issue1705170
49-
# but AFAICT this can't happen anymore.
50-
assert value is not None
51-
try:
52-
await self._agen.athrow(type, value, traceback)
53-
raise RuntimeError(
54-
"async generator didn't stop after athrow()"
55-
)
56-
except StopAsyncIteration as exc:
57-
# Suppress StopIteration *unless* it's the same exception that
58-
# was passed to throw(). This prevents a StopIteration
59-
# raised inside the "with" statement from being suppressed.
60-
return (exc is not value)
61-
except RuntimeError as exc:
62-
# Don't re-raise the passed in exception. (issue27112)
63-
if exc is value:
64-
return False
65-
# Likewise, avoid suppressing if a StopIteration exception
66-
# was passed to throw() and later wrapped into a RuntimeError
67-
# (see PEP 479).
68-
if (isinstance(value, (StopIteration, StopAsyncIteration))
69-
and exc.__cause__ is value):
39+
async with aclosing(self._agen):
40+
if type is None:
41+
try:
42+
await self._agen.asend(None)
43+
except StopAsyncIteration:
7044
return False
71-
raise
72-
except:
73-
# only re-raise if it's *not* the exception that was
74-
# passed to throw(), because __exit__() must not raise
75-
# an exception unless __exit__() itself failed. But throw()
76-
# has to raise the exception to signal propagation, so this
77-
# fixes the impedance mismatch between the throw() protocol
78-
# and the __exit__() protocol.
79-
#
80-
if sys.exc_info()[1] is value:
81-
return False
82-
raise
45+
else:
46+
raise RuntimeError("async generator didn't stop")
47+
else:
48+
# It used to be possible to have type != None, value == None:
49+
# https://bugs.python.org/issue1705170
50+
# but AFAICT this can't happen anymore.
51+
assert value is not None
52+
try:
53+
await self._agen.athrow(type, value, traceback)
54+
raise RuntimeError(
55+
"async generator didn't stop after athrow()"
56+
)
57+
except StopAsyncIteration as exc:
58+
# Suppress StopIteration *unless* it's the same exception
59+
# that was passed to throw(). This prevents a
60+
# StopIteration raised inside the "with" statement from
61+
# being suppressed.
62+
return (exc is not value)
63+
except RuntimeError as exc:
64+
# Don't re-raise the passed in exception. (issue27112)
65+
if exc is value:
66+
return False
67+
# Likewise, avoid suppressing if a StopIteration exception
68+
# was passed to throw() and later wrapped into a
69+
# RuntimeError (see PEP 479).
70+
if (isinstance(value, (StopIteration, StopAsyncIteration))
71+
and exc.__cause__ is value):
72+
return False
73+
raise
74+
except:
75+
# only re-raise if it's *not* the exception that was
76+
# passed to throw(), because __exit__() must not raise an
77+
# exception unless __exit__() itself failed. But throw()
78+
# has to raise the exception to signal propagation, so
79+
# this fixes the impedance mismatch between the throw()
80+
# protocol and the __exit__() protocol.
81+
#
82+
if sys.exc_info()[1] is value:
83+
return False
84+
raise
8385

8486
def __enter__(self):
8587
raise RuntimeError(

0 commit comments

Comments
 (0)