From 4ae41f1486666b03682c0e0310714c110caa5a05 Mon Sep 17 00:00:00 2001 From: Raphael MANSUY Date: Fri, 26 Jul 2024 09:01:54 +0800 Subject: [PATCH] fix(code2prompt/utils): Optimize `calculate_prices` function and add more test cases --- code2prompt/utils/price_calculator.py | 72 +++++++++++++++------------ tests/test_price.py | 46 +++++++++++++---- 2 files changed, 77 insertions(+), 41 deletions(-) diff --git a/code2prompt/utils/price_calculator.py b/code2prompt/utils/price_calculator.py index d06c560..8daee85 100644 --- a/code2prompt/utils/price_calculator.py +++ b/code2prompt/utils/price_calculator.py @@ -31,49 +31,57 @@ def calculate_price(token_count, price_per_1000): """ return (token_count / 1_000) * price_per_1000 -def calculate_prices(token_prices, input_token_count, output_token_count, provider=None, model=None): +def calculate_prices(token_prices, input_tokens, output_tokens, provider=None, model=None): """ - Calculate the prices based on the token prices, input and output token counts, provider, and model. + Calculate the prices for a given number of input and output tokens based on token prices. Args: token_prices (dict): A dictionary containing token prices for different providers and models. - input_token_count (int): The number of input tokens. - output_token_count (int): The number of output tokens. - provider (str, optional): The name of the provider. If specified, only prices for this provider will be calculated. Defaults to None. - model (str, optional): The name of the model. If specified, only prices for this model will be calculated. Defaults to None. + input_tokens (int): The number of input tokens. + output_tokens (int): The number of output tokens. + provider (str, optional): The name of the provider. If specified, only prices for the specified provider will be calculated. Defaults to None. + model (str, optional): The name of the model. If specified, only prices for the specified model will be calculated. Defaults to None. Returns: - list: A list of lists containing the calculated prices for each provider and model. Each inner list contains the following information: - - Provider name - - Model name - - Input price - - Input token count - - Total price + list: A list of tuples containing the provider name, model name, price per token, total tokens, and total price for each calculation. + """ - table_data = [] - for provider_data in token_prices["providers"]: - # Convert both strings to lowercase for case-insensitive comparison - if provider and provider_data["name"].lower() != provider.lower(): +def calculate_prices(token_prices, input_tokens, output_tokens, provider=None, model=None): + results = [] + + for p in token_prices["providers"]: + if provider and p["name"] != provider: continue - for model_data in provider_data["models"]: - # Convert both strings to lowercase for case-insensitive comparison - if model and model_data["name"].lower() != model.lower(): + + for m in p["models"]: + if model and m["name"] != model: continue - input_price = model_data.get("input_price", model_data.get("price", 0)) - output_price = model_data.get("output_price", model_data.get("price", 0)) + total_tokens = input_tokens + output_tokens + + if "price" in m: + # Single price for both input and output tokens + price = m["price"] + total_price = (price * total_tokens) / 1000 + price_info = f"${price:.10f}" + elif "input_price" in m and "output_price" in m: + # Separate prices for input and output tokens + input_price = m["input_price"] + output_price = m["output_price"] + total_price = ((input_price * input_tokens) + (output_price * output_tokens)) / 1000 + price_info = f"${input_price:.10f} (input) / ${output_price:.10f} (output)" + else: + # Skip models with unexpected price structure + continue - total_price = ( - calculate_price(input_token_count, input_price) + - calculate_price(output_token_count, output_price) + result = ( + p["name"], # Provider name + m["name"], # Model name + price_info, # Price information + total_tokens, # Total number of tokens + f"${total_price:.10f}" # Total price ) - table_data.append([ - provider_data["name"], - model_data["name"], - f"${input_price:.7f}", - input_token_count, - f"${total_price:.7f}" - ]) + results.append(result) - return table_data \ No newline at end of file + return results \ No newline at end of file diff --git a/tests/test_price.py b/tests/test_price.py index 99d8ef4..de0578b 100644 --- a/tests/test_price.py +++ b/tests/test_price.py @@ -1,8 +1,6 @@ import pytest -from pathlib import Path from unittest.mock import patch, mock_open from code2prompt.utils.price_calculator import load_token_prices, calculate_prices -from code2prompt.main import create_markdown_file # Mock JSON data MOCK_JSON_DATA = ''' @@ -61,27 +59,57 @@ def test_load_token_prices_invalid_json(): with pytest.raises(RuntimeError, match="Error loading token prices"): load_token_prices() -def test_calculate_prices_specific_provider_model(mock_token_prices): +def test_calculate_prices_single_price_model(mock_token_prices): result = calculate_prices(mock_token_prices, 1000, 1000, "provider1", "model1") assert len(result) == 1 - assert result[0][0] == "provider1" - assert result[0][1] == "model1" - assert result[0][2] == "$0.100000" - assert result[0][3] == 1000 - assert result[0][4] == "$0.20" + assert result[0] == ("provider1", "model1", "$0.1000000000", 2000, "$0.2000000000") + +def test_calculate_prices_dual_price_model(mock_token_prices): + result = calculate_prices(mock_token_prices, 1000, 2000, "provider1", "model2") + assert len(result) == 1 + assert result[0] == ("provider1", "model2", "$0.3000000000 (input) / $0.4000000000 (output)", 3000, "$1.1000000000") def test_calculate_prices_all_providers_models(mock_token_prices): result = calculate_prices(mock_token_prices, 1000, 1000) assert len(result) == 4 + assert set(row[0] for row in result) == {"provider1", "provider2"} + assert set(row[1] for row in result) == {"model1", "model2"} def test_calculate_prices_specific_provider(mock_token_prices): result = calculate_prices(mock_token_prices, 1000, 1000, "provider1") assert len(result) == 2 assert all(row[0] == "provider1" for row in result) + assert set(row[1] for row in result) == {"model1", "model2"} def test_calculate_prices_zero_tokens(mock_token_prices): result = calculate_prices(mock_token_prices, 0, 0) - assert all(row[4] == "$0.00" for row in result) + assert len(result) == 4 + assert all(row[4] == "$0.0000000000" for row in result) + +def test_calculate_prices_different_input_output_tokens(mock_token_prices): + result = calculate_prices(mock_token_prices, 1000, 2000, "provider2", "model1") + assert len(result) == 1 + assert result[0] == ("provider2", "model1", "$0.3000000000 (input) / $0.4000000000 (output)", 3000, "$1.1000000000") +def test_calculate_prices_non_existent_provider(mock_token_prices): + result = calculate_prices(mock_token_prices, 1000, 1000, "non_existent_provider") + assert len(result) == 0 +def test_calculate_prices_non_existent_model(mock_token_prices): + result = calculate_prices(mock_token_prices, 1000, 1000, "provider1", "non_existent_model") + assert len(result) == 0 +def test_calculate_prices_large_numbers(mock_token_prices): + result = calculate_prices(mock_token_prices, 1000000, 1000000, "provider1", "model1") + assert len(result) == 1 + assert result[0] == ("provider1", "model1", "$0.1000000000", 2000000, "$200.0000000000") + +def test_calculate_prices_small_numbers(mock_token_prices): + result = calculate_prices(mock_token_prices, 1, 1, "provider1", "model1") + assert len(result) == 1 + assert result[0] == ("provider1", "model1", "$0.1000000000", 2, "$0.0002000000") + +def test_calculate_prices_floating_point_precision(mock_token_prices): + result = calculate_prices(mock_token_prices, 1000, 1000, "provider2", "model1") + assert len(result) == 1 + assert result[0] == ("provider2", "model1", "$0.3000000000 (input) / $0.4000000000 (output)", 2000, "$0.7000000000") \ No newline at end of file