Skip to content

Commit b87bb40

Browse files
committed
Implement cancellation notifications
1 parent 5d3fb65 commit b87bb40

File tree

4 files changed

+197
-43
lines changed

4 files changed

+197
-43
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using System.Text.Json.Serialization;
2+
3+
namespace ModelContextProtocol.Protocol.Messages;
4+
5+
/// <summary>
6+
/// This notification indicates that the result will be unused, so any associated processing SHOULD cease.
7+
/// </summary>
8+
public sealed class CancelledNotification
9+
{
10+
/// <summary>
11+
/// The ID of the request to cancel.
12+
/// </summary>
13+
[JsonPropertyName("requestId")]
14+
public RequestId RequestId { get; set; }
15+
16+
/// <summary>
17+
/// An optional string describing the reason for the cancellation.
18+
/// </summary>
19+
[JsonPropertyName("reason")]
20+
public string? Reason { get; set; }
21+
}

src/ModelContextProtocol/Shared/McpSession.cs

Lines changed: 126 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ namespace ModelContextProtocol.Shared;
1616
/// </summary>
1717
internal sealed class McpSession : IDisposable
1818
{
19+
/// <summary>
20+
/// In-flight request handling, indexed by request ID. The value provides a <see cref="CancellationTokenSource"/>
21+
/// that can be used to request cancellation of the in-flight handler.
22+
/// </summary>
23+
private static readonly ConcurrentDictionary<RequestId, CancellationTokenSource> s_handlingRequests = new();
24+
1925
private readonly ITransport _transport;
2026
private readonly RequestHandlers _requestHandlers;
2127
private readonly NotificationHandlers _notificationHandlers;
@@ -69,25 +75,70 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken)
6975
{
7076
_logger.TransportMessageRead(EndpointName, message.GetType().Name);
7177

72-
// Fire and forget the message handling task to avoid blocking the transport
73-
// If awaiting the task, the transport will not be able to read more messages,
74-
// which could lead to a deadlock if the handler sends a message back
7578
_ = ProcessMessageAsync();
7679
async Task ProcessMessageAsync()
7780
{
81+
IJsonRpcMessageWithId? messageWithId = message as IJsonRpcMessageWithId;
82+
CancellationTokenSource? combinedCts = null;
83+
try
84+
{
85+
// Register before we yield, so that the tracking is guaranteed to be there
86+
// when subsequent messages arrive, even if the asynchronous processing happens
87+
// out of order.
88+
if (messageWithId is not null)
89+
{
90+
combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
91+
s_handlingRequests[messageWithId.Id] = combinedCts;
92+
}
93+
94+
// Fire and forget the message handling to avoid blocking the transport
95+
// If awaiting the task, the transport will not be able to read more messages,
96+
// which could lead to a deadlock if the handler sends a message back
97+
7898
#if NET
79-
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
99+
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
80100
#else
81-
await default(ForceYielding);
101+
await default(ForceYielding);
82102
#endif
83-
try
84-
{
85-
await HandleMessageAsync(message, cancellationToken).ConfigureAwait(false);
103+
104+
// Handle the message.
105+
await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false);
86106
}
87107
catch (Exception ex)
88108
{
89-
var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
90-
_logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
109+
// Only send responses for request errors that aren't user-initiated cancellation.
110+
bool isUserCancellation =
111+
ex is OperationCanceledException &&
112+
!cancellationToken.IsCancellationRequested &&
113+
combinedCts?.IsCancellationRequested is true;
114+
115+
if (!isUserCancellation && message is JsonRpcRequest request)
116+
{
117+
_logger.RequestHandlerError(EndpointName, request.Method, ex);
118+
await _transport.SendMessageAsync(new JsonRpcError
119+
{
120+
Id = request.Id,
121+
JsonRpc = "2.0",
122+
Error = new JsonRpcErrorDetail
123+
{
124+
Code = ErrorCodes.InternalError,
125+
Message = ex.Message
126+
}
127+
}, cancellationToken).ConfigureAwait(false);
128+
}
129+
else if (ex is not OperationCanceledException)
130+
{
131+
var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
132+
_logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
133+
}
134+
}
135+
finally
136+
{
137+
if (messageWithId is not null)
138+
{
139+
s_handlingRequests.TryRemove(messageWithId.Id, out _);
140+
combinedCts!.Dispose();
141+
}
91142
}
92143
}
93144
}
@@ -123,6 +174,24 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken
123174

