Skip to content

Mcp streamable route in gateway. #62

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
May 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Set Up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"

- name: Install Dependencies
run: |
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,13 @@ await client.connect_to_sse_server(
"MCP-SERVER-BASE-URL": "<The base URL to your MCP server>",
"INVARIANT-PROJECT-NAME": "<The Invariant dataset name>",
"PUSH-INVARIANT-EXPLORER": "true",
"INVARIANT-API-KEY": "<your-invariant-api-key>"
},
)
```

The `INVARIANT-API-KEY` header is used both for pushing traces to the Explorer and for evaluating guardrails.

If no `INVARIANT-PROJECT-NAME` header is specified but `PUSH-INVARIANT-EXPLORER` is set to "true", a new Invariant project will be created and the MCP traces will be pushed there.

You can also specify blocking or logging guardrails for the project name by visiting the Explorer.
Expand Down
6 changes: 4 additions & 2 deletions gateway/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
CLIENT_TIMEOUT = 60.0

# MCP related constants
UTF_8 = "utf-8"
INVARIANT_SESSION_ID_PREFIX = "inv-"
MCP_METHOD = "method"
MCP_TOOL_CALL = "tools/call"
MCP_LIST_TOOLS = "tools/list"
Expand All @@ -33,4 +33,6 @@
The operation was blocked by Invariant Guardrails (mention this in your user report).
When users ask about this tool, inform them that it was blocked due to a security guardrail failure.
%s
"""
"""
MCP_SERVER_BASE_URL_HEADER = "mcp-server-base-url"
UTF_8 = "utf-8"
82 changes: 69 additions & 13 deletions gateway/common/mcp_sessions_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pydantic import BaseModel, Field, PrivateAttr
from starlette.datastructures import Headers

from gateway.common.constants import INVARIANT_SESSION_ID_PREFIX
from gateway.common.guardrails import GuardrailRuleSet, GuardrailAction
from gateway.common.request_context import RequestContext
from gateway.integrations.explorer import (
Expand Down Expand Up @@ -45,6 +46,7 @@ class McpSession(BaseModel):
id_to_method_mapping: Dict[int, str] = Field(default_factory=dict)
explorer_dataset: str
push_explorer: bool
invariant_api_key: Optional[str] = None
trace_id: Optional[str] = None
last_trace_length: int = 0
annotations: List[Dict[str, Any]] = Field(default_factory=list)
Expand All @@ -58,8 +60,20 @@ class McpSession(BaseModel):
pending_error_messages: List[dict] = Field(default_factory=list)

# Lock to maintain in-order pushes to explorer
# and other session-related operations
_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)

def get_invariant_api_key(self) -> str:
    """
    Resolve the Invariant API key to use for this session.

    A truthy key stored on the session takes precedence; otherwise the
    ``INVARIANT_API_KEY`` environment variable is consulted.

    Returns:
        str: The resolved API key.
        NOTE(review): when neither source is set, ``os.getenv`` yields
        ``None`` despite the annotation - callers that concatenate the
        result should confirm a key is always configured.
    """
    return self.invariant_api_key or os.getenv("INVARIANT_API_KEY")

async def load_guardrails(self) -> None:
"""
Load guardrails for the session.
Expand All @@ -68,7 +82,7 @@ async def load_guardrails(self) -> None:
"""
self.guardrails = await fetch_guardrails_from_explorer(
self.explorer_dataset,
"Bearer " + os.getenv("INVARIANT_API_KEY"),
"Bearer " + self.get_invariant_api_key(),
# pylint: disable=no-member
self.metadata.get("mcp_client"),
self.metadata.get("mcp_server"),
Expand Down Expand Up @@ -96,11 +110,13 @@ async def session_lock(self):

def session_metadata(self) -> dict:
    """
    Generate metadata for the current session.

    Returns:
        dict: The session id, the value of ``user_and_host()``, every
        entry of ``self.metadata`` (if any), and an
        ``is_stateless_http_server`` flag.
    """
    metadata = {
        "session_id": self.session_id,
        "system_user": user_and_host(),
        **(self.metadata or {}),
    }
    # Assigned after the spread so the computed flag always wins over any
    # same-named key coming from self.metadata.
    metadata["is_stateless_http_server"] = self.session_id.startswith(
        INVARIANT_SESSION_ID_PREFIX
    )
    return metadata

async def get_guardrails_check_result(
self,
Expand All @@ -121,7 +137,7 @@ async def get_guardrails_check_result(
context = RequestContext.create(
request_json={},
dataset_name=self.explorer_dataset,
invariant_authorization="Bearer " + os.getenv("INVARIANT_API_KEY"),
invariant_authorization="Bearer " + self.get_invariant_api_key(),
guardrails=self.guardrails,
guardrails_parameters={
"metadata": self.session_metadata(),
Expand Down Expand Up @@ -191,6 +207,7 @@ async def _push_trace_update(self, deduplicated_annotations: list) -> None:
try:
client = AsyncClient(
api_url=os.getenv("INVARIANT_API_URL", DEFAULT_API_URL),
api_key=self.get_invariant_api_key(),
)

# If no trace exists, create a new one
Expand Down Expand Up @@ -256,6 +273,7 @@ class SseHeaderAttributes(BaseModel):

push_explorer: bool
explorer_dataset: str
invariant_api_key: Optional[str] = None

@classmethod
def from_request_headers(cls, headers: Headers) -> "SseHeaderAttributes":
Expand All @@ -271,6 +289,7 @@ def from_request_headers(cls, headers: Headers) -> "SseHeaderAttributes":
# Extract and process header values
project_name = headers.get("INVARIANT-PROJECT-NAME")
push_explorer_header = headers.get("PUSH-INVARIANT-EXPLORER", "false").lower()
invariant_api_key = headers.get("INVARIANT-API-KEY")

# Determine explorer_dataset
if project_name:
Expand All @@ -282,7 +301,11 @@ def from_request_headers(cls, headers: Headers) -> "SseHeaderAttributes":
push_explorer = push_explorer_header == "true"

# Create and return instance
return cls(push_explorer=push_explorer, explorer_dataset=explorer_dataset)
return cls(
push_explorer=push_explorer,
explorer_dataset=explorer_dataset,
invariant_api_key=invariant_api_key,
)


class McpSessionsManager:
Expand All @@ -292,24 +315,57 @@ class McpSessionsManager:

def __init__(self):
    # Guards every read/write of the per-session lock table below.
    self._global_lock = asyncio.Lock()
    # One lock per session id, taken while a session is being
    # initialized or deleted.
    self._session_locks: dict[str, asyncio.Lock] = {}
    # Sessions keyed by their session id.
    self._sessions: dict[str, McpSession] = {}

def session_exists(self, session_id: str) -> bool:
    """Return True when ``session_id`` is a key of the sessions table."""
    known_sessions = self._sessions
    return session_id in known_sessions

async def _get_session_lock(self, session_id: str) -> asyncio.Lock:
    """
    Return the lock associated with ``session_id``.

    The lock is created and stored on first request. The lookup and the
    possible insertion both happen while holding the global lock, which
    protects the locks dictionary.
    """
    async with self._global_lock:
        lock = self._session_locks.get(session_id)
        if lock is None:
            lock = asyncio.Lock()
            self._session_locks[session_id] = lock
        return lock

async def cleanup_session_lock(self, session_id: str) -> None:
    """
    Drop the lock stored for ``session_id``, if there is one.

    Runs under the global lock; an unknown ``session_id`` is a no-op.
    """
    async with self._global_lock:
        self._session_locks.pop(session_id, None)

async def initialize_session(
    self, session_id: str, sse_header_attributes: SseHeaderAttributes
) -> None:
    """
    Initialize a new session unless one with this id already exists.

    Initialization runs under the per-session lock so that concurrent
    calls for the same ``session_id`` create the session only once.

    Args:
        session_id: Identifier under which the session is stored.
        sse_header_attributes: Header-derived settings; the fields that
            were explicitly set are forwarded to ``McpSession``.
    """
    # Get the lock for this specific session
    session_lock = await self._get_session_lock(session_id)

    async with session_lock:
        # The existence check must happen while holding the lock: another
        # caller may have created the session while this one was waiting.
        if session_id in self._sessions:
            print(
                f"Session {session_id} already exists, skipping initialization",
                flush=True,
            )
            return

        session = McpSession(
            session_id=session_id,
            **sse_header_attributes.model_dump(
                exclude_unset=True,
            ),
        )
        self._sessions[session_id] = session
        # Load guardrails for the session from the explorer
        await session.load_guardrails()

def get_session(self, session_id: str) -> McpSession:
"""Get a session by ID"""
Expand Down
Loading
Loading