Skip to content

Commit

Permalink
Add cost summary to client.py (microsoft#812)
Browse files Browse the repository at this point in the history
* init commit

* add doc, notebook and test

* fix test

* update

* update

* update

* update
  • Loading branch information
yiranwu0 authored Dec 3, 2023
1 parent 5c92fb3 commit 7a4ba1a
Show file tree
Hide file tree
Showing 6 changed files with 485 additions and 17 deletions.
123 changes: 106 additions & 17 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class OpenAIWrapper:
cache_path_root: str = ".cache"
extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"}
openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
total_usage_summary: Dict = None
actual_usage_summary: Dict = None

def __init__(self, *, config_list: List[Dict] = None, **base_config):
"""
Expand Down Expand Up @@ -233,14 +235,15 @@ def yes_or_no_filter(context, response):
# Try to get the response from cache
key = get_key(params)
response = cache.get(key, None)
if response is not None:
self._update_usage_summary(response, use_cache=True)
if response is not None:
# check the filter
pass_filter = filter_func is None or filter_func(context=context, response=response)
if pass_filter or i == last:
# Return the response if it passes the filter or it is the last client
response.config_id = i
response.pass_filter = pass_filter
response.cost = self.cost(response)
return response
continue # filter is not passed; try the next config
try:
Expand All @@ -250,6 +253,9 @@ def yes_or_no_filter(context, response):
if i == last:
raise
else:
# add cost calculation before caching, no matter whether the filter is passed or not
response.cost = self.cost(response)
self._update_usage_summary(response, use_cache=False)
if cache_seed is not None:
# Cache the response
with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
Expand All @@ -261,25 +267,9 @@ def yes_or_no_filter(context, response):
# Return the response if it passes the filter or it is the last client
response.config_id = i
response.pass_filter = pass_filter
response.cost = self.cost(response)
return response
continue # filter is not passed; try the next config

def cost(self, response: Union[ChatCompletion, Completion]) -> float:
    """Calculate the cost of the response.

    Args:
        response: The completion to price. Its ``model`` is looked up in
            ``oai_price1k`` and its ``usage`` supplies the token counts.

    Returns:
        The price of the response, or 0 when the model has no entry in
        ``oai_price1k``. A response that carries no usage data is priced
        as zero tokens.
    """
    model = response.model
    if model not in oai_price1k:
        # TODO: add logging to warn that the model is not found
        return 0

    # ``usage`` may be None on a response; count that as zero tokens
    # instead of failing with AttributeError.
    usage = response.usage
    n_input_tokens = usage.prompt_tokens if usage is not None else 0
    n_output_tokens = usage.completion_tokens if usage is not None else 0

    tmp_price1K = oai_price1k[model]
    # First value is input token rate, second value is output token rate
    if isinstance(tmp_price1K, tuple):
        return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000
    # A single rate applies to input and output tokens alike.
    return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000

def _completions_create(self, client, params):
completions = client.chat.completions if "messages" in params else client.completions
# If streaming is enabled, has messages, and does not have functions, then
Expand Down Expand Up @@ -342,6 +332,105 @@ def _completions_create(self, client, params):
response = completions.create(**params)
return response

def _update_usage_summary(self, response: ChatCompletion | Completion, use_cache: bool) -> None:
"""Update the usage summary.
Usage is calculated no mattter filter is passed or not.
"""

def update_usage(usage_summary):
if usage_summary is None:
usage_summary = {"total_cost": response.cost}
else:
usage_summary["total_cost"] += response.cost

usage_summary[response.model] = {
"cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost,
"prompt_tokens": usage_summary.get(response.model, {}).get("prompt_tokens", 0)
+ response.usage.prompt_tokens,
"completion_tokens": usage_summary.get(response.model, {}).get("completion_tokens", 0)
+ response.usage.completion_tokens,
"total_tokens": usage_summary.get(response.model, {}).get("total_tokens", 0)
+ response.usage.total_tokens,
}
return usage_summary

self.total_usage_summary = update_usage(self.total_usage_summary)
if not use_cache:
self.actual_usage_summary = update_usage(self.actual_usage_summary)

def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None:
"""Print the usage summary."""

def print_usage(usage_summary, usage_type="total"):
word_from_type = "including" if usage_type == "total" else "excluding"
if usage_summary is None:
print("No actual cost incurred (all completions are using cache).", flush=True)
return

print(f"Usage summary {word_from_type} cached usage: ", flush=True)
print(f"Total cost: {round(usage_summary['total_cost'], 5)}", flush=True)
for model, counts in usage_summary.items():
if model == "total_cost":
continue #
print(
f"* Model '{model}': cost: {round(counts['cost'], 5)}, prompt_tokens: {counts['prompt_tokens']}, completion_tokens: {counts['completion_tokens']}, total_tokens: {counts['total_tokens']}",
flush=True,
)

if self.total_usage_summary is None:
print('No usage summary. Please call "create" first.', flush=True)
return

if isinstance(mode, list):
if len(mode) == 0 or len(mode) > 2:
raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]')
if "actual" in mode and "total" in mode:
mode = "both"
elif "actual" in mode:
mode = "actual"
elif "total" in mode:
mode = "total"

print("-" * 100, flush=True)
if mode == "both":
print_usage(self.actual_usage_summary, "actual")
print()
if self.total_usage_summary != self.actual_usage_summary:
print_usage(self.total_usage_summary, "total")
else:
print(
"All completions are non-cached: the total cost with cached completions is the same as actual cost.",
flush=True,
)
elif mode == "total":
print_usage(self.total_usage_summary, "total")
elif mode == "actual":
print_usage(self.actual_usage_summary, "actual")
else:
raise ValueError(f'Invalid mode: {mode}, choose from "actual", "total", ["actual", "total"]')
print("-" * 100, flush=True)

def clear_usage_summary(self) -> None:
"""Clear the usage summary."""
self.total_usage_summary = None
self.actual_usage_summary = None

def cost(self, response: Union[ChatCompletion, Completion]) -> float:
    """Calculate the cost of the response.

    Args:
        response: The completion to price. Its ``model`` is looked up in
            ``oai_price1k`` and its ``usage`` supplies the token counts.

    Returns:
        The price of the response, or 0 when the model has no entry in
        ``oai_price1k``. A response that carries no usage data is priced
        as zero tokens.
    """
    model = response.model
    if model not in oai_price1k:
        # TODO: add logging to warn that the model is not found
        return 0

    # ``usage`` may be None on a response; count that as zero tokens
    # instead of failing with AttributeError.
    usage = response.usage
    n_input_tokens = usage.prompt_tokens if usage is not None else 0
    n_output_tokens = usage.completion_tokens if usage is not None else 0

    tmp_price1K = oai_price1k[model]
    # First value is input token rate, second value is output token rate
    if isinstance(tmp_price1K, tuple):
        return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000
    # A single rate applies to input and output tokens alike.
    return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000

@classmethod
def extract_text_or_function_call(cls, response: ChatCompletion | Completion) -> List[str]:
"""Extract the text or function calls from a completion or chat response.
Expand Down
Loading

0 comments on commit 7a4ba1a

Please sign in to comment.