Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DynamicPrompt class creation #49

Merged
merged 8 commits into from
Nov 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion langchain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from langchain.docstore import Wikipedia
from langchain.faiss import FAISS
from langchain.llms import Cohere, HuggingFaceHub, OpenAI
from langchain.prompt import BasePrompt, Prompt
from langchain.prompt import BasePrompt, DynamicPrompt, Prompt
from langchain.sql_database import SQLDatabase

__all__ = [
Expand All @@ -29,6 +29,7 @@
"Cohere",
"OpenAI",
"BasePrompt",
"DynamicPrompt",
"Prompt",
"ReActChain",
"Wikipedia",
Expand Down
108 changes: 107 additions & 1 deletion langchain/prompt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Prompt schema definition."""
import re
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any, Callable, Dict, List

from pydantic import BaseModel, Extra, root_validator

Expand Down Expand Up @@ -126,3 +127,108 @@ def from_examples(
example_str = example_separator.join(examples)
template = prefix + example_str + suffix
return cls(input_variables=input_variables, template=template)


class DynamicPrompt(BaseModel, BasePrompt):
r"""Schema to represent a dynamic prompt for an LLM.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the r for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah i forgot the linter told me to do this
./langchain/prompt.py:133:1: D301 Use r""" if any backslashes in a docstring


Example:
.. code-block:: python

from langchain import DynamicPrompt
dynamic_prompt = DynamicPrompt(
examples=["Say hi. Hi", "Say ho. Ho"],
example_separator="\n\n",
prefix="",
suffix="\n\nSay {foo}"
input_variables=["foo"],
max_length=200,
get_text_length=word_count
)
"""

examples: List[str]
"""A list of the examples that the prompt template expects."""

example_separator: str = "\n\n"
"""Example separator, e.g. \n\n, for the dynamic prompt creation."""

input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""

prefix: str
"""Prefix for the prompt."""

suffix: str
"""Suffix for the prompt."""

template_format: str = "f-string"
"""The format of the prompt template. Options are: 'f-string'."""

get_text_length: Callable[[str], int] = lambda x: len(re.split("\n| ", x))
"""Function to measure prompt length. Defaults to word count."""

max_length: int = 2048
"""Max length for the prompt, beyond which examples are cut."""

class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid

def template(self, example_list: List[str], **kwargs: Any) -> str:
"""Return template given example list."""
template = self.example_separator.join(
[self.prefix, *example_list, self.suffix]
)
return _FORMATTER_MAPPING[self.template_format](template, **kwargs)

def format(self, **kwargs: Any) -> str:
"""Dynamically format the prompt with the inputs.

Args:
kwargs: Any arguments to be passed to the prompt template.

Returns:
A formatted string.

Example:

.. code-block:: python

prompt.format(variable1="foo")
"""
curr_examples = self.examples
template = self.template(curr_examples, **kwargs)
while self.get_text_length(template) > self.max_length and curr_examples:
curr_examples = curr_examples[:-1]
template = self.template(curr_examples, **kwargs)
return template

@root_validator()
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that prefix, suffix and input variables are consistent."""
input_variables = values["input_variables"]
suffix = values["suffix"]
template_format = values["template_format"]
if template_format not in _FORMATTER_MAPPING:
valid_formats = list(_FORMATTER_MAPPING)
raise ValueError(
f"Invalid template format. Got `{template_format}`;"
f" should be one of {valid_formats}"
)
try:
result = values["get_text_length"]("foo")
assert isinstance(result, int)
except AssertionError:
raise ValueError(
"Invalid text length callable, must take string & return int;"
)
dummy_inputs = {input_variable: "foo" for input_variable in input_variables}
# TODO variables could be in prefix or suffix
try:
formatter_func = _FORMATTER_MAPPING[template_format]
formatter_func(suffix, **dummy_inputs)
sjwhitmore marked this conversation as resolved.
Show resolved Hide resolved
except KeyError:
raise ValueError("Invalid prompt schema.")
return values
110 changes: 110 additions & 0 deletions tests/unit_tests/test_dynamic_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""Test functionality related to dynamic prompts."""
from langchain.prompt import DynamicPrompt, Prompt

# FULL TEMPLATES
LONGER_TEMPLATE = """Test Prompt:

Question: who are you?
Answer: foo

Question: what are you?
Answer: bar

Question: {question}
Answer:"""
SHORTER_TEMPLATE = """Test Prompt:

Question: who are you?
Answer: foo

Question: {question}
Answer:"""
SHORTEST_TEMPLATE = """Test Prompt:

Question: {question}
Answer:"""

# DYNAMIC PROMPT COMPONENTS
PREFIX = """Test Prompt:"""
SUFFIX = """Question: {question}\nAnswer:"""
EXAMPLES = [
"""Question: who are you?\nAnswer: foo""",
"""Question: what are you?\nAnswer: bar""",
]

# INPUTS
TEST_LONG_QUESTION = """I am writing a really long question,
this probably is going to affect the example right?"""
TEST_LONGEST_QUESTION = """This question is super super super,
super super super super super super super super super super super,
super super super super long, this will affect the example right?"""
TEST_SHORT_QUESTION = "Short question?"


def test_dynamic_prompt_valid() -> None:
"""Test dynamic prompt can be successfully constructed from examples."""
input_variables = ["question"]
example_separator = "\n\n"
dynamic_prompt_cls = DynamicPrompt(
examples=EXAMPLES,
suffix=SUFFIX,
input_variables=input_variables,
example_separator=example_separator,
prefix=PREFIX,
)
prompt_cls = Prompt(input_variables=input_variables, template=LONGER_TEMPLATE)
dynamic_prompt_template = dynamic_prompt_cls.format(question="foo?")
prompt_template = prompt_cls.format(question="foo?")
assert dynamic_prompt_template == prompt_template
assert dynamic_prompt_cls.input_variables == prompt_cls.input_variables


def test_dynamic_prompt_trims_one_example() -> None:
"""Test dynamic prompt can trim one example."""
input_variables = ["question"]
example_separator = "\n\n"
dynamic_prompt_cls = DynamicPrompt(
examples=EXAMPLES,
suffix=SUFFIX,
input_variables=input_variables,
example_separator=example_separator,
prefix=PREFIX,
max_length=30,
)
dynamic_prompt = dynamic_prompt_cls.format(question=TEST_LONG_QUESTION)
shorter_prompt = SHORTER_TEMPLATE.format(question=TEST_LONG_QUESTION)
assert dynamic_prompt == shorter_prompt


def test_dynamic_prompt_trims_no_examples() -> None:
"""Test dynamic prompt can trim no examples."""
input_variables = ["question"]
example_separator = "\n\n"
dynamic_prompt_cls = DynamicPrompt(
examples=EXAMPLES,
suffix=SUFFIX,
input_variables=input_variables,
example_separator=example_separator,
prefix=PREFIX,
max_length=30,
)
dynamic_prompt = dynamic_prompt_cls.format(question=TEST_SHORT_QUESTION)
full_prompt = LONGER_TEMPLATE.format(question=TEST_SHORT_QUESTION)
assert dynamic_prompt == full_prompt


def test_dynamic_prompt_trims_all_examples() -> None:
"""Test dynamic prompt can trim all examples."""
input_variables = ["question"]
example_separator = "\n\n"
dynamic_prompt_cls = DynamicPrompt(
examples=EXAMPLES,
suffix=SUFFIX,
input_variables=input_variables,
example_separator=example_separator,
prefix=PREFIX,
max_length=30,
)
dynamic_prompt = dynamic_prompt_cls.format(question=TEST_LONGEST_QUESTION)
full_prompt = SHORTEST_TEMPLATE.format(question=TEST_LONGEST_QUESTION)
assert dynamic_prompt == full_prompt