Commit fb478b3

chore(agents-api): lint
1 parent 7b0fad7 commit fb478b3

6 files changed: +108 -76 lines changed


agents-api/agents_api/clients/litellm.py (+4 -4)
@@ -9,7 +9,7 @@
 from litellm import get_supported_openai_params
 from litellm.utils import CustomStreamWrapper, ModelResponse, get_valid_models

-from ..common.utils.usage import track_usage, track_embedding_usage
+from ..common.utils.usage import track_embedding_usage, track_usage
 from ..env import (
     embedding_dimensions,
     embedding_model_id,
@@ -79,7 +79,7 @@ async def acompletion(
        )

        response = patch_litellm_response(model_response)
-
+
        # Track usage in database if we have a user ID (which should be the developer ID)
        user = settings.get("user")
        if user and isinstance(response, ModelResponse):
@@ -90,7 +90,7 @@ async def acompletion(
                model=model,
                messages=messages,
                response=response,
-                custom_api_used= custom_api_key is not None,
+                custom_api_used=custom_api_key is not None,
                metadata={"tags": kwargs.get("tags", [])},
            )
    except Exception as e:
@@ -134,7 +134,7 @@ async def aembedding(
        drop_params=True,
        **settings,
    )
-
+
    # Track embedding usage if we have a user ID
    user = settings.get("user")
    if user:

agents-api/agents_api/common/utils/usage.py (+22 -20)
@@ -2,16 +2,15 @@
 Utilities for tracking token usage and costs for LLM API calls.
 """

-from typing import Any, Optional
+from typing import Any
 from uuid import UUID

 from beartype import beartype
-from litellm import cost_per_token
-from litellm.utils import ModelResponse, _select_tokenizer as select_tokenizer
-from litellm.utils import token_counter
+from litellm.utils import ModelResponse, token_counter

 from ...queries.usage.create_usage_record import create_usage_record

+
 @beartype
 async def track_usage(
     *,
@@ -36,27 +35,28 @@ async def track_usage(
     Returns:
         None
     """
-
+
     # Try to get token counts from response.usage
     if response.usage:
         prompt_tokens = response.usage.prompt_tokens
         completion_tokens = response.usage.completion_tokens
     else:
         # Calculate tokens manually if usage is not available
         prompt_tokens = token_counter(model=model, messages=messages)
-
+
         # Calculate completion tokens from the response
-        completion_content = []
-        for choice in response.choices:
-            if hasattr(choice, "message") and choice.message:
-                if hasattr(choice.message, "content") and choice.message.content:
-                    completion_content.append({"content": choice.message.content})
-
-        completion_tokens = token_counter(
-            model=model,
-            messages=completion_content
-        ) if completion_content else 0
-
+        completion_content = [
+            {"content": choice.message.content}
+            for choice in response.choices
+            if hasattr(choice, "message")
+            and choice.message
+            and hasattr(choice.message, "content")
+            and choice.message.content
+        ]
+
+        completion_tokens = (
+            token_counter(model=model, messages=completion_content) if completion_content else 0
+        )

     # Map the model name to the actual model name
     actual_model = model
@@ -99,13 +99,15 @@ async def track_embedding_usage(
     Returns:
         None
     """
-
+
     # Try to get token count from response.usage
     if hasattr(response, "usage") and response.usage:
         prompt_tokens = response.usage.prompt_tokens
     else:
         # Calculate tokens manually if usage is not available
-        prompt_tokens = sum(token_counter(model=model, text=input_text) for input_text in inputs)
+        prompt_tokens = sum(
+            token_counter(model=model, text=input_text) for input_text in inputs
+        )

     # Map the model name to the actual model name
     actual_model = model
@@ -118,4 +120,4 @@ async def track_embedding_usage(
         completion_tokens=0,  # Embeddings don't have completion tokens
         custom_api_used=custom_api_used,
         metadata=metadata,
-    )
+    )
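
Note: the list-comprehension rewrite above is layout-only; when the provider response carries no usage block, track_usage still falls back to litellm's token_counter. A minimal sketch of that fallback in isolation (the model name and message contents below are illustrative, not taken from the diff):

    from litellm.utils import token_counter

    # Illustrative prompt and assistant reply, in the same {"content": ...} shape
    # that track_usage builds from response.choices.
    messages = [{"role": "user", "content": "Hello, world!"}]
    completion_content = [{"content": "Hi there!"}]

    prompt_tokens = token_counter(model="gpt-4o-mini", messages=messages)
    completion_tokens = (
        token_counter(model="gpt-4o-mini", messages=completion_content)
        if completion_content
        else 0  # no assistant content at all -> count zero completion tokens
    )
    print(prompt_tokens, completion_tokens)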

agents-api/agents_api/queries/__init__.py (+1 -1)
@@ -17,5 +17,5 @@
 from . import sessions as sessions
 from . import tasks as tasks
 from . import tools as tools
+from . import usage as usage
 from . import users as users
-from . import usage as usage

agents-api/agents_api/queries/usage/__init__.py (+1 -1)
@@ -17,4 +17,4 @@

 __all__ = [
     "create_usage_record",
-]
+]

agents-api/agents_api/queries/usage/create_usage_record.py (+38 -17)
@@ -71,30 +71,51 @@ async def create_usage_record(
     # Calculate cost based on token usage
     # For custom API keys, we still track usage but mark it as such
     total_cost = 0.0
-
+
     if not custom_api_used:
         # Calculate cost using litellm's cost_per_token function
         try:
-            prompt_cost, completion_cost = cost_per_token(model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
+            prompt_cost, completion_cost = cost_per_token(
+                model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
+            )
             total_cost = prompt_cost + completion_cost
-        except Exception as e:
+        except Exception:
            estimated = True
            fallback_pricing = {
-                # Meta Llama models
-                'meta-llama/llama-4-scout': {'api_request': 0.08/1000, 'api_response': 0.45/1000},
-                'meta-llama/llama-4-maverick': {'api_request': 0.19/1000, 'api_response': 0.85/1000},
-                'meta-llama/llama-4-maverick:free': {'api_request': 0.0/1000, 'api_response': 0.0/1000},
-
-                # Qwen model
-                'qwen/qwen-2.5-72b-instruct': {'api_request': 0.7/1000, 'api_response': 0.7/1000},
-
-                # Sao10k model
-                'sao10k/l3.3-euryale-70b': {'api_request': 0.7/1000, 'api_response': 0.8/1000},
-                'sao10k/l3.1-euryale-70b': {'api_request': 0.7/1000, 'api_response': 0.8/1000}
+                # Meta Llama models
+                "meta-llama/llama-4-scout": {
+                    "api_request": 0.08 / 1000,
+                    "api_response": 0.45 / 1000,
+                },
+                "meta-llama/llama-4-maverick": {
+                    "api_request": 0.19 / 1000,
+                    "api_response": 0.85 / 1000,
+                },
+                "meta-llama/llama-4-maverick:free": {
+                    "api_request": 0.0 / 1000,
+                    "api_response": 0.0 / 1000,
+                },
+                # Qwen model
+                "qwen/qwen-2.5-72b-instruct": {
+                    "api_request": 0.7 / 1000,
+                    "api_response": 0.7 / 1000,
+                },
+                # Sao10k model
+                "sao10k/l3.3-euryale-70b": {
+                    "api_request": 0.7 / 1000,
+                    "api_response": 0.8 / 1000,
+                },
+                "sao10k/l3.1-euryale-70b": {
+                    "api_request": 0.7 / 1000,
+                    "api_response": 0.8 / 1000,
+                },
            }
-
+
            if model in fallback_pricing:
-                total_cost = fallback_pricing[model]['api_request'] * prompt_tokens + fallback_pricing[model]['api_response'] * completion_tokens
+                total_cost = (
+                    fallback_pricing[model]["api_request"] * prompt_tokens
+                    + fallback_pricing[model]["api_response"] * completion_tokens
+                )
            else:
                print(f"No fallback pricing found for model {model}")
@@ -112,4 +133,4 @@ async def create_usage_record(
     return (
         usage_query,
         params,
-    )
+    )
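
Note: the fallback table above stores per-token prices derived from per-1K rates (e.g. 0.08 / 1000), so when litellm's cost_per_token raises for a model it does not know, the estimate is simply api_request * prompt_tokens + api_response * completion_tokens, with the record flagged as estimated. A quick worked example using the llama-4-scout rates from the table and illustrative token counts:

    # Rates mirror the fallback table above (per 1K tokens); token counts are illustrative.
    api_request = 0.08 / 1000   # cost per prompt token for meta-llama/llama-4-scout
    api_response = 0.45 / 1000  # cost per completion token

    prompt_tokens = 1000
    completion_tokens = 1000

    total_cost = api_request * prompt_tokens + api_response * completion_tokens
    print(round(total_cost, 6))  # 0.08 + 0.45 = 0.53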

agents-api/tests/test_usage_tracking.py (+42 -33)
@@ -2,21 +2,19 @@
 Tests for usage tracking functionality.
 """

-import io
-from contextlib import redirect_stdout
-from decimal import Decimal
 from datetime import datetime
+from decimal import Decimal
 from unittest.mock import patch

+from agents_api.clients.pg import create_db_pool
+from agents_api.common.utils.usage import track_embedding_usage, track_usage
+from agents_api.queries.usage.create_usage_record import create_usage_record
+from litellm import cost_per_token
+from litellm.utils import Message, ModelResponse, Usage, token_counter
 from ward import test

-from agents_api.common.utils.usage import track_usage, track_embedding_usage
-from agents_api.queries.usage.create_usage_record import create_usage_record
-from litellm.utils import ModelResponse, Usage, Choices, Message
-from agents_api.clients.pg import create_db_pool
 from .fixtures import pg_dsn, test_developer_id
-from litellm import cost_per_token
-from litellm.utils import token_counter
+

 @test("query: create_usage_record creates a record with correct parameters")
 async def _(dsn=pg_dsn, developer_id=test_developer_id) -> None:
@@ -31,10 +29,10 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id) -> None:
     assert len(response) == 1
     record = response[0]
     assert record["developer_id"] == developer_id
-    assert record["model"] == 'gpt-4o-mini'
+    assert record["model"] == "gpt-4o-mini"
     assert record["prompt_tokens"] == 100
     assert record["completion_tokens"] == 100
-    assert record["cost"] == Decimal('0.000075')
+    assert record["cost"] == Decimal("0.000075")
     assert record["estimated"] is False
     assert record["custom_api_used"] is False
     assert record["metadata"] == {}
@@ -60,7 +58,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id) -> None:
         "meta-llama/llama-4-maverick:free",
         "qwen/qwen-2.5-72b-instruct",
         "sao10k/l3.3-euryale-70b",
-        "sao10k/l3.1-euryale-70b"
+        "sao10k/l3.1-euryale-70b",
     ]
     for model in models:
         response = await create_usage_record(
@@ -86,9 +84,11 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id) -> None:
         connection_pool=pool,
     )

-    input_cost, completion_cost = cost_per_token("gpt-4o-mini", prompt_tokens=2041, completion_tokens=34198)
+    input_cost, completion_cost = cost_per_token(
+        "gpt-4o-mini", prompt_tokens=2041, completion_tokens=34198
+    )
     cost = input_cost + completion_cost
-    cost = Decimal(str(cost)).quantize(Decimal('0.000001'))
+    cost = Decimal(str(cost)).quantize(Decimal("0.000001"))

     assert len(response) == 1
     record = response[0]
@@ -125,13 +125,14 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id) -> None:

     assert len(response) == 1
     record = response[0]
-    assert record["cost"] == Decimal('0.000000')
+    assert record["cost"] == Decimal("0.000000")
     assert record["estimated"] is True

+
 @test("query: create_usage_record with fallback pricing with model not in fallback pricing")
 async def _(dsn=pg_dsn, developer_id=test_developer_id) -> None:
     pool = await create_db_pool(dsn=dsn)
-
+
     with patch("builtins.print") as mock_print:
         unknown_model = "unknown-model-name"
         response = await create_usage_record(
@@ -146,7 +147,7 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id) -> None:

     assert len(response) == 1
     record = response[0]
-    assert record["cost"] == Decimal('0.000000')
+    assert record["cost"] == Decimal("0.000000")
     assert record["estimated"] is True
     assert expected_call == actual_call

@@ -171,23 +172,28 @@ async def _(developer_id=test_developer_id) -> None:
         assert call_args["prompt_tokens"] == 100
         assert call_args["completion_tokens"] == 100

-
+
 @test("utils: track_usage without response.usage")
 async def _(developer_id=test_developer_id) -> None:
     with patch("agents_api.common.utils.usage.create_usage_record") as mock_create_usage_record:
         response = ModelResponse(
             usage=None,
-            choices=[{
-                "finish_reason": "stop",
-                "index": 0,
-                "message": Message(content="Hello, world!", role="assistant")
-            }]
+            choices=[
+                {
+                    "finish_reason": "stop",
+                    "index": 0,
+                    "message": Message(content="Hello, world!", role="assistant"),
+                }
+            ],
         )
         response.usage = None
         messages = [{"role": "user", "content": "Hello, world!"}]

         prompt_tokens = token_counter(model="gpt-4o-mini", messages=messages)
-        completion_tokens = token_counter(model="gpt-4o-mini", messages=[{"content": choice.message.content} for choice in response.choices])
+        completion_tokens = token_counter(
+            model="gpt-4o-mini",
+            messages=[{"content": choice.message.content} for choice in response.choices],
+        )

         await track_usage(
             developer_id=developer_id,
@@ -210,16 +216,16 @@ async def _(developer_id=test_developer_id) -> None:
                 completion_tokens=0,
             ),
         )
-
+
         inputs = ["This is a test input for embedding"]
-
+
         await track_embedding_usage(
             developer_id=developer_id,
             model="text-embedding-3-large",
             inputs=inputs,
             response=response,
         )
-
+
         call_args = mock_create_usage_record.call_args[1]
         assert call_args["prompt_tokens"] == 150
         assert call_args["completion_tokens"] == 0
@@ -231,20 +237,23 @@ async def _(developer_id=test_developer_id) -> None:
     with patch("agents_api.common.utils.usage.create_usage_record") as mock_create_usage_record:
         response = ModelResponse()
         response.usage = None
-
+
         inputs = ["First test input", "Second test input"]
-
+
         # Calculate expected tokens manually
-        expected_tokens = sum(token_counter(model="text-embedding-3-large", text=input_text) for input_text in inputs)
-
+        expected_tokens = sum(
+            token_counter(model="text-embedding-3-large", text=input_text)
+            for input_text in inputs
+        )
+
         await track_embedding_usage(
             developer_id=developer_id,
             model="text-embedding-3-large",
             inputs=inputs,
             response=response,
         )
-
+
         call_args = mock_create_usage_record.call_args[1]
         assert call_args["prompt_tokens"] == expected_tokens
         assert call_args["completion_tokens"] == 0
-        assert call_args["model"] == "text-embedding-3-large"
+        assert call_args["model"] == "text-embedding-3-large"
