Skip to content

Add timeout parameter to Client.send method (#35) #72

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion Sources/MCP/Base/Error.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public enum MCPError: Swift.Error, Sendable {
// Transport specific errors
case connectionClosed
case transportError(Swift.Error)
case requestTimedOut(String?)

/// The JSON-RPC 2.0 error code
public var code: Int {
Expand All @@ -33,6 +34,7 @@ public enum MCPError: Swift.Error, Sendable {
case .serverError(let code, _): return code
case .connectionClosed: return -32000
case .transportError: return -32001
case .requestTimedOut: return -32002
}
}

Expand Down Expand Up @@ -72,6 +74,8 @@ extension MCPError: LocalizedError {
return "Connection closed"
case .transportError(let error):
return "Transport error: \(error.localizedDescription)"
case .requestTimedOut(let detail):
return "Request timed out" + (detail.map { ": \($0)" } ?? "")
}
}

Expand All @@ -93,6 +97,8 @@ extension MCPError: LocalizedError {
return "The connection to the server was closed"
case .transportError(let error):
return (error as? LocalizedError)?.failureReason ?? error.localizedDescription
case .requestTimedOut:
return "Request exceeded the client-side timeout duration, default time is 10 seconds"
}
}

Expand All @@ -108,6 +114,8 @@ extension MCPError: LocalizedError {
return "Verify the parameters match the method's expected parameters"
case .connectionClosed:
return "Try reconnecting to the server"
case .requestTimedOut:
return "Try sending the request again, or increase the timeout if necessary"
default:
return nil
}
Expand Down Expand Up @@ -147,7 +155,8 @@ extension MCPError: Codable {
.invalidRequest(let detail),
.methodNotFound(let detail),
.invalidParams(let detail),
.internalError(let detail):
.internalError(let detail),
.requestTimedOut(let detail):
if let detail = detail {
try container.encode(["detail": detail], forKey: .data)
}
Expand Down Expand Up @@ -204,6 +213,8 @@ extension MCPError: Codable {
userInfo: [NSLocalizedDescriptionKey: underlyingErrorString]
)
)
case -32002:
self = .requestTimedOut(unwrapDetail(message))
default:
self = .serverError(code: code, message: message)
}
Expand Down Expand Up @@ -240,6 +251,8 @@ extension MCPError: Hashable {
break
case .transportError(let error):
hasher.combine(error.localizedDescription)
case .requestTimedOut(let detail):
hasher.combine(detail)
}
}
}
Expand Down
55 changes: 43 additions & 12 deletions Sources/MCP/Client/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,9 @@ public actor Client {
// MARK: - Requests

/// Send a request and receive its response
public func send<M: Method>(_ request: Request<M>) async throws -> M.Result {
public func send<M: Method>(_ request: Request<M>, timeout: Duration = .seconds(10.0))
async throws -> M.Result
{
guard let connection = connection else {
throw MCPError.internalError("Client connection not initialized")
}
Expand All @@ -262,21 +264,46 @@ public actor Client {

// Store the pending request first
return try await withCheckedThrowingContinuation { continuation in
Task {
self.addPendingRequest(
id: request.id,
continuation: continuation,
type: M.Result.self
)
self.addPendingRequest(
id: request.id,
continuation: continuation,
type: M.Result.self
)

// Send the request data
var sendRequestTask: Task<Void, Never>? = nil

// Send the request data
// A timeout task is created to remove a request if it is still pending after time out duration
var timeoutTask: Task<Void, Never>? = nil

sendRequestTask = Task {
do {
// Use the existing connection send
try await connection.send(requestData)
} catch {
// If send fails immediately, resume continuation and remove pending request
continuation.resume(throwing: error)
// If send fails immediately, remove pending request and cancel timeout task
self.removePendingRequest(id: request.id) // Ensure cleanup on send error
timeoutTask?.cancel()
continuation.resume(throwing: error)
}
}

timeoutTask = Task {
do {
try await Task.sleep(until: .now + timeout)

// If timed out, remove pending request and cancel send request task
if self.pendingRequests.keys.contains(request.id) {
self.removePendingRequest(id: request.id) // Ensure cleanup on send error
sendRequestTask?.cancel()
continuation.resume(
throwing: MCPError.requestTimedOut(
"Request timed out after \(timeout)"
)
)
}
} catch {
// Do nothing here if the task is cancaled
}
}
}
Expand Down Expand Up @@ -457,9 +484,13 @@ public actor Client {
return result
}

public func ping() async throws {
public func ping(timeout: Duration? = nil) async throws {
let request = Ping.request()
_ = try await send(request)
if let timeout {
_ = try await send(request, timeout: timeout)
} else {
_ = try await send(request)
}
}

// MARK: - Prompts
Expand Down
103 changes: 103 additions & 0 deletions Tests/MCPTests/ClientTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -363,4 +363,107 @@ struct ClientTests {

await client.disconnect()
}

@Test("Request timeout - request should time out if server does not respond")
func testRequestTimesOut() async throws {
let transport = MockTransport()
let client = Client(name: "TestClient", version: "1.0")

try await client.connect(transport: transport)
do {
// Do not queue any response on the transport
// so the client never receives a response
try await client.ping(timeout: .milliseconds(100))
#expect(Bool(false), "Expected request to time out, but it succeeded")

} catch let error as MCPError {
switch error {
case .requestTimedOut(let detail):
// This is the expected error
#expect(Bool(true), "Got requestTimedOut as expected: \(detail ?? "")")
default:
// If it is a different MCPError, fail
#expect(Bool(false), "Expected requestTimedOut, got \(error)")
}
} catch {
#expect(Bool(false), "Expected an MCPError, but got \(error)")
}

await client.disconnect()
}

@Test("Request timeout - request should time out if server responds too late")
func testRequestTimesOutIfResponseIsLate() async throws {
let transport = MockTransport()
let client = Client(name: "TestClient", version: "1.0")

try await client.connect(transport: transport)

// Prepare a ping which will be sent with a short 100ms timeout
let request = Ping.request()

// Prepare a task to queue a response after 200ms
// which is beyond the 100ms timeout
Task {
try? await Task.sleep(for: .milliseconds(200))
let response = Response<Ping>(id: request.id, result: .init())
let anyResponse = try? AnyResponse(response)
if let anyResponse {
try? await transport.queue(response: anyResponse)
} else {
#expect(Bool(false), "Failed to produce any response")
}
}

do {
try await _ = client.send(request, timeout: .milliseconds(100))
#expect(Bool(false), "Expected request to time out, but it succeeded")
} catch let error as MCPError {
switch error {
case .requestTimedOut(let detail):
// This is the expected error
#expect(Bool(true), "Got requestTimedOut as expected: \(detail ?? "")")
default:
// If it is a different MCPError, fail
#expect(Bool(false), "Expected requestTimedOut, got \(error)")
}
} catch {
#expect(Bool(false), "Expected an MCPError, but got \(error)")
}

await client.disconnect()
}

@Test("Request timeout - request should succeed if server responds before timeout")
func testRequestDoesNotTimeOutIfResponseIsFast() async throws {
let transport = MockTransport()
let client = Client(name: "TestClient", version: "1.0")

try await client.connect(transport: transport)

// Prepare a ping which will be sent with a short 200ms timeout
let request = Ping.request()

// Prepare a task to queue a response after 100ms
// which is less than the specified timeout (200ms)
Task {
try? await Task.sleep(for: .milliseconds(100))
let response = Response<Ping>(id: request.id, result: .init())
let anyResponse = try? AnyResponse(response)
if let anyResponse {
try? await transport.queue(response: anyResponse)
} else {
#expect(Bool(false), "Failed to produce any response")
}
}

do {
_ = try await client.send(request, timeout: .milliseconds(200))
#expect(Bool(true), "Request succeeded before timeout")
} catch let error {
#expect(Bool(false), "Did not expect an error here, but got \(error)")
}

await client.disconnect()
}
}