Skip to content

Commit

Permalink
core[patch]: Fix runnable map ser/de (#20631)
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos authored Apr 19, 2024
1 parent 1cbab0e commit 48307e4
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 30 deletions.
34 changes: 17 additions & 17 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2136,7 +2136,7 @@ def _seq_input_schema(
**{
k: (v.annotation, v.default)
for k, v in next_input_schema.__fields__.items()
if k not in first.mapper.steps
if k not in first.mapper.steps__
},
)
elif isinstance(first, RunnablePick):
Expand Down Expand Up @@ -2981,11 +2981,11 @@ def mul_three(x: int) -> int:
print(output) # noqa: T201
"""

steps: Mapping[str, Runnable[Input, Any]]
steps__: Mapping[str, Runnable[Input, Any]]

def __init__(
self,
__steps: Optional[
steps__: Optional[
Mapping[
str,
Union[
Expand All @@ -3001,10 +3001,10 @@ def __init__(
Mapping[str, Union[Runnable[Input, Any], Callable[[Input], Any]]],
],
) -> None:
merged = {**__steps} if __steps is not None else {}
merged = {**steps__} if steps__ is not None else {}
merged.update(kwargs)
super().__init__( # type: ignore[call-arg]
steps={key: coerce_to_runnable(r) for key, r in merged.items()}
steps__={key: coerce_to_runnable(r) for key, r in merged.items()}
)

@classmethod
Expand All @@ -3022,12 +3022,12 @@ class Config:
def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
name = name or self.name or f"RunnableParallel<{','.join(self.steps.keys())}>"
name = name or self.name or f"RunnableParallel<{','.join(self.steps__.keys())}>"
return super().get_name(suffix, name=name)

@property
def InputType(self) -> Any:
for step in self.steps.values():
for step in self.steps__.values():
if step.InputType:
return step.InputType

Expand All @@ -3038,14 +3038,14 @@ def get_input_schema(
) -> Type[BaseModel]:
if all(
s.get_input_schema(config).schema().get("type", "object") == "object"
for s in self.steps.values()
for s in self.steps__.values()
):
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
self.get_name("Input"),
**{
k: (v.annotation, v.default)
for step in self.steps.values()
for step in self.steps__.values()
for k, v in step.get_input_schema(config).__fields__.items()
if k != "__root__"
},
Expand All @@ -3059,13 +3059,13 @@ def get_output_schema(
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
self.get_name("Output"),
**{k: (v.OutputType, None) for k, v in self.steps.items()},
**{k: (v.OutputType, None) for k, v in self.steps__.items()},
)

@property
def config_specs(self) -> List[ConfigurableFieldSpec]:
return get_unique_config_specs(
spec for step in self.steps.values() for spec in step.config_specs
spec for step in self.steps__.values() for spec in step.config_specs
)

def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
Expand All @@ -3074,7 +3074,7 @@ def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
graph = Graph()
input_node = graph.add_node(self.get_input_schema(config))
output_node = graph.add_node(self.get_output_schema(config))
for step in self.steps.values():
for step in self.steps__.values():
step_graph = step.get_graph()
step_graph.trim_first_node()
step_graph.trim_last_node()
Expand All @@ -3096,7 +3096,7 @@ def get_graph(self, config: Optional[RunnableConfig] = None) -> Graph:
def __repr__(self) -> str:
map_for_repr = ",\n ".join(
f"{k}: {indent_lines_after_first(repr(v), ' ' + k + ': ')}"
for k, v in self.steps.items()
for k, v in self.steps__.items()
)
return "{\n " + map_for_repr + "\n}"

Expand Down Expand Up @@ -3127,7 +3127,7 @@ def invoke(
# gather results from all steps
try:
# copy to avoid issues from the caller mutating the steps during invoke()
steps = dict(self.steps)
steps = dict(self.steps__)
with get_executor_for_config(config) as executor:
futures = [
executor.submit(
Expand Down Expand Up @@ -3170,7 +3170,7 @@ async def ainvoke(
# gather results from all steps
try:
# copy to avoid issues from the caller mutating the steps during invoke()
steps = dict(self.steps)
steps = dict(self.steps__)
results = await asyncio.gather(
*(
step.ainvoke(
Expand Down Expand Up @@ -3199,7 +3199,7 @@ def _transform(
config: RunnableConfig,
) -> Iterator[AddableDict]:
# Shallow copy steps to ignore mutations while in progress
steps = dict(self.steps)
steps = dict(self.steps__)
# Each step gets a copy of the input iterator,
# which is consumed in parallel in a separate thread.
input_copies = list(safetee(input, len(steps), lock=threading.Lock()))
Expand Down Expand Up @@ -3264,7 +3264,7 @@ async def _atransform(
config: RunnableConfig,
) -> AsyncIterator[AddableDict]:
# Shallow copy steps to ignore mutations while in progress
steps = dict(self.steps)
steps = dict(self.steps__)
# Each step gets a copy of the input iterator,
# which is consumed in parallel in a separate thread.
input_copies = list(atee(input, len(steps), lock=asyncio.Lock()))
Expand Down
8 changes: 5 additions & 3 deletions libs/core/langchain_core/runnables/passthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,9 @@ def get_name(
self, suffix: Optional[str] = None, *, name: Optional[str] = None
) -> str:
name = (
name or self.name or f"RunnableAssign<{','.join(self.mapper.steps.keys())}>"
name
or self.name
or f"RunnableAssign<{','.join(self.mapper.steps__.keys())}>"
)
return super().get_name(suffix, name=name)

Expand Down Expand Up @@ -488,7 +490,7 @@ def _transform(
**kwargs: Any,
) -> Iterator[Dict[str, Any]]:
# collect mapper keys
mapper_keys = set(self.mapper.steps.keys())
mapper_keys = set(self.mapper.steps__.keys())
# create two streams, one for the map and one for the passthrough
for_passthrough, for_map = safetee(input, 2, lock=threading.Lock())

Expand Down Expand Up @@ -544,7 +546,7 @@ async def _atransform(
**kwargs: Any,
) -> AsyncIterator[Dict[str, Any]]:
# collect mapper keys
mapper_keys = set(self.mapper.steps.keys())
mapper_keys = set(self.mapper.steps__.keys())
# create two streams, one for the map and one for the passthrough
for_passthrough, for_map = atee(input, 2, lock=asyncio.Lock())
# create map output stream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"RunnableParallel"
],
"kwargs": {
"steps": {
"steps__": {
"buz": {
"lc": 1,
"type": "not_implemented",
Expand Down Expand Up @@ -569,7 +569,7 @@
"RunnableParallel"
],
"kwargs": {
"steps": {
"steps__": {
"text": {
"lc": 1,
"type": "constructor",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2051,7 +2051,7 @@
"RunnableParallel"
],
"kwargs": {
"steps": {
"steps__": {
"key": {
"lc": 1,
"type": "not_implemented",
Expand All @@ -2073,7 +2073,7 @@
"RunnableParallel"
],
"kwargs": {
"steps": {
"steps__": {
"question": {
"lc": 1,
"type": "not_implemented",
Expand Down Expand Up @@ -4459,7 +4459,7 @@
"RunnableParallel"
],
"kwargs": {
"steps": {
"steps__": {
"key": {
"lc": 1,
"type": "not_implemented",
Expand All @@ -4481,7 +4481,7 @@
"RunnableParallel"
],
"kwargs": {
"steps": {
"steps__": {
"question": {
"lc": 1,
"type": "not_implemented",
Expand Down Expand Up @@ -8760,7 +8760,7 @@
"RunnableParallel"
],
"kwargs": {
"steps": {
"steps__": {
"question": {
"lc": 1,
"type": "constructor",
Expand Down Expand Up @@ -9860,7 +9860,7 @@
"RunnableParallel"
],
"kwargs": {
"steps": {
"steps__": {
"chat": {
"lc": 1,
"type": "not_implemented",
Expand Down Expand Up @@ -10352,7 +10352,7 @@
"RunnableParallel"
],
"kwargs": {
"steps": {
"steps__": {
"chat": {
"lc": 1,
"type": "constructor",
Expand Down
6 changes: 5 additions & 1 deletion libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
FakeStreamingListLLM,
)
from langchain_core.load import dumpd, dumps
from langchain_core.load.load import loads
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
Expand Down Expand Up @@ -76,7 +77,7 @@
add,
chain,
)
from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.base import RunnableMap, RunnableSerializable
from langchain_core.runnables.utils import Input, Output
from langchain_core.tools import BaseTool, tool
from langchain_core.tracers import (
Expand Down Expand Up @@ -3553,6 +3554,9 @@ async def test_map_astream_iterator_input() -> None:
assert final_value.get("llm") == "i'm a textbot"
assert final_value.get("passthrough") == llm_res

simple_map = RunnableMap(passthrough=RunnablePassthrough())
assert loads(dumps(simple_map)) == simple_map


def test_with_config_with_config() -> None:
llm = FakeListLLM(responses=["i'm a textbot"])
Expand Down

0 comments on commit 48307e4

Please sign in to comment.