Skip to content

Commit

Permalink
Add convenience function to render prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Apr 11, 2023
1 parent 2a2109f commit e00a769
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 54 deletions.
54 changes: 26 additions & 28 deletions examples/meta_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
"""
import argparse

import outlines
import outlines.text as text


Expand All @@ -22,9 +21,9 @@ def solve(question):
Let's solve this problem by splitting it into steps.
"""

answer, prompt = solve(question)
_, completed = solve(question)

return prompt, answer
return completed


def fill_in_the_blanks(question, model_name: str):
Expand All @@ -39,10 +38,10 @@ def determine_goal(question):
def solve(memory):
"""${memory}. Let's begin."""

_, memory = determine_goal(question)
answer, full_interaction = solve(memory)
_, completed = determine_goal(question)
_, completed = solve(completed)

return full_interaction, answer
return completed


def ask_an_expert(question, model_name: str):
Expand Down Expand Up @@ -73,10 +72,10 @@ def get_answer(question, expert, memory):
${question}
"""

expert, memory = find_expert(question)
answer, full_interaction = get_answer(question, expert, memory)
expert, completed = find_expert(question)
_, completed = get_answer(question, expert, completed)

return full_interaction, answer
return completed


def ask_an_expert_simple(question, model_name: str):
Expand All @@ -95,18 +94,16 @@ def get_answer(expert, memory):
For instance,${expert} would answer
"""

expert, memory = find_expert(question)
answer, full_interaction = get_answer(expert, memory)
expert, completed = find_expert(question)
answer, completed = get_answer(expert, completed)

return full_interaction, answer
return completed


def run_example(model_fn, question, model):
print("\n-----------------------------------------\n")
question_s = outlines.text.string()
fn = outlines.chain([question_s], model_fn(question_s, model))
prompt, answer = fn(question)
print(f"{prompt}")
def run_example(model_fn, question, model_name):
completed = model_fn(question, model_name)
print(f"\n-----------------------")
print(f"{completed}")


if __name__ == "__main__":
Expand All @@ -121,16 +118,17 @@ def run_example(model_fn, question, model):

math_q = "f(x) = x*x. What is f(f(3))?"
sat_q = """
Directions: In the following question, a related pair of words or phrases \
is followed by five pairs of words or phrases. Choose the pair that best \
expresses a relationship similar to that in the original pair. \
BRAGGART :: MODESTY
A) FLEDGLING : EXPERIENCE
B) EMBEZZLER : GREED
C) WALLFLOWER : TIMIDITY
D) INVALID : MALADY
E) CANDIDATE : AMBITION
Directions: In the following question, a related pair of words or phrases
is followed by five pairs of words or phrases. Choose the pair that best
expresses a relationship similar to that in the original pair.
BRAGGART :: MODESTY
A) FLEDGLING : EXPERIENCE
B) EMBEZZLER : GREED
C) WALLFLOWER : TIMIDITY
D) INVALID : MALADY
E) CANDIDATE : AMBITION
"""
alignment_q = "What should humankind do to ensure that artificial general intelligence is aligned?"
Expand Down
3 changes: 2 additions & 1 deletion outlines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@
from outlines.function import fn
from outlines.image import as_image
from outlines.program import chain, program
from outlines.text import as_string, completion, render
from outlines.text import as_string, completion, prompt, render

__all__ = [
"chain",
"as_image",
"as_string",
"fn",
"program",
"prompt",
"render",
]
4 changes: 2 additions & 2 deletions outlines/text/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .basic import *
from .completion import completion
from .render import render
from .prompt import prompt, render
from .var import as_string, string

__all__ = ["as_string", "completion", "string", "render"]
__all__ = ["as_string", "completion", "prompt", "string", "render"]
25 changes: 3 additions & 22 deletions outlines/text/completion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import inspect

from outlines.text.render import render
from outlines.text.prompt import prompt


def completion(name: str, stops_at=None):
Expand Down Expand Up @@ -62,20 +60,7 @@ def completion(name: str, stops_at=None):
raise NameError(f"The model provider {provider_name} is not available.")

def decorator(fn):
# Get the names of the parameters to the function, which must correspond
# to the variables defined in the template.
var_names = []
kwargs_data = {}
sig = inspect.signature(fn)
for parameter in sig.parameters.values():
if parameter.default == inspect._empty:
var_names.append(parameter.name)
else:
kwargs_data[parameter.name] = parameter.default

