113 changes: 112 additions & 1 deletion litellm/proxy/vector_store_endpoints/management_endpoints.py
@@ -10,7 +10,7 @@

import copy
import json
from typing import List, Optional
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Depends, HTTPException

@@ -23,6 +23,8 @@
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper
from litellm.secret_managers.main import get_secret
from litellm.types.vector_stores import (
LiteLLM_ManagedVectorStore,
LiteLLM_ManagedVectorStoreListResponse,
@@ -35,6 +37,102 @@
router = APIRouter()


async def _resolve_embedding_config_from_db(
embedding_model: str, prisma_client
) -> Optional[Dict[str, Any]]:
"""
Resolve embedding config from database model configuration.

If litellm_embedding_model is provided but litellm_embedding_config is not,
this function looks up the model in the database and extracts api_key, api_base,
and api_version from the model's litellm_params to build the embedding config.

Args:
embedding_model: The embedding model string (e.g., "text-embedding-ada-002" or "azure/text-embedding-3-large")
prisma_client: The Prisma client instance

Returns:
Dictionary with api_key, api_base, and api_version if model found, None otherwise
"""
if not embedding_model:
return None

# Extract model name - could be "text-embedding-ada-002" or "azure/text-embedding-3-large"
# Try to find model by exact match first, then try without provider prefix
model_name_candidates = [embedding_model]
if "/" in embedding_model:
# If it has a provider prefix, also try without it
_, model_name = embedding_model.split("/", 1)
model_name_candidates.append(model_name)

# Try to find model in database
for model_name in model_name_candidates:
try:
db_model = await prisma_client.db.litellm_proxymodeltable.find_first(
where={"model_name": model_name}
)

if db_model and db_model.litellm_params:
# Extract litellm_params (could be dict or JSON string)
model_params = db_model.litellm_params
if isinstance(model_params, str):
model_params = json.loads(model_params)

# Decrypt values from database (similar to how proxy_server.py does it)
# Values stored in DB are encrypted, so we need to decrypt them first
decrypted_params = {}
if isinstance(model_params, dict):
for k, v in model_params.items():
if isinstance(v, str):
# Decrypt value - returns original value if decryption fails or no key is set
decrypted_value = decrypt_value_helper(
value=v, key=k, return_original_value=True
)
decrypted_params[k] = decrypted_value
else:
decrypted_params[k] = v
else:
decrypted_params = model_params

# Build embedding config from model params
embedding_config = {}

# Extract api_key
api_key = decrypted_params.get("api_key")
if api_key:
# Handle os.environ/ prefix (after decryption, values may be os.environ/ prefixed)
if isinstance(api_key, str) and api_key.startswith("os.environ/"):
api_key = get_secret(api_key)
embedding_config["api_key"] = api_key

# Extract api_base
api_base = decrypted_params.get("api_base")
if api_base:
# Handle os.environ/ prefix (after decryption, values may be os.environ/ prefixed)
if isinstance(api_base, str) and api_base.startswith("os.environ/"):
api_base = get_secret(api_base)
embedding_config["api_base"] = api_base

# Extract api_version
api_version = decrypted_params.get("api_version")
if api_version:
embedding_config["api_version"] = api_version

# Only return config if we have at least api_key or api_base
if embedding_config:
verbose_proxy_logger.debug(
f"Resolved embedding config from database model {model_name}: {list(embedding_config.keys())}"
)
return embedding_config
except Exception as e:
verbose_proxy_logger.debug(
f"Error resolving embedding config for model {model_name}: {str(e)}"
)
continue

return None
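
For reference, the helper above applies a two-step value-resolution pipeline: decrypt each stored string first, then resolve any "os.environ/" reference the decryption exposes. A minimal standalone sketch of that ordering, with decrypt_value_helper and get_secret replaced by hypothetical stand-ins (not part of this PR):

import os
from typing import Any, Dict

def decrypt_and_resolve(params: Dict[str, Any], decrypt) -> Dict[str, Any]:
    """Hypothetical sketch mirroring the pipeline in the function above."""
    resolved: Dict[str, Any] = {}
    for key, value in params.items():
        if isinstance(value, str):
            value = decrypt(value)  # stands in for decrypt_value_helper(...)
            if isinstance(value, str) and value.startswith("os.environ/"):
                # stands in for get_secret(value)
                value = os.environ.get(value.split("/", 1)[1])
        resolved[key] = value
    return resolved

# An api_key stored as an env-var reference resolves from the process environment:
os.environ["MY_EMBEDDING_KEY"] = "sk-example"
config = decrypt_and_resolve(
    {"api_key": "os.environ/MY_EMBEDDING_KEY", "api_version": "2024-01-01"},
    decrypt=lambda v: v,  # identity decrypt for the sketch
)
assert config == {"api_key": "sk-example", "api_version": "2024-01-01"}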


########################################################
# Management Endpoints
########################################################
@@ -85,6 +183,19 @@ async def new_vector_store(
litellm_params_json: Optional[str] = None
_input_litellm_params: dict = vector_store.get("litellm_params", {}) or {}
if _input_litellm_params is not None:
# Auto-resolve embedding config if embedding model is provided but config is not
embedding_model = _input_litellm_params.get("litellm_embedding_model")
if embedding_model and not _input_litellm_params.get("litellm_embedding_config"):
resolved_config = await _resolve_embedding_config_from_db(
embedding_model=embedding_model,
prisma_client=prisma_client
)
if resolved_config:
_input_litellm_params["litellm_embedding_config"] = resolved_config
verbose_proxy_logger.info(
f"Auto-resolved embedding config for model {embedding_model}"
)

litellm_params_dict = GenericLiteLLMParams(
**_input_litellm_params
).model_dump(exclude_none=True)
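
From a client's perspective, the change above means a vector store can be created with only litellm_embedding_model set, and the proxy fills in litellm_embedding_config from the matching model row. A sketch of such a request (the /vector_store/new path and bearer-token auth are assumptions based on LiteLLM proxy conventions, not confirmed by this diff):

import requests

# Hypothetical request against a locally running proxy.
resp = requests.post(
    "http://localhost:4000/vector_store/new",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "vector_store_id": "my-store",
        "custom_llm_provider": "openai",
        "litellm_params": {
            # Only the model is supplied; api_key/api_base/api_version are
            # auto-resolved from the model's database entry by this change.
            "litellm_embedding_model": "text-embedding-ada-002",
        },
    },
    timeout=30,
)
resp.raise_for_status()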
@@ -20,6 +20,10 @@
from litellm.proxy.vector_store_endpoints.endpoints import (
_update_request_data_with_litellm_managed_vector_store_registry,
)
from litellm.proxy.vector_store_endpoints.management_endpoints import (
_resolve_embedding_config_from_db,
new_vector_store,
)
from litellm.proxy.vector_store_endpoints.utils import (
check_vector_store_permission,
is_allowed_to_call_vector_store_endpoint,
@@ -1045,3 +1049,138 @@ async def mock_delete(where):
assert len(vector_stores_to_run) == 0, (
"Deleted vector store should not be returned when trying to use it"
)


@pytest.mark.asyncio
async def test_resolve_embedding_config_from_db():
"""Test that _resolve_embedding_config_from_db correctly resolves embedding config from database."""
mock_prisma_client = MagicMock()

# Mock database model with litellm_params
mock_db_model = MagicMock()
mock_db_model.litellm_params = {
"api_key": "test-api-key",
"api_base": "https://api.openai.com",
"api_version": "2024-01-01"
}

mock_prisma_client.db.litellm_proxymodeltable.find_first = AsyncMock(
return_value=mock_db_model
)

with patch(
"litellm.proxy.vector_store_endpoints.management_endpoints.decrypt_value_helper",
side_effect=lambda value, key, return_original_value: value
):
result = await _resolve_embedding_config_from_db(
embedding_model="text-embedding-ada-002",
prisma_client=mock_prisma_client
)

assert result is not None
assert result["api_key"] == "test-api-key"
assert result["api_base"] == "https://api.openai.com"
assert result["api_version"] == "2024-01-01"
mock_prisma_client.db.litellm_proxymodeltable.find_first.assert_called_once_with(
where={"model_name": "text-embedding-ada-002"}
)

# Test with empty embedding_model
result_empty = await _resolve_embedding_config_from_db(
embedding_model="",
prisma_client=mock_prisma_client
)
assert result_empty is None

# Test with model not found
mock_prisma_client.db.litellm_proxymodeltable.find_first = AsyncMock(
return_value=None
)
result_not_found = await _resolve_embedding_config_from_db(
embedding_model="non-existent-model",
prisma_client=mock_prisma_client
)
assert result_not_found is None


@pytest.mark.asyncio
async def test_new_vector_store_auto_resolves_embedding_config():
"""Test that new_vector_store auto-resolves embedding config when embedding_model is provided but config is not."""
import json
from litellm.types.vector_stores import LiteLLM_ManagedVectorStore

mock_prisma_client = MagicMock()

# Mock vector store request with embedding_model but no embedding_config
vector_store_data: LiteLLM_ManagedVectorStore = {
"vector_store_id": "test-store-001",
"custom_llm_provider": "openai",
"litellm_params": {
"litellm_embedding_model": "text-embedding-ada-002",
# Note: litellm_embedding_config is not provided
}
}

# Mock database model lookup for embedding config resolution
mock_db_model = MagicMock()
mock_db_model.litellm_params = {
"api_key": "resolved-api-key",
"api_base": "https://api.openai.com",
"api_version": "2024-01-01"
}

# Mock user API key
mock_user_api_key = MagicMock(spec=UserAPIKeyAuth)
mock_user_api_key.user_role = None

# Mock database operations
mock_prisma_client.db.litellm_managedvectorstorestable.find_unique = AsyncMock(
return_value=None # Vector store doesn't exist yet
)
mock_prisma_client.db.litellm_proxymodeltable.find_first = AsyncMock(
return_value=mock_db_model
)

# Track what was passed to create
captured_create_data = {}

async def mock_create(*args, **kwargs):
captured_create_data.update(kwargs.get("data", {}))
mock_created_vector_store = MagicMock()
mock_created_vector_store.model_dump.return_value = {
"vector_store_id": "test-store-001",
"custom_llm_provider": "openai",
"litellm_params": kwargs.get("data", {}).get("litellm_params")
}
return mock_created_vector_store

mock_prisma_client.db.litellm_managedvectorstorestable.create = AsyncMock(
side_effect=mock_create
)

mock_registry = MagicMock()
mock_registry.add_vector_store_to_registry = MagicMock()

with patch(
"litellm.proxy.proxy_server.prisma_client",
mock_prisma_client
), patch(
"litellm.proxy.vector_store_endpoints.management_endpoints.decrypt_value_helper",
side_effect=lambda value, key, return_original_value: value
), patch.object(
litellm, "vector_store_registry", mock_registry
):
result = await new_vector_store(
vector_store=vector_store_data,
user_api_key_dict=mock_user_api_key
)

assert result["status"] == "success"
# Verify that embedding config was resolved and included in the create call
litellm_params_json = captured_create_data.get("litellm_params")
assert litellm_params_json is not None
litellm_params_dict = json.loads(litellm_params_json)
assert "litellm_embedding_config" in litellm_params_dict
assert litellm_params_dict["litellm_embedding_config"]["api_key"] == "resolved-api-key"
assert litellm_params_dict["litellm_embedding_config"]["api_base"] == "https://api.openai.com"
assert litellm_params_dict["litellm_embedding_config"]["api_version"] == "2024-01-01"
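
One path the tests above do not exercise is the provider-prefix fallback, where "azure/text-embedding-3-large" misses on the exact name and falls back to the prefix-stripped candidate. A sketch of such a test in the same mocking style (not part of this PR):

@pytest.mark.asyncio
async def test_resolve_embedding_config_strips_provider_prefix():
    """Sketch: a prefixed model name falls back to the unprefixed DB entry."""
    mock_prisma_client = MagicMock()

    mock_db_model = MagicMock()
    mock_db_model.litellm_params = {
        "api_key": "azure-key",
        "api_base": "https://example.azure.com",
    }

    # First lookup ("azure/text-embedding-3-large") misses; the stripped
    # candidate ("text-embedding-3-large") hits.
    mock_prisma_client.db.litellm_proxymodeltable.find_first = AsyncMock(
        side_effect=[None, mock_db_model]
    )

    with patch(
        "litellm.proxy.vector_store_endpoints.management_endpoints.decrypt_value_helper",
        side_effect=lambda value, key, return_original_value: value,
    ):
        result = await _resolve_embedding_config_from_db(
            embedding_model="azure/text-embedding-3-large",
            prisma_client=mock_prisma_client,
        )

    assert result is not None
    assert result["api_key"] == "azure-key"
    assert mock_prisma_client.db.litellm_proxymodeltable.find_first.call_count == 2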