diff --git a/examples/meta_prompting.py b/examples/meta_prompting.py index fd0056f3c..206308d73 100644 --- a/examples/meta_prompting.py +++ b/examples/meta_prompting.py @@ -11,7 +11,6 @@ """ import argparse -import outlines import outlines.text as text @@ -22,9 +21,9 @@ def solve(question): Let's solve this problem by splitting it into steps. """ - answer, prompt = solve(question) + _, completed = solve(question) - return prompt, answer + return completed def fill_in_the_blanks(question, model_name: str): @@ -39,10 +38,10 @@ def determine_goal(question): def solve(memory): """${memory}. Let's begin.""" - _, memory = determine_goal(question) - answer, full_interaction = solve(memory) + _, completed = determine_goal(question) + _, completed = solve(completed) - return full_interaction, answer + return completed def ask_an_expert(question, model_name: str): @@ -73,10 +72,10 @@ def get_answer(question, expert, memory): ${question} """ - expert, memory = find_expert(question) - answer, full_interaction = get_answer(question, expert, memory) + expert, completed = find_expert(question) + _, completed = get_answer(question, expert, completed) - return full_interaction, answer + return completed def ask_an_expert_simple(question, model_name: str): @@ -95,18 +94,16 @@ def get_answer(expert, memory): For instance,${expert} would answer """ - expert, memory = find_expert(question) - answer, full_interaction = get_answer(expert, memory) + expert, completed = find_expert(question) + answer, completed = get_answer(expert, completed) - return full_interaction, answer + return completed -def run_example(model_fn, question, model): - print("\n-----------------------------------------\n") - question_s = outlines.text.string() - fn = outlines.chain([question_s], model_fn(question_s, model)) - prompt, answer = fn(question) - print(f"{prompt}") +def run_example(model_fn, question, model_name): + completed = model_fn(question, model_name) + print(f"\n-----------------------") + print(f"{completed}") if __name__ == "__main__": @@ -121,16 +118,17 @@ def run_example(model_fn, question, model): math_q = "f(x) = x*x. What is f(f(3))?" sat_q = """ - Directions: In the following question, a related pair of words or phrases \ - is followed by five pairs of words or phrases. Choose the pair that best \ - expresses a relationship similar to that in the original pair. \ - - BRAGGART :: MODESTY - A) FLEDGLING : EXPERIENCE - B) EMBEZZLER : GREED - C) WALLFLOWER : TIMIDITY - D) INVALID : MALADY - E) CANDIDATE : AMBITION + +Directions: In the following question, a related pair of words or phrases +is followed by five pairs of words or phrases. Choose the pair that best +expresses a relationship similar to that in the original pair. + +BRAGGART :: MODESTY +A) FLEDGLING : EXPERIENCE +B) EMBEZZLER : GREED +C) WALLFLOWER : TIMIDITY +D) INVALID : MALADY +E) CANDIDATE : AMBITION """ alignment_q = "What should humankind do to ensure that artificial general intelligence is aligned?" diff --git a/outlines/__init__.py b/outlines/__init__.py index 970abe25b..fbf4c8382 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -31,7 +31,7 @@ from outlines.function import fn from outlines.image import as_image from outlines.program import chain, program -from outlines.text import as_string, completion, render +from outlines.text import as_string, completion, prompt, render __all__ = [ "chain", @@ -39,5 +39,6 @@ "as_string", "fn", "program", + "prompt", "render", ] diff --git a/outlines/text/__init__.py b/outlines/text/__init__.py index 6396d4b92..8651d0a8f 100644 --- a/outlines/text/__init__.py +++ b/outlines/text/__init__.py @@ -1,6 +1,6 @@ from .basic import * from .completion import completion -from .render import render +from .prompt import prompt, render from .var import as_string, string -__all__ = ["as_string", "completion", "string", "render"] +__all__ = ["as_string", "completion", "prompt", "string", "render"] diff --git a/outlines/text/completion.py b/outlines/text/completion.py index a28187dbb..df41ab8ab 100644 --- a/outlines/text/completion.py +++ b/outlines/text/completion.py @@ -1,6 +1,4 @@ -import inspect - -from outlines.text.render import render +from outlines.text.prompt import prompt def completion(name: str, stops_at=None): @@ -62,20 +60,7 @@ def completion(name: str, stops_at=None): raise NameError(f"The model provider {provider_name} is not available.") def decorator(fn): - # Get the names of the parameters to the function, which must correspond - # to the variables defined in the template. - var_names = [] - kwargs_data = {} - sig = inspect.signature(fn) - for parameter in sig.parameters.values(): - if parameter.default == inspect._empty: - var_names.append(parameter.name) - else: - kwargs_data[parameter.name] = parameter.default - - # The docstring contains the template that will be rendered to be used - # as a prompt to the language model. - template = inspect.cleandoc(fn.__doc__) + prompt_fn = prompt(fn) def wrapper(*args, **kwargs): """Call the LLM with the rendered template. @@ -91,11 +76,7 @@ def wrapper(*args, **kwargs): call. """ - args_data = {name: arg for name, arg in zip(var_names, args)} - kwargs_data.update(kwargs) - data = {**args_data, **kwargs_data} - - prompt = render(template, **data) + prompt = prompt_fn(*args, **kwargs) result = llm(prompt) return result, prompt + result diff --git a/outlines/text/render.py b/outlines/text/prompt.py similarity index 59% rename from outlines/text/render.py rename to outlines/text/prompt.py index 32e9d8259..1f6fa0a64 100644 --- a/outlines/text/render.py +++ b/outlines/text/prompt.py @@ -1,6 +1,6 @@ import collections import inspect -from typing import Dict, Union +from typing import Callable, Dict, Union from mako.runtime import Context from mako.template import Template @@ -94,3 +94,59 @@ def render( mako_template.render_context(ctx) return buf.get_value() + + +def prompt(fn: Callable): + """Decorator around a function that contains a prompt template. + + This allows to define prompts in the docstring of a function and ease their + manipulation by providing some degree of encapsulation. + + >>> import outlines + >>> + >>> @outlines.prompt + >>> def answer_tpl(question): + ... "I have a ${question}" + ... + >>> prompt = answer_tpl("How are you?") + + This is syntactic sugar and uses the `render` function internally. + Therefore, the wrapped functions return `str` when called with `str` + arguments only, and a `StringVariable` when at least one argument is a + `StringVariable`. + + """ + + # Get the names of the parameters to the function, which must correspond + # to the variables defined in the template. + var_names = [] + kwargs_data = {} + sig = inspect.signature(fn) + for parameter in sig.parameters.values(): + if parameter.default == inspect._empty: + var_names.append(parameter.name) + else: + kwargs_data[parameter.name] = parameter.default + + # The docstring contains the template that will be rendered to be used + # as a prompt to the language model. + docstring = fn.__doc__ + if docstring is None: + raise TypeError("Could not find a template in the function's docstring.") + else: + template = inspect.cleandoc(docstring) + + def wrapper(*args, **kwargs): + """Render and return the template. + + Returns + ------- + A Python `str` when all arguments are Python `str`, a `StringVariable` + otherwise. + + """ + bound_arguments = sig.bind(*args, **kwargs) + bound_arguments.apply_defaults() + return render(template, **bound_arguments.arguments) + + return wrapper diff --git a/tests/text/test_compose.py b/tests/text/test_prompt.py similarity index 50% rename from tests/text/test_compose.py rename to tests/text/test_prompt.py index f7f161b3f..659b0f814 100644 --- a/tests/text/test_compose.py +++ b/tests/text/test_prompt.py @@ -1,5 +1,6 @@ import pytest +import outlines.text as text from outlines.text import render, string from outlines.text.basic import Add from outlines.text.var import StringConstant, StringVariable @@ -59,3 +60,76 @@ def test_template_few_shots(): examples=examples, ) assert isinstance(prompt, StringVariable) + + +def test_prompt_basic(): + @text.prompt + def test_tpl(variable): + """${variable} test""" + + with pytest.raises(TypeError): + test_tpl(v="test") + + p = test_tpl("test") + assert p == "test test" + + p = test_tpl(variable="test") + assert p == "test test" + + @text.prompt + def test_single_quote_tpl(variable): + "${variable} test" + + p = test_tpl("test") + assert p == "test test" + + +def test_prompt_kwargs(): + @text.prompt + def test_kwarg_tpl(var, other_var="other"): + """${var} and ${other_var}""" + + p = test_kwarg_tpl("test") + assert p == "test and other" + + p = test_kwarg_tpl("test", other_var="kwarg") + assert p == "test and kwarg" + + p = test_kwarg_tpl("test", "test") + assert p == "test and test" + + +def test_not_prompt(): + with pytest.raises(TypeError, match="template"): + + @text.prompt + def test_empty(variable): + pass + + with pytest.raises(TypeError, match="template"): + + @text.prompt + def test_only_code(variable): + return variable + + +def test_prompt_few_shots(): + @text.prompt + def few_shots_tpl(w, examples): + """This is a test + + ${w} + + % for s, t in examples: + Search: ${s} + Trap: ${t} + % endfor + """ + + prompt = few_shots_tpl("Test", [["a", "b"], ["c", "d"]]) + assert ( + prompt == "This is a test\n\nTest\n\nSearch: a\nTrap: b\nSearch: c\nTrap: d\n" + ) + + prompt = few_shots_tpl(string(), [["a", "b"], ["c", "d"]]) + assert isinstance(prompt, StringVariable)