Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions hindsight-api/hindsight_api/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import logging
from contextlib import asynccontextmanager
from typing import Optional

from fastapi import FastAPI

Expand Down Expand Up @@ -46,14 +45,14 @@ def create_app(
# Both HTTP and MCP
app = create_app(memory, mcp_api_enabled=True)
"""
mcp_app = None
mcp_servers = None

# Create MCP app first if enabled (we need its lifespan for chaining)
# Create MCP servers first if enabled (we need their lifespans for chaining)
if mcp_api_enabled:
try:
from .mcp import create_mcp_app
from .mcp import MCPMiddleware, create_mcp_servers

mcp_app = create_mcp_app(memory=memory)
mcp_servers = create_mcp_servers(memory=memory)
except ImportError as e:
logger.error(f"MCP server requested but dependencies not available: {e}")
logger.error("Install with: pip install hindsight-api[mcp]")
Expand All @@ -70,11 +69,9 @@ def create_app(
app = FastAPI(title="Hindsight API", version="0.0.7")
logger.info("HTTP REST API disabled")

# Mount MCP server and chain its lifespan if enabled
if mcp_app is not None:
# Get both MCP apps' underlying Starlette apps for lifespan access
multi_bank_starlette_app = mcp_app.multi_bank_app
single_bank_starlette_app = mcp_app.single_bank_app
# Add MCP middleware and chain its lifespan if enabled
if mcp_servers is not None:
multi_bank_server, single_bank_server, multi_bank_starlette_app, single_bank_starlette_app = mcp_servers

# Store the original lifespan
original_lifespan = app.router.lifespan_context
Expand All @@ -94,8 +91,19 @@ async def chained_lifespan(app_instance: FastAPI):
# Replace the app's lifespan with the chained version
app.router.lifespan_context = chained_lifespan

# Mount the MCP middleware
app.mount(mcp_mount_path, mcp_app)
# Add MCP as a wrapping middleware — intercepts /mcp* requests directly,
# passes everything else through to the FastAPI app. No Starlette Mount
# means no 307 redirect for /mcp (no trailing slash).
app.add_middleware(
MCPMiddleware,
memory=memory,
prefix=mcp_mount_path,
multi_bank_app=multi_bank_starlette_app,
single_bank_app=single_bank_starlette_app,
multi_bank_server=multi_bank_server,
single_bank_server=single_bank_server,
)

logger.info(f"MCP server enabled at {mcp_mount_path}/")

return app
Expand Down
130 changes: 64 additions & 66 deletions hindsight-api/hindsight_api/api/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def create_mcp_server(memory: MemoryEngine, multi_bank: bool = True) -> FastMCP:


class MCPMiddleware:
"""ASGI middleware that handles authentication and routes to appropriate MCP server.
"""ASGI middleware that intercepts MCP requests and routes to appropriate MCP server.

This middleware wraps the main FastAPI app and intercepts requests matching the
configured prefix (default: /mcp). Non-MCP requests pass through to the inner app.

Authentication:
1. If HINDSIGHT_API_MCP_AUTH_TOKEN is set (legacy), validates against that token
Expand Down Expand Up @@ -149,27 +152,33 @@ class MCPMiddleware:
--header "X-Bank-Id: my-bank" --header "Authorization: Bearer <token>"
"""

def __init__(self, app, memory: MemoryEngine):
def __init__(
self,
app,
memory: MemoryEngine,
prefix: str = "/mcp",
multi_bank_app=None,
single_bank_app=None,
multi_bank_server=None,
single_bank_server=None,
):
self.app = app
self.prefix = prefix
self.memory = memory
self.tenant_extension = memory._tenant_extension

# Create two server instances:
# 1. Multi-bank server (for /mcp/ root endpoint)
self.multi_bank_server = create_mcp_server(memory, multi_bank=True)
self.multi_bank_app = self.multi_bank_server.http_app(path="/")

# 2. Single-bank server (for /mcp/{bank_id}/ endpoints)
self.single_bank_server = create_mcp_server(memory, multi_bank=False)
self.single_bank_app = self.single_bank_server.http_app(path="/")

# Backward compatibility: expose multi_bank_app as mcp_app
self.mcp_app = self.multi_bank_app

# Expose the lifespan for the parent app to chain (use multi-bank as default)
self.lifespan = (
self.multi_bank_app.lifespan_handler if hasattr(self.multi_bank_app, "lifespan_handler") else None
)
if multi_bank_app and single_bank_app:
# Pre-created servers (used when called via add_middleware from create_app)
self.multi_bank_app = multi_bank_app
self.single_bank_app = single_bank_app
self.multi_bank_server = multi_bank_server
self.single_bank_server = single_bank_server
else:
# Create servers internally (for direct construction / tests)
self.multi_bank_server = create_mcp_server(memory, multi_bank=True)
self.multi_bank_app = self.multi_bank_server.http_app(path="/")
self.single_bank_server = create_mcp_server(memory, multi_bank=False)
self.single_bank_app = self.single_bank_server.http_app(path="/")

def _get_header(self, scope: dict, name: str) -> str | None:
"""Extract a header value from ASGI scope."""
Expand All @@ -181,9 +190,20 @@ def _get_header(self, scope: dict, name: str) -> str | None:

async def __call__(self, scope, receive, send):
if scope["type"] != "http":
await self.multi_bank_app(scope, receive, send)
await self.app(scope, receive, send)
return

path = scope.get("path", "")

# Check if this is an MCP request (matches prefix)
if not (path == self.prefix or path.startswith(self.prefix + "/")):
# Not an MCP request — pass through to the inner app
await self.app(scope, receive, send)
return

# Strip prefix from path
path = path[len(self.prefix) :] or "/"

# Extract auth token from header (for tenant auth propagation)
auth_header = self._get_header(scope, "Authorization")
auth_token: str | None = None
Expand Down Expand Up @@ -222,36 +242,15 @@ async def __call__(self, scope, receive, send):
_current_schema.set(tenant_context.schema_name) if tenant_context and tenant_context.schema_name else None
)

