Skip to content

Commit c63a872

Browse files
zatsmattt
andauthored
Fix client parsing (#8)
* Allow decoding Decodable type directly from Value Summary: I noticed we are not converting value associated with response to the strong types requests recorded This change will help to decode the value type later once we support the rest of plumbing (next commit) * Store and pass original request when response arrives Summary: Looks like there were multiple issues with parsing response: 1. It appears that we would mistakenly pass response both for response and request arguments in the private func handleResponse(_ response: Response<AnyMethod>, for request: Any) 2. Later I couldn't figure out how to win against swift type system, so ended up creating a type-erased request to store in the pendingRequests instead of original type under Any (that wouldn't allow to downcast original type later) Test Plan: 1. Standup iMCP client from https://github.com/loopwork-ai/iMCP 2. Create a dummy project setting up MCP client against "/private/var/folders/26/8gncz37x7slbfrr95d9jcr400000gn/T/AppTranslocation/62ADEC4D-3545-4E98-A612-0E6DF52CE525/d/iMCP.app/Contents/MacOS/imcp-server" and call initialize Note response is being parsed correctly and client.initialize() does not hang indefinitely * Fix warning: 'Cast from Client.AnyPendingRequest to unrelated type Client.PendingRequest<Any> always fails' * Fix warning: 'Conditional cast from NWError to NWError always succeeds' * Uncomment code in RoundtripTests * Formatting * Reorganize declarations * Inline JSONDecoder extension * Expand roundtrip tests --------- Co-authored-by: Sash Zats sash@zats.io <> Co-authored-by: Mattt Zmuda <mattt@loopwork.com>
1 parent 9ebe125 commit c63a872

File tree

3 files changed

+122
-45
lines changed

3 files changed

+122
-45
lines changed

