Skip to content

Implement cancellation notifications #146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/ModelContextProtocol/Logging/Log.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ internal static partial class Log
[LoggerMessage(Level = LogLevel.Error, Message = "Request failed for {endpointName} with method {method}: {message} ({code})")]
internal static partial void RequestFailed(this ILogger logger, string endpointName, string method, string message, int code);

[LoggerMessage(Level = LogLevel.Information, Message = "Request '{requestId}' canceled via client notification with reason '{Reason}'.")]
internal static partial void RequestCanceled(this ILogger logger, RequestId requestId, string? reason);

[LoggerMessage(Level = LogLevel.Information, Message = "Request response received payload for {endpointName}: {payload}")]
internal static partial void RequestResponseReceivedPayload(this ILogger logger, string endpointName, string payload);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using System.Text.Json.Serialization;

namespace ModelContextProtocol.Protocol.Messages;

/// <summary>
/// This notification indicates that the result will be unused, so any associated processing SHOULD cease.
/// </summary>
public sealed class CancelledNotification
{
/// <summary>
/// The ID of the request to cancel.
/// </summary>
[JsonPropertyName("requestId")]
public RequestId RequestId { get; set; }

/// <summary>
/// An optional string describing the reason for the cancellation.
/// </summary>
[JsonPropertyName("reason")]
public string? Reason { get; set; }
}
165 changes: 127 additions & 38 deletions src/ModelContextProtocol/Shared/McpSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ internal sealed class McpSession : IDisposable
private readonly RequestHandlers _requestHandlers;
private readonly NotificationHandlers _notificationHandlers;

/// <summary>Collection of requests sent on this session and waiting for responses.</summary>
private readonly ConcurrentDictionary<RequestId, TaskCompletionSource<IJsonRpcMessage>> _pendingRequests = [];
/// <summary>
/// Collection of requests received on this session and currently being handled. The value provides a <see cref="CancellationTokenSource"/>
/// that can be used to request cancellation of the in-flight handler.
/// </summary>
private readonly ConcurrentDictionary<RequestId, CancellationTokenSource> _handlingRequests = new();
private readonly JsonSerializerOptions _jsonOptions;
private readonly ILogger _logger;

