Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions src/client/content/config/tabs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,16 +200,19 @@ def _render_model_specific_config(model: dict, model_type: str, provider_models:
value=max_tokens,
)
else:
output_vector_size = next(
(m.get("output_vector_size", 8191) for m in provider_models if m.get("key") == model["id"]),
model.get("output_vector_size", 8191),
)
# First try to get max_chunk_size from the model, then fall back to output_vector_size from provider
max_chunk_size = model.get("max_chunk_size")
if max_chunk_size is None:
max_chunk_size = next(
(m.get("max_chunk_size", 8192) for m in provider_models if m.get("key") == model["id"]),
8192,
)
model["max_chunk_size"] = st.number_input(
"Max Chunk Size:",
help=help_text.help_dict["chunk_size"],
min_value=0,
key="add_model_max_chunk_size",
value=output_vector_size,
value=max_chunk_size,
)

return model
Expand Down
58 changes: 0 additions & 58 deletions src/server/api/core/databases.py

This file was deleted.

54 changes: 52 additions & 2 deletions src/server/api/utils/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import oracledb
from langchain_community.vectorstores import oraclevs as LangchainVS

import server.api.core.databases as core_databases
import server.api.core.settings as core_settings
from server.bootstrap.bootstrap import DATABASE_OBJECTS

from common.schema import (
Database,
Expand Down Expand Up @@ -38,6 +38,56 @@ def __init__(self, status_code: int, detail: str):
super().__init__(detail)


class ExistsDatabaseError(ValueError):
"""Raised when the database already exist."""


class UnknownDatabaseError(ValueError):
"""Raised when the database doesn't exist."""


#####################################################
# CRUD Functions
#####################################################
def create(database: Database) -> Database:
"""Create a new Database definition"""

try:
_ = get(name=database.name)
raise ExistsDatabaseError(f"Database: {database.name} already exists")
except UnknownDatabaseError:
pass

if any(not getattr(database, key) for key in ("user", "password", "dsn")):
raise ValueError("'user', 'password', and 'dsn' are required")

DATABASE_OBJECTS.append(database)
return get(name=database.name)


def get(name: Optional[DatabaseNameType] = None) -> Union[list[Database], None]:
"""
Return all Database objects if `name` is not provided,
or the single Database if `name` is provided.
If a `name` is provided and not found, raise exception
"""
database_objects = DATABASE_OBJECTS

logger.debug("%i databases are defined", len(database_objects))
database_filtered = [db for db in database_objects if (name is None or db.name == name)]
logger.debug("%i databases after filtering", len(database_filtered))

if name and not database_filtered:
raise UnknownDatabaseError(f"{name} not found")

return database_filtered


def delete(name: DatabaseNameType) -> None:
"""Remove database from database objects"""
DATABASE_OBJECTS[:] = [d for d in DATABASE_OBJECTS if d.name != name]


#####################################################
# Protected Functions
#####################################################
Expand Down Expand Up @@ -231,7 +281,7 @@ def get_databases(
db_name: Optional[DatabaseNameType] = None, validate: bool = False
) -> Union[list[Database], Database, None]:
"""Return list of Database Objects"""
databases = core_databases.get_database(db_name)
databases = get(db_name)
if validate:
for db in databases:
try:
Expand Down
20 changes: 10 additions & 10 deletions src/server/api/utils/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,18 @@ def get(
def update(payload: schema.Model) -> schema.Model:
"""Update an existing Model definition"""

(model_update,) = get(model_provider=payload.provider, model_id=payload.id)
if payload.enabled and model_update.api_base and not is_url_accessible(model_update.api_base)[0]:
model_update.enabled = False
raise URLUnreachableError("Model: Unable to update. API URL is inaccessible.")
# Get the existing model from MODEL_OBJECTS (this is a reference to the object in the list)
(model_existing,) = get(model_provider=payload.provider, model_id=payload.id)

for key, value in payload:
if hasattr(model_update, key):
setattr(model_update, key, value)
else:
raise InvalidModelError(f"Model: Invalid setting - {key}.")
# Check URL accessibility if enabling the model
if payload.enabled and payload.api_base and not is_url_accessible(payload.api_base)[0]:
model_existing.enabled = False
raise URLUnreachableError("Model: Unable to update. API URL is inaccessible.")

return model_update
# Update all fields from payload in place
for key, value in payload.model_dump().items():
setattr(model_existing, key, value)
return model_existing


def delete(model_provider: schema.ModelProviderType, model_id: schema.ModelIdType) -> None:
Expand Down
3 changes: 2 additions & 1 deletion src/server/api/v1/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

from fastapi import APIRouter, HTTPException
import oracledb

import server.api.utils.databases as utils_databases

Expand All @@ -15,7 +16,7 @@
# Validate the DEFAULT Databases
try:
_ = utils_databases.get_databases(db_name="DEFAULT", validate=True)
except Exception:
except (ValueError, PermissionError, ConnectionError, LookupError, oracledb.DatabaseError):
pass

auth = APIRouter()
Expand Down
2 changes: 1 addition & 1 deletion src/server/bootstrap/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _get_base_models_list() -> list[dict]:
"provider": "ollama",
"api_base": os.environ.get("ON_PREM_OLLAMA_URL", default="http://127.0.0.1:11434"),
"api_key": "",
"max_chunk_size": 8192,
"max_chunk_size": 512,
},
]

Expand Down
46 changes: 46 additions & 0 deletions tests/server/integration/test_endpoints_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,52 @@ def test_models_update_edge_cases(self, client, auth_headers):
)
assert response.status_code == 404

def test_models_update_max_chunk_size(self, client, auth_headers):
"""Test updating max_chunk_size for embedding models (regression test)"""
# Create an embedding model with default max_chunk_size
payload = {
"id": "test-embed-chunk-size",
"enabled": False,
"type": "embed",
"provider": "test_provider",
"api_base": "http://127.0.0.1:11434",
"max_chunk_size": 8192,
}

# Create the model
response = client.post("/v1/models", headers=auth_headers["valid_auth"], json=payload)
assert response.status_code == 201
assert response.json()["max_chunk_size"] == 8192

# Update the max_chunk_size to 512
payload["max_chunk_size"] = 512
response = client.patch(
f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"], json=payload
)
assert response.status_code == 200
assert response.json()["max_chunk_size"] == 512

# Verify the update persists by fetching the model again
response = client.get(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"])
assert response.status_code == 200
assert response.json()["max_chunk_size"] == 512

# Update to a different value to ensure it's not cached
payload["max_chunk_size"] = 1024
response = client.patch(
f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"], json=payload
)
assert response.status_code == 200
assert response.json()["max_chunk_size"] == 1024

# Verify again
response = client.get(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"])
assert response.status_code == 200
assert response.json()["max_chunk_size"] == 1024

# Clean up
client.delete(f"/v1/models/{payload['provider']}/{payload['id']}", headers=auth_headers["valid_auth"])

def test_models_response_schema_validation(self, client, auth_headers):
"""Test response schema validation for all endpoints"""
# Test /v1/models response schema
Expand Down
Loading
Loading