Skip to content

Commit af43905

Browse files
committed
Implement cancellation notifications
1 parent 7f5ec50 commit af43905

File tree

4 files changed

+197
-43
lines changed

4 files changed

+197
-43
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using System.Text.Json.Serialization;
2+
3+
namespace ModelContextProtocol.Protocol.Messages;
4+
5+
/// <summary>
6+
/// This notification indicates that the result will be unused, so any associated processing SHOULD cease.
7+
/// </summary>
8+
public sealed class CancelledNotification
9+
{
10+
/// <summary>
11+
/// The ID of the request to cancel.
12+
/// </summary>
13+
[JsonPropertyName("requestId")]
14+
public RequestId RequestId { get; set; }
15+
16+
/// <summary>
17+
/// An optional string describing the reason for the cancellation.
18+
/// </summary>
19+
[JsonPropertyName("reason")]
20+
public string? Reason { get; set; }
21+
}

src/ModelContextProtocol/Shared/McpSession.cs

Lines changed: 126 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ namespace ModelContextProtocol.Shared;
1616
/// </summary>
1717
internal sealed class McpSession : IDisposable
1818
{
19+
/// <summary>
20+
/// In-flight request handling, indexed by request ID. The value provides a <see cref="CancellationTokenSource"/>
21+
/// that can be used to request cancellation of the in-flight handler.
22+
/// </summary>
23+
private static readonly ConcurrentDictionary<RequestId, CancellationTokenSource> s_handlingRequests = new();
24+
1925
private readonly ITransport _transport;
2026
private readonly RequestHandlers _requestHandlers;
2127
private readonly NotificationHandlers _notificationHandlers;
@@ -69,25 +75,70 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken)
6975
{
7076
_logger.TransportMessageRead(EndpointName, message.GetType().Name);
7177

72-
// Fire and forget the message handling task to avoid blocking the transport
73-
// If awaiting the task, the transport will not be able to read more messages,
74-
// which could lead to a deadlock if the handler sends a message back
7578
_ = ProcessMessageAsync();
7679
async Task ProcessMessageAsync()
7780
{
81+
IJsonRpcMessageWithId? messageWithId = message as IJsonRpcMessageWithId;
82+
CancellationTokenSource? combinedCts = null;
83+
try
84+
{
85+
// Register before we yield, so that the tracking is guaranteed to be there
86+
// when subsequent messages arrive, even if the asynchronous processing happens
87+
// out of order.
88+
if (messageWithId is not null)
89+
{
90+
combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
91+
s_handlingRequests[messageWithId.Id] = combinedCts;
92+
}
93+
94+
// Fire and forget the message handling to avoid blocking the transport
95+
// If awaiting the task, the transport will not be able to read more messages,
96+
// which could lead to a deadlock if the handler sends a message back
97+
7898
#if NET
79-
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
99+
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
80100
#else
81-
await default(ForceYielding);
101+
await default(ForceYielding);
82102
#endif
83-
try
84-
{
85-
await HandleMessageAsync(message, cancellationToken).ConfigureAwait(false);
103+
104+
// Handle the message.
105+
await HandleMessageAsync(message, combinedCts?.Token ?? cancellationToken).ConfigureAwait(false);
86106
}
87107
catch (Exception ex)
88108
{
89-
var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
90-
_logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
109+
// Only send responses for request errors that aren't user-initiated cancellation.
110+
bool isUserCancellation =
111+
ex is OperationCanceledException &&
112+
!cancellationToken.IsCancellationRequested &&
113+
combinedCts?.IsCancellationRequested is true;
114+
115+
if (!isUserCancellation && message is JsonRpcRequest request)
116+
{
117+
_logger.RequestHandlerError(EndpointName, request.Method, ex);
118+
await _transport.SendMessageAsync(new JsonRpcError
119+
{
120+
Id = request.Id,
121+
JsonRpc = "2.0",
122+
Error = new JsonRpcErrorDetail
123+
{
124+
Code = ErrorCodes.InternalError,
125+
Message = ex.Message
126+
}
127+
}, cancellationToken).ConfigureAwait(false);
128+
}
129+
else if (ex is not OperationCanceledException)
130+
{
131+
var payload = JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>());
132+
_logger.MessageHandlerError(EndpointName, message.GetType().Name, payload, ex);
133+
}
134+
}
135+
finally
136+
{
137+
if (messageWithId is not null)
138+
{
139+
s_handlingRequests.TryRemove(messageWithId.Id, out _);
140+
combinedCts!.Dispose();
141+
}
91142
}
92143
}
93144
}
@@ -123,6 +174,24 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken
123174

