Skip to content

Commit

Permalink
core: In RunnableSequence pass kwargs to the first step (#22393)
Browse files Browse the repository at this point in the history
- This is a pattern that shows up occasionally in langgraph questions,
people chain a graph to something else after, and want to pass the graph
some kwargs (eg. stream_mode)
  • Loading branch information
nfcampos authored Jun 3, 2024
1 parent eabcfaa commit ed8e9c4
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 39 deletions.
69 changes: 40 additions & 29 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2379,7 +2379,9 @@ def __ror__(
name=self.name,
)

def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None, **kwargs: Any
) -> Output:
from langchain_core.beta.runnables.context import config_with_context

# setup callbacks and context
Expand All @@ -2396,13 +2398,14 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Outpu
# invoke all steps in sequence
try:
for i, step in enumerate(self.steps):
input = step.invoke(
input,
# mark each step as a child run
patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
),
# mark each step as a child run
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
)
if i == 0:
input = step.invoke(input, config, **kwargs)
else:
input = step.invoke(input, config)
# finish the root run
except BaseException as e:
run_manager.on_chain_error(e)
Expand Down Expand Up @@ -2433,13 +2436,14 @@ async def ainvoke(
# invoke all steps in sequence
try:
for i, step in enumerate(self.steps):
input = await step.ainvoke(
input,
# mark each step as a child run
patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
),
# mark each step as a child run
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i+1}")
)
if i == 0:
input = await step.ainvoke(input, config, **kwargs)
else:
input = await step.ainvoke(input, config)
# finish the root run
except BaseException as e:
await run_manager.on_chain_error(e)
Expand Down Expand Up @@ -2519,7 +2523,7 @@ def batch(
if i not in failed_inputs_map
],
return_exceptions=return_exceptions,
**kwargs,
**(kwargs if stepidx == 0 else {}),
)
# If an input failed, add it to the map
for i, inp in zip(remaining_idxs, inputs):
Expand Down Expand Up @@ -2549,6 +2553,8 @@ def batch(
)
for rm, config in zip(run_managers, configs)
],
return_exceptions=return_exceptions,
**(kwargs if i == 0 else {}),
)

# finish the root runs
Expand Down Expand Up @@ -2646,7 +2652,7 @@ async def abatch(
if i not in failed_inputs_map
],
return_exceptions=return_exceptions,
**kwargs,
**(kwargs if stepidx == 0 else {}),
)
# If an input failed, add it to the map
for i, inp in zip(remaining_idxs, inputs):
Expand Down Expand Up @@ -2676,6 +2682,8 @@ async def abatch(
)
for rm, config in zip(run_managers, configs)
],
return_exceptions=return_exceptions,
**(kwargs if i == 0 else {}),
)
# finish the root runs
except BaseException as e:
Expand Down Expand Up @@ -2704,6 +2712,7 @@ def _transform(
input: Iterator[Input],
run_manager: CallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> Iterator[Output]:
from langchain_core.beta.runnables.context import config_with_context

Expand All @@ -2714,14 +2723,14 @@ def _transform(
# steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output
final_pipeline = cast(Iterator[Output], input)
for step in steps:
final_pipeline = step.transform(
final_pipeline,
patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"),
),
for idx, step in enumerate(steps):
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{idx+1}")
)
if idx == 0:
final_pipeline = step.transform(final_pipeline, config, **kwargs)
else:
final_pipeline = step.transform(final_pipeline, config)

for output in final_pipeline:
yield output
Expand All @@ -2731,6 +2740,7 @@ async def _atransform(
input: AsyncIterator[Input],
run_manager: AsyncCallbackManagerForChainRun,
config: RunnableConfig,
**kwargs: Any,
) -> AsyncIterator[Output]:
from langchain_core.beta.runnables.context import aconfig_with_context

Expand All @@ -2742,14 +2752,15 @@ async def _atransform(
# steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output
final_pipeline = cast(AsyncIterator[Output], input)
for step in steps:
final_pipeline = step.atransform(
final_pipeline,
patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{steps.index(step)+1}"),
),
for idx, step in enumerate(steps):
config = patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{idx+1}"),
)
if idx == 0:
final_pipeline = step.atransform(final_pipeline, config, **kwargs)
else:
final_pipeline = step.atransform(final_pipeline, config)
async for output in final_pipeline:
yield output

Expand Down
131 changes: 121 additions & 10 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def invoke(
self,
input: str,
config: Optional[RunnableConfig] = None,
**kwargs: Any,
) -> int:
return len(input)

Expand Down Expand Up @@ -1409,26 +1410,136 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
fake = FakeRunnable()
mock = mocker.Mock()

seq: Runnable = fake | RunnablePassthrough(mock)
seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock)