path = scope.get("path", "")

# Strip any mount prefix (e.g., /mcp) that FastAPI might not have stripped
root_path = scope.get("root_path", "")
if root_path and path.startswith(root_path):
path = path[len(root_path) :] or "/"

# Also handle case where mount path wasn't stripped (e.g., /mcp/...)
if path.startswith("/mcp/"):
path = path[4:] # Remove /mcp prefix
elif path == "/mcp":
path = "/"

# Ensure path has leading slash (needed after stripping mount path)
if path and not path.startswith("/"):
path = "/" + path

# Try to get bank_id from header first (for Claude Code compatibility)
bank_id = self._get_header(scope, "X-Bank-Id")
bank_id_from_path = False

# MCP endpoint paths that should not be treated as bank_ids
MCP_ENDPOINTS = {"sse", "messages"}

# If no header, try to extract from path: /{bank_id}/...
new_path = path
if not bank_id and path.startswith("/") and len(path) > 1:
parts = path[1:].split("/", 1)
# Don't treat MCP endpoints as bank_ids
if parts[0] and parts[0] not in MCP_ENDPOINTS:
if parts[0]:
# First segment looks like a bank_id
bank_id = parts[0]
bank_id_from_path = True
Expand Down Expand Up @@ -280,9 +279,19 @@ async def __call__(self, scope, receive, send):
# Clear root_path since we're passing directly to the app
new_scope["root_path"] = ""

# Wrap send to rewrite the SSE endpoint URL to include bank_id if using path-based routing
# Wrap send to rewrite the SSE endpoint URL to include bank_id if using path-based routing.
# Only rewrite SSE (text/event-stream) responses to avoid corrupting tool results
# that might contain the literal string "data: /messages".
is_sse_response = False

async def send_wrapper(message):
if message["type"] == "http.response.body" and bank_id_from_path:
nonlocal is_sse_response
if message["type"] == "http.response.start":
for header_name, header_value in message.get("headers", []):
if header_name == b"content-type" and b"text/event-stream" in header_value:
is_sse_response = True
break
if message["type"] == "http.response.body" and bank_id_from_path and is_sse_response:
body = message.get("body", b"")
if body and b"/messages" in body:
# Rewrite /messages to /{bank_id}/messages in SSE endpoint event
Expand Down Expand Up @@ -320,30 +329,19 @@ async def _send_error(self, send, status: int, message: str):
)


def create_mcp_app(memory: MemoryEngine):
"""
Create an ASGI app that handles MCP requests with dynamic tool exposure.

Authentication:
Uses the TenantExtension from the MemoryEngine (same auth as REST API).

Two modes based on URL structure:

1. Single-bank mode (recommended for agent isolation):
- URL: /mcp/{bank_id}/
- Tools: retain, recall, reflect (no bank_id parameter)
- Example: claude mcp add --transport http my-agent http://localhost:8888/mcp/my-agent-bank/
def create_mcp_servers(memory: MemoryEngine):
"""Create multi-bank and single-bank MCP servers and their Starlette apps.

2. Multi-bank mode (for cross-bank operations):
- URL: /mcp/
- Tools: retain, recall, reflect, list_banks, create_bank (all with bank_id parameter)
- Bank ID from: X-Bank-Id header or HINDSIGHT_MCP_BANK_ID env var (default: "default")
- Example: claude mcp add --transport http hindsight http://localhost:8888/mcp --header "X-Bank-Id: my-bank"

Args:
memory: MemoryEngine instance
Returns the servers and apps separately so lifespans can be chained before
the middleware wraps the main app.

