Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 13 additions & 37 deletions litellm/proxy/auth/model_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,18 @@ def _check_wildcard_routing(model: str) -> bool:
- openai/*
- *
"""
if "*" in model:
return True
return False
return "*" in model


def get_provider_models(
provider: str, litellm_params: Optional[LiteLLM_Params] = None
) -> Optional[List[str]]:
def get_provider_models(provider: str, litellm_params: Optional[LiteLLM_Params] = None) -> Optional[List[str]]:
"""
Returns the list of known models by provider
"""
if provider == "*":
return get_valid_models(litellm_params=litellm_params)

if provider in litellm.models_by_provider:
provider_models = get_valid_models(
custom_llm_provider=provider, litellm_params=litellm_params
)
provider_models = get_valid_models(custom_llm_provider=provider, litellm_params=litellm_params)
return provider_models
return None

Expand All @@ -51,9 +45,7 @@ def _get_models_from_access_groups(
new_models = []
for idx, model in enumerate(all_models):
if model in model_access_groups:
if (
not include_model_access_groups
): # remove access group, unless requested - e.g. when creating a key
if not include_model_access_groups: # remove access group, unless requested - e.g. when creating a key
idx_to_remove.append(idx)
new_models.extend(model_access_groups[model])

Expand All @@ -80,7 +72,6 @@ async def get_mcp_server_ids(

# Make a direct SQL query to get just the mcp_servers
try:

result = await prisma_client.db.litellm_objectpermissiontable.find_unique(
where={"object_permission_id": user_api_key_dict.object_permission_id},
)
Expand Down Expand Up @@ -114,9 +105,7 @@ def get_key_models(
if SpecialModelNames.all_proxy_models.value in all_models:
all_models = proxy_model_list

all_models = _get_models_from_access_groups(
model_access_groups=model_access_groups, all_models=all_models
)
all_models = _get_models_from_access_groups(model_access_groups=model_access_groups, all_models=all_models)

verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models)))
return all_models
Expand Down Expand Up @@ -176,6 +165,7 @@ def get_complete_model_list(
"""

unique_models = []

def append_unique(models):
for model in models:
if model not in unique_models:
Expand All @@ -188,7 +178,7 @@ def append_unique(models):
else:
append_unique(proxy_model_list)
if include_model_access_groups:
append_unique(list(model_access_groups.keys())) # TODO: keys order
append_unique(list(model_access_groups.keys())) # TODO: keys order

if user_model:
append_unique([user_model])
Expand All @@ -215,9 +205,7 @@ def append_unique(models):
return complete_model_list


def get_known_models_from_wildcard(
wildcard_model: str, litellm_params: Optional[LiteLLM_Params] = None
) -> List[str]:
def get_known_models_from_wildcard(wildcard_model: str, litellm_params: Optional[LiteLLM_Params] = None) -> List[str]:
try:
wildcard_provider_prefix, wildcard_suffix = wildcard_model.split("/", 1)
except ValueError: # safely fail
Expand All @@ -233,25 +221,17 @@ def get_known_models_from_wildcard(

# get all known provider models

wildcard_models = get_provider_models(
provider=provider, litellm_params=litellm_params
)
wildcard_models = get_provider_models(provider=provider, litellm_params=litellm_params)

if wildcard_models is None:
return []
if wildcard_suffix != "*":
## CHECK IF PARTIAL FILTER e.g. `gemini-*`
model_prefix = wildcard_suffix.replace("*", "")

is_partial_filter = any(
wc_model.startswith(model_prefix) for wc_model in wildcard_models
)
is_partial_filter = any(wc_model.startswith(model_prefix) for wc_model in wildcard_models)
if is_partial_filter:
filtered_wildcard_models = [
wc_model
for wc_model in wildcard_models
if wc_model.startswith(model_prefix)
]
filtered_wildcard_models = [wc_model for wc_model in wildcard_models if wc_model.startswith(model_prefix)]
wildcard_models = filtered_wildcard_models
else:
# add model prefix to wildcard models
Expand All @@ -274,9 +254,7 @@ def _get_wildcard_models(
all_wildcard_models = []
for model in unique_models:
if _check_wildcard_routing(model=model):
if (
return_wildcard_routes
): # will add the wildcard route to the list eg: anthropic/*.
if return_wildcard_routes: # will add the wildcard route to the list eg: anthropic/*.
all_wildcard_models.append(model)

## get litellm params from model
Expand Down Expand Up @@ -341,9 +319,7 @@ def get_all_fallbacks(

try:
# Use existing function to get fallback model group
fallback_model_group, _ = get_fallback_model_group(
fallbacks=fallbacks_config, model_group=model
)
fallback_model_group, _ = get_fallback_model_group(fallbacks=fallbacks_config, model_group=model)

if fallback_model_group is None:
return []
Expand Down