
Commit 377717e

Merge pull request #1316 from julep-ai/f/usage-tracking
feat(agents-api): add usage tracking + usage table
2 parents 09cf589 + 5de4fb2 commit 377717e

8 files changed: +640 −1 lines changed


agents-api/agents_api/clients/litellm.py (+43 −1)
@@ -1,5 +1,6 @@
 from functools import wraps
 from typing import Literal
+from uuid import UUID

 import aiohttp
 from beartype import beartype
@@ -8,6 +9,7 @@
 from litellm import get_supported_openai_params
 from litellm.utils import CustomStreamWrapper, ModelResponse, get_valid_models

+from ..common.utils.usage import track_embedding_usage, track_usage
 from ..env import (
     embedding_dimensions,
     embedding_model_id,
@@ -76,7 +78,26 @@ async def acompletion(
         api_key=custom_api_key or litellm_master_key,
     )

-    return patch_litellm_response(model_response)
+    response = patch_litellm_response(model_response)
+
+    # Track usage in database if we have a user ID (which should be the developer ID)
+    user = settings.get("user")
+    if user and isinstance(response, ModelResponse):
+        try:
+            model = response.model
+            await track_usage(
+                developer_id=UUID(user),
+                model=model,
+                messages=messages,
+                response=response,
+                custom_api_used=custom_api_key is not None,
+                metadata={"tags": kwargs.get("tags", [])},
+            )
+        except Exception as e:
+            # Log error but don't fail the request if usage tracking fails
+            print(f"Error tracking usage: {e}")
+
+    return response


 @wraps(_aembedding)
@@ -114,6 +135,27 @@ async def aembedding(
         **settings,
     )

+    # Track embedding usage if we have a user ID
+    user = settings.get("user")
+    if user:
+        try:
+            model = response.model
+            await track_embedding_usage(
+                developer_id=UUID(user),
+                model=model,
+                inputs=input,
+                response=response,
+                custom_api_used=bool(custom_api_key),
+                metadata={
+                    "request_id": response.id if hasattr(response, "id") else None,
+                    "embedding_count": len(input),
+                    "tags": settings.get("tags", []),
+                },
+            )
+        except Exception as e:
+            # Log error but don't fail the request if usage tracking fails
+            print(f"Error tracking embedding usage: {e}")
+
     embedding_list: list[dict[Literal["embedding"], list[float]]] = response.data

     # Truncate the embedding to the specified dimensions
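
Usage tracking above only fires when settings.get("user") returns the developer ID, so callers must pass it through the OpenAI-style user field. Below is a minimal sketch of what a call site might look like, assuming the user and tags keyword arguments flow into settings/kwargs the way this diff reads them; the model name and call site are illustrative and not part of this commit.

import asyncio
from uuid import uuid4

from agents_api.clients.litellm import acompletion


async def main() -> None:
    # Hypothetical developer ID; in the service it comes from the request's auth context.
    developer_id = uuid4()

    response = await acompletion(
        model="gpt-4o-mini",  # illustrative model name
        messages=[{"role": "user", "content": "Hello"}],
        user=str(developer_id),  # read back via settings.get("user") to attribute usage
        tags=["example"],  # surfaces in the usage record's metadata["tags"]
    )
    print(response.usage)


asyncio.run(main())
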
agents-api/agents_api/common/utils/usage.py (new file, +123)
"""
Utilities for tracking token usage and costs for LLM API calls.
"""

from typing import Any
from uuid import UUID

from beartype import beartype
from litellm.utils import ModelResponse, token_counter

from ...queries.usage.create_usage_record import create_usage_record


@beartype
async def track_usage(
    *,
    developer_id: UUID,
    model: str,
    messages: list[dict],
    response: ModelResponse,
    custom_api_used: bool = False,
    metadata: dict[str, Any] = {},
) -> None:
    """
    Tracks token usage and costs for an LLM API call.

    Parameters:
        developer_id (UUID): The unique identifier for the developer.
        model (str): The model used for the API call.
        messages (list[dict]): The messages sent to the model.
        response (ModelResponse): The response from the LLM API call.
        custom_api_used (bool): Whether a custom API key was used.
        metadata (dict): Additional metadata about the usage.

    Returns:
        None
    """

    # Try to get token counts from response.usage
    if response.usage:
        prompt_tokens = response.usage.prompt_tokens
        completion_tokens = response.usage.completion_tokens
    else:
        # Calculate tokens manually if usage is not available
        prompt_tokens = token_counter(model=model, messages=messages)

        # Calculate completion tokens from the response
        completion_content = [
            {"content": choice.message.content}
            for choice in response.choices
            if hasattr(choice, "message")
            and choice.message
            and hasattr(choice.message, "content")
            and choice.message.content
        ]

        completion_tokens = (
            token_counter(model=model, messages=completion_content) if completion_content else 0
        )

    # Map the model name to the actual model name
    actual_model = model

    # Create usage record
    await create_usage_record(
        developer_id=developer_id,
        model=actual_model,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        custom_api_used=custom_api_used,
        metadata={
            "request_id": response.id if hasattr(response, "id") else None,
            **metadata,
        },
    )


@beartype
async def track_embedding_usage(
    *,
    developer_id: UUID,
    model: str,
    inputs: list[str],
    response: Any,
    custom_api_used: bool = False,
    metadata: dict[str, Any] = {},
) -> None:
    """
    Tracks token usage and costs for an embedding API call.

    Parameters:
        developer_id (UUID): The unique identifier for the developer.
        model (str): The model used for the embedding.
        inputs (list[str]): The inputs sent for embedding.
        response (Any): The response from the embedding API call.
        custom_api_used (bool): Whether a custom API key was used.
        metadata (dict): Additional metadata about the usage.

    Returns:
        None
    """

    # Try to get token count from response.usage
    if hasattr(response, "usage") and response.usage:
        prompt_tokens = response.usage.prompt_tokens
    else:
        # Calculate tokens manually if usage is not available
        prompt_tokens = sum(
            token_counter(model=model, text=input_text) for input_text in inputs
        )

    # Map the model name to the actual model name
    actual_model = model

    # Create usage record for embeddings (no completion tokens)
    await create_usage_record(
        developer_id=developer_id,
        model=actual_model,
        prompt_tokens=prompt_tokens,
        completion_tokens=0,  # Embeddings don't have completion tokens
        custom_api_used=custom_api_used,
        metadata=metadata,
    )
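
When a provider response carries no usage block, both helpers fall back to litellm's token_counter. A small self-contained sketch of that fallback path, with a made-up model name and messages:

from litellm.utils import token_counter

model = "gpt-4o-mini"  # illustrative only
messages = [{"role": "user", "content": "Summarize the last standup."}]

# Prompt side: count tokens over the request messages, as track_usage does.
prompt_tokens = token_counter(model=model, messages=messages)

# Completion side: wrap each returned message's text in a {"content": ...} dict,
# mirroring how track_usage rebuilds completion_content from response.choices.
completion_content = [{"content": "Here is a short summary of the standup."}]
completion_tokens = token_counter(model=model, messages=completion_content)

print(prompt_tokens, completion_tokens)
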

agents-api/agents_api/queries/__init__.py (+1)
@@ -17,4 +17,5 @@
 from . import sessions as sessions
 from . import tasks as tasks
 from . import tools as tools
+from . import usage as usage
 from . import users as users
agents-api/agents_api/queries/usage/__init__.py (new file, +20)
"""
The `usage` module within the `queries` package provides functionality for tracking token usage
and costs associated with LLM API calls. This includes:

- Recording prompt and completion tokens
- Calculating costs based on model pricing
- Storing usage data with developer attribution
- Supporting both standard and custom API usage

Each function in this module constructs and executes SQL queries for database operations
related to usage tracking and reporting.
"""

# ruff: noqa: F401, F403, F405

from .create_usage_record import create_usage_record

__all__ = [
    "create_usage_record",
]
agents-api/agents_api/queries/usage/create_usage_record.py (new file, +137)
"""
This module contains functionality for creating usage records in the PostgreSQL database.
It tracks token usage and costs for LLM API calls.
"""

from typing import Any
from uuid import UUID

from beartype import beartype
from litellm import cost_per_token

from ...common.utils.db_exceptions import common_db_exceptions
from ...metrics.counters import query_metrics
from ..utils import pg_query, rewrap_exceptions

FALLBACK_PRICING = {
    # Meta Llama models
    "meta-llama/llama-4-scout": {
        "api_request": 0.08 / 1000,
        "api_response": 0.45 / 1000,
    },
    "meta-llama/llama-4-maverick": {
        "api_request": 0.19 / 1000,
        "api_response": 0.85 / 1000,
    },
    "meta-llama/llama-4-maverick:free": {
        "api_request": 0.0 / 1000,
        "api_response": 0.0 / 1000,
    },
    # Qwen model
    "qwen/qwen-2.5-72b-instruct": {
        "api_request": 0.7 / 1000,
        "api_response": 0.7 / 1000,
    },
    # Sao10k model
    "sao10k/l3.3-euryale-70b": {
        "api_request": 0.7 / 1000,
        "api_response": 0.8 / 1000,
    },
    "sao10k/l3.1-euryale-70b": {
        "api_request": 0.7 / 1000,
        "api_response": 0.8 / 1000,
    },
}

# Define the raw SQL query
usage_query = """
INSERT INTO usage (
    developer_id,
    model,
    prompt_tokens,
    completion_tokens,
    cost,
    estimated,
    custom_api_used,
    metadata
)
VALUES (
    $1, -- developer_id
    $2, -- model
    $3, -- prompt_tokens
    $4, -- completion_tokens
    $5, -- cost
    $6, -- estimated
    $7, -- custom_api_used
    $8 -- metadata
)
RETURNING *;
"""


@rewrap_exceptions(common_db_exceptions("usage", ["create"]))
@query_metrics("create_usage_record")
@pg_query
@beartype
async def create_usage_record(
    *,
    developer_id: UUID,
    model: str,
    prompt_tokens: int,
    completion_tokens: int,
    custom_api_used: bool = False,
    estimated: bool = False,
    metadata: dict[str, Any] | None = None,
) -> tuple[str, list]:
    """
    Creates a usage record to track token usage and costs.

    Parameters:
        developer_id (UUID): The unique identifier for the developer.
        model (str): The model used for the API call.
        prompt_tokens (int): Number of tokens in the prompt.
        completion_tokens (int): Number of tokens in the completion.
        custom_api_used (bool): Whether a custom API key was used.
        estimated (bool): Whether the token count is estimated.
        metadata (dict | None): Additional metadata about the usage.

    Returns:
        tuple[str, list]: SQL query and parameters for creating the usage record.
    """
    # Calculate cost based on token usage
    # For custom API keys, we still track usage but mark it as such
    total_cost = 0.0

    if not custom_api_used:
        # Calculate cost using litellm's cost_per_token function
        try:
            prompt_cost, completion_cost = cost_per_token(
                model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
            )
            total_cost = prompt_cost + completion_cost
        except Exception:
            estimated = True

            if model in FALLBACK_PRICING:
                total_cost = (
                    FALLBACK_PRICING[model]["api_request"] * prompt_tokens
                    + FALLBACK_PRICING[model]["api_response"] * completion_tokens
                )
            else:
                print(f"No fallback pricing found for model {model}")

    params = [
        developer_id,
        model,
        prompt_tokens,
        completion_tokens,
        total_cost,
        estimated,
        custom_api_used,
        metadata or {},
    ]

    return (
        usage_query,
        params,
    )
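
For non-custom API calls, create_usage_record first asks litellm's cost_per_token for pricing and only falls back to the FALLBACK_PRICING table (marking the row as estimated) when that lookup fails. Below is a rough sketch of that decision factored into a standalone helper, using one entry copied from the table above; it is an illustration, not code from this commit.

from litellm import cost_per_token

# One entry copied from FALLBACK_PRICING above (USD per token).
LLAMA_4_SCOUT = {"api_request": 0.08 / 1000, "api_response": 0.45 / 1000}


def estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> tuple[float, bool]:
    """Return (cost, estimated), mirroring the logic in create_usage_record."""
    try:
        prompt_cost, completion_cost = cost_per_token(
            model, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
        )
        return prompt_cost + completion_cost, False
    except Exception:
        if model == "meta-llama/llama-4-scout":
            cost = (
                LLAMA_4_SCOUT["api_request"] * prompt_tokens
                + LLAMA_4_SCOUT["api_response"] * completion_tokens
            )
            return cost, True
        return 0.0, True


# Fallback example: 1,000 prompt + 500 completion tokens on llama-4-scout
# -> 1000 * 0.00008 + 500 * 0.00045 = 0.08 + 0.225 = 0.305 USD, estimated=True
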

0 commit comments
