Commit 03fec71

finitearth and mo374z authored
Fix/template (#39)
* v1.3.1 (#37)

#### Added features
* new features for the VLLM Wrapper (accept seeding to ensure reproducibility)
* fixes in the "MarkerBasedClassificator"
* fixes in prompt creation and task description handling
* generalize the Classificator
* add verbosity and callback handling in EvoPromptGA
* add timestamp to the callback
* removed datasets from repo
* changed task creation (now by default with a dataset)
* add generation prompt to vllm input
* allow for parquet as fileoutput callback
* added sys_prompts
* change usage of csv callbacks
* add system prompt to token counts
* fix merge issues
* drag system prompts from api to task
* added release notes

---------

Co-authored-by: Moritz Schlager <87517800+mo374z@users.noreply.github.com>
1 parent 39b0779 commit 03fec71

File tree

14 files changed: +88 −114 lines

docs/release-notes.md

Lines changed: 10 additions & 0 deletions
@@ -1,5 +1,15 @@
 # Release Notes
 
+## Release v1.3.2
+### What's changed
+#### Added features
+* Allow for configuration and evaluation of system prompts in all LLM-Classes
+* CSV Callback is now FileOutputCallback and able to write Parquet files
+
+#### Further Changes:
+* Fixed LLM-Call templates in VLLM
+* refined OPRO-implementation to be closer to the paper
+
 ## Release v1.3.1
 ### What's changed
 #### Added features
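
The callback rename is the main user-facing change of this release. A minimal before/after sketch, assuming the public import path mirrors the file layout (`promptolution.callbacks`) and using an illustrative output directory:

```python
from promptolution.callbacks import FileOutputCallback

# v1.3.1 used CSVCallback(dir) and always wrote CSV.
# v1.3.2 replaces it with FileOutputCallback; Parquet is the default format.
parquet_cb = FileOutputCallback("results")                   # writes results/step_results.parquet
csv_cb = FileOutputCallback("results", file_type="csv")      # keeps CSV output
```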

promptolution/callbacks.py

Lines changed: 30 additions & 48 deletions
@@ -88,32 +88,37 @@ def on_train_end(self, optimizer, logs=None):
         return True
 
 
-class CSVCallback(Callback):
-    """Callback for saving optimization progress to a CSV file.
+class FileOutputCallback(Callback):
+    """Callback for saving optimization progress to a specified file type.
 
-    This callback saves prompts and scores at each step to a CSV file.
+    This callback saves information about each step to a file.
 
     Attributes:
-        dir (str): Directory the CSV file is saved to.
+        dir (str): Directory the file is saved to.
         step (int): The current step number.
+        file_type (str): The type of file to save the output to.
     """
 
-    def __init__(self, dir):
-        """Initialize the CSVCallback.
+    def __init__(self, dir, file_type: Literal["parquet", "csv"] = "parquet"):
+        """Initialize the FileOutputCallback.
 
         Args:
             dir (str): Directory the CSV file is saved to.
+            file_type (str): The type of file to save the output to.
         """
         if not os.path.exists(dir):
             os.makedirs(dir)
 
-        self.dir = dir
-        self.dir = dir
+        self.file_type = file_type
+
+        if file_type == "parquet":
+            self.path = dir + "/step_results.parquet"
+        elif file_type == "csv":
+            self.path = dir + "/step_results.csv"
+        else:
+            raise ValueError(f"File type {file_type} not supported.")
+
         self.step = 0
-        self.input_tokens = 0
-        self.output_tokens = 0
-        self.start_time = datetime.now()
-        self.step_time = datetime.now()
 
     def on_step_end(self, optimizer):
         """Save prompts and scores to csv.
@@ -125,47 +130,24 @@ def on_step_end(self, optimizer):
         df = pd.DataFrame(
             {
                 "step": [self.step] * len(optimizer.prompts),
-                "input_tokens": [optimizer.meta_llm.input_token_count - self.input_tokens] * len(optimizer.prompts),
-                "output_tokens": [optimizer.meta_llm.output_token_count - self.output_tokens] * len(optimizer.prompts),
-                "time_elapsed": [(datetime.now() - self.step_time).total_seconds()] * len(optimizer.prompts),
+                "input_tokens": [optimizer.meta_llm.input_token_count] * len(optimizer.prompts),
+                "output_tokens": [optimizer.meta_llm.output_token_count] * len(optimizer.prompts),
+                "time": [datetime.now().total_seconds()] * len(optimizer.prompts),
                 "score": optimizer.scores,
                 "prompt": optimizer.prompts,
             }
         )
-        self.step_time = datetime.now()
-        self.input_tokens = optimizer.meta_llm.input_token_count
-        self.output_tokens = optimizer.meta_llm.output_token_count
 
-        if not os.path.exists(self.dir + "step_results.csv"):
-            df.to_csv(self.dir + "step_results.csv", index=False)
-        else:
-            df.to_csv(self.dir + "step_results.csv", mode="a", header=False, index=False)
-
-        return True
-
-    def on_train_end(self, optimizer):
-        """Called at the end of training.
-
-        Args:
-            optimizer: The optimizer object that called the callback.
-        """
-        df = pd.DataFrame(
-            dict(
-                steps=self.step,
-                input_tokens=optimizer.meta_llm.input_token_count,
-                output_tokens=optimizer.meta_llm.output_token_count,
-                time_elapsed=(datetime.now() - self.start_time).total_seconds(),
-                time=datetime.now(),
-                score=np.array(optimizer.scores).mean(),
-                best_prompts=str(optimizer.prompts),
-            ),
-            index=[0],
-        )
-
-        if not os.path.exists(self.dir + "train_results.csv"):
-            df.to_csv(self.dir + "train_results.csv", index=False)
-        else:
-            df.to_csv(self.dir + "train_results.csv", mode="a", header=False, index=False)
+        if self.file_type == "parquet":
+            if self.step == 1:
+                df.to_parquet(self.path, index=False)
+            else:
+                df.to_parquet(self.path, mode="a", index=False)
+        elif self.file_type == "csv":
+            if self.step == 1:
+                df.to_csv(self.path, index=False)
+            else:
+                df.to_csv(self.path, mode="a", header=False, index=False)
 
         return True
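
After a run, the per-step log can be loaded back with pandas. A small sketch, assuming the default Parquet output and an illustrative `results` directory; the columns named in the comment are the ones built in `on_step_end` above:

```python
import pandas as pd

# One row per prompt per step: step, input_tokens, output_tokens, time, score, prompt.
df = pd.read_parquet("results/step_results.parquet")

# Best-scoring prompt per optimization step.
best = df.sort_values("score", ascending=False).groupby("step").head(1)
print(best[["step", "score", "prompt"]])
```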

promptolution/llms/api_llm.py

Lines changed: 6 additions & 4 deletions
@@ -10,19 +10,20 @@
 import requests
 from langchain_anthropic import ChatAnthropic
 from langchain_community.chat_models.deepinfra import ChatDeepInfra, ChatDeepInfraException
-from langchain_core.messages import HumanMessage
+from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_openai import ChatOpenAI
 
 from promptolution.llms.base_llm import BaseLLM
 
 logger = Logger(__name__)
 
 
-async def invoke_model(prompt, model, semaphore):
+async def invoke_model(prompt, system_prompt, model, semaphore):
     """Asynchronously invoke a language model with retry logic.
 
     Args:
         prompt (str): The input prompt for the model.
+        system_prompt (str): The system prompt for the model.
         model: The language model to invoke.
         semaphore (asyncio.Semaphore): Semaphore to limit concurrent calls.
 
@@ -39,7 +40,7 @@ async def invoke_model(prompt, model, semaphore):
 
     while attempts < max_retries:
         try:
-            response = await model.ainvoke([HumanMessage(content=prompt)])
+            response = await model.ainvoke([SystemMessage(content=system_prompt), HumanMessage(content=prompt)])
             return response.content
         except ChatDeepInfraException as e:
             print(f"DeepInfra error: {e}. Attempt {attempts}/{max_retries}. Retrying in {delay} seconds...")
@@ -80,13 +81,14 @@ def __init__(self, model_id: str, token: str = None, **kwargs: Any):
         else:
             self.model = ChatDeepInfra(model_name=model_id, deepinfra_api_token=token)
 
-    def _get_response(self, prompts: List[str]) -> List[str]:
+    def _get_response(self, prompts: List[str], system_prompts: List[str] = None) -> List[str]:
        """Get responses for a list of prompts in a synchronous manner.
 
        This method includes retry logic for handling connection errors and rate limits.
 
        Args:
            prompts (list[str]): List of input prompts.
+           system_prompts (list[str]): List of system prompts. If not provided, uses default system_prompts
 
        Returns:
            list[str]: List of model responses.
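
The helper now sends the system prompt as a separate `SystemMessage` ahead of the user turn. A rough usage sketch, assuming a LangChain chat model is available and calling the helper directly rather than through the `APILLM` wrapper; the model name and concurrency limit are illustrative, not taken from this diff:

```python
import asyncio

from langchain_openai import ChatOpenAI

from promptolution.llms.api_llm import invoke_model


async def main():
    semaphore = asyncio.Semaphore(5)  # cap concurrent API calls
    model = ChatOpenAI(model="gpt-4o-mini")
    system_prompt = "You are a helpful assistant."
    prompts = ["Classify: great movie!", "Classify: terrible plot."]
    responses = await asyncio.gather(
        *(invoke_model(p, system_prompt, model, semaphore) for p in prompts)
    )
    print(responses)


asyncio.run(main())
```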

promptolution/llms/base_llm.py

Lines changed: 13 additions & 5 deletions
@@ -6,6 +6,8 @@
 
 import numpy as np
 
+from promptolution.templates import DEFAULT_SYS_PROMPT
+
 logger = logging.getLogger(__name__)
 
 
@@ -54,7 +56,7 @@ def update_token_count(self, inputs: List[str], outputs: List[str]):
         self.input_token_count += input_tokens
         self.output_token_count += output_tokens
 
-    def get_response(self, prompts: str) -> str:
+    def get_response(self, prompts: List[str], system_prompts: List[str] = None) -> List[str]:
         """Generate responses for the given prompts.
 
         This method calls the _get_response method to generate responses
@@ -64,31 +66,37 @@ def get_response(self, prompts: str) -> str:
         Args:
             prompts (str or List[str]): Input prompt(s). If a single string is provided,
                 it's converted to a list containing that string.
+            system_prompts (str or List[str]): System prompt(s) to provide context to the model.
 
         Returns:
             List[str]: A list of generated responses, one for each input prompt.
         """
+        if system_prompts is None:
+            system_prompts = DEFAULT_SYS_PROMPT
         if isinstance(prompts, str):
             prompts = [prompts]
-        responses = self._get_response(prompts)
-        self.update_token_count(prompts, responses)
+        if isinstance(system_prompts, str):
+            system_prompts = [system_prompts] * len(prompts)
+        responses = self._get_response(prompts, system_prompts)
+        self.update_token_count(prompts + system_prompts, responses)
 
         return responses
 
     @abstractmethod
-    def _get_response(self, prompts: List[str]) -> List[str]:
+    def _get_response(self, prompts: List[str], system_prompts: List[str] = None) -> List[str]:
        """Generate responses for the given prompts.
 
        This method should be implemented by subclasses to define how
        the LLM generates responses.
 
        Args:
            prompts (List[str]): A list of input prompts.
+           system_prompts (List[str]): A list of system prompts to provide context to the model.
 
        Returns:
            List[str]: A list of generated responses corresponding to the input prompts.
        """
-        pass
+        raise NotImplementedError
 
 
 class DummyLLM(BaseLLM):
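
The updated `get_response` broadcasts a single system prompt across all prompts and falls back to `DEFAULT_SYS_PROMPT` when none is given. A toy subclass sketch to illustrate the `_get_response` contract, assuming `BaseLLM` can be instantiated without arguments and that its token bookkeeping needs no extra setup (neither is shown in this diff):

```python
from typing import List

from promptolution.llms.base_llm import BaseLLM


class EchoLLM(BaseLLM):
    """Toy LLM that echoes its inputs, for illustration only."""

    def _get_response(self, prompts: List[str], system_prompts: List[str] = None) -> List[str]:
        # get_response has already expanded system_prompts to one entry per prompt.
        return [f"[{sys_prompt}] {prompt}" for prompt, sys_prompt in zip(prompts, system_prompts)]


llm = EchoLLM()
# The single string below is broadcast to both prompts; omit it to use DEFAULT_SYS_PROMPT.
print(llm.get_response(["hello", "world"], system_prompts="Be terse."))
```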

promptolution/llms/local_llm.py

Lines changed: 6 additions & 2 deletions
@@ -50,7 +50,7 @@ def __init__(self, model_id: str, batch_size=8):
         self.pipeline.tokenizer.pad_token_id = self.pipeline.tokenizer.eos_token_id
         self.pipeline.tokenizer.padding_side = "left"
 
-    def _get_response(self, prompts: list[str]):
+    def _get_response(self, prompts: list[str], system_prompts: list[str]) -> list[str]:
         """Generate responses for a list of prompts using the local language model.
 
         Args:
@@ -63,8 +63,12 @@ def _get_response(self, prompts: list[str]):
         This method uses torch.no_grad() for inference to reduce memory usage.
         It handles both single and batch inputs, ensuring consistent output format.
         """
+        inputs = []
+        for prompt, sys_prompt in zip(prompts, system_prompts):
+            inputs.append([{"role": "system", "prompt": sys_prompt}, {"role": "user", "prompt": prompt}])
+
         with torch.no_grad():
-            response = self.pipeline(prompts, pad_token_id=self.pipeline.tokenizer.eos_token_id)
+            response = self.pipeline(inputs, pad_token_id=self.pipeline.tokenizer.eos_token_id)
 
         if len(response) != 1:
             response = [r[0] if isinstance(r, list) else r for r in response]

promptolution/llms/vllm.py

Lines changed: 5 additions & 4 deletions
@@ -108,7 +108,7 @@ def __init__(
         # Initialize tokenizer separately for potential pre-processing
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-    def _get_response(self, inputs: list[str]):
+    def _get_response(self, prompts: list[str], system_prompts: list[str]) -> list[str]:
         """Generate responses for a list of prompts using the vLLM engine.
 
         Args:
@@ -126,13 +126,14 @@ def _get_response(self, inputs: list[str]):
                 [
                     {
                         "role": "system",
-                        "content": "You are a helpful assistant.",
+                        "content": sys_prompt,
                     },
-                    {"role": "user", "content": input},
+                    {"role": "user", "content": prompt},
                 ],
                 tokenize=False,
+                add_generation_prompt=True,
             )
-            for input in inputs
+            for prompt, sys_prompt in zip(prompts, system_prompts)
         ]
 
         # generate responses for self.batch_size prompts at the same time
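
The added `add_generation_prompt=True` makes the chat template end with the assistant header, so the model starts a fresh answer instead of continuing the user turn. A standalone sketch of what the template call produces, using an illustrative (non-gated) model id:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Classify: great movie!"},
]

# Returns the formatted prompt string, ending with the assistant header.
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(text)
```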

promptolution/predictors/base_predictor.py

Lines changed: 6 additions & 2 deletions
@@ -31,7 +31,9 @@ def __init__(self, llm: BaseLLM):
         """
         self.llm = llm
 
-    def predict(self, prompts: List[str], xs: np.ndarray, return_seq: bool = False) -> np.ndarray:
+    def predict(
+        self, prompts: List[str], xs: np.ndarray, system_prompts: List[str] = None, return_seq: bool = False
+    ) -> np.ndarray:
         """Abstract method to make predictions based on prompts and input data.
 
         Args:
@@ -48,7 +50,9 @@ def predict(self, prompts: List[str], xs: np.ndarray, return_seq: bool = False)
         if isinstance(prompts, str):
             prompts = [prompts]
 
-        outputs = self.llm.get_response([prompt + "\n" + x for prompt in prompts for x in xs])
+        outputs = self.llm.get_response(
+            [prompt + "\n" + x for prompt in prompts for x in xs], system_prompts=system_prompts
+        )
         preds = self._extract_preds(outputs)
 
         shape = (len(prompts), len(xs))
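
`predict` fans every prompt out over every input before reshaping, so the LLM receives `len(prompts) * len(xs)` combined inputs and the predictions come back as a `(len(prompts), len(xs))` array. A small self-contained illustration of that bookkeeping (the data is made up):

```python
import numpy as np

prompts = ["Classify the sentiment:", "Label the review:"]
xs = np.array(["great movie", "terrible plot", "okay acting"])

# Same flattening as in predict(): one combined input per (prompt, x) pair.
flat_inputs = [prompt + "\n" + x for prompt in prompts for x in xs]
print(len(flat_inputs))  # 6 requests go through llm.get_response(..., system_prompts=...)

shape = (len(prompts), len(xs))
print(shape)  # predictions are reshaped to (2, 3)
```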

promptolution/tasks/base_task.py

Lines changed: 3 additions & 2 deletions
@@ -21,12 +21,13 @@ def __init__(self, *args, **kwargs):
         pass
 
     @abstractmethod
-    def evaluate(self, prompts: List[str], predictor) -> np.ndarray:
+    def evaluate(self, prompts: List[str], predictor, system_promtps: List[str] = None) -> np.ndarray:
         """Abstract method to evaluate prompts using a given predictor.
 
         Args:
             prompts (List[str]): List of prompts to evaluate.
             predictor: The predictor to use for evaluation.
+            system_promtps (List[str]): List of system prompts to evaluate.
 
         Returns:
             np.ndarray: Array of evaluation scores for each prompt.
@@ -58,7 +59,7 @@ def __init__(self):
         self.ys = np.array(["positive", "negative", "positive"])
         self.classes = ["negative", "positive"]
 
-    def evaluate(self, prompts: List[str], predictor) -> np.ndarray:
+    def evaluate(self, prompts: List[str], predictor, system_prompts=None) -> np.ndarray:
         """Generate random evaluation scores for the given prompts.
 
         Args:

promptolution/tasks/classification_tasks.py

Lines changed: 3 additions & 1 deletion
@@ -63,6 +63,7 @@ def evaluate(
         self,
         prompts: List[str],
         predictor: BasePredictor,
+        system_prompts: List[str] = None,
         n_samples: int = 20,
         subsample: bool = False,
         return_seq: bool = False,
@@ -72,6 +73,7 @@ def evaluate(
         Args:
             prompts (List[str]): List of prompts to evaluate.
             predictor (BasePredictor): Predictor to use for evaluation.
+            system_prompts (List[str], optional): List of system prompts to evaluate. Defaults to None.
             n_samples (int, optional): Number of samples to use if subsampling. Defaults to 20.
             subsample (bool, optional): Whether to use subsampling.
                 If set to true, samples a different subset per call. Defaults to False.
@@ -95,7 +97,7 @@ def evaluate(
         ys_subsample = self.ys[indices]
 
         # Make predictions on the subsample
-        preds = predictor.predict(prompts, xs_subsample, return_seq=return_seq)
+        preds = predictor.predict(prompts, xs_subsample, system_prompts=system_prompts, return_seq=return_seq)
 
         if return_seq:
             preds, seqs = preds

promptolution/templates.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+DEFAULT_SYS_PROMPT = "You are a helpful assistant."
 EVOPROMPT_DE_TEMPLATE = """Please follow the instruction step-by-step to generate a better prompt.
 Identifying the different parts between Prompt 1 and Prompt 2:
 Prompt 1: Your task is to classify the comment as one of the following categories: terrible, bad, okay, good, great.
