Skip to content
4 changes: 4 additions & 0 deletions newsfragments/612.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
The nursery context manager was rewritten to avoid use of
`@asynccontextmanager` and `@async_generator`. This reduces extraneous frames
in exception traces and addresses bugs regarding `StopIteration` and
`StopAsyncIteration` exceptions not propagating correctly.
96 changes: 46 additions & 50 deletions trio/_core/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
from sniffio import current_async_library_cvar

import attr
from async_generator import (
async_generator, yield_, asynccontextmanager, isasyncgen
)
from async_generator import isasyncgen
from sortedcontainers import SortedDict
from outcome import Error, Value

Expand Down Expand Up @@ -295,57 +293,55 @@ def started(self, value=None):
self._old_nursery._check_nursery_closed()


@asynccontextmanager
@async_generator
@enable_ki_protection
async def open_nursery():
"""Returns an async context manager which creates a new nursery.
class NurseryManager:
"""Nursery context manager.

This context manager's ``__aenter__`` method executes synchronously. Its
``__aexit__`` method blocks until all child tasks have exited.
Note we explicitly avoid @asynccontextmanager and @async_generator
since they add a lot of extraneous stack frames to exceptions, as
well as cause problematic behavior with handling of StopIteration
and StopAsyncIteration.

"""
assert currently_ki_protected()
with open_cancel_scope() as scope:
nursery = Nursery(current_task(), scope)
nested_child_exc = None
try:
await yield_(nursery)
except BaseException as exc:
nested_child_exc = exc

@enable_ki_protection
async def __aenter__(self):
assert currently_ki_protected()
self._scope_manager = open_cancel_scope()
scope = self._scope_manager.__enter__()
self._nursery = Nursery(current_task(), scope)
return self._nursery

@enable_ki_protection
async def __aexit__(self, etype, exc, tb):
assert currently_ki_protected()
await nursery._nested_child_finished(nested_child_exc)


# I *think* this is equivalent to the above, and it gives *much* nicer
# exception tracebacks... but I'm a little nervous about it because it's much
# trickier code :-(
#
# class NurseryManager:
# @enable_ki_protection
# async def __aenter__(self):
# self._scope_manager = open_cancel_scope()
# scope = self._scope_manager.__enter__()
# self._parent_nursery = Nursery(current_task(), scope)
# return self._parent_nursery
#
# @enable_ki_protection
# async def __aexit__(self, etype, exc, tb):
# try:
# await self._parent_nursery._clean_up(exc)
# except BaseException as new_exc:
# if not self._scope_manager.__exit__(
# type(new_exc), new_exc, new_exc.__traceback__):
# if exc is new_exc:
# return False
# else:
# raise
# else:
# self._scope_manager.__exit__(None, None, None)
# return True
#
# def open_nursery():
# return NurseryManager()
try:
await self._nursery._nested_child_finished(exc)
except BaseException as new_exc:
try:
if self._scope_manager.__exit__(
type(new_exc), new_exc, new_exc.__traceback__
):
return True
except BaseException as scope_manager_exc:
if scope_manager_exc == exc:
return False
raise # scope_manager_exc
raise # new_exc
else:
self._scope_manager.__exit__(None, None, None)
return True

def __enter__(self):
raise RuntimeError(
"use 'async with open_nursery(...)', not 'with open_nursery(...)'"
)

def __exit__(self): # pragma: no cover
assert False, """Never called, but should be defined"""


def open_nursery():
return NurseryManager()


class Nursery:
Expand Down
74 changes: 74 additions & 0 deletions trio/_core/tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .tutil import check_sequence_matches, gc_collect_harder
from ... import _core
from ..._timeouts import sleep
from ..._util import aiter_compat
from ...testing import (
wait_all_tasks_blocked,
Sequencer,
Expand Down Expand Up @@ -1823,6 +1824,79 @@ async def start_sleep_then_crash(nursery):
assert _core.current_time() - t0 == 7


async def test_nursery_explicit_exception():
with pytest.raises(KeyError):
async with _core.open_nursery():
raise KeyError()


async def test_nursery_stop_iteration():
async def fail():
raise ValueError

try:
async with _core.open_nursery() as nursery:
nursery.start_soon(fail)
raise StopIteration
except _core.MultiError as e:
assert tuple(map(type, e.exceptions)) == (StopIteration, ValueError)


async def test_nursery_stop_async_iteration():
class it(object):
def __init__(self, count):
self.count = count
self.val = 0

async def __anext__(self):
await sleep(0)
val = self.val
if val >= self.count:
raise StopAsyncIteration
self.val += 1
return val

class async_zip(object):
def __init__(self, *largs):
self.nexts = [obj.__anext__ for obj in largs]

async def _accumulate(self, f, items, i):
items[i] = await f()

@aiter_compat
def __aiter__(self):
return self

async def __anext__(self):
nexts = self.nexts
items = [
None,
] * len(nexts)
got_stop = False

def handle(exc):
nonlocal got_stop
if isinstance(exc, StopAsyncIteration):
got_stop = True
return None
else: # pragma: no cover
return exc

with _core.MultiError.catch(handle):
async with _core.open_nursery() as nursery:
for i, f in enumerate(nexts):
nursery.start_soon(self._accumulate, f, items, i)

if got_stop:
raise StopAsyncIteration
return items

result = []
async for vals in async_zip(it(4), it(2)):
result.append(vals)
assert result == [[0, 0], [1, 1]]


def test_contextvar_support():
var = contextvars.ContextVar("test")
var.set("before")
Expand Down