Sources/MCP/Base/Transports.swift

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -396,13 +396,7 @@ public actor StdioTransport: Transport {
396396
if !receiveContinuationResumed {
397397
receiveContinuationResumed = true
398398
if let error = error {
399-
if let nwError = error as? NWError {
400-
continuation.resume(throwing: MCP.Error.transportError(nwError))
401-
} else {
402-
continuation.resume(
403-
throwing: MCP.Error.internalError("Receive error: \(error)")
404-
)
405-
}
399+
continuation.resume(throwing: MCP.Error.transportError(error))
406400
} else if let content = content {
407401
continuation.resume(returning: content)
408402
} else {

Sources/MCP/Client/Client.swift

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,48 @@ public actor Client {
8787
/// The task for the message handling loop
8888
private var task: Task<Void, Never>?
8989

90+
/// An error indicating a type mismatch when decoding a pending request
91+
private struct TypeMismatchError: Swift.Error {}
92+
9093
/// A pending request with a continuation for the result
9194
private struct PendingRequest<T> {
9295
let continuation: CheckedContinuation<T, Swift.Error>
9396
}
97+
98+
/// A type-erased pending request
99+
private struct AnyPendingRequest {
100+
private let _resume: (Result<Any, Swift.Error>) -> Void
101+
102+
init<T: Sendable & Decodable>(_ request: PendingRequest<T>) {
103+
_resume = { result in
104+
switch result {
105+
case .success(let value):
106+
if let typedValue = value as? T {
107+
request.continuation.resume(returning: typedValue)
108+
} else if let value = value as? Value,
109+
let data = try? JSONEncoder().encode(value),
110+
let decoded = try? JSONDecoder().decode(T.self, from: data)
111+
{
112+
request.continuation.resume(returning: decoded)
113+
} else {
114+
request.continuation.resume(throwing: TypeMismatchError())
115+
}
116+
case .failure(let error):
117+
request.continuation.resume(throwing: error)
118+
}
119+
}
120+
}
121+
func resume(returning value: Any) {
122+
_resume(.success(value))
123+
}
124+
125+
func resume(throwing error: Swift.Error) {
126+
_resume(.failure(error))
127+
}
128+
}
129+
94130
/// A dictionary of type-erased pending requests, keyed by request ID
95-
private var pendingRequests: [ID: Any] = [:]
131+
private var pendingRequests: [ID: AnyPendingRequest] = [:]
96132

97133
public init(
98134
name: String,
@@ -129,8 +165,10 @@ public actor Client {
129165

130166
// Attempt to decode string data as AnyResponse or AnyMessage
131167
let decoder = JSONDecoder()
132-
if let response = try? decoder.decode(AnyResponse.self, from: data) {
133-
await handleResponse(response, for: response)
168+
if let response = try? decoder.decode(AnyResponse.self, from: data),
169+
let request = pendingRequests[response.id]
170+
{
171+
await handleResponse(response, for: request)
134172
} else if let message = try? decoder.decode(AnyMessage.self, from: data) {
135173
await handleMessage(message)
136174
} else {
@@ -158,11 +196,7 @@ public actor Client {
158196
public func disconnect() async {
159197
// Cancel all pending requests
160198
for (id, request) in pendingRequests {
161-
// We know this cast is safe because we only store PendingRequest values
162-
if let typedRequest = request as? PendingRequest<Any> {
163-
typedRequest.continuation.resume(
164-
throwing: Error.internalError("Client disconnected"))
165-
}
199+
request.resume(throwing: Error.internalError("Client disconnected"))
166200
pendingRequests.removeValue(forKey: id)
167201
}
168202

@@ -220,12 +254,12 @@ public actor Client {
220254
}
221255
}
222256

223-
private func addPendingRequest<T>(
257+
private func addPendingRequest<T: Sendable & Decodable>(
224258
id: ID,
225259
continuation: CheckedContinuation<T, Swift.Error>,
226260
type: T.Type
227261
) {
228-
pendingRequests[id] = PendingRequest(continuation: continuation)
262+
pendingRequests[id] = AnyPendingRequest(PendingRequest(continuation: continuation))
229263
}
230264

231265
private func removePendingRequest(id: ID) {
@@ -320,19 +354,18 @@ public actor Client {
320354

321355
// MARK: -
322356

323-
private func handleResponse(_ response: Response<AnyMethod>, for request: Any) async {
357+
private func handleResponse(_ response: Response<AnyMethod>, for request: AnyPendingRequest)
358+
async
359+
{
324360
await logger?.debug(
325361
"Processing response",
326362
metadata: ["id": "\(response.id)"])
327363

328-
// We know this cast is safe because we only store PendingRequest values
329-
guard let typedRequest = request as? PendingRequest<Any> else { return }
330-
331364
switch response.result {
332365
case .success(let value):
333-
typedRequest.continuation.resume(returning: value)
366+
request.resume(returning: value)
334367
case .failure(let error):
335-
typedRequest.continuation.resume(throwing: error)
368+
request.resume(throwing: error)
336369
}
337370

338371
removePendingRequest(id: response.id)

Tests/MCPTests/RoundtripTests.swift

Lines changed: 72 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@ import Testing
88
@Suite("Roundtrip Tests")
99
struct RoundtripTests {
1010
@Test(
11-
"Initialize roundtrip",
1211
.timeLimit(.minutes(1))
1312
)
14-
func testInitializeRoundtrip() async throws {
13+
func testRoundtrip() async throws {
1514
let (clientToServerRead, clientToServerWrite) = try FileDescriptor.pipe()
1615
let (serverToClientRead, serverToClientWrite) = try FileDescriptor.pipe()
1716

@@ -36,32 +35,83 @@ struct RoundtripTests {
3635
version: "1.0.0",
3736
capabilities: .init(prompts: .init(), tools: .init())
3837
)
38+
await server.withMethodHandler(ListTools.self) { _ in
39+
return ListTools.Result(tools: [
40+
Tool(
41+
name: "add",
42+
description: "Adds two numbers together",
43+
inputSchema: [
44+
"a": ["type": "integer", "description": "The first number"],
45+
"a": ["type": "integer", "description": "The second number"],
46+
])
47+
])
48+
}
49+
await server.withMethodHandler(CallTool.self) { request in
50+
guard request.name == "add" else {
51+
return CallTool.Result(content: [.text("Invalid tool name")], isError: true)
52+
}
53+
54+
guard let a = request.arguments?["a"]?.intValue,
55+
let b = request.arguments?["b"]?.intValue
56+
else {
57+
return CallTool.Result(
58+
content: [.text("Did not receive valid arguments")], isError: true)
59+
}
60+
61+
return CallTool.Result(content: [.text("\(a + b)")], isError: false)
62+
}
63+
3964
let client = Client(name: "TestClient", version: "1.0")
4065

4166
try await server.start(transport: serverTransport)
4267
try await client.connect(transport: clientTransport)
4368

44-
// let initTask = Task {
45-
// let result = try await client.initialize()
69+
let initTask = Task {
70+
let result = try await client.initialize()
71+
72+
#expect(result.serverInfo.name == "TestServer")
73+
#expect(result.serverInfo.version == "1.0.0")
74+
#expect(result.capabilities.prompts != nil)
75+
#expect(result.capabilities.tools != nil)
76+
#expect(result.protocolVersion == Version.latest)
77+
}
78+
try await withThrowingTaskGroup(of: Void.self) { group in
79+
group.addTask {
80+
try await Task.sleep(for: .seconds(1))
81+
initTask.cancel()
82+
throw CancellationError()
83+
}
84+
group.addTask {
85+
try await initTask.value
86+
}
87+
try await group.next()
88+
group.cancelAll()
89+
}
90+
91+
let listToolsTask = Task {
92+
let result = try await client.listTools()
93+
#expect(result.count == 1)
94+
#expect(result[0].name == "add")
95+
}
96+
97+
let callToolTask = Task {
98+
let result = try await client.callTool(name: "add", arguments: ["a": 1, "b": 2])
99+
#expect(result.isError == false)
100+
#expect(result.content == [.text("3")])
101+
}
46102

47-
// #expect(result.serverInfo.name == "TestServer")
48-
// #expect(result.serverInfo.version == "1.0.0")
49-
// #expect(result.capabilities.prompts != nil)
50-
// #expect(result.capabilities.tools != nil)
51-
// #expect(result.protocolVersion == Version.latest)
52-
// }
53-
// try await withThrowingTaskGroup(of: Void.self) { group in
54-
// group.addTask {
55-
// try await Task.sleep(for: .seconds(1))
56-
// initTask.cancel()
57-
// throw CancellationError()
58-
// }
59-
// group.addTask {
60-
// try await initTask.value
61-
// }
62-
// try await group.next()
63-
// group.cancelAll()
64-
// }
103+
try await withThrowingTaskGroup(of: Void.self) { group in
104+
group.addTask {
105+
try await Task.sleep(for: .seconds(1))
106+
listToolsTask.cancel()
107+
throw CancellationError()
108+
}
109+
group.addTask {
110+
try await callToolTask.value
111+
}
112+
try await group.next()
113+
group.cancelAll()
114+
}
65115

66116
await server.stop()
67117
await client.disconnect()

0 commit comments

Comments
 (0)