124175
private async Task HandleNotification(JsonRpcNotification notification)
125176
{
177+
// Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.)
178+
if (notification.Method == NotificationMethods.CancelledNotification)
179+
{
180+
try
181+
{
182+
if (GetCancelledNotificationParams(notification.Params) is CancelledNotification cn &&
183+
s_handlingRequests.TryGetValue(cn.RequestId, out var cts))
184+
{
185+
await cts.CancelAsync().ConfigureAwait(false);
186+
}
187+
}
188+
catch
189+
{
190+
// "Invalid cancellation notifications SHOULD be ignored"
191+
}
192+
}
193+
194+
// Handle user-defined notifications.
126195
if (_notificationHandlers.TryGetValue(notification.Method, out var handlers))
127196
{
128197
foreach (var notificationHandler in handlers)
@@ -161,33 +230,15 @@ private async Task HandleRequest(JsonRpcRequest request, CancellationToken cance
161230
{
162231
if (_requestHandlers.TryGetValue(request.Method, out var handler))
163232
{
164-
try
233+
_logger.RequestHandlerCalled(EndpointName, request.Method);
234+
var result = await handler(request, cancellationToken).ConfigureAwait(false);
235+
_logger.RequestHandlerCompleted(EndpointName, request.Method);
236+
await _transport.SendMessageAsync(new JsonRpcResponse
165237
{
166-
_logger.RequestHandlerCalled(EndpointName, request.Method);
167-
var result = await handler(request, cancellationToken).ConfigureAwait(false);
168-
_logger.RequestHandlerCompleted(EndpointName, request.Method);
169-
await _transport.SendMessageAsync(new JsonRpcResponse
170-
{
171-
Id = request.Id,
172-
JsonRpc = "2.0",
173-
Result = result
174-
}, cancellationToken).ConfigureAwait(false);
175-
}
176-
catch (Exception ex)
177-
{
178-
_logger.RequestHandlerError(EndpointName, request.Method, ex);
179-
// Send error response
180-
await _transport.SendMessageAsync(new JsonRpcError
181-
{
182-
Id = request.Id,
183-
JsonRpc = "2.0",
184-
Error = new JsonRpcErrorDetail
185-
{
186-
Code = -32000, // Implementation defined error
187-
Message = ex.Message
188-
}
189-
}, cancellationToken).ConfigureAwait(false);
190-
}
238+
Id = request.Id,
239+
JsonRpc = "2.0",
240+
Result = result
241+
}, cancellationToken).ConfigureAwait(false);
191242
}
192243
else
193244
{
@@ -273,7 +324,7 @@ public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, Can
273324
}
274325
}
275326

276-
public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
327+
public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
277328
{
278329
Throw.IfNull(message);
279330

@@ -288,7 +339,44 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
288339
_logger.SendingMessage(EndpointName, JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>()));
289340
}
290341

291-
return _transport.SendMessageAsync(message, cancellationToken);
342+
await _transport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
343+
344+
// If the sent notification was a cancellation notification, cancel the pending request's await, as either the
345+
// server won't be sending a response, or per the specification, the response should be ignored. There are inherent
346+
// race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
347+
if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification &&
348+
GetCancelledNotificationParams(notification.Params) is CancelledNotification cn &&
349+
_pendingRequests.TryRemove(cn.RequestId, out var tcs))
350+
{
351+
tcs.TrySetCanceled(default);
352+
}
353+
}
354+
355+
private static CancelledNotification? GetCancelledNotificationParams(object? notificationParams)
356+
{
357+
try
358+
{
359+
switch (notificationParams)
360+
{
361+
case null:
362+
return null;
363+
364+
case CancelledNotification cn:
365+
return cn;
366+
367+
case JsonElement je:
368+
return JsonSerializer.Deserialize(je, McpJsonUtilities.DefaultOptions.GetTypeInfo<CancelledNotification>());
369+
370+
default:
371+
return JsonSerializer.Deserialize(
372+
JsonSerializer.Serialize(notificationParams, McpJsonUtilities.DefaultOptions.GetTypeInfo<object?>()),
373+
McpJsonUtilities.DefaultOptions.GetTypeInfo<CancelledNotification>());
374+
}
375+
}
376+
catch
377+
{
378+
return null;
379+
}
292380
}
293381

