Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions libs/langchain/langchain/schema/runnable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,7 @@ def output_schema(self) -> Type[BaseModel]:
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return []

def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
class _Config:
arbitrary_types_allowed = True

Expand All @@ -150,7 +148,7 @@ class _Config:
for spec in config_specs
},
)
if config_specs
if config_specs and "configurable" in include
else None
)

Expand All @@ -161,7 +159,7 @@ class _Config:
**{
field_name: (field_type, None)
for field_name, field_type in RunnableConfig.__annotations__.items()
if field_name in include
if field_name in [i for i in include if i != "configurable"]
},
)

Expand Down Expand Up @@ -873,7 +871,7 @@ def configurable_fields(
"available keys are {self.__fields__.keys()}"
)

return RunnableConfigurableFields(bound=self, fields=kwargs)
return RunnableConfigurableFields(default=self, fields=kwargs)

def configurable_alternatives(
self,
Expand All @@ -885,7 +883,7 @@ def configurable_alternatives(
)

return RunnableConfigurableAlternatives(
which=which, bound=self, alternatives=kwargs
which=which, default=self, alternatives=kwargs
)


Expand Down Expand Up @@ -2051,9 +2049,7 @@ def output_schema(self) -> type[BaseModel]:
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return self.bound.config_specs

def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
return self.bound.config_schema(include=include)

@classmethod
Expand Down Expand Up @@ -2132,9 +2128,7 @@ def output_schema(self) -> Type[BaseModel]:
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
return self.bound.config_specs

def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
return self.bound.config_schema(include=include)

@classmethod
Expand Down
59 changes: 33 additions & 26 deletions libs/langchain/langchain/schema/runnable/configurable.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from __future__ import annotations

import enum
from abc import abstractmethod
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Type,
Expand All @@ -32,7 +32,7 @@


class DynamicRunnable(RunnableSerializable[Input, Output]):
bound: RunnableSerializable[Input, Output]
default: RunnableSerializable[Input, Output]

class Config:
arbitrary_types_allowed = True
Expand All @@ -47,19 +47,19 @@ def get_lc_namespace(cls) -> List[str]:

@property
def InputType(self) -> Type[Input]:
return self.bound.InputType
return self.default.InputType

@property
def OutputType(self) -> Type[Output]:
return self.bound.OutputType
return self.default.OutputType

@property
def input_schema(self) -> Type[BaseModel]:
return self.bound.input_schema
return self.default.input_schema

@property
def output_schema(self) -> Type[BaseModel]:
return self.bound.output_schema
return self.default.output_schema

@abstractmethod
def _prepare(
Expand Down Expand Up @@ -88,8 +88,8 @@ def batch(
configs = get_config_list(config, len(inputs))
prepared = [self._prepare(c) for c in configs]

if all(p is self.bound for p in prepared):
return self.bound.batch(
if all(p is self.default for p in prepared):
return self.default.batch(
inputs, config, return_exceptions=return_exceptions, **kwargs
)

Expand Down Expand Up @@ -131,8 +131,8 @@ async def abatch(
configs = get_config_list(config, len(inputs))
prepared = [self._prepare(c) for c in configs]

if all(p is self.bound for p in prepared):
return await self.bound.abatch(
if all(p is self.default for p in prepared):
return await self.default.abatch(
inputs, config, return_exceptions=return_exceptions, **kwargs
)

Expand Down Expand Up @@ -202,18 +202,18 @@ def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
id=spec.id,
name=spec.name,
description=spec.description
or self.bound.__fields__[field_name].field_info.description,
or self.default.__fields__[field_name].field_info.description,
annotation=spec.annotation
or self.bound.__fields__[field_name].annotation,
default=getattr(self.bound, field_name),
or self.default.__fields__[field_name].annotation,
default=getattr(self.default, field_name),
)
for field_name, spec in self.fields.items()
]

def configurable_fields(
self, **kwargs: ConfigurableField
) -> RunnableSerializable[Input, Output]:
return self.bound.configurable_fields(**{**self.fields, **kwargs})
return self.default.configurable_fields(**{**self.fields, **kwargs})

def _prepare(
self, config: Optional[RunnableConfig] = None
Expand All @@ -227,49 +227,56 @@ def _prepare(
}

if configurable:
return self.bound.__class__(**{**self.bound.dict(), **configurable})
return self.default.__class__(**{**self.default.dict(), **configurable})
else:
return self.bound
return self.default


# Before Python 3.11 native StrEnum is not available
class StrEnum(str, enum.Enum):
pass


class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
which: ConfigurableField

alternatives: Dict[str, RunnableSerializable[Input, Output]]

default_key: str = "default"

@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
alt_keys = self.alternatives.keys()
which_keys = tuple(Literal[k] for k in alt_keys) + ( # type: ignore
Literal["default"],
which_enum = StrEnum( # type: ignore[call-overload]
self.which.name or self.which.id,
((v, v) for v in list(self.alternatives.keys()) + [self.default_key]),
)
return [
ConfigurableFieldSpec(
id=self.which.id,
name=self.which.name,
description=self.which.description,
annotation=Union[which_keys], # type: ignore
default="default",
annotation=which_enum,
default=self.default_key,
),
*self.bound.config_specs,
*self.default.config_specs,
] + [s for alt in self.alternatives.values() for s in alt.config_specs]

def configurable_fields(
self, **kwargs: ConfigurableField
) -> RunnableSerializable[Input, Output]:
return self.__class__(
which=self.which,
bound=self.bound.configurable_fields(**kwargs),
default=self.default.configurable_fields(**kwargs),
alternatives=self.alternatives,
)

def _prepare(
self, config: Optional[RunnableConfig] = None
) -> Runnable[Input, Output]:
config = config or {}
which = config.get("configurable", {}).get(self.which.id)
if not which:
return self.bound
which = str(config.get("configurable", {}).get(self.which.id, self.default_key))
if which == self.default_key:
return self.default
elif which in self.alternatives:
return self.alternatives[which]
else:
Expand Down
4 changes: 1 addition & 3 deletions libs/langchain/langchain/schema/runnable/fallbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
for spec in step.config_specs
)

def config_schema(
self, *, include: Optional[Sequence[str]] = None
) -> Type[BaseModel]:
def config_schema(self, *, include: Sequence[str]) -> Type[BaseModel]:
return self.runnable.config_schema(include=include)

@classmethod
Expand Down
25 changes: 15 additions & 10 deletions libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def test_configurable_fields() -> None:

assert fake_llm_configurable.invoke("...") == "a"

assert fake_llm_configurable.config_schema().schema() == {
assert fake_llm_configurable.config_schema(include=["configurable"]).schema() == {
"title": "RunnableConfigurableFieldsConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
Expand Down Expand Up @@ -606,7 +606,7 @@ def test_configurable_fields() -> None:
text="Hello, John!"
)

assert prompt_configurable.config_schema().schema() == {
assert prompt_configurable.config_schema(include=["configurable"]).schema() == {
"title": "RunnableConfigurableFieldsConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
Expand Down Expand Up @@ -638,7 +638,7 @@ def test_configurable_fields() -> None:

assert chain_configurable.invoke({"name": "John"}) == "a"

assert chain_configurable.config_schema().schema() == {
assert chain_configurable.config_schema(include=["configurable"]).schema() == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
Expand Down Expand Up @@ -690,7 +690,9 @@ def test_configurable_fields() -> None:
"llm3": "a",
}

assert chain_with_map_configurable.config_schema().schema() == {
assert chain_with_map_configurable.config_schema(
include=["configurable"]
).schema() == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
Expand Down Expand Up @@ -760,22 +762,25 @@ def test_configurable_fields_example() -> None:

assert chain_configurable.invoke({"name": "John"}) == "a"

assert chain_configurable.config_schema().schema() == {
assert chain_configurable.config_schema(include=["configurable"]).schema() == {
"title": "RunnableSequenceConfig",
"type": "object",
"properties": {"configurable": {"$ref": "#/definitions/Configurable"}},
"definitions": {
"LLM": {
"title": "LLM",
"description": "An enumeration.",
"enum": ["chat", "default"],
"type": "string",
},
"Configurable": {
"title": "Configurable",
"type": "object",
"properties": {
"llm": {
"title": "LLM",
"default": "default",
"anyOf": [
{"enum": ["chat"], "type": "string"},
{"enum": ["default"], "type": "string"},
],
"allOf": [{"$ref": "#/definitions/LLM"}],
},
"llm_responses": {
"title": "LLM Responses",
Expand All @@ -791,7 +796,7 @@ def test_configurable_fields_example() -> None:
"type": "string",
},
},
}
},
},
}

Expand Down