Skip to content

Commit 1a4f2b5

Browse files
committed
gh-124958: fix asyncio.TaskGroup and _PyFuture refcycles
1 parent 5e9e506 commit 1a4f2b5

File tree

4 files changed

+131
-9
lines changed

4 files changed

+131
-9
lines changed

Lib/asyncio/futures.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,7 @@ def result(self):
190190
the future is done and has an exception set, this exception is raised.
191191
"""
192192
if self._state == _CANCELLED:
193-
exc = self._make_cancelled_error()
194-
raise exc
193+
raise self._make_cancelled_error()
195194
if self._state != _FINISHED:
196195
raise exceptions.InvalidStateError('Result is not ready.')
197196
self.__log_traceback = False
@@ -208,8 +207,7 @@ def exception(self):
208207
InvalidStateError.
209208
"""
210209
if self._state == _CANCELLED:
211-
exc = self._make_cancelled_error()
212-
raise exc
210+
raise self._make_cancelled_error()
213211
if self._state != _FINISHED:
214212
raise exceptions.InvalidStateError('Exception is not set.')
215213
self.__log_traceback = False

Lib/asyncio/taskgroups.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ async def __aenter__(self):
6666
return self
6767

6868
async def __aexit__(self, et, exc, tb):
69+
try:
6970
self._exiting = True
7071

7172
if (exc is not None and
@@ -146,14 +147,22 @@ async def __aexit__(self, et, exc, tb):
146147
if self._parent_task.cancelling():
147148
self._parent_task.uncancel()
148149
self._parent_task.cancel()
150+
raise BaseExceptionGroup(
151+
'unhandled errors in a TaskGroup',
152+
self._errors,
153+
) from None
154+
finally:
149155
# Exceptions are heavy objects that can have object
150156
# cycles (bad for GC); let's not keep a reference to
151157
# a bunch of them.
152-
try:
153-
me = BaseExceptionGroup('unhandled errors in a TaskGroup', self._errors)
154-
raise me from None
155-
finally:
158+
propagate_cancellation_error = None
159+
self._parent_task = None
156160
self._errors = None
161+
self._base_error = None
162+
et = None
163+
exc = None
164+
tb = None
165+
157166

158167
def create_task(self, coro, *, name=None, context=None):
159168
"""Create a new task in this group and return it.

Lib/test/test_asyncio/test_futures.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,28 @@ def __del__(self):
659659
fut = self._new_future(loop=self.loop)
660660
fut.set_result(Evil())
661661

662+
def test_future_cancelled_result_refcycles(self):
663+
f = self._new_future(loop=self.loop)
664+
f.cancel()
665+
exc = None
666+
try:
667+
f.result()
668+
except asyncio.CancelledError as e:
669+
exc = e
670+
self.assertIsNotNone(exc)
671+
self.assertListEqual(gc.get_referrers(exc), [])
672+
673+
def test_future_cancelled_exception_refcycles(self):
674+
f = self._new_future(loop=self.loop)
675+
f.cancel()
676+
exc = None
677+
try:
678+
f.exception()
679+
except asyncio.CancelledError as e:
680+
exc = e
681+
self.assertIsNotNone(exc)
682+
self.assertListEqual(gc.get_referrers(exc), [])
683+
662684

663685
@unittest.skipUnless(hasattr(futures, '_CFuture'),
664686
'requires the C _asyncio module')

Lib/test/test_asyncio/test_taskgroups.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Adapted with permission from the EdgeDB project;
22
# license: PSFL.
33

4-
4+
import gc
55
import asyncio
66
import contextvars
77
import contextlib
@@ -11,6 +11,10 @@
1111

1212
from test.test_asyncio.utils import await_without_task
1313

14+
if False:
15+
asyncio.Task = asyncio.tasks._CTask = asyncio.tasks.Task = asyncio.tasks._PyTask
16+
asyncio.Future = asyncio.futures._CFuture = asyncio.futures.Future = asyncio.futures._PyFuture
17+
1418

1519
# To prevent a warning "test altered the execution environment"
1620
def tearDownModule():
@@ -899,6 +903,95 @@ async def outer():
899903

900904
await outer()
901905

906+
async def test_exception_refcycles_direct(self):
907+
"""Test that TaskGroup doesn't keep a reference to the raised ExceptionGroup"""
908+
tg = asyncio.TaskGroup()
909+
exc = None
910+
911+
class _Done(Exception):
912+
pass
913+
914+
try:
915+
async with tg:
916+
raise _Done
917+
except ExceptionGroup as e:
918+
exc = e
919+
920+
self.assertIsNotNone(exc)
921+
self.assertListEqual(gc.get_referrers(exc), [])
922+
923+
924+
async def test_exception_refcycles_errors(self):
925+
"""Test that TaskGroup deletes self._errors, and __aexit__ args"""
926+
tg = asyncio.TaskGroup()
927+
exc = None
928+
929+
class _Done(Exception):
930+
pass
931+
932+
try:
933+
async with tg:
934+
raise _Done
935+
except* _Done as excs:
936+
exc = excs.exceptions[0]
937+
938+
self.assertIsInstance(exc, _Done)
939+
self.assertListEqual(gc.get_referrers(exc), [])
940+
941+
942+
async def test_exception_refcycles_parent_task(self):
943+
"""Test that TaskGroup deletes self._parent_task"""
944+
tg = asyncio.TaskGroup()
945+
exc = None
946+
947+
class _Done(Exception):
948+
pass
949+
950+
async def coro_fn():
951+
async with tg:
952+
raise _Done
953+
954+
try:
955+
async with asyncio.TaskGroup() as tg2:
956+
tg2.create_task(coro_fn())
957+
except* _Done as excs:
958+
exc = excs.exceptions[0].exceptions[0]
959+
960+
self.assertIsInstance(exc, _Done)
961+
self.assertListEqual(gc.get_referrers(exc), [])
962+
963+
async def test_exception_refcycles_propagate_cancellation_error(self):
964+
"""Test that TaskGroup deletes propagate_cancellation_error"""
965+
tg = asyncio.TaskGroup()
966+
exc = None
967+
968+
try:
969+
async with asyncio.timeout(-1):
970+
async with tg:
971+
await asyncio.sleep(0)
972+
except TimeoutError as e:
973+
exc = e.__cause__
974+
975+
self.assertIsInstance(exc, asyncio.CancelledError)
976+
self.assertListEqual(gc.get_referrers(exc), [])
977+
978+
async def test_exception_refcycles_base_error(self):
979+
"""Test that TaskGroup deletes self._base_error"""
980+
class MyKeyboardInterrupt(KeyboardInterrupt):
981+
pass
982+
983+
tg = asyncio.TaskGroup()
984+
exc = None
985+
986+
try:
987+
async with tg:
988+
raise MyKeyboardInterrupt
989+
except MyKeyboardInterrupt as e:
990+
exc = e
991+
992+
self.assertIsNotNone(exc)
993+
self.assertListEqual(gc.get_referrers(exc), [])
994+
902995

903996
if __name__ == "__main__":
904997
unittest.main()

0 commit comments

Comments
 (0)