Merge pull request #81 from BatsResearch/add-openllm-support
Add OpenLLM support
dotpyu authored Sep 23, 2024
2 parents 8847315 + b9c0a00 commit a513883
Showing 5 changed files with 165 additions and 9 deletions.
8 changes: 7 additions & 1 deletion alfred/client/client.py
@@ -80,6 +80,7 @@ def __init__(
                 "google",
                 "groq",
                 "torch",
+                "openllm",
                 "dummy",
             ], f"Invalid model type: {self.model_type}"
         else:
@@ -99,7 +100,7 @@ def __init__(
             self.run = self.cache.cached_query(self.run)

         self.grpcClient = None
-        if end_point:
+        if end_point and model_type not in ["dummy", "openllm", ]:
             end_point_pieces = end_point.split(":")
             self.end_point_ip, self.end_point_port = (
                 "".join(end_point_pieces[:-1]),
@@ -180,6 +181,11 @@ def __init__(
             from ..fm.openai import OpenAIModel

             self.model = OpenAIModel(self.model, **kwargs)
+        elif self.model_type == "openllm":
+            from ..fm.openllm import OpenLLMModel
+
+            base_url = kwargs.get("base_url", end_point)
+            self.model = OpenLLMModel(self.model, base_url=base_url, **kwargs)
         elif self.model_type == "cohere":
             from ..fm.cohere import CohereModel

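Taken together, the new branch lets a Client route requests to an OpenLLM server through the OpenAI-compatible wrapper. A minimal usage sketch follows (not part of the commit; the endpoint URL and model id are placeholder assumptions, and since `__init__` reads `base_url = kwargs.get("base_url", end_point)`, the URL may equivalently be supplied via `end_point`):

from alfred.client import Client

# Hypothetical example: assumes an OpenLLM server is serving the named model
# at the URL below.
client = Client(
    model_type="openllm",
    model="facebook/opt-1.3b",
    base_url="http://localhost:3000/v1",  # picked up from kwargs in __init__
)
response = client("What is the capital of France?")  # __call__ runs the query
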
148 changes: 148 additions & 0 deletions alfred/fm/openllm.py
@@ -0,0 +1,148 @@
import json
import logging
from typing import Optional, List, Any, Union, Tuple

import openai
from openai._exceptions import (
    AuthenticationError,
    APIError,
    APITimeoutError,
    RateLimitError,
    BadRequestError,
    APIConnectionError,
    APIStatusError,
)

from .model import APIAccessFoundationModel
from .response import CompletionResponse, RankedResponse
from .utils import retry

logger = logging.getLogger(__name__)


class OpenLLMModel(APIAccessFoundationModel):
    """
    A wrapper for the OpenLLM Models using OpenAI's Python package
    """

    @retry(
        num_retries=3,
        wait_time=0.1,
        exceptions=(
            AuthenticationError,
            APIConnectionError,
            APITimeoutError,
            RateLimitError,
            APIError,
            BadRequestError,
            APIStatusError,
        ),
    )
    def _api_query(
        self,
        query: Union[str, List, Tuple],
        temperature: float = 0.0,
        max_tokens: int = 64,
        **kwargs: Any,
    ) -> str:
        """
        Run a single query through the foundation model using OpenAI's Python package
        :param query: The prompt to be used for the query
        :type query: Union[str, List, Tuple]
        :param temperature: The temperature of the model
        :type temperature: float
        :param max_tokens: The maximum number of tokens to be returned
        :type max_tokens: int
        :param kwargs: Additional keyword arguments
        :type kwargs: Any
        :return: The generated completion
        :rtype: str
        """
        chat = kwargs.get("chat", False)

        if chat:
            messages = query if isinstance(query, list) else [{"role": "user", "content": query}]
            response = self.openai_client.chat.completions.create(
                model=self.model_string,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            return response.choices[0].message.content
        else:
            prompt = query[0]['content'] if isinstance(query, list) else query
            response = self.openai_client.completions.create(
                model=self.model_string,
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            return response.choices[0].text

    def __init__(
        self, model_string: str = "", api_key: Optional[str] = None, **kwargs: Any
    ):
        """
        Initialize the OpenLLM API wrapper.
        :param model_string: The model to be used for generating completions.
        :type model_string: str
        :param api_key: The API key to be used for the OpenAI API.
        :type api_key: Optional[str]
        """
        self.model_string = model_string
        base_url = kwargs.get("base_url", None)
        api_key = api_key or "na"
        self.openai_client = openai.OpenAI(base_url=base_url, api_key=api_key)
        super().__init__(model_string, {"api_key": api_key, "base_url": base_url})

    def _generate_batch(
        self,
        batch_instance: Union[List[str], Tuple],
        **kwargs,
    ) -> List[CompletionResponse]:
        """
        Generate completions for a batch of prompts using the OpenAI API.
        :param batch_instance: A list of prompts for which to generate completions.
        :type batch_instance: List[str] or List[Tuple]
        :param kwargs: Additional keyword arguments to pass to the API.
        :type kwargs: Any
        :return: A list of `CompletionResponse` objects containing the generated completions.
        :rtype: List[CompletionResponse]
        """
        output = []
        for query in batch_instance:
            output.append(
                CompletionResponse(prediction=self._api_query(query, **kwargs))
            )
        return output

    def _score_batch(
        self,
        batch_instance: Union[List[Tuple[str, str]], List[str]],
        scoring_instruction: str = "Instruction: Given the query, choose your answer from [[label_space]]:\nQuery:\n",
        **kwargs,
    ) -> List[RankedResponse]:
        """
        Score candidates using the OpenAI API.
        :param batch_instance: A list of prompts for which to generate candidate preferences.
        :type batch_instance: List[str] or List[Tuple]
        :param scoring_instruction: The instruction prompt for scoring
        :type scoring_instruction: str
        """
        output = []
        for query in batch_instance:
            _scoring_prompt = (
                scoring_instruction.replace(
                    "[[label_space]]", ",".join(query.candidates)
                )
                + query.prompt
            )
            output.append(
                RankedResponse(
                    prediction=self._api_query(_scoring_prompt, **kwargs), scores={}
                )
            )
        return output
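
Since `OpenLLMModel` simply points OpenAI's v1 Python client at an OpenLLM endpoint, its two query paths boil down to the standalone calls below. This is an illustrative sketch, not part of the commit: the endpoint URL and model id are assumptions, standing in for whatever an OpenLLM server actually serves.

import openai

# OpenLLM endpoints generally do not validate API keys, which is why the
# wrapper above falls back to the placeholder key "na".
client = openai.OpenAI(base_url="http://localhost:3000/v1", api_key="na")

# Completion path (the wrapper's default, chat=False):
completion = client.completions.create(
    model="facebook/opt-1.3b",  # assumed model id
    prompt="Hello, world",
    max_tokens=64,
    temperature=0.0,
)
print(completion.choices[0].text)

# Chat path (the wrapper's chat=True):
chat = client.chat.completions.create(
    model="facebook/opt-1.3b",
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=64,
    temperature=0.0,
)
print(chat.choices[0].message.content)
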
1 change: 1 addition & 0 deletions docs/README.md
@@ -32,6 +32,7 @@ A full list of `Alfred` project modules.
 - [Model](alfred/fm/model.md#model)
 - [Onnx](alfred/fm/onnx.md#onnx)
 - [Openai](alfred/fm/openai.md#openai)
+- [Openllm](alfred/fm/openllm.md#openllm)
 - [Query](alfred/fm/query/index.md#query)
 - [CompletionQuery](alfred/fm/query/completion_query.md#completionquery)
 - [Query](alfred/fm/query/query.md#query)
16 changes: 8 additions & 8 deletions docs/alfred/client/client.md

### Client().__call__

-[Show source in client.py:313](../../../alfred/client/client.py#L313)
+[Show source in client.py:319](../../../alfred/client/client.py#L319)

__call__() function to run the model on the queries.
Equivalent to run() function.

### Client().calibrate

-[Show source in client.py:329](../../../alfred/client/client.py#L329)
+[Show source in client.py:335](../../../alfred/client/client.py#L335)

calibrate is used to calibrate foundation models contextually given the template.
A voter class may be passed to calibrate the model with a specific voter.

### Client().chat

-[Show source in client.py:427](../../../alfred/client/client.py#L427)
+[Show source in client.py:433](../../../alfred/client/client.py#L433)

Chat with the model APIs.
Currently, Alfred supports Chat APIs from Anthropic and OpenAI

### Client().encode

-[Show source in client.py:401](../../../alfred/client/client.py#L401)
+[Show source in client.py:407](../../../alfred/client/client.py#L407)

embed() function to embed the queries.

### Client().generate

-[Show source in client.py:272](../../../alfred/client/client.py#L272)
+[Show source in client.py:278](../../../alfred/client/client.py#L278)

Wrapper function to generate the response(s) from the model. (For completion)


### Client().remote_run

-[Show source in client.py:246](../../../alfred/client/client.py#L246)
+[Show source in client.py:252](../../../alfred/client/client.py#L252)

Wrapper function for running the model on the queries thru a gRPC Server.


### Client().run

-[Show source in client.py:226](../../../alfred/client/client.py#L226)
+[Show source in client.py:232](../../../alfred/client/client.py#L232)

Run the model on the queries.


### Client().score

-[Show source in client.py:289](../../../alfred/client/client.py#L289)
+[Show source in client.py:295](../../../alfred/client/client.py#L295)

Wrapper function to score the response(s) from the model. (For ranking)

1 change: 1 addition & 0 deletions docs/alfred/fm/index.md
@@ -22,6 +22,7 @@
 - [Model](./model.md)
 - [Onnx](./onnx.md)
 - [Openai](./openai.md)
+- [Openllm](./openllm.md)
 - [Query](query/index.md)
 - [Remote](remote/index.md)
 - [Response](response/index.md)