assert await seq.ainvoke("hello", my_kwarg="value") == 5
assert mock.call_args_list == [
mocker.call("hello", my_kwarg="value"),
mocker.call(5),
]
mock.reset_mock()

assert await seq.abatch(["hello", "byebye"], my_kwarg="value") == [5, 6]
assert len(mock.call_args_list) == 4
for call in [
mocker.call("hello", my_kwarg="value"),
mocker.call("byebye", my_kwarg="value"),
mocker.call(5),
mocker.call(6),
]:
assert call in mock.call_args_list
mock.reset_mock()

assert await seq.abatch(
["hello", "byebye"], my_kwarg="value", return_exceptions=True
) == [
5,
6,
]
assert len(mock.call_args_list) == 4
for call in [
mocker.call("hello", my_kwarg="value"),
mocker.call("byebye", my_kwarg="value"),
mocker.call(5),
mocker.call(6),
]:
assert call in mock.call_args_list
mock.reset_mock()

assert await seq.ainvoke("hello") == 5
assert mock.call_args_list == [mocker.call(5)]
assert sorted(
[
a
async for a in seq.abatch_as_completed(
["hello", "byebye"], my_kwarg="value", return_exceptions=True
)
]
) == [
(0, 5),
(1, 6),
]
assert len(mock.call_args_list) == 4
for call in [
mocker.call("hello", my_kwarg="value"),
mocker.call("byebye", my_kwarg="value"),
mocker.call(5),
mocker.call(6),
]:
assert call in mock.call_args_list
mock.reset_mock()

assert [
part async for part in seq.astream("hello", dict(metadata={"key": "value"}))
part
async for part in seq.astream(
"hello", dict(metadata={"key": "value"}), my_kwarg="value"
)
] == [5]
assert mock.call_args_list == [mocker.call(5)]
assert mock.call_args_list == [
mocker.call("hello", my_kwarg="value"),
mocker.call(5),
]
mock.reset_mock()

assert seq.invoke("hello", my_kwarg="value") == 5 # type: ignore[call-arg]
assert mock.call_args_list == [
mocker.call("hello", my_kwarg="value"),
mocker.call(5),
]
mock.reset_mock()

assert seq.batch(["hello", "byebye"], my_kwarg="value") == [5, 6]
assert len(mock.call_args_list) == 4
for call in [
mocker.call("hello", my_kwarg="value"),
mocker.call("byebye", my_kwarg="value"),
mocker.call(5),
mocker.call(6),
]:
assert call in mock.call_args_list
mock.reset_mock()

assert seq.batch(["hello", "byebye"], my_kwarg="value", return_exceptions=True) == [
5,
6,
]
assert len(mock.call_args_list) == 4
for call in [
mocker.call("hello", my_kwarg="value"),
mocker.call("byebye", my_kwarg="value"),
mocker.call(5),
mocker.call(6),
]:
assert call in mock.call_args_list
mock.reset_mock()

assert seq.invoke("hello") == 5
assert mock.call_args_list == [mocker.call(5)]
assert sorted(
a
for a in seq.batch_as_completed(
["hello", "byebye"], my_kwarg="value", return_exceptions=True
)
) == [
(0, 5),
(1, 6),
]
assert len(mock.call_args_list) == 4
for call in [
mocker.call("hello", my_kwarg="value"),
mocker.call("byebye", my_kwarg="value"),
mocker.call(5),
mocker.call(6),
]:
assert call in mock.call_args_list
mock.reset_mock()

assert [part for part in seq.stream("hello", dict(metadata={"key": "value"}))] == [
5
assert [
part
for part in seq.stream(
"hello", dict(metadata={"key": "value"}), my_kwarg="value"
)
] == [5]
assert mock.call_args_list == [
mocker.call("hello", my_kwarg="value"),
mocker.call(5),
]
assert mock.call_args_list == [mocker.call(5)]
mock.reset_mock()


Expand Down

0 comments on commit ed8e9c4

Please sign in to comment.