|
1 | 1 | """MCP server configuration loading utilities.""" |
2 | 2 |
|
3 | 3 | import json |
| 4 | +import re |
| 5 | +from abc import ABC, abstractmethod |
| 6 | +from dataclasses import dataclass, field |
| 7 | +from enum import Enum |
4 | 8 | from pathlib import Path |
5 | | -from typing import Dict, List, Optional |
| 9 | +from typing import Callable, Dict, List, Optional |
| 10 | +from urllib.parse import urlparse |
6 | 11 |
|
7 | 12 | from mcp import StdioServerParameters |
| 13 | +from mcp.client.sse import sse_client |
8 | 14 | from mcp.client.stdio import stdio_client |
| 15 | +from mcp.client.streamable_http import streamablehttp_client |
9 | 16 |
|
10 | 17 | from .mcp_client import MCPClient |
11 | 18 | from .mcp_types import MCPTransport |
12 | 19 |
|
13 | 20 |
|
14 | | -class MCPServerConfig: |
15 | | - """Configuration for an MCP server following MCP standards.""" |
| 21 | +class MCPTransportType(Enum): |
| 22 | + """MCP transport types.""" |
16 | 23 |
|
17 | | - def __init__( |
18 | | - self, |
19 | | - name: str, |
20 | | - command: str, |
21 | | - args: Optional[List[str]] = None, |
22 | | - env: Optional[Dict[str, str]] = None, |
23 | | - timeout: Optional[int] = None, |
24 | | - ): |
25 | | - """Initialize MCP server configuration. |
26 | | -
|
27 | | - Args: |
28 | | - name: Server name |
29 | | - command: Command to run the server |
30 | | - args: Command arguments |
31 | | - env: Environment variables |
32 | | - timeout: Timeout in milliseconds |
33 | | - """ |
34 | | - self.name = name |
35 | | - self.command = command |
36 | | - self.args = args or [] |
37 | | - self.env = env or {} |
38 | | - self.timeout = timeout or 60000 |
| 24 | + STDIO = "stdio" |
| 25 | + STREAMABLE_HTTP = "streamable-http" |
| 26 | + SSE = "sse" |
39 | 27 |
|
40 | | - def create_client(self) -> MCPClient: |
41 | | - """Create an MCPClient from this configuration.""" |
| 28 | + |
| 29 | +@dataclass |
| 30 | +class MCPTransportConfig(ABC): |
| 31 | + """Base configuration for MCP Transports.""" |
| 32 | + |
| 33 | + name: str |
| 34 | + timeout: float = 60000 |
| 35 | + transport_type: Optional[MCPTransportType] = None |
| 36 | + |
| 37 | + @abstractmethod |
| 38 | + def create_transport_callable(self) -> Callable[[], MCPTransport]: |
| 39 | + """Create a transport callable for this configuration.""" |
| 40 | + pass |
| 41 | + |
| 42 | + |
| 43 | +@dataclass |
| 44 | +class StdioTransportConfig(MCPTransportConfig): |
| 45 | + """Configuration for STDIO transport (local subprocess).""" |
| 46 | + |
| 47 | + command: str = "" |
| 48 | + args: List[str] = field(default_factory=list) |
| 49 | + env: Dict[str, str] = field(default_factory=dict) |
| 50 | + cwd: Optional[str] = None |
| 51 | + |
| 52 | + def __post_init__(self) -> None: |
| 53 | + """Set transport type after initialization.""" |
| 54 | + self.transport_type = MCPTransportType.STDIO |
| 55 | + |
| 56 | + def create_transport_callable(self) -> Callable[[], MCPTransport]: |
| 57 | + """Create STDIO transport callable.""" |
42 | 58 |
|
43 | 59 | def transport_callable() -> MCPTransport: |
44 | | - server_params = StdioServerParameters(command=self.command, args=self.args, env=self.env) |
| 60 | + server_params = StdioServerParameters( |
| 61 | + command=self.command, |
| 62 | + args=self.args, |
| 63 | + env=self.env, |
| 64 | + cwd=self.cwd, |
| 65 | + ) |
45 | 66 | return stdio_client(server_params) |
46 | 67 |
|
47 | | - return MCPClient(transport_callable) |
| 68 | + return transport_callable |
| 69 | + |
| 70 | + |
| 71 | +@dataclass |
| 72 | +class HTTPTransportConfig(MCPTransportConfig): |
| 73 | + """Configuration for HTTP transport.""" |
| 74 | + |
| 75 | + url: str = "" |
| 76 | + authorization_token: Optional[str] = None |
| 77 | + headers: Dict[str, str] = field(default_factory=dict) |
| 78 | + |
| 79 | + def __post_init__(self) -> None: |
| 80 | + """Set transport type after initialization.""" |
| 81 | + self.transport_type = MCPTransportType.STREAMABLE_HTTP |
| 82 | + |
| 83 | + def create_transport_callable(self) -> Callable[[], MCPTransport]: |
| 84 | + """Create STREAMABLE_HTTP transport callable.""" |
| 85 | + |
| 86 | + def transport_callable() -> MCPTransport: |
| 87 | + headers = self.headers.copy() |
| 88 | + if self.authorization_token: |
| 89 | + headers["Authorization"] = f"Bearer {self.authorization_token}" |
| 90 | + return streamablehttp_client(self.url, headers=headers if headers else None) |
| 91 | + |
| 92 | + return transport_callable |
| 93 | + |
| 94 | + |
| 95 | +@dataclass |
| 96 | +class SSETransportConfig(MCPTransportConfig): |
| 97 | + """Configuration for SSE transport.""" |
| 98 | + |
| 99 | + url: str = "" |
| 100 | + authorization_token: Optional[str] = None |
| 101 | + headers: Dict[str, str] = field(default_factory=dict) |
| 102 | + |
| 103 | + def __post_init__(self) -> None: |
| 104 | + """Set transport type after initialization.""" |
| 105 | + self.transport_type = MCPTransportType.SSE |
| 106 | + |
| 107 | + def create_transport_callable(self) -> Callable[[], MCPTransport]: |
| 108 | + """Create SSE transport callable.""" |
| 109 | + |
| 110 | + def transport_callable() -> MCPTransport: |
| 111 | + headers = self.headers.copy() if self.headers else {} |
| 112 | + if self.authorization_token: |
| 113 | + headers["Authorization"] = f"Bearer {self.authorization_token}" |
| 114 | + return sse_client(self.url, headers=headers if headers else None) |
| 115 | + |
| 116 | + return transport_callable |
| 117 | + |
| 118 | + |
| 119 | +def _infer_transport_from_url(url: str) -> str: |
| 120 | + """Infer transport type from URL when none specified. |
| 121 | +
|
| 122 | + - If path contains '/sse' (optionally followed by / ? & or end), treat as 'sse' |
| 123 | + - Otherwise default to 'streamable-http' |
| 124 | + """ |
| 125 | + try: |
| 126 | + path = (urlparse(url).path or "").lower() |
| 127 | + except Exception: |
| 128 | + return "streamable-http" |
| 129 | + return "sse" if re.search(r"/sse(/|\?|&|$)", path) else "streamable-http" |
| 130 | + |
| 131 | + |
| 132 | +class MCPServerConfig: |
| 133 | + """Configuration for an MCP server following MCP standards.""" |
| 134 | + |
| 135 | + def __init__(self, transport_config: MCPTransportConfig) -> None: |
| 136 | + """Initialize MCP server configuration.""" |
| 137 | + self.transport_config = transport_config |
| 138 | + self.name = transport_config.name |
| 139 | + self.timeout = transport_config.timeout |
| 140 | + |
| 141 | + def create_client(self) -> MCPClient: |
| 142 | + """Create an MCPClient from this configuration.""" |
| 143 | + return MCPClient(self.transport_config.create_transport_callable()) |
48 | 144 |
|
49 | 145 | @classmethod |
50 | 146 | def from_config(cls, config_path: str) -> List["MCPServerConfig"]: |
51 | | - """Load MCP server configurations from standard mcp.json format. |
52 | | -
|
53 | | - Args: |
54 | | - config_path: Path to the MCP configuration file |
55 | | -
|
56 | | - Returns: |
57 | | - List of MCPServerConfig instances |
58 | | -
|
59 | | - Config file examples: |
60 | | - Anthropic MCP Server Config Examples: (https://modelcontextprotocol.io/examples) |
61 | | - AmazonQ MCP Server Config Examples: (https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/command-line-mcp-understanding-config.html) |
62 | | -
|
63 | | - Expected format: |
64 | | - { |
65 | | - "mcpServers": { |
66 | | - "server-name": { |
67 | | - "command": "command-to-run", |
68 | | - "args": ["arg1", "arg2"], |
69 | | - "env": { |
70 | | - "ENV_VAR1": "value1", |
71 | | - "ENV_VAR2": "value2" |
72 | | - }, |
73 | | - "timeout": 60000 |
74 | | - } |
75 | | - } |
76 | | - } |
77 | | -
|
78 | | - """ |
| 147 | + """Load MCP server configurations from a config file.""" |
79 | 148 | config_file = Path(config_path) |
80 | 149 | if not config_file.exists(): |
81 | 150 | raise FileNotFoundError(f"Config file not found: {config_path}") |
82 | 151 |
|
83 | 152 | with open(config_file) as f: |
84 | 153 | config_data = json.load(f) |
85 | 154 |
|
| 155 | + if "mcpServers" in config_data and config_data["mcpServers"] is not None: |
| 156 | + return cls._parse_mcp_servers_format(config_data["mcpServers"]) |
| 157 | + |
| 158 | + return [] |
| 159 | + |
| 160 | + @classmethod |
| 161 | + def _parse_mcp_servers_format(cls, mcp_servers: Dict[str, Dict]) -> List["MCPServerConfig"]: |
| 162 | + """Parse mcpServers format (Claude Desktop/legacy and enhanced).""" |
86 | 163 | servers = [] |
87 | | - mcp_server_name = set() |
88 | | - mcp_servers = config_data.get("mcpServers", {}) |
89 | | - expected_attrs = {"command", "args", "env", "timeout"} |
| 164 | + server_names = set() |
90 | 165 |
|
91 | 166 | for name, server_config in mcp_servers.items(): |
92 | | - if len(name) == 0 or len(name) > 250 or server_config.get("command") is None or name in mcp_server_name: |
93 | | - raise ValueError(f"Invalid server configuration for {name}") |
94 | | - if set(server_config.keys()) - expected_attrs: |
95 | | - raise ValueError(f"Invalid server configuration for {name}") |
96 | | - |
97 | | - servers.append( |
98 | | - cls( |
99 | | - name=name, |
100 | | - command=server_config["command"], |
101 | | - args=server_config.get("args"), |
102 | | - env=server_config.get("env"), |
103 | | - timeout=server_config.get("timeout"), |
104 | | - ) |
| 167 | + if len(name) == 0 or len(name) > 250 or name in server_names: |
| 168 | + raise ValueError(f"Invalid server name: {name}") |
| 169 | + server_names.add(name) |
| 170 | + |
| 171 | + # Accept multiple keys, infer from URL when absent |
| 172 | + raw_transport = ( |
| 173 | + server_config.get("transport") |
| 174 | + or server_config.get("transportType") |
| 175 | + or server_config.get("transport_type") |
105 | 176 | ) |
106 | | - mcp_server_name.add(name) |
| 177 | + if raw_transport is None and "url" in server_config: |
| 178 | + raw_transport = _infer_transport_from_url(str(server_config["url"])) |
| 179 | + transport_type = (raw_transport or "stdio").lower() |
| 180 | + |
| 181 | + timeout = server_config.get("timeout", 60000) |
| 182 | + |
| 183 | + try: |
| 184 | + if transport_type == "stdio" or "command" in server_config: |
| 185 | + if "command" not in server_config: |
| 186 | + raise ValueError(f"STDIO server {name} missing 'command'") |
| 187 | + |
| 188 | + transport_config: MCPTransportConfig = StdioTransportConfig( |
| 189 | + name=name, |
| 190 | + command=server_config["command"], |
| 191 | + args=server_config.get("args", []), |
| 192 | + env=server_config.get("env", {}), |
| 193 | + cwd=server_config.get("cwd"), |
| 194 | + timeout=timeout, |
| 195 | + ) |
| 196 | + |
| 197 | + elif transport_type == "streamable-http": |
| 198 | + if "url" not in server_config: |
| 199 | + raise ValueError(f"Steamable-HTTP server {name} missing 'url'") |
| 200 | + |
| 201 | + transport_config = HTTPTransportConfig( |
| 202 | + name=name, |
| 203 | + url=server_config["url"], |
| 204 | + authorization_token=server_config.get("authorization_token"), |
| 205 | + headers=server_config.get("headers", {}), |
| 206 | + timeout=timeout, |
| 207 | + ) |
| 208 | + |
| 209 | + elif transport_type == "sse": |
| 210 | + if "url" not in server_config: |
| 211 | + raise ValueError(f"SSE server {name} missing 'url'") |
| 212 | + |
| 213 | + transport_config = SSETransportConfig( |
| 214 | + name=name, |
| 215 | + url=server_config["url"], |
| 216 | + authorization_token=server_config.get("authorization_token"), |
| 217 | + headers=server_config.get("headers", {}), |
| 218 | + timeout=timeout, |
| 219 | + ) |
| 220 | + else: |
| 221 | + raise ValueError(f"Unsupported transport type: {transport_type}") |
| 222 | + |
| 223 | + servers.append(cls(transport_config)) |
| 224 | + |
| 225 | + except Exception as e: |
| 226 | + raise ValueError(f"Invalid configuration for server {name}: {e}") from e |
107 | 227 |
|
108 | 228 | return servers |
0 commit comments