# The docstring contains the template that will be rendered to be used
# as a prompt to the language model.
template = inspect.cleandoc(fn.__doc__)
prompt_fn = prompt(fn)

def wrapper(*args, **kwargs):
"""Call the LLM with the rendered template.
Expand All @@ -91,11 +76,7 @@ def wrapper(*args, **kwargs):
call.
"""
args_data = {name: arg for name, arg in zip(var_names, args)}
kwargs_data.update(kwargs)
data = {**args_data, **kwargs_data}

prompt = render(template, **data)
prompt = prompt_fn(*args, **kwargs)
result = llm(prompt)
return result, prompt + result

Expand Down
58 changes: 57 additions & 1 deletion outlines/text/render.py → outlines/text/prompt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import collections
import inspect
from typing import Dict, Union
from typing import Callable, Dict, Union

from mako.runtime import Context
from mako.template import Template
Expand Down Expand Up @@ -94,3 +94,59 @@ def render(
mako_template.render_context(ctx)

return buf.get_value()


def prompt(fn: Callable):
"""Decorator around a function that contains a prompt template.
This allows to define prompts in the docstring of a function and ease their
manipulation by providing some degree of encapsulation.
>>> import outlines
>>>
>>> @outlines.prompt
>>> def answer_tpl(question):
... "I have a ${question}"
...
>>> prompt = answer_tpl("How are you?")
This is syntactic sugar and uses the `render` function internally.
Therefore, the wrapped functions return `str` when called with `str`
arguments only, and a `StringVariable` when at least one argument is a
`StringVariable`.
"""

# Get the names of the parameters to the function, which must correspond
# to the variables defined in the template.
var_names = []
kwargs_data = {}
sig = inspect.signature(fn)
for parameter in sig.parameters.values():
if parameter.default == inspect._empty:
var_names.append(parameter.name)
else:
kwargs_data[parameter.name] = parameter.default

# The docstring contains the template that will be rendered to be used
# as a prompt to the language model.
docstring = fn.__doc__
if docstring is None:
raise TypeError("Could not find a template in the function's docstring.")
else:
template = inspect.cleandoc(docstring)

def wrapper(*args, **kwargs):
"""Render and return the template.
Returns
-------
A Python `str` when all arguments are Python `str`, a `StringVariable`
otherwise.
"""
bound_arguments = sig.bind(*args, **kwargs)
bound_arguments.apply_defaults()
return render(template, **bound_arguments.arguments)

return wrapper
74 changes: 74 additions & 0 deletions tests/text/test_compose.py → tests/text/test_prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

import outlines.text as text
from outlines.text import render, string
from outlines.text.basic import Add
from outlines.text.var import StringConstant, StringVariable
Expand Down Expand Up @@ -59,3 +60,76 @@ def test_template_few_shots():
examples=examples,
)
assert isinstance(prompt, StringVariable)


def test_prompt_basic():
@text.prompt
def test_tpl(variable):
"""${variable} test"""

with pytest.raises(TypeError):
test_tpl(v="test")

p = test_tpl("test")
assert p == "test test"

p = test_tpl(variable="test")
assert p == "test test"

@text.prompt
def test_single_quote_tpl(variable):
"${variable} test"

p = test_tpl("test")
assert p == "test test"


def test_prompt_kwargs():
@text.prompt
def test_kwarg_tpl(var, other_var="other"):
"""${var} and ${other_var}"""

p = test_kwarg_tpl("test")
assert p == "test and other"

p = test_kwarg_tpl("test", other_var="kwarg")
assert p == "test and kwarg"

p = test_kwarg_tpl("test", "test")
assert p == "test and test"


def test_not_prompt():
with pytest.raises(TypeError, match="template"):

@text.prompt
def test_empty(variable):
pass

with pytest.raises(TypeError, match="template"):

@text.prompt
def test_only_code(variable):
return variable


def test_prompt_few_shots():
@text.prompt
def few_shots_tpl(w, examples):
"""This is a test
${w}
% for s, t in examples:
Search: ${s}
Trap: ${t}
% endfor
"""

prompt = few_shots_tpl("Test", [["a", "b"], ["c", "d"]])
assert (
prompt == "This is a test\n\nTest\n\nSearch: a\nTrap: b\nSearch: c\nTrap: d\n"
)

prompt = few_shots_tpl(string(), [["a", "b"], ["c", "d"]])
assert isinstance(prompt, StringVariable)

0 comments on commit e00a769

Please sign in to comment.