Skip to content

Commit

Permalink
Revert "nc/runnable-dynamic-schemas-from-config" (langchain-ai#12037)
Browse files Browse the repository at this point in the history
This reverts commit a46eef6.

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change,
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
  - **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
  - **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
  • Loading branch information
nfcampos authored Oct 19, 2023
1 parent a46eef6 commit 85eaa4c
Show file tree
Hide file tree
Showing 13 changed files with 82 additions and 151 deletions.
10 changes: 4 additions & 6 deletions libs/langchain/langchain/chains/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,15 @@ class Chain(RunnableSerializable[Dict[str, Any], Dict[str, Any]], ABC):
chains and cannot return as rich of an output as `__call__`.
"""

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@property
def input_schema(self) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainInput", **{k: (Any, None) for k in self.input_keys}
)

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@property
def output_schema(self) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"ChainOutput", **{k: (Any, None) for k in self.output_keys}
Expand Down
23 changes: 9 additions & 14 deletions libs/langchain/langchain/chains/combine_documents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.pydantic_v1 import BaseModel, Field, create_model
from langchain.schema.runnable.config import RunnableConfig
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter


Expand All @@ -29,17 +28,15 @@ class BaseCombineDocumentsChain(Chain, ABC):
input_key: str = "input_documents" #: :meta private:
output_key: str = "output_text" #: :meta private:

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@property
def input_schema(self) -> Type[BaseModel]:
return create_model(
"CombineDocumentsInput",
**{self.input_key: (List[Document], None)}, # type: ignore[call-overload]
)

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@property
def output_schema(self) -> Type[BaseModel]:
return create_model(
"CombineDocumentsOutput",
**{self.output_key: (str, None)}, # type: ignore[call-overload]
Expand Down Expand Up @@ -170,18 +167,16 @@ def output_keys(self) -> List[str]:
"""
return self.combine_docs_chain.output_keys

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@property
def input_schema(self) -> Type[BaseModel]:
return create_model(
"AnalyzeDocumentChain",
**{self.input_key: (str, None)}, # type: ignore[call-overload]
)

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.combine_docs_chain.get_output_schema(config)
@property
def output_schema(self) -> Type[BaseModel]:
return self.combine_docs_chain.output_schema

def _call(
self,
Expand Down
10 changes: 4 additions & 6 deletions libs/langchain/langchain/chains/combine_documents/map_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple

from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
from langchain.schema.runnable.config import RunnableConfig


class MapReduceDocumentsChain(BaseCombineDocumentsChain):
Expand Down Expand Up @@ -99,9 +98,8 @@ class MapReduceDocumentsChain(BaseCombineDocumentsChain):
return_intermediate_steps: bool = False
"""Return the results of the map steps in the output."""

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@property
def output_schema(self) -> type[BaseModel]:
if self.return_intermediate_steps:
return create_model(
"MapReduceDocumentsOutput",
Expand All @@ -111,7 +109,7 @@ def get_output_schema(
}, # type: ignore[call-overload]
)

return super().get_output_schema(config)
return super().output_schema

@property
def output_keys(self) -> List[str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union, cast
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.output_parsers.regex import RegexParser
from langchain.pydantic_v1 import BaseModel, Extra, create_model, root_validator
from langchain.schema.runnable.config import RunnableConfig


class MapRerankDocumentsChain(BaseCombineDocumentsChain):
Expand Down Expand Up @@ -78,9 +77,8 @@ class Config:
extra = Extra.forbid
arbitrary_types_allowed = True

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@property
def output_schema(self) -> type[BaseModel]:
schema: Dict[str, Any] = {
self.output_key: (str, None),
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from langchain.schema import BasePromptTemplate, BaseRetriever, Document
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage
from langchain.schema.runnable.config import RunnableConfig
from langchain.schema.vectorstore import VectorStore

# Depending on the memory type and configuration, the chat history format may differ.
Expand Down Expand Up @@ -96,9 +95,8 @@ def input_keys(self) -> List[str]:
"""Input keys."""
return ["question", "chat_history"]

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@property
def input_schema(self) -> Type[BaseModel]:
return InputType

@property
Expand Down
5 changes: 2 additions & 3 deletions libs/langchain/langchain/schema/prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@ def OutputType(self) -> Any:

return Union[StringPromptValue, ChatPromptValueConcrete]

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@property
def input_schema(self) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"PromptInput",
Expand Down
78 changes: 28 additions & 50 deletions libs/langchain/langchain/schema/runnable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,6 @@ def OutputType(self) -> Type[Output]:

@property
def input_schema(self) -> Type[BaseModel]:
"""The type of input this runnable accepts specified as a pydantic model."""
return self.get_input_schema()

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
"""The type of input this runnable accepts specified as a pydantic model."""
root_type = self.InputType

Expand All @@ -180,12 +174,6 @@ def get_input_schema(

@property
def output_schema(self) -> Type[BaseModel]:
"""The type of output this runnable produces specified as a pydantic model."""
return self.get_output_schema()

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
"""The type of output this runnable produces specified as a pydantic model."""
root_type = self.OutputType

Expand Down Expand Up @@ -1056,15 +1044,13 @@ def InputType(self) -> Type[Input]:
def OutputType(self) -> Type[Output]:
return self.last.OutputType

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.first.get_input_schema(config)
@property
def input_schema(self) -> Type[BaseModel]:
return self.first.input_schema

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.last.get_output_schema(config)
@property
def output_schema(self) -> Type[BaseModel]:
return self.last.output_schema

@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
Expand Down Expand Up @@ -1565,11 +1551,10 @@ def InputType(self) -> Any:

return Any

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@property
def input_schema(self) -> Type[BaseModel]:
if all(
s.get_input_schema(config).schema().get("type", "object") == "object"
s.input_schema.schema().get("type", "object") == "object"
for s in self.steps.values()
):
# This is correct, but pydantic typings/mypy don't think so.
Expand All @@ -1578,16 +1563,15 @@ def get_input_schema(
**{
k: (v.annotation, v.default)
for step in self.steps.values()
for k, v in step.get_input_schema(config).__fields__.items()
for k, v in step.input_schema.__fields__.items()
if k != "__root__"
},
)

return super().get_input_schema(config)
return super().input_schema

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@property
def output_schema(self) -> Type[BaseModel]:
# This is correct, but pydantic typings/mypy don't think so.
return create_model( # type: ignore[call-overload]
"RunnableParallelOutput",
Expand Down Expand Up @@ -2056,9 +2040,8 @@ def InputType(self) -> Any:
except ValueError:
return Any

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@property
def input_schema(self) -> Type[BaseModel]:
"""The pydantic schema for the input to this runnable."""
func = getattr(self, "func", None) or getattr(self, "afunc")

Expand All @@ -2083,7 +2066,7 @@ def get_input_schema(
**{key: (Any, None) for key in dict_keys}, # type: ignore
)

return super().get_input_schema(config)
return super().input_schema

@property
def OutputType(self) -> Any:
Expand Down Expand Up @@ -2232,13 +2215,12 @@ class Config:
def InputType(self) -> Any:
return List[self.bound.InputType] # type: ignore[name-defined]

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@property
def input_schema(self) -> Type[BaseModel]:
return create_model(
"RunnableEachInput",
__root__=(
List[self.bound.get_input_schema(config)], # type: ignore
List[self.bound.input_schema], # type: ignore[name-defined]
None,
),
)
Expand All @@ -2247,14 +2229,12 @@ def get_input_schema(
def OutputType(self) -> Type[List[Output]]:
return List[self.bound.OutputType] # type: ignore[name-defined]

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
schema = self.bound.get_output_schema(config)
@property
def output_schema(self) -> Type[BaseModel]:
return create_model(
"RunnableEachOutput",
__root__=(
List[schema], # type: ignore
List[self.bound.output_schema], # type: ignore[name-defined]
None,
),
)
Expand Down Expand Up @@ -2352,15 +2332,13 @@ def InputType(self) -> Type[Input]:
def OutputType(self) -> Type[Output]:
return self.bound.OutputType

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.bound.get_input_schema(merge_configs(self.config, config))
@property
def input_schema(self) -> Type[BaseModel]:
return self.bound.input_schema

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self.bound.get_output_schema(merge_configs(self.config, config))
@property
def output_schema(self) -> Type[BaseModel]:
return self.bound.output_schema

@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
Expand Down
11 changes: 5 additions & 6 deletions libs/langchain/langchain/schema/runnable/branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,20 +130,19 @@ def get_lc_namespace(cls) -> List[str]:
"""The namespace of a RunnableBranch is the namespace of its default branch."""
return cls.__module__.split(".")[:-1]

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
@property
def input_schema(self) -> Type[BaseModel]:
runnables = (
[self.default]
+ [r for _, r in self.branches]
+ [r for r, _ in self.branches]
)

for runnable in runnables:
if runnable.get_input_schema(config).schema().get("type") is not None:
return runnable.get_input_schema(config)
if runnable.input_schema.schema().get("type") is not None:
return runnable.input_schema

return super().get_input_schema(config)
return super().input_schema

@property
def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
Expand Down
14 changes: 6 additions & 8 deletions libs/langchain/langchain/schema/runnable/configurable.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,13 @@ def InputType(self) -> Type[Input]:
def OutputType(self) -> Type[Output]:
return self.default.OutputType

def get_input_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self._prepare(config).get_input_schema(config)
@property
def input_schema(self) -> Type[BaseModel]:
return self.default.input_schema

def get_output_schema(
self, config: Optional[RunnableConfig] = None
) -> Type[BaseModel]:
return self._prepare(config).get_output_schema(config)
@property
def output_schema(self) -> Type[BaseModel]:
return self.default.output_schema

@abstractmethod
def _prepare(
Expand Down
Loading

0 comments on commit 85eaa4c

Please sign in to comment.