Skip to content

Fix client parsing #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions Sources/MCP/Base/Transports.swift
Original file line number Diff line number Diff line change
Expand Up @@ -396,13 +396,7 @@ public actor StdioTransport: Transport {
if !receiveContinuationResumed {
receiveContinuationResumed = true
if let error = error {
if let nwError = error as? NWError {
continuation.resume(throwing: MCP.Error.transportError(nwError))
} else {
continuation.resume(
throwing: MCP.Error.internalError("Receive error: \(error)")
)
}
continuation.resume(throwing: MCP.Error.transportError(error))
} else if let content = content {
continuation.resume(returning: content)
} else {
Expand Down
65 changes: 49 additions & 16 deletions Sources/MCP/Client/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,48 @@ public actor Client {
/// The task for the message handling loop
private var task: Task<Void, Never>?

/// An error indicating a type mismatch when decoding a pending request
private struct TypeMismatchError: Swift.Error {}

/// A pending request with a continuation for the result
private struct PendingRequest<T> {
let continuation: CheckedContinuation<T, Swift.Error>
}

/// A type-erased pending request
private struct AnyPendingRequest {
private let _resume: (Result<Any, Swift.Error>) -> Void

init<T: Sendable & Decodable>(_ request: PendingRequest<T>) {
_resume = { result in
switch result {
case .success(let value):
if let typedValue = value as? T {
request.continuation.resume(returning: typedValue)
} else if let value = value as? Value,
let data = try? JSONEncoder().encode(value),
let decoded = try? JSONDecoder().decode(T.self, from: data)
{
request.continuation.resume(returning: decoded)
} else {
request.continuation.resume(throwing: TypeMismatchError())
}
case .failure(let error):
request.continuation.resume(throwing: error)
}
}
}
func resume(returning value: Any) {
_resume(.success(value))
}

func resume(throwing error: Swift.Error) {
_resume(.failure(error))
}
}

/// A dictionary of type-erased pending requests, keyed by request ID
private var pendingRequests: [ID: Any] = [:]
private var pendingRequests: [ID: AnyPendingRequest] = [:]

public init(
name: String,
Expand Down Expand Up @@ -129,8 +165,10 @@ public actor Client {

// Attempt to decode string data as AnyResponse or AnyMessage
let decoder = JSONDecoder()
if let response = try? decoder.decode(AnyResponse.self, from: data) {
await handleResponse(response, for: response)
if let response = try? decoder.decode(AnyResponse.self, from: data),
let request = pendingRequests[response.id]
{
await handleResponse(response, for: request)
} else if let message = try? decoder.decode(AnyMessage.self, from: data) {
await handleMessage(message)
} else {
Expand Down Expand Up @@ -158,11 +196,7 @@ public actor Client {
public func disconnect() async {
// Cancel all pending requests
for (id, request) in pendingRequests {
// We know this cast is safe because we only store PendingRequest values
if let typedRequest = request as? PendingRequest<Any> {
typedRequest.continuation.resume(
throwing: Error.internalError("Client disconnected"))
}
request.resume(throwing: Error.internalError("Client disconnected"))
pendingRequests.removeValue(forKey: id)
}

Expand Down Expand Up @@ -220,12 +254,12 @@ public actor Client {
}
}

private func addPendingRequest<T>(
private func addPendingRequest<T: Sendable & Decodable>(
id: ID,
continuation: CheckedContinuation<T, Swift.Error>,
type: T.Type
) {
pendingRequests[id] = PendingRequest(continuation: continuation)
pendingRequests[id] = AnyPendingRequest(PendingRequest(continuation: continuation))
}

private func removePendingRequest(id: ID) {
Expand Down Expand Up @@ -320,19 +354,18 @@ public actor Client {

// MARK: -

private func handleResponse(_ response: Response<AnyMethod>, for request: Any) async {
private func handleResponse(_ response: Response<AnyMethod>, for request: AnyPendingRequest)
async
{
await logger?.debug(
"Processing response",
metadata: ["id": "\(response.id)"])

// We know this cast is safe because we only store PendingRequest values
guard let typedRequest = request as? PendingRequest<Any> else { return }

switch response.result {
case .success(let value):
typedRequest.continuation.resume(returning: value)
request.resume(returning: value)
case .failure(let error):
typedRequest.continuation.resume(throwing: error)
request.resume(throwing: error)
}

removePendingRequest(id: response.id)
Expand Down
94 changes: 72 additions & 22 deletions Tests/MCPTests/RoundtripTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ import Testing
@Suite("Roundtrip Tests")
struct RoundtripTests {
@Test(
"Initialize roundtrip",
.timeLimit(.minutes(1))
)
func testInitializeRoundtrip() async throws {
func testRoundtrip() async throws {
let (clientToServerRead, clientToServerWrite) = try FileDescriptor.pipe()
let (serverToClientRead, serverToClientWrite) = try FileDescriptor.pipe()

Expand All @@ -36,32 +35,83 @@ struct RoundtripTests {
version: "1.0.0",
capabilities: .init(prompts: .init(), tools: .init())
)
await server.withMethodHandler(ListTools.self) { _ in
return ListTools.Result(tools: [
Tool(
name: "add",
description: "Adds two numbers together",
inputSchema: [
"a": ["type": "integer", "description": "The first number"],
"a": ["type": "integer", "description": "The second number"],
])
])
}
await server.withMethodHandler(CallTool.self) { request in
guard request.name == "add" else {
return CallTool.Result(content: [.text("Invalid tool name")], isError: true)
}

guard let a = request.arguments?["a"]?.intValue,
let b = request.arguments?["b"]?.intValue
else {
return CallTool.Result(
content: [.text("Did not receive valid arguments")], isError: true)
}

return CallTool.Result(content: [.text("\(a + b)")], isError: false)
}

let client = Client(name: "TestClient", version: "1.0")

try await server.start(transport: serverTransport)
try await client.connect(transport: clientTransport)

// let initTask = Task {
// let result = try await client.initialize()
let initTask = Task {
let result = try await client.initialize()

#expect(result.serverInfo.name == "TestServer")
#expect(result.serverInfo.version == "1.0.0")
#expect(result.capabilities.prompts != nil)
#expect(result.capabilities.tools != nil)
#expect(result.protocolVersion == Version.latest)
}
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await Task.sleep(for: .seconds(1))
initTask.cancel()
throw CancellationError()
}
group.addTask {
try await initTask.value
}
try await group.next()
group.cancelAll()
}

let listToolsTask = Task {
let result = try await client.listTools()
#expect(result.count == 1)
#expect(result[0].name == "add")
}

let callToolTask = Task {
let result = try await client.callTool(name: "add", arguments: ["a": 1, "b": 2])
#expect(result.isError == false)
#expect(result.content == [.text("3")])
}

// #expect(result.serverInfo.name == "TestServer")
// #expect(result.serverInfo.version == "1.0.0")
// #expect(result.capabilities.prompts != nil)
// #expect(result.capabilities.tools != nil)
// #expect(result.protocolVersion == Version.latest)
// }
// try await withThrowingTaskGroup(of: Void.self) { group in
// group.addTask {
// try await Task.sleep(for: .seconds(1))
// initTask.cancel()
// throw CancellationError()
// }
// group.addTask {
// try await initTask.value
// }
// try await group.next()
// group.cancelAll()
// }
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await Task.sleep(for: .seconds(1))
listToolsTask.cancel()
throw CancellationError()
}
group.addTask {
try await callToolTask.value
}
try await group.next()
group.cancelAll()
}

await server.stop()
await client.disconnect()
Expand Down