124175
private async Task HandleNotification(JsonRpcNotification notification)
125176
{
177+
// Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.)
178+
if (notification.Method == NotificationMethods.CancelledNotification)
179+
{
180+
try
181+
{
182+
if (GetCancelledNotificationParams(notification.Params) is CancelledNotification cn &&
183+
s_handlingRequests.TryGetValue(cn.RequestId, out var cts))
184+
{
185+
await cts.CancelAsync().ConfigureAwait(false);
186+
}
187+
}
188+
catch
189+
{
190+
// "Invalid cancellation notifications SHOULD be ignored"
191+
}
192+
}
193+
194+
// Handle user-defined notifications.
126195
if (_notificationHandlers.TryGetValue(notification.Method, out var handlers))
127196
{
128197
foreach (var notificationHandler in handlers)
@@ -161,33 +230,15 @@ private async Task HandleRequest(JsonRpcRequest request, CancellationToken cance
161230
{
162231
if (_requestHandlers.TryGetValue(request.Method, out var handler))
163232
{
164-
try
233+
_logger.RequestHandlerCalled(EndpointName, request.Method);
234+
var result = await handler(request, cancellationToken).ConfigureAwait(false);
235+
_logger.RequestHandlerCompleted(EndpointName, request.Method);
236+
await _transport.SendMessageAsync(new JsonRpcResponse
165237
{
166-
_logger.RequestHandlerCalled(EndpointName, request.Method);
167-
var result = await handler(request, cancellationToken).ConfigureAwait(false);
168-
_logger.RequestHandlerCompleted(EndpointName, request.Method);
169-
await _transport.SendMessageAsync(new JsonRpcResponse
170-
{
171-
Id = request.Id,
172-
JsonRpc = "2.0",
173-
Result = result
174-
}, cancellationToken).ConfigureAwait(false);
175-
}
176-
catch (Exception ex)
177-
{
178-
_logger.RequestHandlerError(EndpointName, request.Method, ex);
179-
// Send error response
180-
await _transport.SendMessageAsync(new JsonRpcError
181-
{
182-
Id = request.Id,
183-
JsonRpc = "2.0",
184-
Error = new JsonRpcErrorDetail
185-
{
186-
Code = -32000, // Implementation defined error
187-
Message = ex.Message
188-
}
189-
}, cancellationToken).ConfigureAwait(false);
190-
}
238+
Id = request.Id,
239+
JsonRpc = "2.0",
240+
Result = result
241+
}, cancellationToken).ConfigureAwait(false);
191242
}
192243
else
193244
{
@@ -273,7 +324,7 @@ public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, Can
273324
}
274325
}
275326

276-
public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
327+
public async Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancellationToken = default)
277328
{
278329
Throw.IfNull(message);
279330

@@ -288,7 +339,44 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
288339
_logger.SendingMessage(EndpointName, JsonSerializer.Serialize(message, _jsonOptions.GetTypeInfo<IJsonRpcMessage>()));
289340
}
290341

291-
return _transport.SendMessageAsync(message, cancellationToken);
342+
await _transport.SendMessageAsync(message, cancellationToken).ConfigureAwait(false);
343+
344+
// If the sent notification was a cancellation notification, cancel the pending request's await, as either the
345+
// server won't be sending a response, or per the specification, the response should be ignored. There are inherent
346+
// race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
347+
if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification &&
348+
GetCancelledNotificationParams(notification.Params) is CancelledNotification cn &&
349+
_pendingRequests.TryRemove(cn.RequestId, out var tcs))
350+
{
351+
tcs.TrySetCanceled(default);
352+
}
353+
}
354+
355+
private static CancelledNotification? GetCancelledNotificationParams(object? notificationParams)
356+
{
357+
try
358+
{
359+
switch (notificationParams)
360+
{
361+
case null:
362+
return null;
363+
364+
case CancelledNotification cn:
365+
return cn;
366+
367+
case JsonElement je:
368+
return JsonSerializer.Deserialize(je, McpJsonUtilities.DefaultOptions.GetTypeInfo<CancelledNotification>());
369+
370+
default:
371+
return JsonSerializer.Deserialize(
372+
JsonSerializer.Serialize(notificationParams, McpJsonUtilities.DefaultOptions.GetTypeInfo<object?>()),
373+
McpJsonUtilities.DefaultOptions.GetTypeInfo<CancelledNotification>());
374+
}
375+
}
376+
catch
377+
{
378+
return null;
379+
}
292380
}
293381

294382
public void Dispose()

