
Commit 5e2c0a9

Feature/prompt generation (#12)
* added prompt_creation.py
* change version
1 parent 37591b7 commit 5e2c0a9

File tree

promptolution/utils/prompt_creation.py
pyproject.toml
scripts/prompt_creation_run.py

3 files changed: +140 -1 lines changed

promptolution/utils/prompt_creation.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
"""Utility functions for prompt creation."""

from typing import List, Optional, Union

import numpy as np

from promptolution.llms.base_llm import BaseLLM
from promptolution.tasks.base_task import BaseTask
from promptolution.tasks.classification_tasks import ClassificationTask


def create_prompt_variation(prompt: Union[List[str], str], llm: BaseLLM, meta_prompt: Optional[str] = None) -> List[str]:
    """Generate a variation of the given prompt(s) while keeping the semantic meaning.

    Idea taken from Zhou et al. (2022), https://arxiv.org/pdf/2211.01910

    Args:
        prompt (Union[List[str], str]): The prompt(s) to generate variations of.
        llm (BaseLLM): The language model to use for generating the variations.
        meta_prompt (Optional[str]): The meta prompt to use for generating the variations.
            If None, a default meta prompt is used. Should contain a <prev_prompt> tag.

    Returns:
        List[str]: A list of generated variations of the input prompt(s).
    """
    if meta_prompt is None:
        meta_prompt = """Generate a single variation of the following instruction while keeping the semantic meaning.
Generate the variation starting with <prompt> and ending with </prompt> tags.

Input: <prev_prompt>

Output:"""

    if isinstance(prompt, str):
        prompt = [prompt]
    varied_prompts = llm.get_response([meta_prompt.replace("<prev_prompt>", p) for p in prompt])

    # keep only the text between the <prompt> and </prompt> tags
    varied_prompts = [p.split("</prompt>")[0].split("<prompt>")[-1] for p in varied_prompts]

    return varied_prompts

def create_prompts_from_samples(task: BaseTask, llm: BaseLLM, meta_prompt: Optional[str] = None, n_samples: int = 3) -> List[str]:
    """Generate a set of prompts from dataset examples sampled from a given task.

    Idea taken from Zhou et al. (2022), https://arxiv.org/pdf/2211.01910
    Samples are selected such that
    (1) all possible classes are represented,
    (2) the samples are as representative as possible.

    Args:
        task (BaseTask): The task to generate prompts for.
            The xs and ys from this object are used to generate the prompts.
        llm (BaseLLM): The language model to use for generating the prompts.
        meta_prompt (Optional[str]): The meta prompt to use for generating the prompts.
            If None, a default meta prompt is used.
        n_samples (int): The number of samples to use for generating prompts.

    Returns:
        List[str]: A list of generated prompts.
    """
    if isinstance(task, ClassificationTask):
        # for classification tasks, sample such that all classes are represented,
        # proportionally to their frequency and with at least one sample per class
        unique_classes, counts = np.unique(task.ys, return_counts=True)
        proportions = counts / len(task.ys)
        samples_per_class = np.round(proportions * n_samples).astype(int)
        samples_per_class = np.maximum(samples_per_class, 1)

        # sample per class (the loop variable must not shadow the n_samples argument)
        xs = []
        ys = []
        for cls, n_cls_samples in zip(unique_classes, samples_per_class):
            indices = np.where(task.ys == cls)[0]
            indices = np.random.choice(indices, n_cls_samples, replace=False)
            xs.extend(task.xs[indices])
            ys.extend(task.ys[indices])

    else:
        # if not a classification task, sample randomly
        indices = np.random.choice(len(task.xs), n_samples, replace=False)
        xs = task.xs[indices].tolist()
        ys = task.ys[indices].tolist()

    if meta_prompt is None:
        meta_prompt = (
            "You are asked to give the corresponding prompt that gives the following outputs given these inputs. "
            "Return it starting with <prompt> and ending with </prompt> tags. "
            "Include the name of the output classes in the prompt."
        )

    for x, y in zip(xs, ys):
        meta_prompt += f"\n\nInput: {x}\nOutput: {y}"

    meta_prompt += "\nThe instruction was"

    prompt = llm.get_response([meta_prompt])[0]
    prompt = prompt.split("</prompt>")[0].split("<prompt>")[-1]

    return [prompt]
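For intuition, the class-proportional sampling in create_prompts_from_samples rounds each class's share of n_samples and then clamps every class to at least one sample, so rare classes are never dropped (which means the total can exceed n_samples). A minimal standalone sketch with toy labels, not part of this commit:

import numpy as np

ys = np.array(["a"] * 8 + ["b"] + ["c"])  # toy labels: "a" dominates
n_samples = 3

unique_classes, counts = np.unique(ys, return_counts=True)         # ['a' 'b' 'c'], [8 1 1]
proportions = counts / len(ys)                                     # [0.8 0.1 0.1]
samples_per_class = np.round(proportions * n_samples).astype(int)  # [2 0 0]
samples_per_class = np.maximum(samples_per_class, 1)               # [2 1 1]

for cls, k in zip(unique_classes, samples_per_class):
    print(cls, k)  # a 2, b 1, c 1 -- four samples in total, not three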

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
[tool.poetry]
name = "promptolution"
-version = "0.1.1"
+version = "0.2.0"
description = ""
authors = ["Tom Zehle, Moritz Schlager, Timo Heiß"]
readme = "README.md"

scripts/prompt_creation_run.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
"""Script to run prompt creation and evaluation."""

from configparser import ConfigParser
from logging import Logger

from promptolution.llms import get_llm
from promptolution.predictors import get_predictor
from promptolution.tasks import get_tasks
from promptolution.utils.prompt_creation import create_prompt_variation, create_prompts_from_samples

logger = Logger(__name__)


def main():
    """Main function to run the experiment."""
    config = ConfigParser()
    config.task_name = "subj"
    config.ds_path = "data_sets/cls/subj"
    config.random_seed = 42

    llm = get_llm("meta-llama/Meta-Llama-3-8B-Instruct")
    task = get_tasks(config)[0]
    predictor = get_predictor("meta-llama/Meta-Llama-3-8B-Instruct", classes=task.classes)

    init_prompts = create_prompts_from_samples(task, llm)
    logger.critical(f"Initial prompts: {init_prompts}")

    # evaluate on task
    scores = task.evaluate(init_prompts, predictor)
    logger.critical(f"Initial scores {scores.mean()} +/- {scores.std()}")

    varied_prompts = create_prompt_variation(init_prompts, llm)

    logger.critical(f"Varied prompts: {varied_prompts}")

    # evaluate on task
    scores = task.evaluate(varied_prompts, predictor)
    logger.critical(f"Varied scores {scores.mean()} +/- {scores.std()}")


if __name__ == "__main__":
    main()
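Since create_prompt_variation only ever calls llm.get_response, a duck-typed stand-in is enough to dry-run these utilities without downloading Meta-Llama-3-8B-Instruct. A minimal sketch; EchoLLM is a hypothetical stub, not part of promptolution:

from promptolution.utils.prompt_creation import create_prompt_variation


class EchoLLM:
    """Hypothetical stub LLM: returns a canned <prompt>-tagged answer for every input."""

    def get_response(self, prompts):
        return ["<prompt>Classify the sentence as subjective or objective.</prompt>" for _ in prompts]


print(create_prompt_variation("Label the input.", EchoLLM()))
# ['Classify the sentence as subjective or objective.']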
