diff --git a/src/asynkit/coroutine.py b/src/asynkit/coroutine.py index 850f2e2..3c88c77 100644 --- a/src/asynkit/coroutine.py +++ b/src/asynkit/coroutine.py @@ -108,7 +108,10 @@ def coro_is_new(coro: Suspendable) -> bool: elif inspect.isgenerator(coro): return inspect.getgeneratorstate(coro) == inspect.GEN_CREATED elif inspect.isasyncgen(coro): - return coro.ag_frame is not None and not coro.ag_running + # async generators have an ag_await if they are suspended + # ag_running() means that it is inside an anext() or athrow() + # but it may be suspended. + return coro.ag_frame is not None and coro.ag_await is None and not coro.ag_running else: raise TypeError( f"a coroutine or coroutine like object is required. Got: {type(coro)}" @@ -124,9 +127,7 @@ def coro_is_suspended(coro: Suspendable) -> bool: elif inspect.isgenerator(coro): return inspect.getgeneratorstate(coro) == inspect.GEN_SUSPENDED elif inspect.isasyncgen(coro): - # This is true only if we are inside an anext() or athrow(), not if the - # inner coroutine is itself doing an await before yielding a value. - return coro.ag_running + return coro.ag_await is not None else: raise TypeError( f"a coroutine or coroutine like object is required. Got: {type(coro)}"