Skip to content

Commit

Permalink
fix(code2prompt/utils): Optimize calculate_prices function and add …
Browse files Browse the repository at this point in the history
…more test cases
  • Loading branch information
raphaelmansuy authored and CTY-git committed Sep 23, 2024
1 parent 44875b6 commit 4ae41f1
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 41 deletions.
72 changes: 40 additions & 32 deletions code2prompt/utils/price_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,49 +31,57 @@ def calculate_price(token_count, price_per_1000):
"""
return (token_count / 1_000) * price_per_1000

def calculate_prices(token_prices, input_tokens, output_tokens, provider=None, model=None):
    """
    Calculate the total price of a request for every matching provider/model.

    All prices in ``token_prices`` are expressed per 1,000 tokens.

    Args:
        token_prices (dict): Pricing data shaped as
            ``{"providers": [{"name": ..., "models": [{"name": ..., ...}]}]}``.
            Each model carries either a single ``"price"`` key (applied to
            input and output tokens alike) or both ``"input_price"`` and
            ``"output_price"`` keys.
        input_tokens (int): The number of input tokens.
        output_tokens (int): The number of output tokens.
        provider (str, optional): If given, only this provider is priced.
            Matching is case-insensitive. Defaults to None (all providers).
        model (str, optional): If given, only models with this name are
            priced. Matching is case-insensitive. Defaults to None (all models).

    Returns:
        list: One tuple per priced model, in the order they appear in
        ``token_prices``:
        ``(provider name, model name, price info, total tokens, total price)``.
        ``price info`` and ``total price`` are strings formatted as dollar
        amounts with 10 decimal places. Models that have neither pricing
        structure are skipped.
    """
    # Normalise the filters once instead of on every loop iteration.
    # An empty string is treated like None, i.e. "no filter".
    provider_filter = provider.lower() if provider else None
    model_filter = model.lower() if model else None

    # The token total does not depend on the model, so compute it once.
    total_tokens = input_tokens + output_tokens
    results = []

    for p in token_prices["providers"]:
        if provider_filter is not None and p["name"].lower() != provider_filter:
            continue

        for m in p["models"]:
            if model_filter is not None and m["name"].lower() != model_filter:
                continue

            if "price" in m:
                # Single price for both input and output tokens.
                # When present, it takes precedence over split prices.
                price = m["price"]
                total_price = (price * total_tokens) / 1000
                price_info = f"${price:.10f}"
            elif "input_price" in m and "output_price" in m:
                # Separate prices for input and output tokens.
                input_price = m["input_price"]
                output_price = m["output_price"]
                total_price = ((input_price * input_tokens) + (output_price * output_tokens)) / 1000
                price_info = f"${input_price:.10f} (input) / ${output_price:.10f} (output)"
            else:
                # Skip models with an unexpected price structure.
                continue

            results.append(
                (
                    p["name"],  # Provider name
                    m["name"],  # Model name
                    price_info,  # Price information
                    total_tokens,  # Total number of tokens
                    f"${total_price:.10f}",  # Total price
                )
            )

    return results
46 changes: 37 additions & 9 deletions tests/test_price.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import pytest
from pathlib import Path
from unittest.mock import patch, mock_open
from code2prompt.utils.price_calculator import load_token_prices, calculate_prices
from code2prompt.main import create_markdown_file

# Mock JSON data
MOCK_JSON_DATA = '''
Expand Down Expand Up @@ -61,27 +59,57 @@ def test_load_token_prices_invalid_json():
with pytest.raises(RuntimeError, match="Error loading token prices"):
load_token_prices()

def test_calculate_prices_single_price_model(mock_token_prices):
    """A model with a single price is billed on input and output tokens combined."""
    result = calculate_prices(mock_token_prices, 1000, 1000, "provider1", "model1")
    assert len(result) == 1
    # Row layout: (provider, model, price info, total tokens, total price).
    assert result[0] == ("provider1", "model1", "$0.1000000000", 2000, "$0.2000000000")

def test_calculate_prices_dual_price_model(mock_token_prices):
    """A model with separate input/output prices is billed per token kind."""
    expected_row = (
        "provider1",
        "model2",
        "$0.3000000000 (input) / $0.4000000000 (output)",
        3000,
        "$1.1000000000",
    )
    rows = calculate_prices(mock_token_prices, 1000, 2000, "provider1", "model2")
    assert len(rows) == 1
    assert rows[0] == expected_row

def test_calculate_prices_all_providers_models(mock_token_prices):
    """Without provider/model filters, every model of every provider is priced."""
    rows = calculate_prices(mock_token_prices, 1000, 1000)
    assert len(rows) == 4
    assert {row[0] for row in rows} == {"provider1", "provider2"}
    assert {row[1] for row in rows} == {"model1", "model2"}

def test_calculate_prices_specific_provider(mock_token_prices):
    """Filtering by provider alone keeps all of that provider's models."""
    rows = calculate_prices(mock_token_prices, 1000, 1000, "provider1")
    assert len(rows) == 2
    for row in rows:
        assert row[0] == "provider1"
    assert {row[1] for row in rows} == {"model1", "model2"}

def test_calculate_prices_zero_tokens(mock_token_prices):
    """Zero input and output tokens yield a zero total price for every row."""
    result = calculate_prices(mock_token_prices, 0, 0)
    assert len(result) == 4
    # Totals are formatted with 10 decimal places.
    assert all(row[4] == "$0.0000000000" for row in result)

def test_calculate_prices_different_input_output_tokens(mock_token_prices):
    """Unequal input and output token counts are each billed at their own rate."""
    expected_row = (
        "provider2",
        "model1",
        "$0.3000000000 (input) / $0.4000000000 (output)",
        3000,
        "$1.1000000000",
    )
    rows = calculate_prices(mock_token_prices, 1000, 2000, "provider2", "model1")
    assert len(rows) == 1
    assert rows[0] == expected_row

def test_calculate_prices_non_existent_provider(mock_token_prices):
    """An unknown provider name produces no rows."""
    rows = calculate_prices(mock_token_prices, 1000, 1000, "non_existent_provider")
    assert len(rows) == 0

def test_calculate_prices_non_existent_model(mock_token_prices):
    """An unknown model name under a known provider produces no rows."""
    rows = calculate_prices(
        mock_token_prices, 1000, 1000, "provider1", "non_existent_model"
    )
    assert len(rows) == 0

def test_calculate_prices_large_numbers(mock_token_prices):
    """Token counts in the millions are priced and formatted correctly."""
    expected_row = ("provider1", "model1", "$0.1000000000", 2000000, "$200.0000000000")
    rows = calculate_prices(mock_token_prices, 1000000, 1000000, "provider1", "model1")
    assert len(rows) == 1
    assert rows[0] == expected_row

def test_calculate_prices_small_numbers(mock_token_prices):
    """A single input and output token still produces a non-zero formatted total."""
    expected_row = ("provider1", "model1", "$0.1000000000", 2, "$0.0002000000")
    rows = calculate_prices(mock_token_prices, 1, 1, "provider1", "model1")
    assert len(rows) == 1
    assert rows[0] == expected_row

def test_calculate_prices_floating_point_precision(mock_token_prices):
    """Summing split input/output prices formats to the exact expected total."""
    expected_row = (
        "provider2",
        "model1",
        "$0.3000000000 (input) / $0.4000000000 (output)",
        2000,
        "$0.7000000000",
    )
    rows = calculate_prices(mock_token_prices, 1000, 1000, "provider2", "model1")
    assert len(rows) == 1
    assert rows[0] == expected_row

0 comments on commit 4ae41f1

Please sign in to comment.