From ed8e9c437a62e305bb38a2624282152262b014d4 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Mon, 3 Jun 2024 07:18:10 -0700 Subject: [PATCH] core: In RunnableSequence pass kwargs to the first step (#22393) - This is a pattern that shows up occasionally in langgraph questions: people chain a graph to something else afterwards and want to pass the graph some kwargs (e.g. stream_mode) --- libs/core/langchain_core/runnables/base.py | 69 +++++---- .../unit_tests/runnables/test_runnable.py | 131 ++++++++++++++++-- 2 files changed, 161 insertions(+), 39 deletions(-) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index b51cb5d89fc87..43b025ad441c5 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -2379,7 +2379,9 @@ def __ror__( name=self.name, ) - def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: + def invoke( + self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any + ) -> Output: from langchain_core.beta.runnables.context import config_with_context # setup callbacks and context @@ -2396,13 +2398,14 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Outpu # invoke all steps in sequence try: for i, step in enumerate(self.steps): - input = step.invoke( - input, - # mark each step as a child run - patch_config( - config, callbacks=run_manager.get_child(f"seq:step:{i+1}") - ), + # mark each step as a child run + config = patch_config( + config, callbacks=run_manager.get_child(f"seq:step:{i+1}") ) + if i == 0: + input = step.invoke(input, config, **kwargs) + else: + input = step.invoke(input, config) # finish the root run except BaseException as e: run_manager.on_chain_error(e) @@ -2433,13 +2436,14 @@ async def ainvoke( # invoke all steps in sequence try: for i, step in enumerate(self.steps): - input = await step.ainvoke( - input, - # mark each step as a child run - patch_config( - config, 
callbacks=run_manager.get_child(f"seq:step:{i+1}") - ), + # mark each step as a child run + config = patch_config( + config, callbacks=run_manager.get_child(f"seq:step:{i+1}") ) + if i == 0: + input = await step.ainvoke(input, config, **kwargs) + else: + input = await step.ainvoke(input, config) # finish the root run except BaseException as e: await run_manager.on_chain_error(e) @@ -2519,7 +2523,7 @@ def batch( if i not in failed_inputs_map ], return_exceptions=return_exceptions, - **kwargs, + **(kwargs if stepidx == 0 else {}), ) # If an input failed, add it to the map for i, inp in zip(remaining_idxs, inputs): @@ -2549,6 +2553,8 @@ def batch( ) for rm, config in zip(run_managers, configs) ], + return_exceptions=return_exceptions, + **(kwargs if i == 0 else {}), ) # finish the root runs @@ -2646,7 +2652,7 @@ async def abatch( if i not in failed_inputs_map ], return_exceptions=return_exceptions, - **kwargs, + **(kwargs if stepidx == 0 else {}), ) # If an input failed, add it to the map for i, inp in zip(remaining_idxs, inputs): @@ -2676,6 +2682,8 @@ async def abatch( ) for rm, config in zip(run_managers, configs) ], + return_exceptions=return_exceptions, + **(kwargs if i == 0 else {}), ) # finish the root runs except BaseException as e: @@ -2704,6 +2712,7 @@ def _transform( input: Iterator[Input], run_manager: CallbackManagerForChainRun, config: RunnableConfig, + **kwargs: Any, ) -> Iterator[Output]: from langchain_core.beta.runnables.context import config_with_context @@ -2714,14 +2723,14 @@ def _transform( # steps that don't natively support transforming an input stream will # buffer input in memory until all available, and then start emitting output final_pipeline = cast(Iterator[Output], input) - for step in steps: - final_pipeline = step.transform( - final_pipeline, - patch_config( - config, - callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"), - ), + for idx, step in enumerate(steps): + config = patch_config( + config, 
callbacks=run_manager.get_child(f"seq:step:{idx+1}") ) + if idx == 0: + final_pipeline = step.transform(final_pipeline, config, **kwargs) + else: + final_pipeline = step.transform(final_pipeline, config) for output in final_pipeline: yield output @@ -2731,6 +2740,7 @@ async def _atransform( input: AsyncIterator[Input], run_manager: AsyncCallbackManagerForChainRun, config: RunnableConfig, + **kwargs: Any, ) -> AsyncIterator[Output]: from langchain_core.beta.runnables.context import aconfig_with_context @@ -2742,14 +2752,15 @@ async def _atransform( # steps that don't natively support transforming an input stream will # buffer input in memory until all available, and then start emitting output final_pipeline = cast(AsyncIterator[Output], input) - for step in steps: - final_pipeline = step.atransform( - final_pipeline, - patch_config( - config, - callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"), - ), + for idx, step in enumerate(steps): + config = patch_config( + config, + callbacks=run_manager.get_child(f"seq:step:{idx+1}"), ) + if idx == 0: + final_pipeline = step.atransform(final_pipeline, config, **kwargs) + else: + final_pipeline = step.atransform(final_pipeline, config) async for output in final_pipeline: yield output diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 72a9494a8078a..2672d17c5bc36 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -182,6 +182,7 @@ def invoke( self, input: str, config: Optional[RunnableConfig] = None, + **kwargs: Any, ) -> int: return len(input) @@ -1409,26 +1410,136 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None: fake = FakeRunnable() mock = mocker.Mock() - seq: Runnable = fake | RunnablePassthrough(mock) + seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock) + + assert await seq.ainvoke("hello", 
my_kwarg="value") == 5 + assert mock.call_args_list == [ + mocker.call("hello", my_kwarg="value"), + mocker.call(5), + ] + mock.reset_mock() + + assert await seq.abatch(["hello", "byebye"], my_kwarg="value") == [5, 6] + assert len(mock.call_args_list) == 4 + for call in [ + mocker.call("hello", my_kwarg="value"), + mocker.call("byebye", my_kwarg="value"), + mocker.call(5), + mocker.call(6), + ]: + assert call in mock.call_args_list + mock.reset_mock() + + assert await seq.abatch( + ["hello", "byebye"], my_kwarg="value", return_exceptions=True + ) == [ + 5, + 6, + ] + assert len(mock.call_args_list) == 4 + for call in [ + mocker.call("hello", my_kwarg="value"), + mocker.call("byebye", my_kwarg="value"), + mocker.call(5), + mocker.call(6), + ]: + assert call in mock.call_args_list + mock.reset_mock() - assert await seq.ainvoke("hello") == 5 - assert mock.call_args_list == [mocker.call(5)] + assert sorted( + [ + a + async for a in seq.abatch_as_completed( + ["hello", "byebye"], my_kwarg="value", return_exceptions=True + ) + ] + ) == [ + (0, 5), + (1, 6), + ] + assert len(mock.call_args_list) == 4 + for call in [ + mocker.call("hello", my_kwarg="value"), + mocker.call("byebye", my_kwarg="value"), + mocker.call(5), + mocker.call(6), + ]: + assert call in mock.call_args_list mock.reset_mock() assert [ - part async for part in seq.astream("hello", dict(metadata={"key": "value"})) + part + async for part in seq.astream( + "hello", dict(metadata={"key": "value"}), my_kwarg="value" + ) ] == [5] - assert mock.call_args_list == [mocker.call(5)] + assert mock.call_args_list == [ + mocker.call("hello", my_kwarg="value"), + mocker.call(5), + ] + mock.reset_mock() + + assert seq.invoke("hello", my_kwarg="value") == 5 # type: ignore[call-arg] + assert mock.call_args_list == [ + mocker.call("hello", my_kwarg="value"), + mocker.call(5), + ] + mock.reset_mock() + + assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6] + assert len(mock.call_args_list) == 4 + for call in [ 
+ mocker.call("hello", my_kwarg="value"), + mocker.call("byebye", my_kwarg="value"), + mocker.call(5), + mocker.call(6), + ]: + assert call in mock.call_args_list + mock.reset_mock() + + assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [ + 5, + 6, + ] + assert len(mock.call_args_list) == 4 + for call in [ + mocker.call("hello", my_kwarg="value"), + mocker.call("byebye", my_kwarg="value"), + mocker.call(5), + mocker.call(6), + ]: + assert call in mock.call_args_list mock.reset_mock() - assert seq.invoke("hello") == 5 - assert mock.call_args_list == [mocker.call(5)] + assert sorted( + a + for a in seq.batch_as_completed( + ["hello", "byebye"], my_kwarg="value", return_exceptions=True + ) + ) == [ + (0, 5), + (1, 6), + ] + assert len(mock.call_args_list) == 4 + for call in [ + mocker.call("hello", my_kwarg="value"), + mocker.call("byebye", my_kwarg="value"), + mocker.call(5), + mocker.call(6), + ]: + assert call in mock.call_args_list mock.reset_mock() - assert [part for part in seq.stream("hello", dict(metadata={"key": "value"}))] == [ - 5 + assert [ + part + for part in seq.stream( + "hello", dict(metadata={"key": "value"}), my_kwarg="value" + ) + ] == [5] + assert mock.call_args_list == [ + mocker.call("hello", my_kwarg="value"), + mocker.call(5), ] - assert mock.call_args_list == [mocker.call(5)] mock.reset_mock()