Skip to content

Fixing watsonx error: 'model_id' or 'model' cannot be specified in the request body for models in a deployment space #11854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions litellm/llms/watsonx/chat/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def completion(
litellm_params=litellm_params,
)

## UPDATE PAYLOAD (optional params)
## UPDATE PAYLOAD (optional params and special cases for models deployed in spaces)
watsonx_auth_payload = watsonx_chat_transformation._prepare_payload(
model=model,
api_params=api_params,
Expand All @@ -70,7 +70,7 @@ def completion(
)

return super().completion(
model=model,
model=watsonx_auth_payload.get("model_id", None),
messages=messages,
api_base=api_base,
custom_llm_provider=custom_llm_provider,
Expand Down
14 changes: 13 additions & 1 deletion litellm/llms/watsonx/chat/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import List, Optional, Tuple, Union

from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.watsonx import WatsonXAIEndpoint
from litellm.types.llms.watsonx import WatsonXAIEndpoint, WatsonXAPIParams

from ....utils import _remove_additional_properties, _remove_strict_from_schema
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
Expand Down Expand Up @@ -108,3 +108,15 @@ def get_complete_url(
url=url, api_version=optional_params.pop("api_version", None)
)
return url

def _prepare_payload(self, model: str, api_params: WatsonXAPIParams) -> dict:
"""
Prepare payload for deployment models.
Deployment models cannot have 'model_id' or 'model' in the request body.
"""
payload: dict = {}
payload["model_id"] = None if model.startswith("deployment/") else model
payload["project_id"] = (
None if model.startswith("deployment/") else api_params["project_id"]
)
return payload
205 changes: 205 additions & 0 deletions tests/test_litellm/llms/watsonx/test_watsonx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import json
import os
import sys

sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import litellm
from litellm import completion
from litellm.llms.custom_httpx.http_handler import HTTPHandler
from unittest.mock import patch, Mock
import pytest
from typing import Optional


@pytest.fixture
def watsonx_chat_completion_call():
def _call(
model="watsonx/my-test-model",
messages=None,
api_key="test_api_key",
space_id: Optional[str] = None,
headers=None,
client=None,
patch_token_call=True,
):
if messages is None:
messages = [{"role": "user", "content": "Hello, how are you?"}]
if client is None:
client = HTTPHandler()

if patch_token_call:
mock_response = Mock()
mock_response.json.return_value = {
"access_token": "mock_access_token",
"expires_in": 3600,
}
mock_response.raise_for_status = Mock() # No-op to simulate no exception

with patch.object(client, "post") as mock_post, patch.object(
litellm.module_level_client, "post", return_value=mock_response
) as mock_get:
try:
completion(
model=model,
messages=messages,
api_key=api_key,
headers=headers or {},
client=client,
space_id=space_id,
)
except Exception as e:
print(e)

return mock_post, mock_get
else:
with patch.object(client, "post") as mock_post:
try:
completion(
model=model,
messages=messages,
api_key=api_key,
headers=headers or {},
client=client,
space_id=space_id,
)
except Exception as e:
print(e)
return mock_post, None

return _call


def test_watsonx_deployment_model_id_not_in_payload(
monkeypatch, watsonx_chat_completion_call
):
"""Test that deployment models do not include 'model_id' in the request payload"""
monkeypatch.setenv("WATSONX_PROJECT_ID", "test-project-id")
monkeypatch.setenv("WATSONX_API_BASE", "https://test-api.watsonx.ai")
model = "watsonx/deployment/test-deployment-id"
messages = [{"role": "user", "content": "Test message"}]

mock_post, _ = watsonx_chat_completion_call(model=model, messages=messages)

assert mock_post.call_count == 1
json_data = json.loads(mock_post.call_args.kwargs["data"])
# Ensure model_id is not in the payload for deployment models
assert "model_id" not in json_data or json_data["model_id"] is None
# Ensure project_id is also not in the payload for deployment models
assert "project_id" not in json_data or json_data["project_id"] is None


def test_watsonx_regular_model_includes_model_id(
monkeypatch, watsonx_chat_completion_call
):
"""Test that regular models include 'model_id' in the request payload"""
monkeypatch.setenv("WATSONX_PROJECT_ID", "test-project-id")
monkeypatch.setenv("WATSONX_API_BASE", "https://test-api.watsonx.ai")
model = "watsonx/regular-model"
messages = [{"role": "user", "content": "Test message"}]

mock_post, _ = watsonx_chat_completion_call(model=model, messages=messages)

assert mock_post.call_count == 1
json_data = json.loads(mock_post.call_args.kwargs["data"])
# Ensure model_id is included in the payload for regular models
assert "model_id" in json_data
assert json_data["model_id"] == "regular-model" # Provider prefix is stripped
# Ensure project_id is also included for regular models
assert "project_id" in json_data


@pytest.fixture
def watsonx_completion_call():
def _call(
model="watsonx_text/my-test-model",
prompt="Hello, how are you?",
api_key="test_api_key",
space_id: Optional[str] = None,
headers=None,
client=None,
patch_token_call=True,
):
if client is None:
client = HTTPHandler()

if patch_token_call:
mock_response = Mock()
mock_response.json.return_value = {
"access_token": "mock_access_token",
"expires_in": 3600,
}
mock_response.raise_for_status = Mock()

with patch.object(client, "post") as mock_post, patch.object(
litellm.module_level_client, "post", return_value=mock_response
) as mock_get:
try:
litellm.text_completion(
model=model,
prompt=prompt,
api_key=api_key,
headers=headers or {},
client=client,
space_id=space_id,
)
except Exception as e:
print(e)

return mock_post, mock_get
else:
with patch.object(client, "post") as mock_post:
try:
litellm.text_completion(
model=model,
prompt=prompt,
api_key=api_key,
headers=headers or {},
client=client,
space_id=space_id,
)
except Exception as e:
print(e)
return mock_post, None

return _call


def test_watsonx_completion_deployment_model_id_not_in_payload(
monkeypatch, watsonx_completion_call
):
"""Test that deployment models do not include 'model_id' in completion request payload"""
monkeypatch.setenv("WATSONX_PROJECT_ID", "test-project-id")
monkeypatch.setenv("WATSONX_API_BASE", "https://test-api.watsonx.ai")
model = "watsonx_text/deployment/test-deployment-id"
prompt = "Test prompt"

mock_post, _ = watsonx_completion_call(model=model, prompt=prompt)

assert mock_post.call_count == 1
json_data = json.loads(mock_post.call_args.kwargs["data"])
# Ensure model_id is not in the payload for deployment models
assert "model_id" not in json_data
# Ensure project_id is also not in the payload for deployment models
assert "project_id" not in json_data


def test_watsonx_completion_regular_model_includes_model_id(
monkeypatch, watsonx_completion_call
):
"""Test that regular models include 'model_id' in completion request payload"""
monkeypatch.setenv("WATSONX_PROJECT_ID", "test-project-id")
monkeypatch.setenv("WATSONX_API_BASE", "https://test-api.watsonx.ai")
model = "watsonx_text/regular-model"
prompt = "Test prompt"

mock_post, _ = watsonx_completion_call(model=model, prompt=prompt)

assert mock_post.call_count == 1
json_data = json.loads(mock_post.call_args.kwargs["data"])
# Ensure model_id is included in the payload for regular models
assert "model_id" in json_data
assert json_data["model_id"] == "regular-model" # Provider prefix is stripped
# Ensure project_id is also included for regular models
assert "project_id" in json_data
Loading