294382
public void Dispose()

src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element)
121121
// MCP Request Params / Results
122122
[JsonSerializable(typeof(CallToolRequestParams))]
123123
[JsonSerializable(typeof(CallToolResponse))]
124+
[JsonSerializable(typeof(CancelledNotification))]
124125
[JsonSerializable(typeof(CompleteRequestParams))]
125126
[JsonSerializable(typeof(CompleteResult))]
126127
[JsonSerializable(typeof(CreateMessageRequestParams))]

tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ public async Task Can_List_Registered_Tools()
9191
IMcpClient client = await CreateMcpClientForServer();
9292

9393
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
94-
Assert.Equal(12, tools.Count);
94+
Assert.Equal(13, tools.Count);
9595

9696
McpClientTool echoTool = tools.First(t => t.Name == "Echo");
9797
Assert.Equal("Echo", echoTool.Name);
@@ -138,7 +138,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T
138138
cancellationToken: TestContext.Current.CancellationToken))
139139
{
140140
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
141-
Assert.Equal(12, tools.Count);
141+
Assert.Equal(13, tools.Count);
142142

143143
McpClientTool echoTool = tools.First(t => t.Name == "Echo");
144144
Assert.Equal("Echo", echoTool.Name);
@@ -165,7 +165,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
165165
IMcpClient client = await CreateMcpClientForServer();
166166

167167
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
168-
Assert.Equal(12, tools.Count);
168+
Assert.Equal(13, tools.Count);
169169

170170
Channel<JsonRpcNotification> listChanged = Channel.CreateUnbounded<JsonRpcNotification>();
171171
client.AddNotificationHandler(NotificationMethods.ToolListChangedNotification, notification =>
@@ -186,7 +186,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
186186
await notificationRead;
187187

188188
tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
189-
Assert.Equal(13, tools.Count);
189+
Assert.Equal(14, tools.Count);
190190
Assert.Contains(tools, t => t.Name == "NewTool");
191191

192192
notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken);
@@ -195,7 +195,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
195195
await notificationRead;
196196

197197
tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
198-
Assert.Equal(12, tools.Count);
198+
Assert.Equal(13, tools.Count);
199199
Assert.DoesNotContain(tools, t => t.Name == "NewTool");
200200
}
201201

@@ -560,6 +560,35 @@ public async Task HandlesIProgressParameter()
560560
}
561561
}
562562

563+
[Fact]
564+
public async Task CancellationNotificationsPropagateToToolTokens()
565+
{
566+
IMcpClient client = await CreateMcpClientForServer();
567+
568+
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
569+
Assert.NotNull(tools);
570+
Assert.NotEmpty(tools);
571+
McpClientTool cancelableTool = tools.First(t => t.Name == nameof(EchoTool.InfiniteCancelableOperation));
572+
573+
var requestId = new RequestId(Guid.NewGuid().ToString());
574+
var invokeTask = client.SendRequestAsync<CallToolResponse>(new JsonRpcRequest()
575+
{
576+
Method = RequestMethods.ToolsCall,
577+
Id = requestId,
578+
Params = new CallToolRequestParams() { Name = cancelableTool.ProtocolTool.Name },
579+
}, TestContext.Current.CancellationToken);
580+
581+
await client.SendNotificationAsync(
582+
NotificationMethods.CancelledNotification,
583+
parameters: new CancelledNotification()
584+
{
585+
RequestId = requestId,
586+
},
587+
cancellationToken: TestContext.Current.CancellationToken);
588+
589+
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => invokeTask);
590+
}
591+
563592
[McpServerToolType]
564593
public sealed class EchoTool(ObjectWithId objectFromDI)
565594
{
@@ -625,6 +654,21 @@ public static string EchoComplex(ComplexObject complex)
625654
return complex.Name!;
626655
}
627656

657+
[McpServerTool]
658+
public static async Task<string> InfiniteCancelableOperation(CancellationToken cancellationToken)
659+
{
660+
try
661+
{
662+
await Task.Delay(Timeout.Infinite, cancellationToken);
663+
}
664+
catch (Exception)
665+
{
666+
return "canceled";
667+
}
668+
669+
return "unreachable";
670+
}
671+
628672
[McpServerTool]
629673
public string GetCtorParameter() => $"{_randomValue}:{objectFromDI.Id}";
630674

0 commit comments

Comments
 (0)