Skip to content

Commit 5fc0126

Browse files
authored
Merge pull request #5 from hwchase17/harrison/chains
add initial chains
2 parents 4cc39aa + 434234e commit 5fc0126

File tree

7 files changed

+190
-0
lines changed

7 files changed

+190
-0
lines changed

langchain/chains/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Chains are easily reusable components which can be linked together."""

langchain/chains/base.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Base interface that all chains should implement."""
2+
from abc import ABC, abstractmethod
3+
from typing import Any, Dict, List
4+
5+
6+
class Chain(ABC):
    """Base interface that all chains should implement."""

    @property
    @abstractmethod
    def input_keys(self) -> List[str]:
        """Input keys this chain expects."""

    @property
    @abstractmethod
    def output_keys(self) -> List[str]:
        """Output keys this chain expects."""

    def _validate_inputs(self, inputs: Dict[str, str]) -> None:
        """Check that all inputs are present."""
        absent = {key for key in self.input_keys if key not in inputs}
        if absent:
            raise ValueError(f"Missing some input keys: {absent}")

    def _validate_outputs(self, outputs: Dict[str, str]) -> None:
        """Check that exactly the expected output keys were produced."""
        produced = set(outputs)
        expected = set(self.output_keys)
        if produced != expected:
            raise ValueError(
                f"Did not get output keys that were expected. "
                f"Got: {produced}. Expected: {expected}."
            )

    @abstractmethod
    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Run the logic of this chain and return the output."""

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Run the logic of this chain and add to output."""
        self._validate_inputs(inputs)
        results = self._run(inputs)
        self._validate_outputs(results)
        # Merge so callers see both their inputs and the new outputs.
        return dict(inputs, **results)

langchain/chains/llm.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Chain that just formats a prompt and calls an LLM."""
2+
from typing import Any, Dict, List
3+
4+
from pydantic import BaseModel, Extra
5+
6+
from langchain.chains.base import Chain
7+
from langchain.llms.base import LLM
8+
from langchain.prompt import Prompt
9+
10+
11+
class LLMChain(Chain, BaseModel):
    """Chain to run queries against LLMs."""

    # Prompt object used to format the final input to the LLM.
    prompt: Prompt
    # Wrapped LLM that is called with the formatted prompt.
    llm: LLM
    # Key under which the LLM response is returned.
    return_key: str = "text"

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects."""
        return self.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key."""
        return [self.return_key]

    def _run(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Format the prompt from ``inputs`` and call the LLM."""
        template_args = {var: inputs[var] for var in self.prompt.input_variables}
        formatted_prompt = self.prompt.template.format(**template_args)

        # "stop" is an optional extra input, forwarded to the LLM if present.
        llm_kwargs = {}
        if "stop" in inputs:
            llm_kwargs["stop"] = inputs["stop"]
        return {self.return_key: self.llm(formatted_prompt, **llm_kwargs)}

    def predict(self, **kwargs: Any) -> str:
        """More user-friendly interface for interacting with LLMs."""
        return self(kwargs)[self.return_key]
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Tests for correct functioning of chains."""
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""Test logic on base chain class."""
2+
from typing import Dict, List
3+
4+
import pytest
5+
from pydantic import BaseModel
6+
7+
from langchain.chains.base import Chain
8+
9+
10+
class FakeChain(Chain, BaseModel):
    """Fake chain class for testing purposes."""

    # When False, _run emits a wrong output key to trigger validation errors.
    be_correct: bool = True

    @property
    def input_keys(self) -> List[str]:
        """Input key of foo."""
        return ["foo"]

    @property
    def output_keys(self) -> List[str]:
        """Output key of bar."""
        return ["bar"]

    def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Return the expected key when correct, a wrong one otherwise."""
        return {"bar": "baz"} if self.be_correct else {"baz": "bar"}
30+
31+
32+
def test_bad_inputs() -> None:
    """Test errors are raised if input keys are not found."""
    chain_under_test = FakeChain()
    # "foobar" does not match the chain's required "foo" input key.
    with pytest.raises(ValueError):
        chain_under_test({"foobar": "baz"})
37+
38+
39+
def test_bad_outputs() -> None:
    """Test errors are raised if outputs keys are not found."""
    # An incorrect chain returns "baz" instead of the expected "bar" key.
    misbehaving_chain = FakeChain(be_correct=False)
    with pytest.raises(ValueError):
        misbehaving_chain({"foo": "baz"})
44+
45+
46+
def test_correct_call() -> None:
    """Test correct call of fake chain."""
    result = FakeChain()({"foo": "bar"})
    # Output should contain both the original input and the chain's output.
    assert result == {"foo": "bar", "bar": "baz"}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""Test LLM chain."""
2+
import pytest
3+
4+
from langchain.chains.llm import LLMChain
5+
from langchain.prompt import Prompt
6+
from tests.unit_tests.llms.fake_llm import FakeLLM
7+
8+
9+
@pytest.fixture
def fake_llm_chain() -> LLMChain:
    """Fake LLM chain for testing purposes."""
    fake_prompt = Prompt(input_variables=["bar"], template="This is a {bar}:")
    # Use a non-default return key to verify it is respected everywhere.
    return LLMChain(prompt=fake_prompt, llm=FakeLLM(), return_key="text1")
14+
15+
16+
def test_missing_inputs(fake_llm_chain: LLMChain) -> None:
    """Test error is raised if inputs are missing."""
    # "foo" is not among the prompt's input variables, so validation must fail.
    with pytest.raises(ValueError):
        fake_llm_chain({"foo": "bar"})
20+
21+
22+
def test_valid_call(fake_llm_chain: LLMChain) -> None:
    """Test valid call of LLM chain."""
    plain_output = fake_llm_chain({"bar": "baz"})
    assert plain_output == {"bar": "baz", "text1": "foo"}

    # With stop words supplied, the fake LLM answers `bar` instead of `foo`.
    stopped_output = fake_llm_chain({"bar": "baz", "stop": ["foo"]})
    assert stopped_output == {"bar": "baz", "stop": ["foo"], "text1": "bar"}
31+
32+
33+
def test_predict_method(fake_llm_chain: LLMChain) -> None:
    """Test predict method works."""
    # predict() should unwrap the return-key value from the output dict.
    assert fake_llm_chain.predict(bar="baz") == "foo"

tests/unit_tests/llms/fake_llm.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Fake LLM wrapper for testing purposes."""
2+
from typing import List, Optional
3+
4+
from langchain.llms.base import LLM
5+
6+
7+
class FakeLLM(LLM):
    """Fake LLM wrapper for testing purposes."""

    def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Return `foo` if no stop words, otherwise `bar`."""
        return "foo" if stop is None else "bar"

0 commit comments

Comments
 (0)