Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,7 +1412,7 @@ async def _call_regular_mcp_tool(
if extra_headers is None:
extra_headers = {}
for header in mcp_server.extra_headers:
if header in raw_headers:
if isinstance(header, str) and header in raw_headers:
extra_headers[header] = raw_headers[header]

if mcp_server.static_headers:
Expand Down Expand Up @@ -1692,6 +1692,25 @@ def get_mcp_server_by_name(self, server_name: str) -> Optional[MCPServer]:
return server
return None

def get_mcp_servers_from_ids(
self, server_ids: List[str]
) -> List[MCPServer]:
"""
Get MCP servers from a list of server IDs.

Args:
server_ids: List of server IDs to retrieve

Returns:
List of MCPServer objects corresponding to the provided IDs
"""
servers: List[MCPServer] = []
for server_id in server_ids:
server = self.get_mcp_server_by_id(server_id)
if server:
servers.append(server)
return servers

def _generate_stable_server_id(
self,
server_name: str,
Expand Down
17 changes: 17 additions & 0 deletions litellm/responses/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,27 @@ async def aresponses_api_with_mcp(
user_api_key_auth = kwargs.get("litellm_metadata", {}).get(
"user_api_key_auth"
)

# Extract MCP auth headers from the request to pass to MCP server
secret_fields = kwargs.get("secret_fields")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type this please

(
mcp_auth_header,
mcp_server_auth_headers,
oauth2_headers,
raw_headers_from_request,
) = ResponsesAPIRequestUtils.extract_mcp_headers_from_request(
secret_fields=secret_fields,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this cover the null case?

tools=tools,
)

tool_results = await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
tool_server_map=tool_server_map,
tool_calls=tool_calls,
user_api_key_auth=user_api_key_auth,
mcp_auth_header=mcp_auth_header,
mcp_server_auth_headers=mcp_server_auth_headers,
oauth2_headers=oauth2_headers,
raw_headers=raw_headers_from_request,
)

if tool_results:
Expand Down
12 changes: 11 additions & 1 deletion litellm/responses/mcp/litellm_proxy_mcp_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,13 @@ def _parse_mcp_result(result: Any) -> str:

@staticmethod
async def _execute_tool_calls(
tool_server_map: dict[str, str], tool_calls: List[Any], user_api_key_auth: Any
tool_server_map: dict[str, str],
tool_calls: List[Any],
user_api_key_auth: Any,
mcp_auth_header: Optional[str] = None,
mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None,
oauth2_headers: Optional[Dict[str, str]] = None,
raw_headers: Optional[Dict[str, str]] = None,
) -> List[Dict[str, Any]]:
"""Execute tool calls and return results."""
from fastapi import HTTPException
Expand Down Expand Up @@ -435,6 +441,10 @@ async def _execute_tool_calls(
name=tool_name,
arguments=parsed_arguments,
user_api_key_auth=user_api_key_auth,
mcp_auth_header=mcp_auth_header,
mcp_server_auth_headers=mcp_server_auth_headers,
oauth2_headers=oauth2_headers,
raw_headers=raw_headers,
proxy_logging_obj=proxy_logging_obj,
)

Expand Down
57 changes: 57 additions & 0 deletions litellm/responses/mcp/mcp_streaming_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,61 @@ def __init__(
"custom_llm_provider", None
)

self._extract_mcp_headers_from_params()

# Mark as async iterator
self.is_async = True

def _extract_mcp_headers_from_params(self) -> None:
"""Extract MCP headers from original request params to pass to tool calls"""
from typing import Dict, Optional
from starlette.datastructures import Headers
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
MCPRequestHandler,
)

# Extract headers from secret_fields in original_request_params
raw_headers_from_request: Optional[Dict[str, str]] = None
secret_fields = self.original_request_params.get("secret_fields")
if secret_fields and isinstance(secret_fields, dict):
raw_headers_from_request = secret_fields.get("raw_headers")

# Extract MCP-specific headers
self.mcp_auth_header: Optional[str] = None
self.mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None
self.oauth2_headers: Optional[Dict[str, str]] = None
self.raw_headers: Optional[Dict[str, str]] = raw_headers_from_request

if raw_headers_from_request:
headers_obj = Headers(raw_headers_from_request)
self.mcp_auth_header = MCPRequestHandler._get_mcp_auth_header_from_headers(headers_obj)
self.mcp_server_auth_headers = MCPRequestHandler._get_mcp_server_auth_headers_from_headers(headers_obj)
self.oauth2_headers = MCPRequestHandler._get_oauth2_headers_from_headers(headers_obj)

# Also check if headers are provided in tools array (from request body)
tools = self.original_request_params.get("tools")
if tools:
for tool in tools:
if isinstance(tool, dict) and tool.get("type") == "mcp":
tool_headers = tool.get("headers", {})
if tool_headers and isinstance(tool_headers, dict):
# Merge tool headers into mcp_server_auth_headers
headers_obj_from_tool = Headers(tool_headers)
tool_mcp_server_auth_headers = MCPRequestHandler._get_mcp_server_auth_headers_from_headers(headers_obj_from_tool)

if tool_mcp_server_auth_headers:
if self.mcp_server_auth_headers is None:
self.mcp_server_auth_headers = {}
# Merge the headers from tool into existing headers
for server_alias, headers_dict in tool_mcp_server_auth_headers.items():
if server_alias not in self.mcp_server_auth_headers:
self.mcp_server_auth_headers[server_alias] = {}
self.mcp_server_auth_headers[server_alias].update(headers_dict)

# Also merge raw headers
if self.raw_headers is None:
self.raw_headers = {}
self.raw_headers.update(tool_headers)

def _should_auto_execute_tools(self) -> bool:
"""Check if tools should be auto-executed"""
Expand Down Expand Up @@ -511,6 +564,10 @@ async def _generate_tool_execution_events(self) -> None:
tool_server_map=self.tool_server_map,
tool_calls=tool_calls,
user_api_key_auth=self.user_api_key_auth,
mcp_auth_header=self.mcp_auth_header,
mcp_server_auth_headers=self.mcp_server_auth_headers,
oauth2_headers=self.oauth2_headers,
raw_headers=self.raw_headers,
)

