Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion litellm/integrations/custom_guardrail.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,28 @@ def get_disable_global_guardrail(self, data: dict) -> Optional[bool]:
return metadata["disable_global_guardrail"]
return False

def _is_valid_response_type(self, result: Any) -> bool:
    """
    Return True when ``result`` is a recognized LLMResponseTypes instance.

    TypedDict members of ``LLMResponseTypes`` do not support isinstance()
    and raise TypeError; in that case the type cannot be validated, so the
    result is allowed through (e.g. passthrough responses such as raw
    httpx.Response objects). ``None`` is always treated as invalid.
    """
    if result is None:
        return False

    candidate_types = get_args(LLMResponseTypes)
    try:
        is_known_type = isinstance(result, candidate_types)
    except TypeError as type_err:
        # isinstance() raises TypeError for TypedDict classes; since the
        # type cannot be checked, allow the result through. Any other
        # TypeError is unexpected and is re-raised.
        if "TypedDict" not in str(type_err):
            raise
        return True
    return is_known_type

def get_guardrail_from_metadata(
self, data: dict
) -> Union[List[str], List[Dict[str, DynamicGuardrailParams]]]:
Expand Down Expand Up @@ -342,7 +364,7 @@ async def async_post_call_success_deployment_hook(
response=response,
)

if result is None or not isinstance(result, get_args(LLMResponseTypes)):
if not self._is_valid_response_type(result):
return response

