Skip to content

Commit a9bcbf9

Browse files
authored
Update transports to send and receive data instead of strings (modelcontextprotocol#42)
* Update transports to send and receive data instead of strings * Rename MockTransport helper methods * Adopt Task.sleep(for:...) instead of variant taking nanoseconds * Update comment * Revert inadvertant change to visibility of send method
1 parent de829ee commit a9bcbf9

File tree

7 files changed

+138
-147
lines changed

7 files changed

+138
-147
lines changed

Sources/MCP/Base/Transports.swift

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ public protocol Transport: Actor {
1919
/// Disconnects from the transport
2020
func disconnect() async
2121

22-
/// Sends a message string
23-
func send(_ message: String) async throws
22+
/// Sends data
23+
func send(_ data: Data) async throws
2424

25-
/// Receives message strings as an async sequence
26-
func receive() -> AsyncThrowingStream<String, Swift.Error>
25+
/// Receives data in an async sequence
26+
func receive() -> AsyncThrowingStream<Data, Swift.Error>
2727
}
2828

2929
/// Standard input/output transport implementation
@@ -33,8 +33,8 @@ public actor StdioTransport: Transport {
3333
public nonisolated let logger: Logger
3434

3535
private var isConnected = false
36-
private let messageStream: AsyncStream<String>
37-
private let messageContinuation: AsyncStream<String>.Continuation
36+
private let messageStream: AsyncStream<Data>
37+
private let messageContinuation: AsyncStream<Data>.Continuation
3838

3939
public init(
4040
input: FileDescriptor = FileDescriptor.standardInput,
@@ -50,7 +50,7 @@ public actor StdioTransport: Transport {
5050
factory: { _ in SwiftLogNoOpLogHandler() })
5151

5252
// Create message stream
53-
var continuation: AsyncStream<String>.Continuation!
53+
var continuation: AsyncStream<Data>.Continuation!
5454
self.messageStream = AsyncStream { continuation = $0 }
5555
self.messageContinuation = continuation
5656
}
@@ -105,15 +105,13 @@ public actor StdioTransport: Transport {
105105
let messageData = pendingData[..<newlineIndex]
106106
pendingData = pendingData[(newlineIndex + 1)...]
107107

108-
if let message = String(data: messageData, encoding: .utf8),
109-
!message.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
110-
{
111-
logger.debug("Message received", metadata: ["message": "\(message)"])
112-
messageContinuation.yield(message)
108+
if !messageData.isEmpty {
109+
logger.debug("Message received", metadata: ["size": "\(messageData.count)"])
110+
messageContinuation.yield(Data(messageData))
113111
}
114112
}
115113
} catch let error where Error.isResourceTemporarilyUnavailable(error) {
116-
try? await Task.sleep(nanoseconds: 10_000_000) // 10ms backoff
114+
try? await Task.sleep(for: .milliseconds(10))
117115
continue
118116
} catch {
119117
if !Task.isCancelled {
@@ -133,17 +131,16 @@ public actor StdioTransport: Transport {
133131
logger.info("Transport disconnected")
134132
}
135133

136-
public func send(_ message: String) async throws {
134+
public func send(_ message: Data) async throws {
137135
guard isConnected else {
138136
throw Error.transportError(Errno.socketNotConnected)
139137
}
140138

141-
let message = message + "\n"
142-
guard let data = message.data(using: .utf8) else {
143-
throw Error.transportError(Errno.invalidArgument)
144-
}
139+
// Add newline as delimiter
140+
var messageWithNewline = message
141+
messageWithNewline.append(UInt8(ascii: "\n"))
145142

146-
var remaining = data
143+
var remaining = messageWithNewline
147144
while !remaining.isEmpty {
148145
do {
149146
let written = try remaining.withUnsafeBytes { buffer in
@@ -153,15 +150,15 @@ public actor StdioTransport: Transport {
153150
remaining = remaining.dropFirst(written)
154151
}
155152
} catch let error where Error.isResourceTemporarilyUnavailable(error) {
156-
try await Task.sleep(nanoseconds: 10_000_000) // 10ms backoff
153+
try await Task.sleep(for: .milliseconds(10))
157154
continue
158155
} catch {
159156
throw Error.transportError(error)
160157
}
161158
}
162159
}
163160

164-
public func receive() -> AsyncThrowingStream<String, Swift.Error> {
161+
public func receive() -> AsyncThrowingStream<Data, Swift.Error> {
165162
return AsyncThrowingStream { continuation in
166163
Task {
167164
for await message in messageStream {
@@ -182,8 +179,8 @@ public actor StdioTransport: Transport {
182179
public nonisolated let logger: Logger
183180

184181
private var isConnected = false
185-
private let messageStream: AsyncThrowingStream<String, Swift.Error>
186-
private let messageContinuation: AsyncThrowingStream<String, Swift.Error>.Continuation
182+
private let messageStream: AsyncThrowingStream<Data, Swift.Error>
183+
private let messageContinuation: AsyncThrowingStream<Data, Swift.Error>.Continuation
187184

188185
// Track connection state for continuations
189186
private var connectionContinuationResumed = false
@@ -198,7 +195,7 @@ public actor StdioTransport: Transport {
198195
)
199196

200197
// Create message stream
201-
var continuation: AsyncThrowingStream<String, Swift.Error>.Continuation!
198+
var continuation: AsyncThrowingStream<Data, Swift.Error>.Continuation!
202199
self.messageStream = AsyncThrowingStream { continuation = $0 }
203200
self.messageContinuation = continuation
204201
}
@@ -289,14 +286,14 @@ public actor StdioTransport: Transport {
289286
logger.info("Network transport disconnected")
290287
}
291288

292-
public func send(_ message: String) async throws {
289+
public func send(_ message: Data) async throws {
293290
guard isConnected else {
294291
throw MCP.Error.internalError("Transport not connected")
295292
}
296293

297-
guard let data = (message + "\n").data(using: .utf8) else {
298-
throw MCP.Error.internalError("Failed to encode message")
299-
}
294+
// Add newline as delimiter
295+
var messageWithNewline = message
296+
messageWithNewline.append(UInt8(ascii: "\n"))
300297

301298
// Use a local actor-isolated variable to track continuation state
302299
var sendContinuationResumed = false
@@ -309,7 +306,7 @@ public actor StdioTransport: Transport {
309306
}
310307

311308
connection.send(
312-
content: data,
309+
content: messageWithNewline,
313310
completion: .contentProcessed { [weak self] error in
314311
guard let self = self else { return }
315312

@@ -329,7 +326,7 @@ public actor StdioTransport: Transport {
329326
}
330327
}
331328

332-
public func receive() -> AsyncThrowingStream<String, Swift.Error> {
329+
public func receive() -> AsyncThrowingStream<Data, Swift.Error> {
333330
return AsyncThrowingStream { continuation in
334331
Task {
335332
do {
@@ -357,11 +354,10 @@ public actor StdioTransport: Transport {
357354
let messageData = buffer[..<newlineIndex]
358355
buffer = buffer[(newlineIndex + 1)...]
359356

360-
if let message = String(data: messageData, encoding: .utf8),
361-
!message.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty
362-
{
363-
logger.debug("Message received", metadata: ["message": "\(message)"])
364-
messageContinuation.yield(message)
357+
if !messageData.isEmpty {
358+
logger.debug(
359+
"Message received", metadata: ["size": "\(messageData.count)"])
360+
messageContinuation.yield(Data(messageData))
365361
}
366362
}
367363
} catch let error as NWError {

Sources/MCP/Client/Client.swift

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -181,15 +181,10 @@ public actor Client {
181181

182182
do {
183183
let stream = await connection.receive()
184-
for try await string in stream {
184+
for try await data in stream {
185185
if Task.isCancelled { break } // Check inside loop too
186186

187-
// Decode and handle incoming message
188-
guard let data = string.data(using: .utf8) else {
189-
throw Error.parseError("Invalid UTF-8 data")
190-
}
191-
192-
// Attempt to decode string data as AnyResponse or AnyMessage
187+
// Attempt to decode data as AnyResponse or AnyMessage
193188
let decoder = JSONDecoder()
194189
if let response = try? decoder.decode(AnyResponse.self, from: data),
195190
let request = pendingRequests[response.id]
@@ -207,7 +202,7 @@ public actor Client {
207202
}
208203
}
209204
} catch let error where Error.isResourceTemporarilyUnavailable(error) {
210-
try? await Task.sleep(nanoseconds: 10_000_000) // 10ms
205+
try? await Task.sleep(for: .milliseconds(10))
211206
continue
212207
} catch {
213208
await logger?.error(
@@ -256,22 +251,19 @@ public actor Client {
256251
}
257252

258253
let requestData = try JSONEncoder().encode(request)
259-
guard let requestString = String(data: requestData, encoding: .utf8) else {
260-
throw Error.internalError("Failed to encode request")
261-
}
262254

255+
// Store the pending request first
263256
return try await withCheckedThrowingContinuation { continuation in
264-
// Store the pending request first
265257
Task {
266258
self.addPendingRequest(
267259
id: request.id,
268260
continuation: continuation,
269261
type: M.Result.self
270262
)
271263

272-
// Send the request
264+
// Send the request data
273265
do {
274-
try await connection.send(requestString)
266+
try await connection.send(requestData)
275267
} catch {
276268
continuation.resume(throwing: error)
277269
self.removePendingRequest(id: request.id)

Sources/MCP/Server/Server.swift

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,11 @@ public actor Server {
172172
task = Task {
173173
do {
174174
let stream = await transport.receive()
175-
for try await string in stream {
175+
for try await data in stream {
176176
if Task.isCancelled { break } // Check cancellation inside loop
177177

178178
var requestID: ID?
179179
do {
180-
guard let data = string.data(using: .utf8) else {
181-
throw Error.parseError("Invalid UTF-8 data")
182-
}
183-
184180
// Attempt to decode string data as AnyRequest or AnyMessage
185181
let decoder = JSONDecoder()
186182
if let request = try? decoder.decode(AnyRequest.self, from: data) {
@@ -203,7 +199,7 @@ public actor Server {
203199
}
204200
} catch let error where Error.isResourceTemporarilyUnavailable(error) {
205201
// Resource temporarily unavailable, retry after a short delay
206-
try? await Task.sleep(nanoseconds: 10_000_000) // 10ms
202+
try? await Task.sleep(for: .milliseconds(10))
207203
continue
208204
} catch {
209205
await logger?.error(
@@ -266,19 +262,17 @@ public actor Server {
266262

267263
// MARK: - Sending
268264

269-
/// Send a response to a client
265+
/// Send a response to a request
270266
public func send<M: Method>(_ response: Response<M>) async throws {
271267
guard let connection = connection else {
272268
throw Error.internalError("Server connection not initialized")
273269
}
270+
274271
let encoder = JSONEncoder()
275272
encoder.outputFormatting = [.sortedKeys, .withoutEscapingSlashes]
276273

277274
let responseData = try encoder.encode(response)
278-
279-
if let responseStr = String(data: responseData, encoding: .utf8) {
280-
try await connection.send(responseStr)
281-
}
275+
try await connection.send(responseData)
282276
}
283277

284278
/// Send a notification to connected clients
@@ -291,10 +285,7 @@ public actor Server {
291285
encoder.outputFormatting = [.sortedKeys, .withoutEscapingSlashes]
292286

293287
let notificationData = try encoder.encode(notification)
294-
295-
if let notificationStr = String(data: notificationData, encoding: .utf8) {
296-
try await connection.send(notificationStr)
297-
}
288+
try await connection.send(notificationData)
298289
}
299290

300291
// MARK: -
@@ -407,7 +398,7 @@ public actor Server {
407398

408399
// Send initialized notification after a short delay
409400
Task {
410-
try? await Task.sleep(nanoseconds: 100_000_000) // 100ms
401+
try? await Task.sleep(for: .milliseconds(10))
411402
try? await self.notify(InitializedNotification.message())
412403
}
413404

Tests/MCPTests/ClientTests.swift

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,27 +27,27 @@ struct ClientTests {
2727

2828
try await client.connect(transport: transport)
2929
// Small delay to ensure message loop is started
30-
try await Task.sleep(nanoseconds: 10_000_000) // 10ms
30+
try await Task.sleep(for: .milliseconds(10))
3131

3232
// Create a task for initialize that we'll cancel
3333
let initTask = Task {
3434
try await client.initialize()
3535
}
3636

3737
// Give it a moment to send the request
38-
try await Task.sleep(nanoseconds: 10_000_000) // 10ms
38+
try await Task.sleep(for: .milliseconds(10))
3939

4040
#expect(await transport.sentMessages.count == 1)
41-
#expect(await transport.sentMessages[0].contains(Initialize.name))
42-
#expect(await transport.sentMessages[0].contains(client.name))
43-
#expect(await transport.sentMessages[0].contains(client.version))
41+
#expect(await transport.sentMessages.first?.contains(Initialize.name) == true)
42+
#expect(await transport.sentMessages.first?.contains(client.name) == true)
43+
#expect(await transport.sentMessages.first?.contains(client.version) == true)
4444

4545
// Cancel the initialize task
4646
initTask.cancel()
4747

4848
// Disconnect client to clean up message loop and give time for continuation cleanup
4949
await client.disconnect()
50-
try await Task.sleep(nanoseconds: 50_000_000) // 50ms
50+
try await Task.sleep(for: .milliseconds(50))
5151
}
5252

5353
@Test(
@@ -60,25 +60,25 @@ struct ClientTests {
6060

6161
try await client.connect(transport: transport)
6262
// Small delay to ensure message loop is started
63-
try await Task.sleep(nanoseconds: 10_000_000) // 10ms
63+
try await Task.sleep(for: .milliseconds(10))
6464

6565
// Create a task for the ping that we'll cancel
6666
let pingTask = Task {
6767
try await client.ping()
6868
}
6969

7070
// Give it a moment to send the request
71-
try await Task.sleep(nanoseconds: 10_000_000) // 10ms
71+
try await Task.sleep(for: .milliseconds(10))
7272

7373
#expect(await transport.sentMessages.count == 1)
74-
#expect(await transport.sentMessages[0].contains(Ping.name))
74+
#expect(await transport.sentMessages.first?.contains(Ping.name) == true)
7575

7676
// Cancel the ping task
7777
pingTask.cancel()
7878

7979
// Disconnect client to clean up message loop and give time for continuation cleanup
8080
await client.disconnect()
81-
try await Task.sleep(nanoseconds: 50_000_000) // 50ms
81+
try await Task.sleep(for: .milliseconds(50))
8282
}
8383

8484
@Test("Connection failure handling")
@@ -168,7 +168,7 @@ struct ClientTests {
168168

169169
// Wait a bit for any setup to complete
170170
try await Task.sleep(for: .milliseconds(10))
171-
171+
172172
// Send the listPrompts request and immediately provide an error response
173173
let promptsTask = Task {
174174
do {
@@ -187,7 +187,7 @@ struct ClientTests {
187187
id: decodedRequest.id,
188188
error: Error.methodNotFound("Test: Prompts capability not available")
189189
)
190-
try await transport.queueResponse(errorResponse)
190+
try await transport.queue(response: errorResponse)
191191

192192
// Try the request now that we have a response queued
193193
do {

0 commit comments

Comments
 (0)