# Create completion events and output_item.done events for tool execution
Expand Down
60 changes: 60 additions & 0 deletions litellm/responses/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
Type,
Expand Down Expand Up @@ -350,6 +351,65 @@ def convert_text_format_to_text_param(
return text
return text

@staticmethod
def extract_mcp_headers_from_request(
secret_fields: Optional[Dict[str, Any]],
tools: Optional[Iterable[Any]],
) -> tuple[
Optional[str],
Optional[Dict[str, Dict[str, str]]],
Optional[Dict[str, str]],
Optional[Dict[str, str]],
]:
"""
Extract MCP auth headers from the request to pass to MCP server.
Headers from tools.headers in request body should be passed to MCP server.
"""
from starlette.datastructures import Headers
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
MCPRequestHandler,
)

# Extract headers from secret_fields which contains the original request headers
raw_headers_from_request: Optional[Dict[str, str]] = None
if secret_fields and isinstance(secret_fields, dict):
raw_headers_from_request = secret_fields.get("raw_headers")

# Extract MCP-specific headers using MCPRequestHandler methods
mcp_auth_header: Optional[str] = None
mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None
oauth2_headers: Optional[Dict[str, str]] = None

if raw_headers_from_request:
headers_obj = Headers(raw_headers_from_request)
mcp_auth_header = MCPRequestHandler._get_mcp_auth_header_from_headers(headers_obj)
mcp_server_auth_headers = MCPRequestHandler._get_mcp_server_auth_headers_from_headers(headers_obj)
oauth2_headers = MCPRequestHandler._get_oauth2_headers_from_headers(headers_obj)

if tools:
for tool in tools:
if isinstance(tool, dict) and tool.get("type") == "mcp":
tool_headers = tool.get("headers", {})
if tool_headers and isinstance(tool_headers, dict):
# Merge tool headers into mcp_server_auth_headers
# Extract server-specific headers from tool.headers
headers_obj_from_tool = Headers(tool_headers)
tool_mcp_server_auth_headers = MCPRequestHandler._get_mcp_server_auth_headers_from_headers(headers_obj_from_tool)
if tool_mcp_server_auth_headers:
if mcp_server_auth_headers is None:
mcp_server_auth_headers = {}
# Merge the headers from tool into existing headers
for server_alias, headers_dict in tool_mcp_server_auth_headers.items():
if server_alias not in mcp_server_auth_headers:
mcp_server_auth_headers[server_alias] = {}
mcp_server_auth_headers[server_alias].update(headers_dict)
# Also merge raw headers (non-prefixed headers from tool.headers)
if raw_headers_from_request is None:
raw_headers_from_request = {}
raw_headers_from_request.update(tool_headers)

return mcp_auth_header, mcp_server_auth_headers, oauth2_headers, raw_headers_from_request


class ResponseAPILoggingUtils:
@staticmethod
Expand Down
Loading