
Commit 5d33d97

finitearth and mo374z authored
Refactor/generic api llm (#41)
* v1.3.2 (#40)

#### Added features

* Allow for configuration and evaluation of system prompts in all LLM classes
* CSV callback is now FileOutputCallback and is able to write Parquet files
* Fixed LLM call templates in VLLM
* Refined the OPRO implementation to be closer to the paper
* Implemented API calls
* Removed the default for system messages
* Rolled back renaming

---------

Co-authored-by: mo374z <schlager.mo@t-online.de>
1 parent 3e86324 commit 5d33d97

6 files changed: +157 additions, −135 deletions

promptolution/llms/api_llm.py

Lines changed: 56 additions & 114 deletions
@@ -1,144 +1,86 @@
 """Module to interface with various language models through their respective APIs."""
 
-import asyncio
-import time
-from logging import Logger
-from typing import Any, List
 
-import nest_asyncio
-import openai
-import requests
-from langchain_anthropic import ChatAnthropic
-from langchain_community.chat_models.deepinfra import ChatDeepInfra, ChatDeepInfraException
-from langchain_core.messages import HumanMessage, SystemMessage
-from langchain_openai import ChatOpenAI
+try:
+    import asyncio
 
-from promptolution.llms.base_llm import BaseLLM
+    from openai import AsyncOpenAI
 
-logger = Logger(__name__)
+    import_successful = True
+except ImportError:
+    import_successful = False
 
+from logging import Logger
+from typing import Any, List
 
-async def invoke_model(prompt, system_prompt, model, semaphore):
-    """Asynchronously invoke a language model with retry logic.
+from promptolution.llms.base_llm import BaseLLM
 
-    Args:
-        prompt (str): The input prompt for the model.
-        system_prompt (str): The system prompt for the model.
-        model: The language model to invoke.
-        semaphore (asyncio.Semaphore): Semaphore to limit concurrent calls.
+logger = Logger(__name__)
 
-    Returns:
-        str: The model's response content.
 
-    Raises:
-        ChatDeepInfraException: If all retry attempts fail.
-    """
+async def _invoke_model(prompt, system_prompt, max_tokens, model_id, client, semaphore):
     async with semaphore:
-        max_retries = 100
-        delay = 3
-        attempts = 0
-
-        while attempts < max_retries:
-            try:
-                response = await model.ainvoke([SystemMessage(content=system_prompt), HumanMessage(content=prompt)])
-                return response.content
-            except ChatDeepInfraException as e:
-                print(f"DeepInfra error: {e}. Attempt {attempts}/{max_retries}. Retrying in {delay} seconds...")
-                attempts += 1
-                await asyncio.sleep(delay)
+        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
+        response = await client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+            max_tokens=max_tokens,
+        )
+        return response
 
 
 class APILLM(BaseLLM):
-    """A class to interface with various language models through their respective APIs.
+    """A class to interface with language models through their respective APIs.
 
-    This class supports Claude (Anthropic), GPT (OpenAI), and LLaMA (DeepInfra) models.
-    It handles API key management, model initialization, and provides methods for
-    both synchronous and asynchronous inference.
+    This class provides a unified interface for making API calls to language models
+    using the OpenAI client library. It handles rate limiting through semaphores
+    and supports both synchronous and asynchronous operations.
 
     Attributes:
-        model: The initialized language model instance.
-
-    Methods:
-        get_response: Synchronously get responses for a list of prompts.
-        get_response_async: Asynchronously get responses for a list of prompts.
+        model_id (str): Identifier for the model to use.
+        client (AsyncOpenAI): The initialized API client.
+        max_tokens (int): Maximum number of tokens in model responses.
+        semaphore (asyncio.Semaphore): Semaphore to limit concurrent API calls.
     """
 
-    def __init__(self, model_id: str, token: str = None, **kwargs: Any):
-        """Initialize the APILLM with a specific model.
+    def __init__(
+        self, api_url: str, model_id: str, token: str = None, max_concurrent_calls=50, max_tokens=512, **kwargs: Any
+    ):
+        """Initialize the APILLM with a specific model and API configuration.
 
         Args:
+            api_url (str): The base URL for the API endpoint.
             model_id (str): Identifier for the model to use.
-            token (str): API key for the model.
+            token (str, optional): API key for authentication. Defaults to None.
+            max_concurrent_calls (int, optional): Maximum number of concurrent API calls. Defaults to 50.
+            max_tokens (int, optional): Maximum number of tokens in model responses. Defaults to 512.
+            **kwargs (Any): Additional parameters to pass to the API client.
 
         Raises:
-            ValueError: If an unknown model identifier is provided.
+            ImportError: If required libraries are not installed.
         """
+        if not import_successful:
+            raise ImportError(
+                "Could not import at least one of the required libraries: openai, asyncio. "
+                "Please ensure they are installed in your environment."
+            )
         super().__init__()
-        if "claude" in model_id:
-            self.model = ChatAnthropic(model=model_id, api_key=token)
-        elif "gpt" in model_id:
-            self.model = ChatOpenAI(model=model_id, api_key=token)
-        else:
-            self.model = ChatDeepInfra(model_name=model_id, deepinfra_api_token=token)
-
-    def _get_response(self, prompts: List[str], system_prompts: List[str] = None) -> List[str]:
-        """Get responses for a list of prompts in a synchronous manner.
+        self.model_id = model_id
+        self.client = AsyncOpenAI(base_url=api_url, api_key=token, **kwargs)
+        self.max_tokens = max_tokens
 
-        This method includes retry logic for handling connection errors and rate limits.
+        self.semaphore = asyncio.Semaphore(max_concurrent_calls)
 
-        Args:
-            prompts (list[str]): List of input prompts.
-            system_prompts (list[str]): List of system prompts. If not provided, uses default system_prompts
-
-        Returns:
-            list[str]: List of model responses.
-
-        Raises:
-            requests.exceptions.ConnectionError: If max retries are exceeded.
-        """
-        max_retries = 100
-        delay = 3
-        attempts = 0
-
-        nest_asyncio.apply()
-
-        while attempts < max_retries:
-            try:
-                responses = asyncio.run(self.get_response_async(prompts))
-                return responses
-            except requests.exceptions.ConnectionError as e:
-                attempts += 1
-                logger.critical(
-                    f"Connection error: {e}. Attempt {attempts}/{max_retries}. Retrying in {delay} seconds..."
-                )
-                time.sleep(delay)
-            except openai.RateLimitError as e:
-                attempts += 1
-                logger.critical(
-                    f"Rate limit error: {e}. Attempt {attempts}/{max_retries}. Retrying in {delay} seconds..."
-                )
-                time.sleep(delay)
-
-        # If the loop exits, it means max retries were reached
-        raise requests.exceptions.ConnectionError("Max retries exceeded. Connection could not be established.")
-
-    async def get_response_async(self, prompts: list[str], max_concurrent_calls=200) -> list[str]:
-        """Asynchronously get responses for a list of prompts.
-
-        This method uses a semaphore to limit the number of concurrent API calls.
-
-        Args:
-            prompts (list[str]): List of input prompts.
-            max_concurrent_calls (int): Maximum number of concurrent API calls allowed.
-
-        Returns:
-            list[str]: List of model responses.
-        """
-        semaphore = asyncio.Semaphore(max_concurrent_calls)
-        tasks = []
-
-        for prompt in prompts:
-            tasks.append(invoke_model(prompt, self.model, semaphore))
+    def _get_response(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
+        # Setup for async execution in sync context
+        loop = asyncio.get_event_loop()
+        responses = loop.run_until_complete(self._get_response_async(prompts, system_prompts))
+        return responses
 
+    async def _get_response_async(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
+        tasks = [
+            _invoke_model(prompt, system_prompt, self.max_tokens, self.model_id, self.client, self.semaphore)
+            for prompt, system_prompt in zip(prompts, system_prompts)
+        ]
         responses = await asyncio.gather(*tasks)
-        return responses
+        return [response.choices[0].message.content for response in responses]
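Reviewer note: for orientation, a minimal usage sketch (not part of this commit) of the refactored APILLM against an OpenAI-compatible endpoint. The public wrapper on BaseLLM is not shown in this diff, so the sketch calls the `_get_response` hook directly; the URL, model id, and key are placeholders.

```python
# Minimal sketch: exercising the refactored APILLM (all values are placeholders).
from promptolution.llms.api_llm import APILLM

llm = APILLM(
    api_url="https://api.openai.com/v1",  # any OpenAI-compatible base URL
    model_id="gpt-4o-2024-08-06",
    token="YOUR_API_KEY",
    max_concurrent_calls=10,  # bound on parallel requests via the semaphore
    max_tokens=256,
)

prompts = ["Classify: 'Stocks rallied on Friday.'"]
system_prompts = ["You are a precise news classifier."]

# _get_response drives the AsyncOpenAI client from synchronous code by
# running _get_response_async to completion on the event loop.
print(llm._get_response(prompts, system_prompts))
```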

promptolution/llms/base_llm.py

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ def set_generation_seed(self, seed: int):
         pass
 
     @abstractmethod
-    def _get_response(self, prompts: List[str], system_prompts: List[str] = None) -> List[str]:
+    def _get_response(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
         """Generate responses for the given prompts.
 
         This method should be implemented by subclasses to define how
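Reviewer note: since `system_prompts` is no longer optional, every implementation must accept one system prompt per input prompt. A toy sketch of a conforming subclass (hypothetical `EchoLLM`, assuming `_get_response` is the only abstract hook on BaseLLM):

```python
from typing import List

from promptolution.llms.base_llm import BaseLLM


class EchoLLM(BaseLLM):
    """Toy subclass illustrating the now-required system_prompts parameter."""

    def _get_response(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
        # One system prompt is expected per input prompt.
        return [f"[{sys}] {user}" for user, sys in zip(prompts, system_prompts)]
```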

promptolution/llms/local_llm.py

Lines changed: 10 additions & 9 deletions
@@ -2,11 +2,10 @@
 try:
     import torch
     import transformers
-except ImportError as e:
-    import logging
 
-    logger = logging.getLogger(__name__)
-    logger.warning(f"Could not import torch or transformers in local_llm.py: {e}")
+    imports_successful = True
+except ImportError:
+    imports_successful = False
 
 
 from promptolution.llms.base_llm import BaseLLM
 
@@ -35,6 +34,11 @@ def __init__(self, model_id: str, batch_size=8):
         This method sets up a text generation pipeline with bfloat16 precision,
         automatic device mapping, and specific generation parameters.
         """
+        if not imports_successful:
+            raise ImportError(
+                "Could not import at least one of the required libraries: torch, transformers. "
+                "Please ensure they are installed in your environment."
+            )
         super().__init__()
 
         self.pipeline = transformers.pipeline(
@@ -78,8 +82,5 @@ def _get_response(self, prompts: list[str], system_prompts: list[str]) -> list[str]:
 
     def __del__(self):
         """Cleanup method to delete the pipeline and free up GPU memory."""
-        try:
-            del self.pipeline
-            torch.cuda.empty_cache()
-        except Exception as e:
-            logger.warning(f"Error during LocalLLM cleanup: {e}")
+        del self.pipeline
+        torch.cuda.empty_cache()
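Reviewer note: the same deferred-ImportError pattern now appears in api_llm.py, local_llm.py, and vllm.py. In isolation it looks like this (a sketch with a hypothetical class name, not code from the commit): the module imports cleanly without its heavy optional dependencies, and the failure surfaces only when the class is actually instantiated.

```python
try:
    import torch
    import transformers

    imports_successful = True
except ImportError:
    # Defer the failure: importing this module must not require the extras.
    imports_successful = False


class NeedsTransformers:
    def __init__(self):
        # Raise only at instantiation time, with an actionable message.
        if not imports_successful:
            raise ImportError(
                "Could not import at least one of the required libraries: torch, transformers. "
                "Please ensure they are installed in your environment."
            )
```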

promptolution/llms/vllm.py

Lines changed: 9 additions & 2 deletions
@@ -12,8 +12,10 @@
     import torch
     from transformers import AutoTokenizer
     from vllm import LLM, SamplingParams
-except ImportError as e:
-    logger.warning(f"Could not import vllm, torch or transformers in vllm.py: {e}")
+
+    imports_successful = True
+except ImportError:
+    imports_successful = False
 
 
 class VLLM(BaseLLM):
 
@@ -68,6 +70,11 @@ def __init__(
         Note:
             This method sets up a vLLM engine with specified parameters for efficient inference.
         """
+        if not imports_successful:
+            raise ImportError(
+                "Could not import at least one of the required libraries: torch, transformers, vllm. "
+                "Please ensure they are installed in your environment."
+            )
         super().__init__()
 
         self.dtype = dtype

pyproject.toml

Lines changed: 11 additions & 9 deletions
@@ -1,34 +1,36 @@
 [tool.poetry]
 name = "promptolution"
 version = "1.3.2"
-description = ""
+description = "A framework for prompt optimization and a zoo of prompt optimization algorithms."
 authors = ["Tom Zehle, Moritz Schlager, Timo Heiß"]
 readme = "README.md"
 
 [tool.poetry.dependencies]
 python = "^3.9"
 numpy = "^1.26.0"
-langchain-anthropic = "^0.1.22"
-langchain-openai = "^0.1.21"
-langchain-core = "^0.2.29"
-langchain-community = "^0.2.12"
 pandas = "^2.2.2"
 tqdm = "^4.66.5"
 scikit-learn = "^1.5.2"
+
+[tool.poetry.group.requests.dependencies]
+openai = "^1.0.0"
+requests = "^2.31.0"
+
+[tool.poetry.group.vllm.dependencies]
 vllm = "^0.7.3"
-datasets = "^3.3.2"
+
+[tool.poetry.group.transformers.dependencies]
+transformers = "^4.48.0"
 
 [tool.poetry.group.dev.dependencies]
 matplotlib = "^3.9.2"
 seaborn = "^0.13.2"
-transformers = "^4.48.0"
 black = "^24.4.2"
 flake8 = "^7.1.0"
 isort = "^5.13.2"
 pre-commit = "^3.7.1"
 ipykernel = "^6.29.5"
 
-
 [tool.poetry.group.docs.dependencies]
 mkdocs = "^1.6.1"
 mkdocs-material = "^9.5.39"
@@ -46,4 +48,4 @@ line_length = 120
 profile = "black"
 
 [tool.pydocstyle]
-convention = "google"
\ No newline at end of file
+convention = "google"
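Reviewer note on consuming these groups: with the langchain stack dropped, the API path needs only `openai` and `requests`, while the heavyweight backends live in their own dependency groups. Poetry installs non-optional groups by default, so a plain `poetry install` still pulls `vllm` and `transformers`; skipping a backend requires e.g. `poetry install --without vllm,transformers`, and the groups would need `optional = true` to become opt-in via `--with`.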

scripts/api_test.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+"""Test run for the Opro optimizer."""
+
+import argparse
+import random
+from logging import Logger
+
+from promptolution.callbacks import LoggerCallback
+from promptolution.templates import EVOPROMPT_GA_TEMPLATE
+from promptolution.tasks import ClassificationTask
+from promptolution.predictors import MarkerBasedClassificator
+from promptolution.optimizers import EvoPromptGA
+from datasets import load_dataset
+
+from promptolution.llms.api_llm import APILLM
+
+logger = Logger(__name__)
+
+"""Run a test run for any of the implemented optimizers."""
+parser = argparse.ArgumentParser()
+parser.add_argument("--base-url", default="https://api.openai.com/v1")
+parser.add_argument("--model", default="gpt-4o-2024-08-06")
+# parser.add_argument("--base-url", default="https://api.deepinfra.com/v1/openai")
+# parser.add_argument("--model", default="meta-llama/Meta-Llama-3-8B-Instruct")
+# parser.add_argument("--base-url", default="https://api.anthropic.com/v1/")
+# parser.add_argument("--model", default="claude-3-haiku-20240307")
+parser.add_argument("--n-steps", type=int, default=2)
+parser.add_argument("--token", default=None)
+args = parser.parse_args()
+
+df = load_dataset("SetFit/ag_news", split="train", revision="main").to_pandas().sample(300)
+
+df["input"] = df["text"]
+df["target"] = df["label_text"]
+
+task = ClassificationTask(
+    df,
+    description="The dataset contains news articles categorized into four classes: World, Sports, Business, and Tech. The task is to classify each news article into one of the four categories.",
+    x_column="input",
+    y_column="target",
+)
+
+initial_prompts = [
+    "Classify this news article as World, Sports, Business, or Tech. Provide your answer between <final_answer> and </final_answer> tags.",
+    "Read the following news article and determine which category it belongs to: World, Sports, Business, or Tech. Your classification must be placed between <final_answer> </final_answer> markers.",
+    "Your task is to identify whether this news article belongs to World, Sports, Business, or Tech news. Provide your classification between the markers <final_answer> </final_answer>.",
+    "Conduct a thorough analysis of the provided news article and classify it as belonging to one of these four categories: World, Sports, Business, or Tech. Your answer should be presented within <final_answer> </final_answer> markers.",
+]
+
+llm = APILLM(api_url=args.base_url, model_id=args.model, token=args.token)
+downstream_llm = llm
+meta_llm = llm
+
+predictor = MarkerBasedClassificator(downstream_llm, classes=task.classes)
+
+callbacks = [LoggerCallback(logger)]
+
+optimizer = EvoPromptGA(
+    task=task,
+    prompt_template=EVOPROMPT_GA_TEMPLATE,
+    predictor=predictor,
+    meta_llm=meta_llm,
+    initial_prompts=initial_prompts,
+    callbacks=callbacks,
+    n_eval_samples=20,
+    verbosity=2,  # for debugging
+)
+
+best_prompts = optimizer.optimize(n_steps=args.n_steps)
+
+logger.info(f"Optimized prompts: {best_prompts}")
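Reviewer note: given the argparse flags above, a smoke test against the default OpenAI endpoint would be invoked as `python scripts/api_test.py --token <API_KEY>`, or against DeepInfra by passing the commented-out `--base-url` and `--model` values; `--n-steps` controls the number of optimization steps (default 2).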
