Skip to content

Commit 1c5c9c8

Browse files
ambvgraingert
andauthored
[3.9] bpo-44566: resolve differences between asynccontextmanager and contextmanager (GH-27024). (#27269)
(cherry picked from commit 7f1c330) Co-authored-by: Thomas Grainger <tagrain@gmail.com>
1 parent dae4928 commit 1c5c9c8

File tree

4 files changed

+78
-44
lines changed

4 files changed

+78
-44
lines changed

Lib/contextlib.py

Lines changed: 52 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -97,18 +97,20 @@ def __init__(self, func, args, kwds):
9797
# for the class instead.
9898
# See http://bugs.python.org/issue19404 for more details.
9999

100-
101-
class _GeneratorContextManager(_GeneratorContextManagerBase,
102-
AbstractContextManager,
103-
ContextDecorator):
104-
"""Helper for @contextmanager decorator."""
105-
106100
def _recreate_cm(self):
107-
# _GCM instances are one-shot context managers, so the
101+
# _GCMB instances are one-shot context managers, so the
108102
# CM must be recreated each time a decorated function is
109103
# called
110104
return self.__class__(self.func, self.args, self.kwds)
111105

106+
107+
class _GeneratorContextManager(
108+
_GeneratorContextManagerBase,
109+
AbstractContextManager,
110+
ContextDecorator,
111+
):
112+
"""Helper for @contextmanager decorator."""
113+
112114
def __enter__(self):
113115
# do not keep args and kwds alive unnecessarily
114116
# they are only needed for recreation, which is not possible anymore
@@ -118,8 +120,8 @@ def __enter__(self):
118120
except StopIteration:
119121
raise RuntimeError("generator didn't yield") from None
120122

121-
def __exit__(self, type, value, traceback):
122-
if type is None:
123+
def __exit__(self, typ, value, traceback):
124+
if typ is None:
123125
try:
124126
next(self.gen)
125127
except StopIteration:
@@ -130,9 +132,9 @@ def __exit__(self, type, value, traceback):
130132
if value is None:
131133
# Need to force instantiation so we can reliably
132134
# tell if we get the same exception back
133-
value = type()
135+
value = typ()
134136
try:
135-
self.gen.throw(type, value, traceback)
137+
self.gen.throw(typ, value, traceback)
136138
except StopIteration as exc:
137139
# Suppress StopIteration *unless* it's the same exception that
138140
# was passed to throw(). This prevents a StopIteration
@@ -142,35 +144,39 @@ def __exit__(self, type, value, traceback):
142144
# Don't re-raise the passed in exception. (issue27122)
143145
if exc is value:
144146
return False
145-
# Likewise, avoid suppressing if a StopIteration exception
147+
# Avoid suppressing if a StopIteration exception
146148
# was passed to throw() and later wrapped into a RuntimeError
147-
# (see PEP 479).
148-
if type is StopIteration and exc.__cause__ is value:
149+
# (see PEP 479 for sync generators; async generators also
150+
# have this behavior). But do this only if the exception wrapped
151+
# by the RuntimeError is actually Stop(Async)Iteration (see
152+
# issue29692).
153+
if (
154+
isinstance(value, StopIteration)
155+
and exc.__cause__ is value
156+
):
149157
return False
150158
raise
151-
except:
159+
except BaseException as exc:
152160
# only re-raise if it's *not* the exception that was
153161
# passed to throw(), because __exit__() must not raise
154162
# an exception unless __exit__() itself failed. But throw()
155163
# has to raise the exception to signal propagation, so this
156164
# fixes the impedance mismatch between the throw() protocol
157165
# and the __exit__() protocol.
158-
#
159-
# This cannot use 'except BaseException as exc' (as in the
160-
# async implementation) to maintain compatibility with
161-
# Python 2, where old-style class exceptions are not caught
162-
# by 'except BaseException'.
163-
if sys.exc_info()[1] is value:
164-
return False
165-
raise
166+
if exc is not value:
167+
raise
168+
return False
166169
raise RuntimeError("generator didn't stop after throw()")
167170

168171

169172
class _AsyncGeneratorContextManager(_GeneratorContextManagerBase,
170173
AbstractAsyncContextManager):
171-
"""Helper for @asynccontextmanager."""
174+
"""Helper for @asynccontextmanager decorator."""
172175

173176
async def __aenter__(self):
177+
# do not keep args and kwds alive unnecessarily
178+
# they are only needed for recreation, which is not possible anymore
179+
del self.args, self.kwds, self.func
174180
try:
175181
return await self.gen.__anext__()
176182
except StopAsyncIteration:
@@ -181,35 +187,48 @@ async def __aexit__(self, typ, value, traceback):
181187
try:
182188
await self.gen.__anext__()
183189
except StopAsyncIteration:
184-
return
190+
return False
185191
else:
186192
raise RuntimeError("generator didn't stop")
187193
else:
188194
if value is None:
195+
# Need to force instantiation so we can reliably
196+
# tell if we get the same exception back
189197
value = typ()
190-
# See _GeneratorContextManager.__exit__ for comments on subtleties
191-
# in this implementation
192198
try:
193199
await self.gen.athrow(typ, value, traceback)
194-
raise RuntimeError("generator didn't stop after athrow()")
195200
except StopAsyncIteration as exc:
201+
# Suppress StopIteration *unless* it's the same exception that
202+
# was passed to throw(). This prevents a StopIteration
203+
# raised inside the "with" statement from being suppressed.
196204
return exc is not value
197205
except RuntimeError as exc:
206+
# Don't re-raise the passed in exception. (issue27122)
198207
if exc is value:
199208
return False
200-
# Avoid suppressing if a StopIteration exception
201-
# was passed to throw() and later wrapped into a RuntimeError
209+
# Avoid suppressing if a Stop(Async)Iteration exception
210+
# was passed to athrow() and later wrapped into a RuntimeError
202211
# (see PEP 479 for sync generators; async generators also
203212
# have this behavior). But do this only if the exception wrapped
204213
# by the RuntimeError is actully Stop(Async)Iteration (see
205214
# issue29692).
206-
if isinstance(value, (StopIteration, StopAsyncIteration)):
207-
if exc.__cause__ is value:
208-
return False
215+
if (
216+
isinstance(value, (StopIteration, StopAsyncIteration))
217+
and exc.__cause__ is value
218+
):
219+
return False
209220
raise
210221
except BaseException as exc:
222+
# only re-raise if it's *not* the exception that was
223+
# passed to throw(), because __exit__() must not raise
224+
# an exception unless __exit__() itself failed. But throw()
225+
# has to raise the exception to signal propagation, so this
226+
# fixes the impedance mismatch between the throw() protocol
227+
# and the __exit__() protocol.
211228
if exc is not value:
212229
raise
230+
return False
231+
raise RuntimeError("generator didn't stop after athrow()")
213232

214233

215234
def contextmanager(func):

Lib/test/test_contextlib.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,19 +125,22 @@ def woohoo():
125125
self.assertEqual(state, [1, 42, 999])
126126

127127
def test_contextmanager_except_stopiter(self):
128-
stop_exc = StopIteration('spam')
129128
@contextmanager
130129
def woohoo():
131130
yield
132-
try:
133-
with self.assertWarnsRegex(DeprecationWarning,
134-
"StopIteration"):
135-
with woohoo():
136-
raise stop_exc
137-
except Exception as ex:
138-
self.assertIs(ex, stop_exc)
139-
else:
140-
self.fail('StopIteration was suppressed')
131+
132+
class StopIterationSubclass(StopIteration):
133+
pass
134+
135+
for stop_exc in (StopIteration('spam'), StopIterationSubclass('spam')):
136+
with self.subTest(type=type(stop_exc)):
137+
try:
138+
with woohoo():
139+
raise stop_exc
140+
except Exception as ex:
141+
self.assertIs(ex, stop_exc)
142+
else:
143+
self.fail(f'{stop_exc} was suppressed')
141144

142145
def test_contextmanager_except_pep479(self):
143146
code = """\

Lib/test/test_contextlib_async.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,18 @@ async def test_contextmanager_except_stopiter(self):
207207
async def woohoo():
208208
yield
209209

210-
for stop_exc in (StopIteration('spam'), StopAsyncIteration('ham')):
210+
class StopIterationSubclass(StopIteration):
211+
pass
212+
213+
class StopAsyncIterationSubclass(StopAsyncIteration):
214+
pass
215+
216+
for stop_exc in (
217+
StopIteration('spam'),
218+
StopAsyncIteration('ham'),
219+
StopIterationSubclass('spam'),
220+
StopAsyncIterationSubclass('spam')
221+
):
211222
with self.subTest(type=type(stop_exc)):
212223
try:
213224
async with woohoo():
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
handle StopIteration subclass raised from @contextlib.contextmanager generator

0 commit comments

Comments
 (0)