53 changes: 40 additions & 13 deletions litellm/llms/anthropic/chat/guardrail_translation/handler.py
@@ -253,20 +253,36 @@ async def process_output_response(
         task_mappings: List[Tuple[int, Optional[int]]] = []
         # Track (content_index, None) for each text

-        response_content = response.get("content", [])
+        # Handle both dict and object responses
+        if hasattr(response, "get"):
+            response_content = response.get("content", [])
+        elif hasattr(response, "content"):
+            response_content = response.content or []
+        else:
+            response_content = []

         if not response_content:
             return response

         # Step 1: Extract all text content and tool calls from response
         for content_idx, content_block in enumerate(response_content):
-            # Check if this is a text or tool_use block by checking the 'type' field
-            if isinstance(content_block, dict) and content_block.get("type") in [
-                "text",
-                "tool_use",
-            ]:
-                # Cast to dict to handle the union type properly
+            # Handle both dict and Pydantic object content blocks
+            if isinstance(content_block, dict):
+                block_type = content_block.get("type")
+                block_dict = content_block
+            elif hasattr(content_block, "type"):
+                block_type = getattr(content_block, "type", None)
+                # Convert Pydantic object to dict for processing
+                if hasattr(content_block, "model_dump"):
+                    block_dict = content_block.model_dump()
+                else:
+                    block_dict = {"type": block_type, "text": getattr(content_block, "text", None)}
+            else:
+                continue
+
+            if block_type in ["text", "tool_use"]:
                 self._extract_output_text_and_images(
-                    content_block=cast(Dict[str, Any], content_block),
+                    content_block=cast(Dict[str, Any], block_dict),
                     content_idx=content_idx,
                     texts_to_check=texts_to_check,
                     images_to_check=images_to_check,
@@ -590,7 +606,14 @@ async def _apply_guardrail_responses_to_output(
             mapping = task_mappings[task_idx]
             content_idx = cast(int, mapping[0])

-            response_content = response.get("content", [])
+            # Handle both dict and object responses
+            if hasattr(response, "get"):
+                response_content = response.get("content", [])
+            elif hasattr(response, "content"):
+                response_content = response.content or []
+            else:
+                continue
+
             if not response_content:
                 continue

@@ -601,7 +624,11 @@ async def _apply_guardrail_responses_to_output(
                 content_block = response_content[content_idx]

                 # Verify it's a text block and update the text field
-                if isinstance(content_block, dict) and content_block.get("type") == "text":
-                    # Cast to dict to handle the union type properly for assignment
-                    content_block = cast("AnthropicResponseTextBlock", content_block)
-                    content_block["text"] = guardrail_response
+                # Handle both dict and Pydantic object content blocks
+                if isinstance(content_block, dict):
+                    if content_block.get("type") == "text":
+                        content_block["text"] = guardrail_response
+                elif hasattr(content_block, "type") and getattr(content_block, "type", None) == "text":
+                    # Update Pydantic object's text attribute
+                    if hasattr(content_block, "text"):
+                        content_block.text = guardrail_response
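
A minimal sketch of the dict-vs-Pydantic normalization pattern these hunks apply. `TextBlock` here is a hypothetical stand-in for the SDK's Pydantic content-block types, not the actual class:

```python
# Sketch of the normalization used above: accept either a plain dict or a
# Pydantic content block and hand downstream code a uniform dict.
from typing import Any, Dict, Optional

from pydantic import BaseModel


class TextBlock(BaseModel):
    """Hypothetical stand-in for an SDK Pydantic content block."""
    type: str = "text"
    text: Optional[str] = None


def normalize_content_block(block: Any) -> Optional[Dict[str, Any]]:
    if isinstance(block, dict):
        return block
    if hasattr(block, "type"):
        if hasattr(block, "model_dump"):
            return block.model_dump()  # full Pydantic -> dict conversion
        # Minimal fallback for attribute-style objects without model_dump()
        return {"type": getattr(block, "type", None), "text": getattr(block, "text", None)}
    return None  # unknown shape: the handler skips such blocks


assert normalize_content_block({"type": "text", "text": "hi"}) == {"type": "text", "text": "hi"}
assert normalize_content_block(TextBlock(text="hi")) == {"type": "text", "text": "hi"}
```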
74 changes: 63 additions & 11 deletions litellm/llms/openai/responses/guardrail_translation/handler.py
@@ -30,7 +30,6 @@

 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

-from openai.types.responses import ResponseFunctionToolCall
 from pydantic import BaseModel

 from litellm._logging import verbose_proxy_logger
@@ -299,8 +298,25 @@ async def process_output_response(
         task_mappings: List[Tuple[int, int]] = []
         # Track (output_item_index, content_index) for each text

+        # Handle both dict and Pydantic object responses
+        if isinstance(response, dict):
+            response_output = response.get("output", [])
+        elif hasattr(response, "output"):
+            response_output = response.output or []
+        else:
+            verbose_proxy_logger.debug(
+                "OpenAI Responses API: No output found in response"
+            )
+            return response
+
+        if not response_output:
+            verbose_proxy_logger.debug(
+                "OpenAI Responses API: Empty output in response"
+            )
+            return response
+
         # Step 1: Extract all text content and tool calls from response output
-        for output_idx, output_item in enumerate(response.output):
+        for output_idx, output_item in enumerate(response_output):
             self._extract_output_text_and_images(
                 output_item=output_item,
                 output_idx=output_idx,
@@ -538,13 +554,18 @@ def _extract_output_text_and_images(
         content: Optional[Union[List[OutputText], List[dict]]] = None
         if isinstance(output_item, BaseModel):
             try:
+                output_item_dump = output_item.model_dump()
                 generic_response_output_item = GenericResponseOutputItem.model_validate(
-                    output_item.model_dump()
+                    output_item_dump
                 )
                 if generic_response_output_item.content:
                     content = generic_response_output_item.content
             except Exception:
-                return
+                # Try to extract content directly from output_item if validation fails
+                if hasattr(output_item, "content") and output_item.content:
+                    content = output_item.content
+                else:
+                    return
         elif isinstance(output_item, dict):
             content = output_item.get("content", [])
         else:
@@ -582,22 +603,53 @@ async def _apply_guardrail_responses_to_output(

         Override this method to customize how responses are applied.
         """
+        # Handle both dict and Pydantic object responses
+        if isinstance(response, dict):
+            response_output = response.get("output", [])
+        elif hasattr(response, "output"):
+            response_output = response.output or []
+        else:
+            return
+
         for task_idx, guardrail_response in enumerate(responses):
             mapping = task_mappings[task_idx]
             output_idx = cast(int, mapping[0])
             content_idx = cast(int, mapping[1])

-            output_item = response.output[output_idx]
+            if output_idx >= len(response_output):
+                continue
+
+            output_item = response_output[output_idx]

-            # Handle both GenericResponseOutputItem and dict
+            # Handle both GenericResponseOutputItem, BaseModel, and dict
             if isinstance(output_item, GenericResponseOutputItem):
-                content_item = output_item.content[content_idx]
-                if isinstance(content_item, OutputText):
-                    content_item.text = guardrail_response
-                elif isinstance(content_item, dict):
-                    content_item["text"] = guardrail_response
+                if output_item.content and content_idx < len(output_item.content):
+                    content_item = output_item.content[content_idx]
+                    if isinstance(content_item, OutputText):
+                        content_item.text = guardrail_response
+                    elif isinstance(content_item, dict):
+                        content_item["text"] = guardrail_response
+            elif isinstance(output_item, BaseModel):
+                # Handle other Pydantic models by converting to GenericResponseOutputItem
+                try:
+                    generic_item = GenericResponseOutputItem.model_validate(
+                        output_item.model_dump()
+                    )
+                    if generic_item.content and content_idx < len(generic_item.content):
+                        content_item = generic_item.content[content_idx]
+                        if isinstance(content_item, OutputText):
+                            content_item.text = guardrail_response
+                            # Update the original response output
+                            if hasattr(output_item, "content") and output_item.content:
+                                original_content = output_item.content[content_idx]
+                                if hasattr(original_content, "text"):
+                                    original_content.text = guardrail_response
+                except Exception:
+                    pass
+            elif isinstance(output_item, dict):
+                content = output_item.get("content", [])
+                if content and content_idx < len(content):
+                    if isinstance(content[content_idx], dict):
+                        content[content_idx]["text"] = guardrail_response
+                    elif hasattr(content[content_idx], "text"):
+                        content[content_idx].text = guardrail_response
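
The same defensive pattern, reduced to a standalone sketch. `get_response_output` and `apply_text` are hypothetical helpers (not litellm functions) showing the extract-then-bounds-checked-write-back flow the handler now follows:

```python
# Hypothetical helpers illustrating the extract / write-back flow above;
# the response shape is a simplified stand-in for the Responses API types.
from typing import Any, List


def get_response_output(response: Any) -> List[Any]:
    """Accept a dict or an object exposing `.output`; return a list either way."""
    if isinstance(response, dict):
        return response.get("output", [])
    if hasattr(response, "output"):
        return response.output or []
    return []


def apply_text(output: List[Any], output_idx: int, content_idx: int, new_text: str) -> None:
    """Write guardrail-modified text back, tolerating dicts and attribute objects."""
    if output_idx >= len(output):
        return
    item = output[output_idx]
    content = item.get("content", []) if isinstance(item, dict) else getattr(item, "content", None) or []
    if content_idx >= len(content):
        return
    block = content[content_idx]
    if isinstance(block, dict):
        block["text"] = new_text
    elif hasattr(block, "text"):
        block.text = new_text


response = {"output": [{"content": [{"type": "output_text", "text": "raw"}]}]}
apply_text(get_response_output(response), 0, 0, "redacted")
assert response["output"][0]["content"][0]["text"] == "redacted"
```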
6 changes: 6 additions & 0 deletions litellm/proxy/guardrails/guardrail_hooks/grayswan/__init__.py
@@ -40,6 +40,12 @@ def initialize_guardrail(
         ),
         categories=_get_config_value(litellm_params, optional_params, "categories"),
         policy_id=_get_config_value(litellm_params, optional_params, "policy_id"),
+        streaming_end_of_stream_only=_get_config_value(
+            litellm_params, optional_params, "streaming_end_of_stream_only"
+        ) or False,
+        streaming_sampling_rate=_get_config_value(
+            litellm_params, optional_params, "streaming_sampling_rate"
+        ) or 5,
         event_hook=litellm_params.mode,
         default_on=litellm_params.default_on,
     )
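
One behavioral note on the `or` defaulting above, shown with a simplified stand-in for `_get_config_value`: `or` falls back for every falsy value, so an explicit `streaming_sampling_rate: 0` in config would be coerced to 5, whereas an `is None` check would preserve it:

```python
# Simplified stand-in for `_get_config_value`, to show how the `or` defaults behave.
from typing import Any, Optional


def get_config_value(config: dict, key: str) -> Optional[Any]:
    return config.get(key)


config = {"streaming_sampling_rate": 0, "streaming_end_of_stream_only": False}

# `or` replaces ANY falsy value with the default, including an explicit 0:
print(get_config_value(config, "streaming_sampling_rate") or 5)  # -> 5

# An explicit None check keeps falsy-but-intentional settings:
value = get_config_value(config, "streaming_sampling_rate")
print(5 if value is None else value)  # -> 0
```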