(litellm sdk speedup router) - adds a helper `_cached_get_model_group_info` to use when trying to get deployment tpm/rpm limits (#7719)

* fix _cached_get_model_group_info

* fixes get_remaining_model_group_usage

* test_cached_get_model_group_info
ishaan-jaff authored Jan 12, 2025
1 parent baa528a commit 15b5203
Showing 2 changed files with 35 additions and 3 deletions.
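What the commit does, in short: Router.get_model_group_info is wrapped behind a functools.lru_cache helper so that building rate-limit response headers no longer recomputes model-group metadata on every request. A minimal sketch of that pattern is below; TinyRouter, its fields, and the limits dict are made up for illustration and are not litellm's actual Router:

from functools import lru_cache
from typing import Dict, Optional


class TinyRouter:
    """Illustrative stand-in for a router that knows per-group limits."""

    def __init__(self, limits: Dict[str, Dict[str, int]]) -> None:
        # e.g. {"gpt-3.5-turbo": {"tpm": 100_000, "rpm": 100}}
        self._limits = limits

    @lru_cache(maxsize=64)
    def _cached_limits(self, model_group: str) -> Optional[Dict[str, int]]:
        # lru_cache keys on (self, model_group); the wrapper also keeps a
        # reference to self for as long as the entry stays cached.
        return self._limits.get(model_group)


router = TinyRouter({"gpt-3.5-turbo": {"tpm": 100_000, "rpm": 100}})
router._cached_limits("gpt-3.5-turbo")      # first call: computed (cache miss)
router._cached_limits("gpt-3.5-turbo")      # second call: served from the cache
print(router._cached_limits.cache_info())   # CacheInfo(hits=1, misses=1, maxsize=64, currsize=1)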
litellm/router.py (20 changes: 17 additions & 3 deletions)

@@ -19,6 +19,7 @@
 import traceback
 import uuid
 from collections import defaultdict
+from functools import lru_cache
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -4696,11 +4697,19 @@ async def get_model_group_usage(
                 rpm_usage += t
         return tpm_usage, rpm_usage

-    async def get_remaining_model_group_usage(self, model_group: str) -> Dict[str, int]:
+    @lru_cache(maxsize=64)
+    def _cached_get_model_group_info(
+        self, model_group: str
+    ) -> Optional[ModelGroupInfo]:
+        """
+        Cached version of get_model_group_info, uses @lru_cache wrapper

-        current_tpm, current_rpm = await self.get_model_group_usage(model_group)
+        This is a speed optimization, since set_response_headers makes a call to get_model_group_info on every request
+        """
+        return self.get_model_group_info(model_group)

-        model_group_info = self.get_model_group_info(model_group)
+    async def get_remaining_model_group_usage(self, model_group: str) -> Dict[str, int]:
+        model_group_info = self._cached_get_model_group_info(model_group)

         if model_group_info is not None and model_group_info.tpm is not None:
             tpm_limit = model_group_info.tpm
@@ -4712,6 +4721,11 @@ async def get_remaining_model_group_usage(self, model_group: str) -> Dict[str, int]:
         else:
             rpm_limit = None

+        if tpm_limit is None and rpm_limit is None:
+            return {}
+
+        current_tpm, current_rpm = await self.get_model_group_usage(model_group)
+
         returned_dict = {}
         if tpm_limit is not None:
             returned_dict["x-ratelimit-remaining-tokens"] = tpm_limit - (
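Taken together, the rewritten get_remaining_model_group_usage now reads limits through the cached helper and only awaits the usage lookup when the model group actually has a tpm or rpm limit. The following is a simplified, self-contained sketch of that flow, not the real method: fake_get_usage stands in for Router.get_model_group_usage, and the x-ratelimit-remaining-requests header name is assumed by symmetry (only x-ratelimit-remaining-tokens is visible in the diff above):

import asyncio
from typing import Dict, Optional, Tuple


async def fake_get_usage(model_group: str) -> Tuple[int, int]:
    # Stand-in for Router.get_model_group_usage: returns (current tpm, current rpm).
    return 1_200, 4


async def remaining_usage_headers(
    limits: Optional[Dict[str, int]], model_group: str
) -> Dict[str, int]:
    tpm_limit = limits.get("tpm") if limits else None
    rpm_limit = limits.get("rpm") if limits else None

    # Early return mirroring the commit: skip the usage lookup entirely
    # when the model group has no configured limits.
    if tpm_limit is None and rpm_limit is None:
        return {}

    current_tpm, current_rpm = await fake_get_usage(model_group)

    headers: Dict[str, int] = {}
    if tpm_limit is not None:
        headers["x-ratelimit-remaining-tokens"] = tpm_limit - current_tpm
    if rpm_limit is not None:
        headers["x-ratelimit-remaining-requests"] = rpm_limit - current_rpm
    return headers


print(asyncio.run(remaining_usage_headers({"tpm": 100_000, "rpm": 100}, "gpt-3.5-turbo")))
# {'x-ratelimit-remaining-tokens': 98800, 'x-ratelimit-remaining-requests': 96}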
tests/router_unit_tests/test_router_helper_utils.py (18 changes: 18 additions & 0 deletions)

@@ -1126,3 +1126,21 @@ async def test_async_callback_filter_deployments(model_list):
     )

     assert len(new_healthy_deployments) == len(healthy_deployments)
+
+
+def test_cached_get_model_group_info(model_list):
+    """Test if the '_cached_get_model_group_info' function is working correctly with LRU cache"""
+    router = Router(model_list=model_list)
+
+    # First call - should hit the actual function
+    result1 = router._cached_get_model_group_info("gpt-3.5-turbo")
+
+    # Second call with same argument - should hit the cache
+    result2 = router._cached_get_model_group_info("gpt-3.5-turbo")
+
+    # Verify results are the same
+    assert result1 == result2
+
+    # Verify the cache info shows hits
+    cache_info = router._cached_get_model_group_info.cache_info()
+    assert cache_info.hits > 0  # Should have at least one cache hit

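The new test relies on the instrumentation functools.lru_cache attaches to the wrapped callable: cache_info() reports hit/miss counters and cache_clear() resets them, which matters here because a cache on a class-level wrapper is shared by every instance in the process. A tiny self-contained illustration, unrelated to litellm's types:

from functools import lru_cache


@lru_cache(maxsize=64)
def lookup(model_group: str) -> str:
    # Stand-in for an expensive metadata lookup.
    return model_group.upper()


lookup("gpt-3.5-turbo")                # miss
lookup("gpt-3.5-turbo")                # hit
info = lookup.cache_info()             # CacheInfo(hits=1, misses=1, maxsize=64, currsize=1)
assert info.hits == 1 and info.misses == 1

lookup.cache_clear()                   # start from a clean cache, e.g. between tests
assert lookup.cache_info().hits == 0 and lookup.cache_info().currsize == 0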