diff --git a/litellm/router.py b/litellm/router.py
index a7d9667f43e5..a15f6a5bcba4 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -19,6 +19,7 @@
 import traceback
 import uuid
 from collections import defaultdict
+from functools import lru_cache
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -4696,11 +4697,19 @@ async def get_model_group_usage(
                 rpm_usage += t
         return tpm_usage, rpm_usage
 
-    async def get_remaining_model_group_usage(self, model_group: str) -> Dict[str, int]:
+    @lru_cache(maxsize=64)
+    def _cached_get_model_group_info(
+        self, model_group: str
+    ) -> Optional[ModelGroupInfo]:
+        """
+        Cached version of get_model_group_info, uses @lru_cache wrapper
 
-        current_tpm, current_rpm = await self.get_model_group_usage(model_group)
+        This is a speed optimization, since set_response_headers makes a call to get_model_group_info on every request
+        """
+        return self.get_model_group_info(model_group)
 
-        model_group_info = self.get_model_group_info(model_group)
+    async def get_remaining_model_group_usage(self, model_group: str) -> Dict[str, int]:
+        model_group_info = self._cached_get_model_group_info(model_group)
         if model_group_info is not None and model_group_info.tpm is not None:
             tpm_limit = model_group_info.tpm
@@ -4712,6 +4721,11 @@ async def get_remaining_model_group_usage(self, model_group: str) -> Dict[str, i
         else:
             rpm_limit = None
 
+        if tpm_limit is None and rpm_limit is None:
+            return {}
+
+        current_tpm, current_rpm = await self.get_model_group_usage(model_group)
+
         returned_dict = {}
         if tpm_limit is not None:
             returned_dict["x-ratelimit-remaining-tokens"] = tpm_limit - (
diff --git a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py
index e3ca281508d7..e02b47ec365e 100644
--- a/tests/router_unit_tests/test_router_helper_utils.py
+++ b/tests/router_unit_tests/test_router_helper_utils.py
@@ -1126,3 +1126,21 @@ async def test_async_callback_filter_deployments(model_list):
     )
 
     assert len(new_healthy_deployments) == len(healthy_deployments)
+
+
+def test_cached_get_model_group_info(model_list):
+    """Test if the '_cached_get_model_group_info' function is working correctly with LRU cache"""
+    router = Router(model_list=model_list)
+
+    # First call - should hit the actual function
+    result1 = router._cached_get_model_group_info("gpt-3.5-turbo")
+
+    # Second call with same argument - should hit the cache
+    result2 = router._cached_get_model_group_info("gpt-3.5-turbo")
+
+    # Verify results are the same
+    assert result1 == result2
+
+    # Verify the cache info shows hits
+    cache_info = router._cached_get_model_group_info.cache_info()
+    assert cache_info.hits > 0  # Should have at least one cache hit
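
For illustration, a minimal standalone sketch of the pattern the diff relies on: @lru_cache applied to an instance method, with cache behavior observable via cache_info() (as the new test checks). FakeRouter and its fields are hypothetical stand-ins, not LiteLLM APIs.

# Sketch only: @lru_cache on an instance method includes `self` in the cache
# key, so entries are scoped per router instance.
from functools import lru_cache
from typing import Optional


class FakeRouter:
    def __init__(self, model_groups: dict):
        self._model_groups = model_groups
        self.lookups = 0  # counts how often the uncached path actually runs

    def get_model_group_info(self, model_group: str) -> Optional[dict]:
        self.lookups += 1  # stands in for the expensive per-request lookup
        return self._model_groups.get(model_group)

    @lru_cache(maxsize=64)
    def _cached_get_model_group_info(self, model_group: str) -> Optional[dict]:
        return self.get_model_group_info(model_group)


router = FakeRouter({"gpt-3.5-turbo": {"tpm": 100_000, "rpm": 1_000}})
router._cached_get_model_group_info("gpt-3.5-turbo")  # miss: runs the lookup
router._cached_get_model_group_info("gpt-3.5-turbo")  # hit: served from cache
print(router.lookups)  # 1 -- the underlying lookup ran only once
print(router._cached_get_model_group_info.cache_info())  # hits=1, misses=1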