Refactor/generic api llm #41
Merged
Changes from all commits (5 commits)
@@ -1,144 +1,86 @@

Before this PR, the module wrapped provider-specific LangChain clients (removed):

```python
"""Module to interface with various language models through their respective APIs."""

import asyncio
import time
from logging import Logger
from typing import Any, List

import nest_asyncio
import openai
import requests
from langchain_anthropic import ChatAnthropic
from langchain_community.chat_models.deepinfra import ChatDeepInfra, ChatDeepInfraException
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI

from promptolution.llms.base_llm import BaseLLM

logger = Logger(__name__)


async def invoke_model(prompt, system_prompt, model, semaphore):
    """Asynchronously invoke a language model with retry logic.

    Args:
        prompt (str): The input prompt for the model.
        system_prompt (str): The system prompt for the model.
        model: The language model to invoke.
        semaphore (asyncio.Semaphore): Semaphore to limit concurrent calls.

    Returns:
        str: The model's response content.

    Raises:
        ChatDeepInfraException: If all retry attempts fail.
    """
    async with semaphore:
        max_retries = 100
        delay = 3
        attempts = 0

        while attempts < max_retries:
            try:
                response = await model.ainvoke([SystemMessage(content=system_prompt), HumanMessage(content=prompt)])
                return response.content
            except ChatDeepInfraException as e:
                print(f"DeepInfra error: {e}. Attempt {attempts}/{max_retries}. Retrying in {delay} seconds...")
                attempts += 1
                await asyncio.sleep(delay)


class APILLM(BaseLLM):
    """A class to interface with various language models through their respective APIs.

    This class supports Claude (Anthropic), GPT (OpenAI), and LLaMA (DeepInfra) models.
    It handles API key management, model initialization, and provides methods for
    both synchronous and asynchronous inference.

    Attributes:
        model: The initialized language model instance.

    Methods:
        get_response: Synchronously get responses for a list of prompts.
        get_response_async: Asynchronously get responses for a list of prompts.
    """

    def __init__(self, model_id: str, token: str = None, **kwargs: Any):
        """Initialize the APILLM with a specific model.

        Args:
            model_id (str): Identifier for the model to use.
            token (str): API key for the model.
            **kwargs (Any): Additional parameters to pass to the API client.

        Raises:
            ValueError: If an unknown model identifier is provided.
        """
        super().__init__()
        if "claude" in model_id:
            self.model = ChatAnthropic(model=model_id, api_key=token)
        elif "gpt" in model_id:
            self.model = ChatOpenAI(model=model_id, api_key=token)
        else:
            self.model = ChatDeepInfra(model_name=model_id, deepinfra_api_token=token)

    def _get_response(self, prompts: List[str], system_prompts: List[str] = None) -> List[str]:
        """Get responses for a list of prompts in a synchronous manner.

        This method includes retry logic for handling connection errors and rate limits.

        Args:
            prompts (list[str]): List of input prompts.
            system_prompts (list[str]): List of system prompts. If not provided, uses default system_prompts.

        Returns:
            list[str]: List of model responses.

        Raises:
            requests.exceptions.ConnectionError: If max retries are exceeded.
        """
        max_retries = 100
        delay = 3
        attempts = 0

        nest_asyncio.apply()

        while attempts < max_retries:
            try:
                responses = asyncio.run(self.get_response_async(prompts))
                return responses
            except requests.exceptions.ConnectionError as e:
                attempts += 1
                logger.critical(
                    f"Connection error: {e}. Attempt {attempts}/{max_retries}. Retrying in {delay} seconds..."
                )
                time.sleep(delay)
            except openai.RateLimitError as e:
                attempts += 1
                logger.critical(
                    f"Rate limit error: {e}. Attempt {attempts}/{max_retries}. Retrying in {delay} seconds..."
                )
                time.sleep(delay)

        # If the loop exits, it means max retries were reached
        raise requests.exceptions.ConnectionError("Max retries exceeded. Connection could not be established.")

    async def get_response_async(self, prompts: list[str], max_concurrent_calls=200) -> list[str]:
        """Asynchronously get responses for a list of prompts.

        This method uses a semaphore to limit the number of concurrent API calls.

        Args:
            prompts (list[str]): List of input prompts.
            max_concurrent_calls (int): Maximum number of concurrent API calls allowed.

        Returns:
            list[str]: List of model responses.
        """
        semaphore = asyncio.Semaphore(max_concurrent_calls)
        tasks = []

        for prompt in prompts:
            tasks.append(invoke_model(prompt, self.model, semaphore))

        responses = await asyncio.gather(*tasks)
        return responses
```

After, the module uses a single OpenAI-compatible client (added):

```python
"""Module to interface with various language models through their respective APIs."""

try:
    import asyncio

    from openai import AsyncOpenAI

    import_successful = True
except ImportError:
    import_successful = False

from logging import Logger
from typing import Any, List

from promptolution.llms.base_llm import BaseLLM

logger = Logger(__name__)


async def _invoke_model(prompt, system_prompt, max_tokens, model_id, client, semaphore):
    async with semaphore:
        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
        response = await client.chat.completions.create(
            model=model_id,
            messages=messages,
            max_tokens=max_tokens,
        )
        return response


class APILLM(BaseLLM):
    """A class to interface with language models through their respective APIs.

    This class provides a unified interface for making API calls to language models
    using the OpenAI client library. It handles rate limiting through semaphores
    and supports both synchronous and asynchronous operations.

    Attributes:
        model_id (str): Identifier for the model to use.
        client (AsyncOpenAI): The initialized API client.
        max_tokens (int): Maximum number of tokens in model responses.
        semaphore (asyncio.Semaphore): Semaphore to limit concurrent API calls.
    """

    def __init__(
        self, api_url: str, model_id: str, token: str = None, max_concurrent_calls=50, max_tokens=512, **kwargs: Any
    ):
        """Initialize the APILLM with a specific model and API configuration.

        Args:
            api_url (str): The base URL for the API endpoint.
            model_id (str): Identifier for the model to use.
            token (str, optional): API key for authentication. Defaults to None.
            max_concurrent_calls (int, optional): Maximum number of concurrent API calls. Defaults to 50.
            max_tokens (int, optional): Maximum number of tokens in model responses. Defaults to 512.
            **kwargs (Any): Additional parameters to pass to the API client.

        Raises:
            ImportError: If required libraries are not installed.
        """
        if not import_successful:
            raise ImportError(
                "Could not import at least one of the required libraries: openai, asyncio. "
                "Please ensure they are installed in your environment."
            )
        super().__init__()

        self.model_id = model_id
        self.client = AsyncOpenAI(base_url=api_url, api_key=token, **kwargs)
        self.max_tokens = max_tokens

        self.semaphore = asyncio.Semaphore(max_concurrent_calls)

    def _get_response(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
        # Setup for async execution in sync context
        loop = asyncio.get_event_loop()
        responses = loop.run_until_complete(self._get_response_async(prompts, system_prompts))
        return responses

    async def _get_response_async(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
        tasks = [
            _invoke_model(prompt, system_prompt, self.max_tokens, self.model_id, self.client, self.semaphore)
            for prompt, system_prompt in zip(prompts, system_prompts)
        ]
        responses = await asyncio.gather(*tasks)
        return [response.choices[0].message.content for response in responses]
```
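For orientation, here is a minimal usage sketch of the refactored class. It is not part of the PR: the base URL, model id, and token are placeholders, and `_get_response` is called directly only to keep the example self-contained (in the library it is presumably reached through the public `BaseLLM` interface, as in the test script below).

```python
# Minimal usage sketch (not from the PR); endpoint, model, and key are placeholders.
from promptolution.llms.api_llm import APILLM

llm = APILLM(
    api_url="https://api.openai.com/v1",  # any OpenAI-compatible endpoint
    model_id="gpt-4o-2024-08-06",
    token="YOUR_API_KEY",  # placeholder
    max_concurrent_calls=10,  # bounds in-flight requests via the semaphore
    max_tokens=256,
)

prompts = [
    "Classify this article: stocks rallied on Friday.",
    "Classify this article: the striker scored twice.",
]
system_prompts = ["You are a concise assistant."] * len(prompts)

# _get_response fans the prompts out concurrently, bounded by the semaphore,
# and returns the message contents in order.
print(llm._get_response(prompts, system_prompts))
```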
@@ -0,0 +1,70 @@

A new test script added by this PR (the module docstring originally said "Opro optimizer"; the script actually exercises EvoPromptGA):

```python
"""Test run for the EvoPromptGA optimizer."""

import argparse
import random
from logging import Logger

from promptolution.callbacks import LoggerCallback
from promptolution.templates import EVOPROMPT_GA_TEMPLATE
from promptolution.tasks import ClassificationTask
from promptolution.predictors import MarkerBasedClassificator
from promptolution.optimizers import EvoPromptGA
from datasets import load_dataset

from promptolution.llms.api_llm import APILLM

logger = Logger(__name__)

"""Run a test run for any of the implemented optimizers."""
parser = argparse.ArgumentParser()
parser.add_argument("--base-url", default="https://api.openai.com/v1")
parser.add_argument("--model", default="gpt-4o-2024-08-06")
# parser.add_argument("--base-url", default="https://api.deepinfra.com/v1/openai")
# parser.add_argument("--model", default="meta-llama/Meta-Llama-3-8B-Instruct")
# parser.add_argument("--base-url", default="https://api.anthropic.com/v1/")
# parser.add_argument("--model", default="claude-3-haiku-20240307")
parser.add_argument("--n-steps", type=int, default=2)
parser.add_argument("--token", default=None)
args = parser.parse_args()

df = load_dataset("SetFit/ag_news", split="train", revision="main").to_pandas().sample(300)

df["input"] = df["text"]
df["target"] = df["label_text"]

task = ClassificationTask(
    df,
    description="The dataset contains news articles categorized into four classes: World, Sports, Business, and Tech. The task is to classify each news article into one of the four categories.",
    x_column="input",
    y_column="target",
)

initial_prompts = [
    "Classify this news article as World, Sports, Business, or Tech. Provide your answer between <final_answer> and </final_answer> tags.",
    "Read the following news article and determine which category it belongs to: World, Sports, Business, or Tech. Your classification must be placed between <final_answer> </final_answer> markers.",
    "Your task is to identify whether this news article belongs to World, Sports, Business, or Tech news. Provide your classification between the markers <final_answer> </final_answer>.",
    "Conduct a thorough analysis of the provided news article and classify it as belonging to one of these four categories: World, Sports, Business, or Tech. Your answer should be presented within <final_answer> </final_answer> markers.",
]

llm = APILLM(api_url=args.base_url, model_id=args.model, token=args.token)
downstream_llm = llm
meta_llm = llm

predictor = MarkerBasedClassificator(downstream_llm, classes=task.classes)

callbacks = [LoggerCallback(logger)]

optimizer = EvoPromptGA(
    task=task,
    prompt_template=EVOPROMPT_GA_TEMPLATE,
    predictor=predictor,
    meta_llm=meta_llm,
    initial_prompts=initial_prompts,
    callbacks=callbacks,
    n_eval_samples=20,
    verbosity=2,  # for debugging
)

best_prompts = optimizer.optimize(n_steps=args.n_steps)

logger.info(f"Optimized prompts: {best_prompts}")
```
Review comment: We could also do these imports inside the APILLM class (if our linter is fine with that); that might be more intuitive, and we would not need the `import_successful` variable.
Reply: The problem is that `asyncio` and co. would then no longer be part of the module's global namespace, meaning that if we use them outside of `__init__`, we get `NameError: name 'asyncio' is not defined`.
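A small sketch of the failure mode described in this reply (the class and method names are illustrative, not from the PR):

```python
# If the import lives inside __init__, the name is local to that method only.


class APILLMInlineImport:
    def __init__(self):
        import asyncio  # binds 'asyncio' in this method's scope only

        self.semaphore = asyncio.Semaphore(50)  # fine: the name is local here

    def get_response(self):
        # Raises NameError: name 'asyncio' is not defined, because the import
        # in __init__ created no module-level binding visible to other methods.
        return asyncio.get_event_loop()


# The guarded module-level import used in the PR avoids this, since the name
# is bound in the module's global namespace:
#
#     try:
#         import asyncio
#
#         import_successful = True
#     except ImportError:
#         import_successful = False
```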