Custom Client support #831

Closed
wants to merge 14 commits into from
RL client which uses transformers
olgavrou committed Nov 30, 2023
commit 6ce9a6e7232831ab690ba40019adb3e86f042696
271 changes: 174 additions & 97 deletions autogen/oai/client.py
@@ -6,6 +6,7 @@
import logging
import inspect
from flaml.automl.logger import logger_formatter
from types import SimpleNamespace

from autogen.oai.openai_utils import get_key, oai_price1k
from autogen.token_count_utils import count_token
@@ -22,23 +23,32 @@
except ImportError:
ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
OpenAI = object

try:
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
ERROR = None
except ImportError:
ERROR = ImportError("Please install transformers and diskcache to use autogen.RLClientWrapper.")

logger = logging.getLogger(__name__)
if not logger.handlers:
# Add the console handler.
_ch = logging.StreamHandler(stream=sys.stdout)
_ch.setFormatter(logger_formatter)
logger.addHandler(_ch)


def template_formatter(
template: str | Callable | None,
context: Optional[Dict] = None,
allow_format_str_template: Optional[bool] = False,
):
if not context or template is None:
return template
if isinstance(template, str):
return template.format(**context) if allow_format_str_template else template
return template(context)
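For illustration (not part of the diff), the formatter only interpolates a string template when allow_format_str_template is enabled, and always calls a callable template with the context dict; the context values below are made up:

context = {"question": "What is 2 + 2?"}

# String template: only interpolated when allow_format_str_template is True.
template_formatter("Q: {question}", context, allow_format_str_template=True)  # "Q: What is 2 + 2?"
template_formatter("Q: {question}", context)                                  # "Q: {question}"

# Callable template: always invoked with the context dict.
template_formatter(lambda ctx: f"Q: {ctx['question']}", context)              # "Q: What is 2 + 2?"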


class ResponseCreator:
cache_path_root: str = ".cache"
@@ -75,7 +85,7 @@ def construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Dict:
]
return params

def create(self, client, completions_create, is_last, create_config: Dict, extra_kwargs: Dict):
def create(self, client, client_id, is_last, create_config: Dict, extra_kwargs: Dict):
# construct the create params
params = self.construct_create_params(create_config, extra_kwargs)
# get the cache_seed, filter_func and context
@@ -96,7 +106,9 @@ def create(self, client, completions_create, is_last, create_config: Dict, extra_kwargs: Dict):
# TODO: add response.cost
return response

response = completions_create(client, params)
response = client.create(params)
if response is None:
return None

if cache_seed is not None:
# Cache the response
@@ -108,10 +120,152 @@ def create(self, client, completions_create, is_last, create_config: Dict, extra_kwargs: Dict):
if pass_filter or is_last:
# Return the response if it passes the filter or it is the last client
response.pass_filter = pass_filter
response.config_id = client_id
response.cost = client.cost(response)
return response
return None
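For context, the filter_func referenced above is expected to be a callable taking the prompt context and the response; a minimal sketch (illustrative only, modeled on the yes_or_no_filter example mentioned later in this file) could look like:

# Illustrative filter: accept a response only if every returned text is "yes" or "no".
def yes_or_no_filter(context, response):
    return all(
        text.strip().lower() in ("yes", "no")
        for text in OpenAIWrapper.extract_text_or_function_call(response)
    )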


class RLClient:
def __init__(self, config: Dict):
import torch

self.device = (
("cuda" if torch.cuda.is_available() else "cpu") if config.get("device", None) is None else config["device"]
)
self.tokenizer = AutoTokenizer.from_pretrained(config["local_model"], load_in_8bit=True, use_fast=False)
self.model = AutoModelForCausalLM.from_pretrained(config["local_model"]).to(self.device)
# get max_length from config or set to 1000
self.max_length = config.get("max_length", 1000)
self.gen_config_params = config.get("params", {})
# correct max_length in self.params
self.gen_config_params["max_length"] = self.max_length
self.gen_config_params["eos_token_id"] = self.tokenizer.eos_token_id
self.gen_config_params["pad_token_id"] = self.tokenizer.eos_token_id
print(f"Loaded model {config['local_model']} to {self.device}")

def create(self, params):
if params.get("stream", False) and "messages" in params and "functions" not in params:
raise NotImplementedError("Local models do not support streaming or functions")
else:
response_contents = [""] * params.get("n", 1)
finish_reasons = [""] * params.get("n", 1)
completion_tokens = 0

response = SimpleNamespace()
inputs = self.tokenizer.apply_chat_template(
params["messages"], return_tensors="pt", add_generation_prompt=True
).to(self.device)

inputs_length = inputs.shape[-1]
# copy gen config params
gen_config_params = self.gen_config_params.copy()
# add inputs_length to max_length
gen_config_params["max_length"] += inputs_length
generation_config = GenerationConfig(**gen_config_params)


response.choices = []

for _ in range(len(response_contents)):
outputs = self.model.generate(inputs, generation_config=generation_config)
# Decode only the newly generated text, excluding the prompt
text = self.tokenizer.decode(outputs[0, inputs_length:], skip_special_tokens=True)
choice = SimpleNamespace()
choice.message = SimpleNamespace()
choice.message.content = text
choice.message.function_call = None
response.choices.append(choice)

return response

def cost(self, response) -> float:
"""Calculate the cost of the response."""
return 0
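A minimal standalone usage sketch of RLClient, assuming torch and transformers are installed; the model id and generation params below are only examples:

config = {
    "local_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # illustrative Hugging Face model id
    "max_length": 256,
    "params": {"do_sample": False},
}
client = RLClient(config)
response = client.create({
    "messages": [{"role": "user", "content": "Say hello in one short sentence."}],
    "n": 1,
})
print(response.choices[0].message.content)
print(client.cost(response))  # always 0 for local models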


class OpenAIClient:
def __init__(self, config: Dict):
self.client = OpenAI(**config)

def create(self, params):
Contributor: This function is modified in #786

Contributor Author: looks like just a change in _completions_create, which can be moved to OpenAIClient.create

completions = self.client.chat.completions if "messages" in params else self.client.completions
# If streaming is enabled, has messages, and does not have functions, then
# iterate over the chunks of the response
if params.get("stream", False) and "messages" in params and "functions" not in params:
response_contents = [""] * params.get("n", 1)
finish_reasons = [""] * params.get("n", 1)
completion_tokens = 0

# Set the terminal text color to green
print("\033[32m", end="")

# Send the chat completion request to OpenAI's API and process the response in chunks
for chunk in completions.create(**params):
if chunk.choices:
for choice in chunk.choices:
content = choice.delta.content
finish_reasons[choice.index] = choice.finish_reason
# If content is present, print it to the terminal and update response variables
if content is not None:
print(content, end="", flush=True)
response_contents[choice.index] += content
completion_tokens += 1
else:
print()

# Reset the terminal text color
print("\033[0m\n")

# Prepare the final ChatCompletion object based on the accumulated data
model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API
prompt_tokens = count_token(params["messages"], model)
response = ChatCompletion(
id=chunk.id,
model=chunk.model,
created=chunk.created,
object="chat.completion",
choices=[],
usage=CompletionUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
for i in range(len(response_contents)):
response.choices.append(
Choice(
index=i,
finish_reason=finish_reasons[i],
message=ChatCompletionMessage(
role="assistant", content=response_contents[i], function_call=None
),
)
)
else:
# If streaming is not enabled or using functions, send a regular chat completion request
# Functions are not supported, so ensure streaming is disabled
params = params.copy()
params["stream"] = False
response = completions.create(**params)
return response
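For reference, a hedged sketch of driving this client directly; the key handling and model choice are examples only:

import os

client = OpenAIClient({"api_key": os.environ["OPENAI_API_KEY"]})
response = client.create({
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Write a haiku about autumn."}],
    "stream": True,  # chunks are echoed to the terminal, then folded into one ChatCompletion
})
print(client.cost(response))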

def cost(self, response: Union[ChatCompletion, Completion]) -> float:
"""Calculate the cost of the response."""
model = response.model
if model not in oai_price1k:
# TODO: add logging to warn that the model is not found
return 0

n_input_tokens = response.usage.prompt_tokens
n_output_tokens = response.usage.completion_tokens
tmp_price1K = oai_price1k[model]
# First value is input token rate, second value is output token rate
if isinstance(tmp_price1K, tuple):
return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000
return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000
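To make the pricing arithmetic concrete, a worked example with illustrative per-1K rates (not the actual values in oai_price1k):

# Suppose oai_price1k["example-model"] == (0.0015, 0.002)   # (input rate, output rate) per 1K tokens
# and the response used 1200 prompt tokens and 300 completion tokens:
#   cost = (0.0015 * 1200 + 0.002 * 300) / 1000
#        = (1.8 + 0.6) / 1000
#        = 0.0024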


class OpenAIWrapper:
"""A wrapper class for openai client."""

@@ -218,8 +372,10 @@ def _client(self, config, openai_config):
"""
openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
self._process_for_azure(openai_config, config)
client = OpenAI(**openai_config)
return client
if "local_model" in config:
return RLClient(config)
else:
return OpenAIClient(openai_config)
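To illustrate the dispatch above, a config_list mixing both client types might look like this; the model ids and key are placeholders:

config_list = [
    {"model": "gpt-4", "api_key": "sk-..."},                # no "local_model" key -> OpenAIClient
    {"local_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"},  # "local_model" key    -> RLClient
]
wrapper = OpenAIWrapper(config_list=config_list)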

@classmethod
def instantiate(
@@ -271,100 +427,21 @@ def yes_or_no_filter(context, response):
try:
response = self.response_creator.create(
client=client,
client_id=i,
is_last=(i == last),
completions_create=self._completions_create,
create_config=create_config,
extra_kwargs=extra_kwargs,
)
if response is None:
continue # filter is not passed; try the next config
response.config_id = i
response.cost = self.cost(response)
return response

if response is not None:
return response
except APIError:
logger.debug(f"config {i} failed", exc_info=1)
if i == last:
raise

def cost(self, response: Union[ChatCompletion, Completion]) -> float:
"""Calculate the cost of the response."""
model = response.model
if model not in oai_price1k:
# TODO: add logging to warn that the model is not found
return 0

n_input_tokens = response.usage.prompt_tokens
n_output_tokens = response.usage.completion_tokens
tmp_price1K = oai_price1k[model]
# First value is input token rate, second value is output token rate
if isinstance(tmp_price1K, tuple):
return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000
return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000

def _completions_create(self, client, params):
completions = client.chat.completions if "messages" in params else client.completions
# If streaming is enabled, has messages, and does not have functions, then
# iterate over the chunks of the response
if params.get("stream", False) and "messages" in params and "functions" not in params:
response_contents = [""] * params.get("n", 1)
finish_reasons = [""] * params.get("n", 1)
completion_tokens = 0

# Set the terminal text color to green
print("\033[32m", end="")

# Send the chat completion request to OpenAI's API and process the response in chunks
for chunk in completions.create(**params):
if chunk.choices:
for choice in chunk.choices:
content = choice.delta.content
finish_reasons[choice.index] = choice.finish_reason
# If content is present, print it to the terminal and update response variables
if content is not None:
print(content, end="", flush=True)
response_contents[choice.index] += content
completion_tokens += 1
else:
print()

# Reset the terminal text color
print("\033[0m\n")

# Prepare the final ChatCompletion object based on the accumulated data
model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API
prompt_tokens = count_token(params["messages"], model)
response = ChatCompletion(
id=chunk.id,
model=chunk.model,
created=chunk.created,
object="chat.completion",
choices=[],
usage=CompletionUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
for i in range(len(response_contents)):
response.choices.append(
Choice(
index=i,
finish_reason=finish_reasons[i],
message=ChatCompletionMessage(
role="assistant", content=response_contents[i], function_call=None
),
)
)
else:
# If streaming is not enabled or using functions, send a regular chat completion request
# Functions are not supported, so ensure streaming is disabled
params = params.copy()
params["stream"] = False
response = completions.create(**params)
return response

@classmethod
def extract_text_or_function_call(cls, response: ChatCompletion | Completion) -> List[str]:
def extract_text_or_function_call(cls, response) -> List[str]:
"""Extract the text or function calls from a completion or chat response.

Args: