Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/client/content/config/tabs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,15 @@ def _render_model_selection(model: dict, provider_models: list, action: str) ->

def _render_api_configuration(model: dict, provider_models: list, disable_for_oci: bool) -> dict:
"""Render API configuration UI and return updated model"""
api_base = next(
litellm_api_base = next(
(m.get("api_base", "") for m in provider_models if m.get("key") == model["id"]), model.get("api_base", "")
)

model["api_base"] = st.text_input(
"Provider URL:",
help=help_text.help_dict["model_url"],
key="add_model_url",
value=api_base,
value=model.get("api_base", litellm_api_base),
disabled=disable_for_oci,
)

Expand Down
112 changes: 100 additions & 12 deletions tests/client/content/config/tabs/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,18 +151,12 @@ def test_page_content_verification(self, app_server, app_test):
class TestModelFunctions:
"""Test individual functions from models.py"""

def _setup_function_test(self, app_test):
"""Helper to set up function testing"""
at = app_test(ST_FILE)
at.run() # Initialize session state
return at

def test_clear_client_models_function(self, app_server, app_test):
"""Test clear_client_models function behavior"""
from client.content.config.tabs.models import clear_client_models

assert app_server is not None
at = self._setup_function_test(app_test)
at = app_test(ST_FILE).run()

# Set up test client settings
at.session_state.client_settings = {
Expand All @@ -189,7 +183,7 @@ def test_clear_client_models_no_matches(self, app_server, app_test):
from client.content.config.tabs.models import clear_client_models

assert app_server is not None
at = self._setup_function_test(app_test)
at = app_test(ST_FILE).run()

original_settings = {
"ll_model": {"model": "different/model"},
Expand All @@ -212,7 +206,7 @@ def test_get_models_function(self, app_server, app_test):
from client.content.config.tabs.models import get_models

assert app_server is not None
at = self._setup_function_test(app_test)
at = app_test(ST_FILE).run()

# Clear existing model configs to test refresh
at.session_state.model_configs = None
Expand All @@ -231,7 +225,7 @@ def test_get_models_force_refresh(self, app_server, app_test):
from client.content.config.tabs.models import get_models

assert app_server is not None
at = self._setup_function_test(app_test)
at = app_test(ST_FILE).run()

# Set some existing data
at.session_state.model_configs = ["old_data"]
Expand All @@ -249,7 +243,7 @@ def test_get_supported_models_function(self, app_server, app_test):
from client.content.config.tabs.models import get_supported_models

assert app_server is not None
at = self._setup_function_test(app_test)
at = app_test(ST_FILE).run()

# Get providers using API context
with patch("client.utils.api_call.state", at.session_state):
Expand Down Expand Up @@ -373,7 +367,7 @@ def test_render_model_selection_with_custom_model_id(self, app_server, app_test)
from client.content.config.tabs.models import _render_model_selection, get_supported_models

assert app_server is not None
at = self._setup_function_test(app_test)
at = app_test(ST_FILE).run()

# Get actual supported models from API
with patch("client.utils.api_call.state", at.session_state):
Expand Down Expand Up @@ -403,3 +397,97 @@ def test_render_model_selection_with_custom_model_id(self, app_server, app_test)

# The model ID should be preserved
assert result_model["id"] == custom_model_id

def test_render_api_configuration_preserves_saved_api_base(self, app_server, app_test):
"""Test that _render_api_configuration preserves saved api_base over LiteLLM default"""
from client.content.config.tabs.models import _render_api_configuration, get_supported_models

assert app_server is not None
at = app_test(ST_FILE).run()

# Get actual supported models from API
with patch("client.utils.api_call.state", at.session_state):
supported_models = get_supported_models("ll")

# Find OpenAI provider
openai_provider = next((p for p in supported_models if p["provider"] == "openai"), None)
assert openai_provider is not None, "OpenAI provider should be available in supported models"

provider_models = openai_provider["models"]

# Find a model that has an api_base in the provider models
model_with_api_base = next((m for m in provider_models if m.get("api_base")), None)
if model_with_api_base:
litellm_api_base = model_with_api_base["api_base"]
model_id = model_with_api_base["key"]
else:
# Fallback if no model has api_base
litellm_api_base = "https://api.openai.com/v1"
model_id = "gpt-4"

# Create a model with a SAVED custom api_base (different from LiteLLM default)
saved_custom_api_base = "https://my-custom-api.example.com/v1"
model = {
"id": model_id,
"provider": "openai",
"type": "ll",
"api_base": saved_custom_api_base, # This is the saved configuration
}

disable_for_oci = False

with patch("client.content.config.tabs.models.state", at.session_state):
# This should preserve the saved api_base, NOT overwrite with LiteLLM default
result_model = _render_api_configuration(model, provider_models, disable_for_oci)

# The saved api_base should be preserved (regression test for commit 5612888)
assert result_model["api_base"] == saved_custom_api_base
assert result_model["api_base"] != litellm_api_base or saved_custom_api_base == litellm_api_base

def test_render_api_configuration_uses_litellm_default_when_no_saved_value(self, app_server, app_test):
"""Test that _render_api_configuration uses LiteLLM default when no saved api_base exists"""
from client.content.config.tabs.models import _render_api_configuration, get_supported_models

assert app_server is not None
at = app_test(ST_FILE).run()

# Get actual supported models from API
with patch("client.utils.api_call.state", at.session_state):
supported_models = get_supported_models("ll")

# Find OpenAI provider
openai_provider = next((p for p in supported_models if p["provider"] == "openai"), None)
assert openai_provider is not None, "OpenAI provider should be available in supported models"

provider_models = openai_provider["models"]

# Find a model that has an api_base in the provider models
model_with_api_base = next((m for m in provider_models if m.get("api_base")), None)
if model_with_api_base:
litellm_api_base = model_with_api_base["api_base"]
model_id = model_with_api_base["key"]
else:
# Fallback if no model has api_base
litellm_api_base = "https://api.openai.com/v1"
model_id = "gpt-4"

# Create a model WITHOUT a saved api_base (new model being added)
model = {
"id": model_id,
"provider": "openai",
"type": "ll",
# Note: No "api_base" key - simulating a new model being added
}

disable_for_oci = False

with patch("client.content.config.tabs.models.state", at.session_state):
# This should use the LiteLLM default api_base
result_model = _render_api_configuration(model, provider_models, disable_for_oci)

# Should fall back to LiteLLM default when no saved value exists
if model_with_api_base:
assert result_model["api_base"] == litellm_api_base
else:
# If no model has api_base, it should be empty string
assert result_model["api_base"] == ""
Loading