Skip to content

Commit 04e334a

Browse files
committed
Clamp recent usage limit requests
1 parent 7357848 commit 04e334a

File tree

6 files changed

+190
-19
lines changed

6 files changed

+190
-19
lines changed

src/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,7 @@
1+
"""Project-wide constants used across multiple modules."""
2+
13
DEFAULT_COMMAND_PREFIX: str = "!/"
4+
"""Default command prefix for interactive commands."""
5+
6+
MAX_RECENT_USAGE_RECORDS: int = 1000
7+
"""Maximum number of recent usage records that can be requested at once."""

src/core/app/controllers/usage_controller.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from src.core.di.services import get_or_build_service_provider
1313
from src.core.domain.usage_data import UsageData
1414
from src.core.interfaces.usage_tracking_interface import IUsageTrackingService
15+
from src.constants import MAX_RECENT_USAGE_RECORDS
16+
from src.core.common.usage_limits import normalize_recent_usage_limit
1517

1618
logger = logging.getLogger(__name__)
1719

@@ -62,8 +64,31 @@ async def get_recent_usage(
6264
if not self.usage_service:
6365
return []
6466

67+
try:
68+
requested_limit = int(limit)
69+
except (TypeError, ValueError):
70+
requested_limit = 0
71+
72+
normalized_limit = normalize_recent_usage_limit(requested_limit)
73+
74+
if normalized_limit == 0:
75+
if logger.isEnabledFor(logging.DEBUG):
76+
logger.debug(
77+
"Recent usage requested with limit=%s; returning empty result", limit
78+
)
79+
return []
80+
81+
if normalized_limit < requested_limit:
82+
if logger.isEnabledFor(logging.INFO):
83+
logger.info(
84+
"Recent usage limit clamped from %s to %s (max=%s)",
85+
limit,
86+
normalized_limit,
87+
MAX_RECENT_USAGE_RECORDS,
88+
)
89+
6590
result = await self.usage_service.get_recent_usage(
66-
session_id=session_id, limit=limit
91+
session_id=session_id, limit=normalized_limit
6792
)
6893
return result # type: ignore[no-any-return]
6994

@@ -92,7 +117,12 @@ async def get_usage_stats(
92117
@router.get("/recent", response_model=list[UsageData])
93118
async def get_recent_usage(
94119
session_id: str | None = Query(None, description="Filter by session ID"),
95-
limit: int = Query(100, description="Maximum number of records to return"),
120+
limit: int = Query(
121+
100,
122+
description="Maximum number of records to return",
123+
ge=0,
124+
le=MAX_RECENT_USAGE_RECORDS,
125+
),
96126
service_provider: Any = Depends(get_or_build_service_provider),
97127
) -> list[UsageData]:
98128
"""Get recent usage data.
@@ -106,5 +136,30 @@ async def get_recent_usage(
106136
List of usage data entities
107137
"""
108138
usage_service = service_provider.get_required_service(IUsageTrackingService)
109-
result = await usage_service.get_recent_usage(session_id=session_id, limit=limit)
139+
try:
140+
requested_limit = int(limit)
141+
except (TypeError, ValueError):
142+
requested_limit = 0
143+
144+
normalized_limit = normalize_recent_usage_limit(requested_limit)
145+
146+
if normalized_limit == 0:
147+
if logger.isEnabledFor(logging.DEBUG):
148+
logger.debug(
149+
"API recent usage requested with limit=%s; returning empty result", limit
150+
)
151+
return []
152+
153+
if normalized_limit < requested_limit:
154+
if logger.isEnabledFor(logging.INFO):
155+
logger.info(
156+
"API recent usage limit clamped from %s to %s (max=%s)",
157+
limit,
158+
normalized_limit,
159+
MAX_RECENT_USAGE_RECORDS,
160+
)
161+
162+
result = await usage_service.get_recent_usage(
163+
session_id=session_id, limit=normalized_limit
164+
)
110165
return result # type: ignore[no-any-return]

src/core/common/usage_limits.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
"""Utilities for normalizing usage-related request parameters."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any
6+
7+
from src.constants import MAX_RECENT_USAGE_RECORDS
8+
9+
10+
def normalize_recent_usage_limit(limit: Any) -> int:
11+
"""Normalize the recent usage limit value to a safe, bounded integer.
12+
13+
Args:
14+
limit: The requested limit value that may come from untrusted sources.
15+
16+
Returns:
17+
A non-negative integer that does not exceed :data:`MAX_RECENT_USAGE_RECORDS`.
18+
Invalid or non-positive values yield ``0`` so callers can short-circuit expensive
19+
repository lookups.
20+
"""
21+
22+
try:
23+
numeric_limit = int(limit)
24+
except (TypeError, ValueError):
25+
return 0
26+
27+
if numeric_limit <= 0:
28+
return 0
29+
30+
return min(numeric_limit, MAX_RECENT_USAGE_RECORDS)
31+

src/core/services/usage_tracking_service.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def headers(self) -> dict[str, str]: ...
3030
def media_type(self) -> str: ...
3131

3232

33+
from src.constants import MAX_RECENT_USAGE_RECORDS
34+
from src.core.common.usage_limits import normalize_recent_usage_limit
3335
from src.core.domain.usage_data import UsageData
3436
from src.core.domain.usage_stats import ModelUsageStats, UsageStatsResponse
3537
from src.core.interfaces.repositories_interface import IUsageRepository
@@ -341,11 +343,37 @@ async def get_recent_usage(
341343
Returns:
342344
List of usage data entities
343345
"""
346+
try:
347+
requested_limit = int(limit)
348+
except (TypeError, ValueError):
349+
requested_limit = 0
350+
351+
normalized_limit = normalize_recent_usage_limit(requested_limit)
352+
353+
if normalized_limit == 0:
354+
if logger.isEnabledFor(logging.DEBUG):
355+
logger.debug(
356+
"Recent usage requested with limit=%s; returning empty result", limit
357+
)
358+
return []
359+
360+
if normalized_limit < requested_limit:
361+
if logger.isEnabledFor(logging.INFO):
362+
logger.info(
363+
"Recent usage limit clamped from %s to %s (max=%s)",
364+
limit,
365+
normalized_limit,
366+
MAX_RECENT_USAGE_RECORDS,
367+
)
368+
344369
if session_id:
345370
data = await self._repository.get_by_session_id(session_id)
346371
else:
347372
data = await self._repository.get_all()
348373

349374
# Sort by timestamp (newest first) and limit
375+
if not data:
376+
return []
377+
350378
sorted_data = sorted(data, key=lambda x: x.timestamp, reverse=True)
351-
return sorted_data[:limit]
379+
return sorted_data[:normalized_limit]

tests/unit/core/app/controllers/test_usage_controller_comprehensive.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from unittest.mock import AsyncMock
99

1010
import pytest
11+
from src.constants import MAX_RECENT_USAGE_RECORDS
1112
from src.core.app.controllers.usage_controller import UsageController
1213
from src.core.domain.usage_data import UsageData
1314
from src.core.interfaces.usage_tracking_interface import IUsageTrackingService
@@ -228,23 +229,29 @@ async def test_get_recent_usage_large_limit(
228229

229230
assert result == mock_usage_data
230231
mock_usage_service.get_recent_usage.assert_called_once_with(
231-
session_id=None, limit=10000
232+
session_id=None, limit=MAX_RECENT_USAGE_RECORDS
232233
)
233234

234235
@pytest.mark.asyncio
235236
async def test_get_recent_usage_zero_limit(
236237
self, controller: UsageController, mock_usage_service: IUsageTrackingService
237238
) -> None:
238239
"""Test get_recent_usage with zero limit."""
239-
mock_usage_data = []
240-
mock_usage_service.get_recent_usage.return_value = mock_usage_data
241-
242240
result = await controller.get_recent_usage(limit=0)
243241

244-
assert result == mock_usage_data
245-
mock_usage_service.get_recent_usage.assert_called_once_with(
246-
session_id=None, limit=0
247-
)
242+
assert result == []
243+
mock_usage_service.get_recent_usage.assert_not_called()
244+
245+
@pytest.mark.asyncio
246+
async def test_get_recent_usage_negative_limit(
247+
self, controller: UsageController, mock_usage_service: IUsageTrackingService
248+
) -> None:
249+
"""Test get_recent_usage with negative limit."""
250+
251+
result = await controller.get_recent_usage(limit=-5)
252+
253+
assert result == []
254+
mock_usage_service.get_recent_usage.assert_not_called()
248255

249256
@pytest.mark.asyncio
250257
async def test_service_error_handling_stats(
@@ -384,15 +391,10 @@ async def test_get_recent_usage_negative_limit(
384391
self, controller: UsageController, mock_usage_service: IUsageTrackingService
385392
) -> None:
386393
"""Test get_recent_usage with negative limit value."""
387-
mock_usage_data = []
388-
mock_usage_service.get_recent_usage.return_value = mock_usage_data
389-
390394
result = await controller.get_recent_usage(limit=-10)
391395

392-
assert result == mock_usage_data
393-
mock_usage_service.get_recent_usage.assert_called_once_with(
394-
session_id=None, limit=-10
395-
)
396+
assert result == []
397+
mock_usage_service.get_recent_usage.assert_not_called()
396398

397399
@pytest.mark.asyncio
398400
async def test_get_usage_stats_zero_days(

tests/unit/core/services/test_usage_tracking_service_comprehensive.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from unittest.mock import AsyncMock, patch
1010

1111
import pytest
12+
from src.constants import MAX_RECENT_USAGE_RECORDS
1213
from src.core.domain.usage_data import UsageData
1314
from src.core.interfaces.repositories_interface import IUsageRepository
1415
from src.core.services.usage_tracking_service import UsageTrackingService
@@ -346,6 +347,28 @@ async def test_get_recent_usage(
346347
assert result == mock_usage_data
347348
mock_repository.get_by_session_id.assert_called_once_with("session1")
348349

350+
@pytest.mark.asyncio
351+
async def test_get_recent_usage_zero_limit(
352+
self, service: UsageTrackingService, mock_repository: IUsageRepository
353+
) -> None:
354+
"""Recent usage should return empty results for non-positive limits."""
355+
356+
result = await service.get_recent_usage(limit=0)
357+
358+
assert result == []
359+
mock_repository.get_all.assert_not_called()
360+
361+
@pytest.mark.asyncio
362+
async def test_get_recent_usage_negative_limit(
363+
self, service: UsageTrackingService, mock_repository: IUsageRepository
364+
) -> None:
365+
"""Negative limits should be treated as zero to avoid large responses."""
366+
367+
result = await service.get_recent_usage(limit=-10)
368+
369+
assert result == []
370+
mock_repository.get_all.assert_not_called()
371+
349372
@pytest.mark.asyncio
350373
async def test_get_recent_usage_defaults(
351374
self, service: UsageTrackingService, mock_repository: IUsageRepository
@@ -359,6 +382,32 @@ async def test_get_recent_usage_defaults(
359382
assert result == mock_usage_data
360383
mock_repository.get_all.assert_called_once()
361384

385+
@pytest.mark.asyncio
386+
async def test_get_recent_usage_large_limit_is_clamped(
387+
self, service: UsageTrackingService, mock_repository: IUsageRepository
388+
) -> None:
389+
"""Large limits should be clamped to protect against excessive workloads."""
390+
391+
mock_usage_data = [
392+
UsageData(
393+
id=str(index),
394+
session_id="session",
395+
model="model",
396+
prompt_tokens=0,
397+
completion_tokens=0,
398+
total_tokens=0,
399+
cost=0.0,
400+
timestamp=datetime.now(timezone.utc) + timedelta(seconds=index),
401+
)
402+
for index in range(MAX_RECENT_USAGE_RECORDS + 50)
403+
]
404+
mock_repository.get_all.return_value = mock_usage_data
405+
406+
result = await service.get_recent_usage(limit=10_000)
407+
408+
assert len(result) == MAX_RECENT_USAGE_RECORDS
409+
mock_repository.get_all.assert_called_once()
410+
362411
@pytest.mark.asyncio
363412
async def test_get_recent_usage_with_session_id(
364413
self, service: UsageTrackingService, mock_repository: IUsageRepository

0 commit comments

Comments
 (0)