Skip to content

Add server initialization hook #19

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 25 additions & 12 deletions Sources/MCP/Server/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,15 @@ public actor Server {
}

/// Start the server
public func start(transport: any Transport) async throws {
/// - Parameters:
/// - transport: The transport to use for the server
/// - initializeHook: An optional hook that runs when the client sends an initialize request
public func start(
transport: any Transport,
initializeHook: (@Sendable (Client.Info, Client.Capabilities) async throws -> Void)? = nil
) async throws {
self.connection = transport
registerDefaultHandlers()
registerDefaultHandlers(initializeHook: initializeHook)
try await transport.connect()

await logger?.info(
Expand Down Expand Up @@ -205,7 +211,8 @@ public actor Server {
"Error processing message", metadata: ["error": "\(error)"])
let response = AnyMethod.response(
id: requestID ?? .random,
error: error as? Error ?? Error.internalError(error.localizedDescription)
error: error as? Error
?? Error.internalError(error.localizedDescription)
)
try? await send(response)
}
Expand Down Expand Up @@ -354,7 +361,9 @@ public actor Server {
}
}

private func registerDefaultHandlers() {
private func registerDefaultHandlers(
initializeHook: (@Sendable (Client.Info, Client.Capabilities) async throws -> Void)?
) {
// Initialize
withMethodHandler(Initialize.self) { [weak self] params in
guard let self = self else {
Expand All @@ -371,26 +380,30 @@ public actor Server {
"Unsupported protocol version: \(params.protocolVersion)")
}

// Call initialization hook if registered
if let hook = initializeHook {
try await hook(params.clientInfo, params.capabilities)
}

// Set initial state
await self.setInitialState(
clientInfo: params.clientInfo,
clientCapabilities: params.capabilities,
protocolVersion: params.protocolVersion
)

let result = Initialize.Result(
protocolVersion: Version.latest,
capabilities: await self.capabilities,
serverInfo: self.serverInfo,
instructions: nil
)

// Send initialized notification after a short delay
Task {
try? await Task.sleep(nanoseconds: 100_000_000) // 100ms
try? await self.notify(InitializedNotification.message())
}

return result
return Initialize.Result(
protocolVersion: Version.latest,
capabilities: await self.capabilities,
serverInfo: self.serverInfo,
instructions: nil
)
}

// Ping
Expand Down
85 changes: 85 additions & 0 deletions Tests/MCPTests/ServerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,89 @@ struct ServerTests {
await server.stop()
await transport.disconnect()
}

@Test("Initialize hook - successful")
func testInitializeHookSuccess() async throws {
let transport = MockTransport()

actor TestState {
var hookCalled = false
func setHookCalled() { hookCalled = true }
func wasHookCalled() -> Bool { hookCalled }
}

let state = TestState()
let server = Server(name: "TestServer", version: "1.0")

// Start with the hook directly
try await server.start(transport: transport) { clientInfo, capabilities in
#expect(clientInfo.name == "TestClient")
#expect(clientInfo.version == "1.0")
await state.setHookCalled()
}

// Wait for server to initialize
try await Task.sleep(nanoseconds: 10_000_000) // 10ms

// Queue an initialize request
try await transport.queueRequest(
Initialize.request(
.init(
protocolVersion: Version.latest,
capabilities: .init(),
clientInfo: .init(name: "TestClient", version: "1.0")
)
))

// Wait for message processing and hook execution
try await Task.sleep(nanoseconds: 200_000_000) // 200ms

#expect(await state.wasHookCalled() == true)
#expect(await transport.sentMessages.count >= 1)

let messages = await transport.sentMessages
if let response = messages.first {
#expect(response.contains("serverInfo"))
}

await server.stop()
}

@Test("Initialize hook - rejection")
func testInitializeHookRejection() async throws {
let transport = MockTransport()

let server = Server(name: "TestServer", version: "1.0")

try await server.start(transport: transport) { clientInfo, _ in
if clientInfo.name == "BlockedClient" {
throw Error.invalidRequest("Client not allowed")
}
}

// Wait for server to initialize
try await Task.sleep(nanoseconds: 10_000_000) // 10ms

// Queue an initialize request from blocked client
try await transport.queueRequest(
Initialize.request(
.init(
protocolVersion: Version.latest,
capabilities: .init(),
clientInfo: .init(name: "BlockedClient", version: "1.0")
)
))

// Wait for message processing
try await Task.sleep(nanoseconds: 200_000_000) // 200ms

#expect(await transport.sentMessages.count >= 2)

let messages = await transport.sentMessages
if let response = messages.first {
#expect(response.contains("error"))
#expect(response.contains("Client not allowed"))
}
await server.stop()
}
}