Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: In astream_events v2 propagate cancel/break to the inner astream call #22865

Merged
merged 2 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 36 additions & 21 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Expand Down Expand Up @@ -79,7 +80,7 @@
is_async_callable,
is_async_generator,
)
from langchain_core.utils.aiter import atee, py_anext
from langchain_core.utils.aiter import aclosing, atee, py_anext
from langchain_core.utils.iter import safetee

if TYPE_CHECKING:
Expand Down Expand Up @@ -1141,8 +1142,9 @@ async def reverse(s: str) -> str:
'Only versions "v1" and "v2" of the schema is currently supported.'
)

async for event in event_stream:
yield event
async with aclosing(event_stream):
async for event in event_stream:
yield event

def transform(
self,
Expand Down Expand Up @@ -1948,7 +1950,7 @@ async def _atransform_stream_with_config(
kwargs["run_manager"] = run_manager
context = copy_context()
context.run(_set_config_context, child_config)
iterator = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]
iterator_ = context.run(transformer, input_for_transform, **kwargs) # type: ignore[arg-type]

if stream_handler := next(
(
Expand All @@ -1960,7 +1962,11 @@ async def _atransform_stream_with_config(
None,
):
# populates streamed_output in astream_log() output if needed
iterator = stream_handler.tap_output_aiter(run_manager.run_id, iterator)
iterator = stream_handler.tap_output_aiter(
run_manager.run_id, iterator_
)
else:
iterator = iterator_
try:
while True:
if accepts_context(asyncio.create_task):
Expand Down Expand Up @@ -2001,6 +2007,9 @@ async def _atransform_stream_with_config(
raise
else:
await run_manager.on_chain_end(final_output, inputs=final_input)
finally:
if hasattr(iterator_, "aclose"):
await iterator_.aclose()


class RunnableSerializable(Serializable, Runnable[Input, Output]):
Expand Down Expand Up @@ -3907,23 +3916,29 @@ async def f(*args, **kwargs): # type: ignore[no-untyped-def]

if is_async_generator(afunc):
output: Optional[Output] = None
async for chunk in cast(
AsyncIterator[Output],
acall_func_with_variable_args(
cast(Callable, afunc),
input,
config,
run_manager,
**kwargs,
),
):
if output is None:
output = chunk
else:
try:
output = output + chunk # type: ignore[operator]
except TypeError:
async with aclosing(
cast(
AsyncGenerator[Any, Any],
acall_func_with_variable_args(
cast(Callable, afunc),
input,
config,
run_manager,
**kwargs,
),
)
) as stream:
async for chunk in cast(
AsyncIterator[Output],
stream,
):
if output is None:
output = chunk
else:
try:
output = output + chunk # type: ignore[operator]
except TypeError:
output = chunk
else:
output = await acall_func_with_variable_args(
cast(Callable, afunc), input, config, run_manager, **kwargs
Expand Down
20 changes: 10 additions & 10 deletions libs/core/langchain_core/tracers/event_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.tracers.log_stream import LogEntry
from langchain_core.tracers.memory_stream import _MemoryStream
from langchain_core.utils.aiter import py_anext
from langchain_core.utils.aiter import aclosing, py_anext

if TYPE_CHECKING:
from langchain_core.documents import Document
Expand Down Expand Up @@ -903,11 +903,10 @@ async def _astream_events_implementation_v2(
async def consume_astream() -> None:
try:
# if astream also calls tap_output_aiter this will be a no-op
async for _ in event_streamer.tap_output_aiter(
run_id, runnable.astream(input, config, **kwargs)
):
# All the content will be picked up
pass
async with aclosing(runnable.astream(input, config, **kwargs)) as stream:
async for _ in event_streamer.tap_output_aiter(run_id, stream):
# All the content will be picked up
pass
finally:
await event_streamer.send_stream.aclose()

Expand Down Expand Up @@ -942,7 +941,8 @@ async def consume_astream() -> None:
yield event
finally:
# Wait for the runnable to finish, if not cancelled (eg. by break)
try:
await task
except asyncio.CancelledError:
pass
if task.cancel():
try:
await task
except asyncio.CancelledError:
pass
40 changes: 40 additions & 0 deletions libs/core/langchain_core/utils/aiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"""

from collections import deque
from contextlib import AbstractAsyncContextManager
from types import TracebackType
from typing import (
Any,
AsyncContextManager,
Expand All @@ -18,6 +20,7 @@
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -207,3 +210,40 @@ async def aclose(self) -> None:


atee = Tee


class aclosing(AbstractAsyncContextManager):
    """Guarantee that an async resource is finalized when the block exits.

    Entering the context hands back the wrapped object unchanged. Leaving
    the context -- whether normally, through ``break``/``return``, or while
    an exception is propagating -- awaits the object's ``aclose()`` method
    if it has one. Objects without an ``aclose`` attribute are left alone.
    Exceptions raised inside the block are never suppressed.

    Usage:

        async with aclosing(make_agen()) as agen:
            async for item in agen:
                ...

    which, for an object exposing ``aclose()``, behaves like:

        agen = make_agen()
        try:
            async for item in agen:
                ...
        finally:
            await agen.aclose()

    """

    def __init__(
        self, thing: Union[AsyncGenerator[Any, Any], AsyncIterator[Any]]
    ) -> None:
        # Keep a reference so __aexit__ can finalize the very same object.
        self.thing = thing

    async def __aenter__(self) -> Union[AsyncGenerator[Any, Any], AsyncIterator[Any]]:
        # No setup is required; expose the wrapped object to the ``as`` target.
        return self.thing

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> None:
        # Plain async iterators are not required to provide ``aclose``;
        # only async generators (and similar) do, so skip the rest.
        if not hasattr(self.thing, "aclose"):
            return
        await self.thing.aclose()
140 changes: 140 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module that contains tests for runnable.astream_events API."""
import asyncio
import sys
import uuid
from itertools import cycle
Expand Down Expand Up @@ -38,13 +39,15 @@
RunnableConfig,
RunnableGenerator,
RunnableLambda,
chain,
ensure_config,
)
from langchain_core.runnables.config import get_callback_manager_for_config
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.schema import StreamEvent
from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import tool
from langchain_core.utils.aiter import aclosing
from tests.unit_tests.stubs import AnyStr


Expand Down Expand Up @@ -2195,3 +2198,140 @@ def passthrough_to_trigger_issue(x: str) -> str:
for event in events
if event["event"] == "on_chat_model_stream"
] == ["hello", " ", "world"]


async def test_break_astream_events() -> None:
    """Breaking out of ``astream_events`` (v2) must cancel the inner chain.

    A three-step async generator chain is streamed; the consumer breaks on
    the first ``on_chain_stream`` event. The assertions then check that the
    first step completed, the second was cancelled mid-sleep, and the third
    never started.
    """

    class AwhileMaker:
        """Awaitable callable recording whether it started or was cancelled."""

        def __init__(self) -> None:
            self.reset()

        def reset(self) -> None:
            self.started = False
            self.cancelled = False

        async def __call__(self, input: Any) -> Any:
            self.started = True
            try:
                await asyncio.sleep(0.5)
            except asyncio.CancelledError:
                # Remember the cancellation, then let it propagate.
                self.cancelled = True
                raise
            return input

    alittlewhile = AwhileMaker()
    awhile = AwhileMaker()
    anotherwhile = AwhileMaker()

    outer_cancelled = False

    @chain
    async def sequence(input: Any) -> Any:
        nonlocal outer_cancelled
        try:
            yield await alittlewhile(input)
            yield await awhile(input)
            yield await anotherwhile(input)
        except asyncio.CancelledError:
            outer_cancelled = True
            raise

    # Interrupt astream_events v2 by breaking out of the consuming loop.
    saw_stream_event = False
    config: RunnableConfig = {"configurable": {"thread_id": 2}}
    events = sequence.astream_events({"value": 1}, config, version="v2")
    async with aclosing(events) as stream:
        async for event in stream:
            if event["event"] != "on_chain_stream":
                continue
            saw_stream_event = True
            assert event["data"]["chunk"] == {"value": 1}
            break

    # The loop reached a stream event and broke out of it.
    assert saw_stream_event
    # The outer generator chain observed the cancellation.
    assert outer_cancelled

    # First step ran to completion.
    assert alittlewhile.started is True
    assert alittlewhile.cancelled is False

    # Second step was in flight and got cancelled.
    assert awhile.started is True
    assert awhile.cancelled is True

    # Third step was never reached.
    assert anotherwhile.started is False


async def test_cancel_astream_events() -> None:
    """Cancelling the task consuming ``astream_events`` (v2) must cancel the chain.

    A three-step async generator chain is streamed from inside a task; the
    consumer cancels its own task on the first ``on_chain_stream`` event.
    The assertions then check that the first step completed, the second was
    cancelled mid-sleep, and the third never started.
    """

    class AwhileMaker:
        """Awaitable callable recording whether it started or was cancelled."""

        def __init__(self) -> None:
            self.reset()

        def reset(self) -> None:
            self.started = False
            self.cancelled = False

        async def __call__(self, input: Any) -> Any:
            self.started = True
            try:
                await asyncio.sleep(0.5)
            except asyncio.CancelledError:
                # Remember the cancellation, then let it propagate.
                self.cancelled = True
                raise
            return input

    alittlewhile = AwhileMaker()
    awhile = AwhileMaker()
    anotherwhile = AwhileMaker()

    outer_cancelled = False

    @chain
    async def sequence(input: Any) -> Any:
        nonlocal outer_cancelled
        try:
            yield await alittlewhile(input)
            yield await awhile(input)
            yield await anotherwhile(input)
        except asyncio.CancelledError:
            outer_cancelled = True
            raise

    saw_stream_event = False

    async def aconsume(stream: AsyncIterator[Any]) -> None:
        nonlocal saw_stream_event
        # No aclosing is used here: the test relies on task cancellation
        # reaching the async generator being consumed.
        async for event in stream:
            if event["event"] != "on_chain_stream":
                continue
            saw_stream_event = True
            assert event["data"]["chunk"] == {"value": 1}
            consumer.cancel()

    config: RunnableConfig = {"configurable": {"thread_id": 2}}
    events = sequence.astream_events({"value": 1}, config, version="v2")
    consumer = asyncio.create_task(aconsume(events))

    with pytest.raises(asyncio.CancelledError):
        await consumer

    # The consumer reached a stream event before requesting cancellation.
    assert saw_stream_event
    # The outer generator chain observed the cancellation.
    assert outer_cancelled

    # First step ran to completion.
    assert alittlewhile.started is True
    assert alittlewhile.cancelled is False

    # Second step was in flight and got cancelled.
    assert awhile.started is True
    assert awhile.cancelled is True

    # Third step was never reached.
    assert anotherwhile.started is False
Loading