Skip to content

Commit 2510dda

Browse files
authored
Add server initialization hook (#19)
1 parent e0186b5 commit 2510dda

File tree

2 files changed

+108
-11
lines changed

2 files changed

+108
-11
lines changed

Sources/MCP/Server/Server.swift

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,15 @@ public actor Server {
155155
}
156156

157157
/// Start the server
158-
public func start(transport: any Transport) async throws {
158+
/// - Parameters:
159+
/// - transport: The transport to use for the server
160+
/// - initializeHook: An optional hook that runs when the client sends an initialize request
161+
public func start(
162+
transport: any Transport,
163+
initializeHook: (@Sendable (Client.Info, Client.Capabilities) async throws -> Void)? = nil
164+
) async throws {
159165
self.connection = transport
160-
registerDefaultHandlers()
166+
registerDefaultHandlers(initializeHook: initializeHook)
161167
try await transport.connect()
162168

163169
await logger?.info(
@@ -364,7 +370,9 @@ public actor Server {
364370
}
365371
}
366372

367-
private func registerDefaultHandlers() {
373+
private func registerDefaultHandlers(
374+
initializeHook: (@Sendable (Client.Info, Client.Capabilities) async throws -> Void)?
375+
) {
368376
// Initialize
369377
withMethodHandler(Initialize.self) { [weak self] params in
370378
guard let self = self else {
@@ -381,26 +389,30 @@ public actor Server {
381389
"Unsupported protocol version: \(params.protocolVersion)")
382390
}
383391

392+
// Call initialization hook if registered
393+
if let hook = initializeHook {
394+
try await hook(params.clientInfo, params.capabilities)
395+
}
396+
397+
// Set initial state
384398
await self.setInitialState(
385399
clientInfo: params.clientInfo,
386400
clientCapabilities: params.capabilities,
387401
protocolVersion: params.protocolVersion
388402
)
389403

390-
let result = Initialize.Result(
391-
protocolVersion: Version.latest,
392-
capabilities: await self.capabilities,
393-
serverInfo: self.serverInfo,
394-
instructions: nil
395-
)
396-
397404
// Send initialized notification after a short delay
398405
Task {
399406
try? await Task.sleep(nanoseconds: 100_000_000) // 100ms
400407
try? await self.notify(InitializedNotification.message())
401408
}
402409

403-
return result
410+
return Initialize.Result(
411+
protocolVersion: Version.latest,
412+
capabilities: await self.capabilities,
413+
serverInfo: self.serverInfo,
414+
instructions: nil
415+
)
404416
}
405417

406418
// Ping

Tests/MCPTests/ServerTests.swift

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,89 @@ struct ServerTests {
4848
await server.stop()
4949
await transport.disconnect()
5050
}
51+
52+
@Test("Initialize hook - successful")
53+
func testInitializeHookSuccess() async throws {
54+
let transport = MockTransport()
55+
56+
actor TestState {
57+
var hookCalled = false
58+
func setHookCalled() { hookCalled = true }
59+
func wasHookCalled() -> Bool { hookCalled }
60+
}
61+
62+
let state = TestState()
63+
let server = Server(name: "TestServer", version: "1.0")
64+
65+
// Start with the hook directly
66+
try await server.start(transport: transport) { clientInfo, capabilities in
67+
#expect(clientInfo.name == "TestClient")
68+
#expect(clientInfo.version == "1.0")
69+
await state.setHookCalled()
70+
}
71+
72+
// Wait for server to initialize
73+
try await Task.sleep(nanoseconds: 10_000_000) // 10ms
74+
75+
// Queue an initialize request
76+
try await transport.queueRequest(
77+
Initialize.request(
78+
.init(
79+
protocolVersion: Version.latest,
80+
capabilities: .init(),
81+
clientInfo: .init(name: "TestClient", version: "1.0")
82+
)
83+
))
84+
85+
// Wait for message processing and hook execution
86+
try await Task.sleep(nanoseconds: 200_000_000) // 200ms
87+
88+
#expect(await state.wasHookCalled() == true)
89+
#expect(await transport.sentMessages.count >= 1)
90+
91+
let messages = await transport.sentMessages
92+
if let response = messages.first {
93+
#expect(response.contains("serverInfo"))
94+
}
95+
96+
await server.stop()
97+
}
98+
99+
@Test("Initialize hook - rejection")
100+
func testInitializeHookRejection() async throws {
101+
let transport = MockTransport()
102+
103+
let server = Server(name: "TestServer", version: "1.0")
104+
105+
try await server.start(transport: transport) { clientInfo, _ in
106+
if clientInfo.name == "BlockedClient" {
107+
throw Error.invalidRequest("Client not allowed")
108+
}
109+
}
110+
111+
// Wait for server to initialize
112+
try await Task.sleep(nanoseconds: 10_000_000) // 10ms
113+
114+
// Queue an initialize request from blocked client
115+
try await transport.queueRequest(
116+
Initialize.request(
117+
.init(
118+
protocolVersion: Version.latest,
119+
capabilities: .init(),
120+
clientInfo: .init(name: "BlockedClient", version: "1.0")
121+
)
122+
))
123+
124+
// Wait for message processing
125+
try await Task.sleep(nanoseconds: 200_000_000) // 200ms
126+
127+
#expect(await transport.sentMessages.count >= 2)
128+
129+
let messages = await transport.sentMessages
130+
if let response = messages.first {
131+
#expect(response.contains("error"))
132+
#expect(response.contains("Client not allowed"))
133+
}
134+
await server.stop()
135+
}
51136
}

0 commit comments

Comments
 (0)