src/ModelContextProtocol/Utils/Json/McpJsonUtilities.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element)
121121
// MCP Request Params / Results
122122
[JsonSerializable(typeof(CallToolRequestParams))]
123123
[JsonSerializable(typeof(CallToolResponse))]
124+
[JsonSerializable(typeof(CancelledNotification))]
124125
[JsonSerializable(typeof(CompleteRequestParams))]
125126
[JsonSerializable(typeof(CompleteResult))]
126127
[JsonSerializable(typeof(CreateMessageRequestParams))]

tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ public async Task Can_List_Registered_Tools()
9292
IMcpClient client = await CreateMcpClientForServer();
9393

9494
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
95-
Assert.Equal(12, tools.Count);
95+
Assert.Equal(13, tools.Count);
9696

9797
McpClientTool echoTool = tools.First(t => t.Name == "Echo");
9898
Assert.Equal("Echo", echoTool.Name);
@@ -139,7 +139,7 @@ public async Task Can_Create_Multiple_Servers_From_Options_And_List_Registered_T
139139
cancellationToken: TestContext.Current.CancellationToken))
140140
{
141141
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
142-
Assert.Equal(12, tools.Count);
142+
Assert.Equal(13, tools.Count);
143143

144144
McpClientTool echoTool = tools.First(t => t.Name == "Echo");
145145
Assert.Equal("Echo", echoTool.Name);
@@ -166,7 +166,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
166166
IMcpClient client = await CreateMcpClientForServer();
167167

168168
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
169-
Assert.Equal(12, tools.Count);
169+
Assert.Equal(13, tools.Count);
170170

171171
Channel<JsonRpcNotification> listChanged = Channel.CreateUnbounded<JsonRpcNotification>();
172172
client.AddNotificationHandler(NotificationMethods.ToolListChangedNotification, notification =>
@@ -187,7 +187,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
187187
await notificationRead;
188188

189189
tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
190-
Assert.Equal(13, tools.Count);
190+
Assert.Equal(14, tools.Count);
191191
Assert.Contains(tools, t => t.Name == "NewTool");
192192

193193
notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken);
@@ -196,7 +196,7 @@ public async Task Can_Be_Notified_Of_Tool_Changes()
196196
await notificationRead;
197197

198198
tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
199-
Assert.Equal(12, tools.Count);
199+
Assert.Equal(13, tools.Count);
200200
Assert.DoesNotContain(tools, t => t.Name == "NewTool");
201201
}
202202

@@ -513,6 +513,35 @@ public async Task HandlesIProgressParameter()
513513
}
514514
}
515515

516+
[Fact]
517+
public async Task CancellationNotificationsPropagateToToolTokens()
518+
{
519+
IMcpClient client = await CreateMcpClientForServer();
520+
521+
var tools = await client.ListToolsAsync(TestContext.Current.CancellationToken);
522+
Assert.NotNull(tools);
523+
Assert.NotEmpty(tools);
524+
McpClientTool cancelableTool = tools.First(t => t.Name == nameof(EchoTool.InfiniteCancelableOperation));
525+
526+
var requestId = new RequestId(Guid.NewGuid().ToString());
527+
var invokeTask = client.SendRequestAsync<CallToolResponse>(new JsonRpcRequest()
528+
{
529+
Method = RequestMethods.ToolsCall,
530+
Id = requestId,
531+
Params = new CallToolRequestParams() { Name = cancelableTool.ProtocolTool.Name },
532+
}, TestContext.Current.CancellationToken);
533+
534+
await client.SendNotificationAsync(
535+
NotificationMethods.CancelledNotification,
536+
parameters: new CancelledNotification()
537+
{
538+
RequestId = requestId,
539+
},
540+
cancellationToken: TestContext.Current.CancellationToken);
541+
542+
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => invokeTask);
543+
}
544+
516545
[McpServerToolType]
517546
public sealed class EchoTool(ObjectWithId objectFromDI)
518547
{
@@ -578,6 +607,21 @@ public static string EchoComplex(ComplexObject complex)
578607
return complex.Name!;
579608
}
580609

610+
[McpServerTool]
611+
public static async Task<string> InfiniteCancelableOperation(CancellationToken cancellationToken)
612+
{
613+
try
614+
{
615+
await Task.Delay(Timeout.Infinite, cancellationToken);
616+
}
617+
catch (Exception)
618+
{
619+
return "canceled";
620+
}
621+
622+
return "unreachable";
623+
}
624+
581625
[McpServerTool]
582626
public string GetCtorParameter() => $"{_randomValue}:{objectFromDI.Id}";
583627

0 commit comments

Comments
 (0)