return result
Expand Down
8 changes: 5 additions & 3 deletions litellm/litellm_core_utils/api_route_to_call_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,22 @@
Each route can have both async (prefixed with 'a') and sync call types.
"""

from typing import List, Optional

from litellm.types.utils import API_ROUTE_TO_CALL_TYPES, CallTypes


def get_call_types_for_route(route: str) -> Optional[List[CallTypes]]:
    """
    Get the list of CallTypes for a given API route.

    Args:
        route: API route path (e.g., "/chat/completions")

    Returns:
        List of CallTypes for that route, or None if route not found.
        Returning None (rather than an empty list) lets callers
        distinguish "unknown route" from "route with no call types".
    """
    return API_ROUTE_TO_CALL_TYPES.get(route, None)


def get_routes_for_call_type(call_type: CallTypes) -> list:
Expand Down
6 changes: 4 additions & 2 deletions litellm/proxy/common_request_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,14 +885,16 @@ async def _handle_llm_api_exception(

@staticmethod
def _get_pre_call_type(
route_type: Literal["acompletion", "aembedding", "aresponses"],
) -> Literal["completion", "embeddings", "responses"]:
route_type: Literal["acompletion", "aembedding", "aresponses", "allm_passthrough_route"],
) -> Literal["completion", "embeddings", "responses", "allm_passthrough_route"]:
if route_type == "acompletion":
return "completion"
elif route_type == "aembedding":
return "embeddings"
elif route_type == "aresponses":
return "responses"
elif route_type == "allm_passthrough_route":
return "allm_passthrough_route"

#########################################################
# Proxy Level Streaming Data Generator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ async def async_moderation_hook(
endpoint_guardrail_translation_mappings = (
load_guardrail_translation_mappings()
)
if CallTypes(call_type) not in endpoint_guardrail_translation_mappings:
if call_type is not None and CallTypes(call_type) not in endpoint_guardrail_translation_mappings:
return data

endpoint_translation = endpoint_guardrail_translation_mappings[
Expand Down Expand Up @@ -180,10 +180,10 @@ async def async_post_call_success_hook(
call_type: Optional[CallTypesLiteral] = None
if user_api_key_dict.request_route is not None:
call_types = get_call_types_for_route(user_api_key_dict.request_route)
if call_types is not None and len(call_types) > 0:
call_type = call_types[0]
if call_types is not None and len(call_types) > 0: # type: ignore
call_type = call_types[0] # type: ignore
if call_type is None:
call_type = _infer_call_type(call_type=None, completion_response=response)
call_type = _infer_call_type(call_type=None, completion_response=response) # type: ignore

if call_type is None:
return response
Expand Down Expand Up @@ -308,10 +308,10 @@ async def async_post_call_streaming_iterator_hook( # noqa: PLR0915
if call_type is None and user_api_key_dict.request_route is not None:
call_types = get_call_types_for_route(user_api_key_dict.request_route)
if call_types is not None:
call_type = call_types[0]
call_type = call_types[0].value

if call_type is None:
call_type = _infer_call_type(call_type=None, completion_response=item)
call_type = _infer_call_type(call_type=None, completion_response=item) # type: ignore

# If call type not supported, just pass through all chunks
if (
Expand Down
153 changes: 153 additions & 0 deletions tests/test_litellm/integrations/test_custom_guardrail.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,156 @@ def test_appends_to_litellm_metadata(self):
assert isinstance(info, list)
assert len(info) == 2
assert info[1]["guardrail_name"] == "test_guardrail"


class TestCustomGuardrailPassthroughSupport:
    """Tests for passthrough endpoint guardrail support - Issue fixes.

    Covers the fixes that allow guardrails to run on passthrough endpoints:
    handling raw (non-LiteLLM) response objects, a None ``call_type``, and
    the TypedDict isinstance() limitation in ``_is_valid_response_type``.
    """

    @pytest.mark.asyncio
    async def test_async_post_call_success_deployment_hook_with_httpx_response(self):
        """
        Test that async_post_call_success_deployment_hook handles raw httpx.Response objects
        from passthrough endpoints without crashing with TypeError.

        This tests Fix #3: TypeError: TypedDict does not support instance and class checks
        """
        import httpx

        custom_guardrail = CustomGuardrail()

        # Mock the async_post_call_success_hook to return None (guardrail didn't modify response)
        custom_guardrail.async_post_call_success_hook = AsyncMock(return_value=None)

        # Create a mock httpx.Response object (typical passthrough response)
        mock_response = AsyncMock(spec=httpx.Response)
        mock_response.status_code = 200
        mock_response.text = "Mock response"

        request_data = {
            "guardrails": ["test_guardrail"],
            "user_api_key_user_id": "test_user",
            "user_api_key_team_id": "test_team",
            "user_api_key_end_user_id": "test_end_user",
            "user_api_key_hash": "test_hash",
            "user_api_key_request_route": "passthrough_route",
        }

        # This should not raise TypeError: TypedDict does not support instance and class checks
        result = await custom_guardrail.async_post_call_success_deployment_hook(
            request_data=request_data,
            response=mock_response,
            call_type=CallTypes.allm_passthrough_route,
        )

        # When result is None, should return the original response
        # NOTE(review): Mock.__eq__ defaults to identity comparison, so this is
        # effectively an identity check; `result is mock_response` would make
        # that intent explicit — confirm before changing.
        assert result == mock_response

    @pytest.mark.asyncio
    async def test_async_post_call_success_deployment_hook_with_none_call_type(self):
        """
        Test that async_post_call_success_deployment_hook handles None call_type gracefully.

        This ensures that even if call_type is None (before fix #1), the guardrail doesn't crash.
        """
        custom_guardrail = CustomGuardrail()

        # Mock the async_post_call_success_hook to return None
        custom_guardrail.async_post_call_success_hook = AsyncMock(return_value=None)

        mock_response = AsyncMock()

        request_data = {
            "guardrails": ["test_guardrail"],
            "user_api_key_user_id": "test_user",
        }

        # Call with None call_type - should not crash
        result = await custom_guardrail.async_post_call_success_deployment_hook(
            request_data=request_data,
            response=mock_response,
            call_type=None,
        )

        # Should return the original response when result is None
        # NOTE(review): identity-style check via Mock's default __eq__ — see note above.
        assert result == mock_response

    def test_is_valid_response_type_with_none(self):
        """
        Test _is_valid_response_type helper method correctly identifies None as invalid.

        This is part of Fix #3: Safely handling TypedDict types that don't support isinstance checks.
        """
        custom_guardrail = CustomGuardrail()

        # None should be invalid
        assert custom_guardrail._is_valid_response_type(None) is False

    def test_is_valid_response_type_with_typeddict_error(self):
        """
        Test _is_valid_response_type gracefully handles TypeError from TypedDict.

        This tests Fix #3: When isinstance() is called with TypedDict types, it raises TypeError.
        The method should catch this and allow the response through.
        """
        from litellm.types.utils import ModelResponse

        custom_guardrail = CustomGuardrail()

        # Create a valid LiteLLM response object
        response = ModelResponse(
            id="test-id",
            choices=[],
            created=0,
            model="test-model",
            object="chat.completion",
        )

        # This should return True (it's a valid response type or TypeError is caught)
        result = custom_guardrail._is_valid_response_type(response)
        assert result is True


class TestPassthroughCallTypeHandling:
    """Tests for passthrough call type handling in common_request_processing."""

    def test_get_pre_call_type_with_allm_passthrough_route(self):
        """
        Test that _get_pre_call_type correctly maps allm_passthrough_route.

        This tests Fix #1: allm_passthrough_route was not being handled, causing call_type to be None.
        """
        from litellm.proxy.common_request_processing import (
            ProxyBaseLLMRequestProcessing,
        )

        # The passthrough route must map to itself, not fall through to None.
        mapped = ProxyBaseLLMRequestProcessing._get_pre_call_type(
            route_type="allm_passthrough_route"
        )
        assert mapped == "allm_passthrough_route"

    def test_get_pre_call_type_preserves_standard_mappings(self):
        """
        Test that _get_pre_call_type still correctly maps standard route types.

        Ensures Fix #1 didn't break existing functionality.
        """
        from litellm.proxy.common_request_processing import (
            ProxyBaseLLMRequestProcessing,
        )

        # Each pre-existing async route type must keep its sync mapping.
        expected_mappings = {
            "acompletion": "completion",
            "aembedding": "embeddings",
            "aresponses": "responses",
        }
        for route_type, expected in expected_mappings.items():
            actual = ProxyBaseLLMRequestProcessing._get_pre_call_type(
                route_type=route_type
            )
            assert actual == expected
Loading