Expand Down Expand Up @@ -69,25 +75,70 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken)
{
_logger.TransportMessageRead(EndpointName, message.GetType().Name);

// Fire and forget the message handling task to avoid blocking the transport
// If awaiting the task, the transport will not be able to read more messages,
// which could lead to a deadlock if the handler sends a message back
_ = ProcessMessageAsync();
async Task ProcessMessageAsync()
{
IJsonRpcMessageWithId? messageWithId = message as IJsonRpcMessageWithId;
CancellationTokenSource? combinedCts = null;
try
{
// Register before we yield, so that the tracking is guaranteed to be there
// when subsequent messages arrive, even if the asynchronous processing happens
// out of order.
if (messageWithId is not null)
{
combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_handlingRequests[messageWithId.Id] = combinedCts;
}

// Fire and forget the message handling to avoid blocking the transport
// If awaiting the task, the transport will not be able to read more messages,
// which could lead to a deadlock if the handler sends a message back

#if NET
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
#else
await default(ForceYielding);
await default(ForceYielding);
#endif
try
{
await HandleMessageAsync(message, cancellationToken).ConfigureAwait(false);

// Handle the message.
await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
_logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
// Only send responses for request errors that aren't user-initiated cancellation.
bool isUserCancellation =
ex is OperationCanceledException &&
!cancellationToken.IsCancellationRequested &&
combinedCts?.IsCancellationRequested is true;

if (!isUserCancellation && message is JsonRpcRequest request)
{
_logger.RequestHandlerError(EndpointName, request.Method, ex);
await _transport.SendMessageAsync(new JsonRpcError
{
Id = request.Id,
JsonRpc = "2.0",
Error = new JsonRpcErrorDetail
{
Code = ErrorCodes.InternalError,
Message = ex.Message
}
}, cancellationToken).ConfigureAwait(false);
}
else if (ex is not OperationCanceledException)
{
var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
_logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
}
}
finally
{
if (messageWithId is not null)
{
_handlingRequests.TryRemove(messageWithId.Id, out _);
combinedCts!.Dispose();
}
}
}
}
Expand Down Expand Up @@ -123,6 +174,25 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken

private async Task HandleNotification(JsonRpcNotification notification)
{
// Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.)
if (notification.Method == NotificationMethods.CancelledNotification)
{
try
{
if (GetCancelledNotificationParams(notification.Params) is CancelledNotification cn &&
_handlingRequests.TryGetValue(cn.RequestId, out var cts))
{
await cts.CancelAsync().ConfigureAwait(false);
_logger.RequestCanceled(cn.RequestId, cn.Reason);
}
}
catch
{
// "Invalid cancellation notifications SHOULD be ignored"
}
}

// Handle user-defined notifications.
if (_notificationHandlers.TryGetValue(notification.Method, out var handlers))
{
foreach (var notificationHandler in handlers)
Expand Down Expand Up @@ -161,33 +231,15 @@ private async Task HandleRequest(JsonRpcRequest request, CancellationToken cance
{
if (_requestHandlers.TryGetValue(request.Method, out var handler))
{
try
{
_logger.RequestHandlerCalled(EndpointName, request.Method);
var result = await handler(request, cancellationToken).ConfigureAwait(false);
_logger.RequestHandlerCompleted(EndpointName, request.Method);
await _transport.SendMessageAsync(new JsonRpcResponse
{
Id = request.Id,
JsonRpc = "2.0",
Result = result
}, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
_logger.RequestHandlerCalled(EndpointName, request.Method);
var result = await handler(request, cancellationToken).ConfigureAwait(false);
_logger.RequestHandlerCompleted(EndpointName, request.Method);
await _transport.SendMessageAsync(new JsonRpcResponse
{
_logger.RequestHandlerError(EndpointName, request.Method, ex);
// Send error response
await _transport.SendMessageAsync(new JsonRpcError
{
Id = request.Id,
JsonRpc = "2.0",
Error = new JsonRpcErrorDetail
{
Code = -32000, // Implementation defined error
Message = ex.Message
}
}, cancellationToken).ConfigureAwait(false);
}
Id = request.Id,
JsonRpc = "2.0",
Result = result
}, cancellationToken).ConfigureAwait(false);
}
else
{
Expand Down Expand Up @@ -273,7 +325,7 @@ public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, Can
}
}

public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
{
Throw.IfNull(message);

Expand All @@ -288,7 +340,44 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
_logger.SendingMessage(EndpointName, JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>()));
}

return _transport.SendMessageAsync(message, cancellationToken);
await _transport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);

// If the sent notification was a cancellation notification, cancel the pending request's await, as either the
// server won't be sending a response, or per the specification, the response should be ignored. There are inherent
// race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification &&
GetCancelledNotificationParams(notification.Params) is CancelledNotification cn &&
_pendingRequests.TryRemove(cn.RequestId, out var tcs))
{
tcs.TrySetCanceled(default);
}
}

private static CancelledNotification? GetCancelledNotificationParams(object? notificationParams)
{
try
{
switch (notificationParams)
{
case null:
return null;

case CancelledNotification cn:
return cn;

case JsonElement je:
return JsonSerializer.Deserialize(je, McpJsonUtilities.DefaultOptions.GetTypeInfo<CancelledNotification>());

default:
return JsonSerializer.Deserialize(
JsonSerializer.Serialize(notificationParams, McpJsonUtilities.DefaultOptions.GetTypeInfo<object?>()),
McpJsonUtilities.DefaultOptions.GetTypeInfo<CancelledNotification>());
}
}
catch
{
return null;
}
}

public void Dispose()
Expand Down
1 change: 1 addition & 0 deletions src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element)
// MCP Request Params / Results
[JsonSerializable(typeof(CallToolRequestParams))]
[JsonSerializable(typeof(CallToolResponse))]
[JsonSerializable(typeof(CancelledNotification))]
[JsonSerializable(typeof(CompleteRequestParams))]
[JsonSerializable(typeof(CompleteResult))]
[JsonSerializable(typeof(CreateMessageRequestParams))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public async Task Can_List_Registered_Tools()
IMcpClient client = await CreateMcpClientForServer();

var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
Assert.Equal(12, tools.Count);
Assert.Equal(13, tools.Count);

McpClientTool echoTool = tools.First(t => t.Name == "Echo");
Assert.Equal("Echo", echoTool.Name);
Expand Down Expand Up @@ -138,7 +138,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T
cancellationToken: TestContext.Current.CancellationToken))
{
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
Assert.Equal(12, tools.Count);
Assert.Equal(13, tools.Count);

McpClientTool echoTool = tools.First(t => t.Name == "Echo");
Assert.Equal("Echo", echoTool.Name);
Expand All @@ -165,7 +165,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
IMcpClient client = await CreateMcpClientForServer();

var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
Assert.Equal(12, tools.Count);
Assert.Equal(13, tools.Count);

Channel<JsonRpcNotification> listChanged = Channel.CreateUnbounded<JsonRpcNotification>();
client.AddNotificationHandler(NotificationMethods.ToolListChangedNotification, notification =>
Expand All @@ -186,7 +186,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
await notificationRead;

tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
Assert.Equal(13, tools.Count);
Assert.Equal(14, tools.Count);
Assert.Contains(tools, t => t.Name == "NewTool");

notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken);
Expand All @@ -195,7 +195,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
await notificationRead;

tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
Assert.Equal(12, tools.Count);
Assert.Equal(13, tools.Count);
Assert.DoesNotContain(tools, t => t.Name == "NewTool");
}

Expand Down Expand Up @@ -560,6 +560,35 @@ public async Task HandlesIProgressParameter()
}
}

[Fact]
public async Task CancellationNotificationsPropagateToToolTokens()
{
IMcpClient client = await CreateMcpClientForServer();

var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
Assert.NotNull(tools);
Assert.NotEmpty(tools);
McpClientTool cancelableTool = tools.First(t => t.Name == nameof(EchoTool.InfiniteCancelableOperation));

var requestId = new RequestId(Guid.NewGuid().ToString());
var invokeTask = client.SendRequestAsync<CallToolResponse>(new JsonRpcRequest()
{
Method = RequestMethods.ToolsCall,
Id = requestId,
Params = new CallToolRequestParams() { Name = cancelableTool.ProtocolTool.Name },
}, TestContext.Current.CancellationToken);

await client.SendNotificationAsync(
NotificationMethods.CancelledNotification,
parameters: new CancelledNotification()
{
RequestId = requestId,
},
cancellationToken: TestContext.Current.CancellationToken);

await Assert.ThrowsAnyAsync<OperationCanceledException>(() => invokeTask);
}

[McpServerToolType]
public sealed class EchoTool(ObjectWithId objectFromDI)
{
Expand Down Expand Up @@ -625,6 +654,21 @@ public static string EchoComplex(ComplexObject complex)
return complex.Name!;
}

[McpServerTool]
public static async Task<string> InfiniteCancelableOperation(CancellationToken cancellationToken)
{
try
{
await Task.Delay(Timeout.Infinite, cancellationToken);
}
catch (Exception)
{
return "canceled";
}

return "unreachable";
}

[McpServerTool]
public string GetCtorParameter() => $"{_randomValue}:{objectFromDI.Id}";

Expand Down