Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions docs/examples/chains/llm_checker.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# LLMCheckerChain\n",
"This notebook showcases how to use LLMCheckerChain."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import LLMCheckerChain\n",
"from langchain.llms import OpenAI\n",
"\n",
"llm = OpenAI(temperature=0.7)\n",
"\n",
"text = \"What type of mammal lays the biggest eggs?\"\n",
"\n",
"checker_chain = LLMCheckerChain(llm=llm, verbose=True)\n",
"\n",
"checker_chain.run(text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
2 changes: 2 additions & 0 deletions langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
ConversationChain,
LLMBashChain,
LLMChain,
LLMCheckerChain,
LLMMathChain,
PALChain,
QAWithSourcesChain,
Expand All @@ -27,6 +28,7 @@
__all__ = [
"LLMChain",
"LLMBashChain",
"LLMCheckerChain",
"LLMMathChain",
"SelfAskWithSearchChain",
"SerpAPIWrapper",
Expand Down
2 changes: 2 additions & 0 deletions langchain/chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from langchain.chains.conversation.base import ConversationChain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.base import LLMBashChain
from langchain.chains.llm_checker.base import LLMCheckerChain
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.llm_requests import LLMRequestsChain
from langchain.chains.pal.base import PALChain
Expand All @@ -18,6 +19,7 @@
"ConversationChain",
"LLMChain",
"LLMBashChain",
"LLMCheckerChain",
"LLMMathChain",
"PALChain",
"QAWithSourcesChain",
Expand Down
4 changes: 4 additions & 0 deletions langchain/chains/llm_checker/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Chain that tries to verify assumptions before answering a question.

Heavily borrowed from https://github.com/jagilley/fact-checker
"""
98 changes: 98 additions & 0 deletions langchain/chains/llm_checker/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Chain for question-answering with self-verification."""


from typing import Dict, List

from pydantic import BaseModel, Extra

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_checker.prompt import (
CHECK_ASSERTIONS_PROMPT,
CREATE_DRAFT_ANSWER_PROMPT,
LIST_ASSERTIONS_PROMPT,
REVISED_ANSWER_PROMPT,
)
from langchain.chains.sequential import SequentialChain
from langchain.llms.base import LLM
from langchain.prompts import PromptTemplate


class LLMCheckerChain(Chain, BaseModel):
    """Chain for question-answering with self-verification.

    Runs four LLM steps in sequence: draft an answer, list the assertions
    the draft relies on, check each assertion, then produce a revised
    answer in light of the checked assertions.

    Example:
        .. code-block:: python
            from langchain import OpenAI, LLMCheckerChain
            llm = OpenAI(temperature=0.7)
            checker_chain = LLMCheckerChain(llm=llm)
    """

    llm: LLM
    """LLM wrapper to use."""
    # Prompts for each of the four verification steps; overridable per instance.
    create_draft_answer_prompt: PromptTemplate = CREATE_DRAFT_ANSWER_PROMPT
    list_assertions_prompt: PromptTemplate = LIST_ASSERTIONS_PROMPT
    check_assertions_prompt: PromptTemplate = CHECK_ASSERTIONS_PROMPT
    revised_answer_prompt: PromptTemplate = REVISED_ANSWER_PROMPT
    """Prompt to use when questioning the documents."""
    input_key: str = "query"  #: :meta private:
    output_key: str = "result"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Return the singular input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return the singular output key.

        :meta private:
        """
        return [self.output_key]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        """Run the draft -> assert -> check -> revise pipeline on the query."""
        question = inputs[self.input_key]

        # Step 1: draft an initial answer to the question.
        create_draft_answer_chain = LLMChain(
            llm=self.llm, prompt=self.create_draft_answer_prompt, output_key="statement"
        )
        # Step 2: list the assumptions the draft answer relies on.
        list_assertions_chain = LLMChain(
            llm=self.llm, prompt=self.list_assertions_prompt, output_key="assertions"
        )
        # Step 3: mark each assumption true/false with an explanation.
        check_assertions_chain = LLMChain(
            llm=self.llm,
            prompt=self.check_assertions_prompt,
            output_key="checked_assertions",
        )
        # Step 4: answer again in light of the checked assertions.
        revised_answer_chain = LLMChain(
            llm=self.llm,
            prompt=self.revised_answer_prompt,
            output_key="revised_statement",
        )

        chains = [
            create_draft_answer_chain,
            list_assertions_chain,
            check_assertions_chain,
            revised_answer_chain,
        ]

        question_to_checked_assertions_chain = SequentialChain(
            chains=chains,
            input_variables=["question"],
            output_variables=["revised_statement"],
            # Propagate this chain's own verbosity instead of hard-coding
            # verbose=True, which forced verbose output on every run even
            # when the user constructed the chain with verbose=False.
            verbose=self.verbose,
        )
        output = question_to_checked_assertions_chain({"question": question})
        return {self.output_key: output["revised_statement"]}
31 changes: 31 additions & 0 deletions langchain/chains/llm_checker/prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate

# Step 1: pass the question through verbatim to obtain a draft answer.
_CREATE_DRAFT_ANSWER_TEMPLATE = """{question}\n\n"""
CREATE_DRAFT_ANSWER_PROMPT = PromptTemplate(
    input_variables=["question"], template=_CREATE_DRAFT_ANSWER_TEMPLATE
)

# Step 2: ask the model to enumerate the assumptions behind the draft answer.
_LIST_ASSERTIONS_TEMPLATE = """Here is a statement:
{statement}
Make a bullet point list of the assumptions you made when producing the above statement.\n\n"""
LIST_ASSERTIONS_PROMPT = PromptTemplate(
    input_variables=["statement"], template=_LIST_ASSERTIONS_TEMPLATE
)

# Step 3: ask the model to fact-check each listed assumption.
_CHECK_ASSERTIONS_TEMPLATE = """Here is a bullet point list of assertions:
{assertions}
For each assertion, determine whether it is true or false. If it is false, explain why.\n\n"""
CHECK_ASSERTIONS_PROMPT = PromptTemplate(
    input_variables=["assertions"], template=_CHECK_ASSERTIONS_TEMPLATE
)

# Step 4: re-answer the original question given the checked assertions.
_REVISED_ANSWER_TEMPLATE = """{checked_assertions}

Question: In light of the above assertions and checks, how would you answer the question '{question}'?

Answer:"""
REVISED_ANSWER_PROMPT = PromptTemplate(
    input_variables=["checked_assertions", "question"],
    template=_REVISED_ANSWER_TEMPLATE,
)
43 changes: 43 additions & 0 deletions tests/unit_tests/chains/test_llm_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# flake8: noqa E501

"""Test LLMCheckerChain functionality."""

import pytest

from langchain.chains.llm_checker.base import LLMCheckerChain
from langchain.chains.llm_checker.prompt import (
_CHECK_ASSERTIONS_TEMPLATE,
_CREATE_DRAFT_ANSWER_TEMPLATE,
_LIST_ASSERTIONS_TEMPLATE,
_REVISED_ANSWER_TEMPLATE,
)
from tests.unit_tests.llms.fake_llm import FakeLLM


@pytest.fixture
def fake_llm_checker_chain() -> LLMCheckerChain:
    """Fake LLMCheckerChain for testing."""
    # Canned responses for each stage of the checker pipeline; each stage's
    # output is the next stage's prompt input, so the strings must match.
    question = "Which mammal lays the biggest eggs?"
    draft = "I don't know which mammal layers the biggest eggs."
    assertions = (
        "1) I know that mammals lay eggs.\n"
        "2) I know that birds lay eggs.\n"
        "3) I know that birds are mammals."
    )
    checked = (
        "1) I know that mammals lay eggs. TRUE\n"
        "2) I know that birds lay eggs. TRUE\n"
        "3) I know that birds are mammals. TRUE"
    )
    queries = {
        _CREATE_DRAFT_ANSWER_TEMPLATE.format(question=question): draft,
        _LIST_ASSERTIONS_TEMPLATE.format(statement=draft): assertions,
        _CHECK_ASSERTIONS_TEMPLATE.format(assertions=assertions): checked,
        _REVISED_ANSWER_TEMPLATE.format(
            checked_assertions=checked, question=question
        ): "I still don't know.",
    }
    return LLMCheckerChain(
        llm=FakeLLM(queries=queries), input_key="q", output_key="a"
    )


def test_simple_question(fake_llm_checker_chain: LLMCheckerChain) -> None:
    """Test that a simple question flows through all four checker stages."""
    question = "Which mammal lays the biggest eggs?"
    output = fake_llm_checker_chain.run(question)
    assert output == "I still don't know."