Returns:
ASGI application
Tuple of (multi_bank_server, single_bank_server, multi_bank_app, single_bank_app)
"""
return MCPMiddleware(None, memory)
multi_bank_server = create_mcp_server(memory, multi_bank=True)
multi_bank_app = multi_bank_server.http_app(path="/")

single_bank_server = create_mcp_server(memory, multi_bank=False)
single_bank_app = single_bank_server.http_app(path="/")

return multi_bank_server, single_bank_server, multi_bank_app, single_bank_app
49 changes: 43 additions & 6 deletions hindsight-api/hindsight_api/mcp_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,19 @@ async def create_bank(bank_id: str, name: str | None = None, mission: str | None
return f'{{"error": "{e}"}}'


def _validate_mental_model_inputs(
name: str | None = None, source_query: str | None = None, max_tokens: int | None = None
) -> str | None:
"""Validate mental model inputs, returning an error message or None if valid."""
if name is not None and not name.strip():
return "name cannot be empty"
if source_query is not None and not source_query.strip():
return "source_query cannot be empty"
if max_tokens is not None and (max_tokens < 256 or max_tokens > 8192):
return f"max_tokens must be between 256 and 8192, got {max_tokens}"
return None


# =========================================================================
# MENTAL MODEL TOOLS
# =========================================================================
Expand Down Expand Up @@ -656,7 +669,7 @@ async def get_mental_model(
request_context=_get_request_context(config),
)
if model is None:
return json.dumps({"error": f"Mental model '{mental_model_id}' not found"})
return json.dumps({"error": f"Mental model '{mental_model_id}' not found in bank '{target_bank}'"})
return json.dumps(model, indent=2, default=str)
except Exception as e:
logger.error(f"Error getting mental model: {e}", exc_info=True)
Expand Down Expand Up @@ -688,7 +701,7 @@ async def get_mental_model(
request_context=_get_request_context(config),
)
if model is None:
return {"error": f"Mental model '{mental_model_id}' not found"}
return {"error": f"Mental model '{mental_model_id}' not found in bank '{target_bank}'"}
return model
except Exception as e:
logger.error(f"Error getting mental model: {e}", exc_info=True)
Expand Down Expand Up @@ -734,6 +747,12 @@ async def create_mental_model(
if target_bank is None:
return '{"error": "No bank_id configured"}'

validation_error = _validate_mental_model_inputs(
name=name, source_query=source_query, max_tokens=max_tokens
)
if validation_error:
return json.dumps({"error": validation_error})

request_context = _get_request_context(config)

# Create with placeholder content
Expand Down Expand Up @@ -803,6 +822,12 @@ async def create_mental_model(
if target_bank is None:
return {"error": "No bank_id configured"}

validation_error = _validate_mental_model_inputs(
name=name, source_query=source_query, max_tokens=max_tokens
)
if validation_error:
return {"error": validation_error}

request_context = _get_request_context(config)

model = await memory.create_mental_model(
Expand Down Expand Up @@ -868,6 +893,12 @@ async def update_mental_model(
if target_bank is None:
return '{"error": "No bank_id configured"}'

validation_error = _validate_mental_model_inputs(
name=name, source_query=source_query, max_tokens=max_tokens
)
if validation_error:
return json.dumps({"error": validation_error})

model = await memory.update_mental_model(
bank_id=target_bank,
mental_model_id=mental_model_id,
Expand All @@ -878,7 +909,7 @@ async def update_mental_model(
request_context=_get_request_context(config),
)
if model is None:
return json.dumps({"error": f"Mental model '{mental_model_id}' not found"})
return json.dumps({"error": f"Mental model '{mental_model_id}' not found in bank '{target_bank}'"})
return json.dumps(model, indent=2, default=str)
except Exception as e:
logger.error(f"Error updating mental model: {e}", exc_info=True)
Expand Down Expand Up @@ -912,6 +943,12 @@ async def update_mental_model(
if target_bank is None:
return {"error": "No bank_id configured"}

validation_error = _validate_mental_model_inputs(
name=name, source_query=source_query, max_tokens=max_tokens
)
if validation_error:
return {"error": validation_error}

model = await memory.update_mental_model(
bank_id=target_bank,
mental_model_id=mental_model_id,
Expand All @@ -922,7 +959,7 @@ async def update_mental_model(
request_context=_get_request_context(config),
)
if model is None:
return {"error": f"Mental model '{mental_model_id}' not found"}
return {"error": f"Mental model '{mental_model_id}' not found in bank '{target_bank}'"}
return model
except Exception as e:
logger.error(f"Error updating mental model: {e}", exc_info=True)
Expand Down Expand Up @@ -959,7 +996,7 @@ async def delete_mental_model(
request_context=_get_request_context(config),
)
if not deleted:
return json.dumps({"error": f"Mental model '{mental_model_id}' not found"})
return json.dumps({"error": f"Mental model '{mental_model_id}' not found in bank '{target_bank}'"})
return json.dumps({"status": "deleted", "mental_model_id": mental_model_id})
except Exception as e:
logger.error(f"Error deleting mental model: {e}", exc_info=True)
Expand Down Expand Up @@ -990,7 +1027,7 @@ async def delete_mental_model(
request_context=_get_request_context(config),
)
if not deleted:
return {"error": f"Mental model '{mental_model_id}' not found"}
return {"error": f"Mental model '{mental_model_id}' not found in bank '{target_bank}'"}
return {"status": "deleted", "mental_model_id": mental_model_id}
except Exception as e:
logger.error(f"Error deleting mental model: {e}", exc_info=True)
Expand Down
Loading
Loading