@@ -16,6 +16,12 @@ namespace ModelContextProtocol.Shared;
16
16
/// </summary>
17
17
internal sealed class McpSession : IDisposable
18
18
{
19
+ /// <summary>
20
+ /// In-flight request handling, indexed by request ID. The value provides a <see cref="CancellationTokenSource"/>
21
+ /// that can be used to request cancellation of the in-flight handler.
22
+ /// </summary>
23
+ private static readonly ConcurrentDictionary < RequestId , CancellationTokenSource > s_handlingRequests = new ( ) ;
24
+
19
25
private readonly ITransport _transport ;
20
26
private readonly RequestHandlers _requestHandlers ;
21
27
private readonly NotificationHandlers _notificationHandlers ;
@@ -69,25 +75,70 @@ public async Task ProcessMessagesAsync(CancellationToken cancellationToken)
69
75
{
70
76
_logger . TransportMessageRead ( EndpointName , message . GetType ( ) . Name ) ;
71
77
72
- // Fire and forget the message handling task to avoid blocking the transport
73
- // If awaiting the task, the transport will not be able to read more messages,
74
- // which could lead to a deadlock if the handler sends a message back
75
78
_ = ProcessMessageAsync ( ) ;
76
79
async Task ProcessMessageAsync ( )
77
80
{
81
+ IJsonRpcMessageWithId ? messageWithId = message as IJsonRpcMessageWithId ;
82
+ CancellationTokenSource ? combinedCts = null ;
83
+ try
84
+ {
85
+ // Register before we yield, so that the tracking is guaranteed to be there
86
+ // when subsequent messages arrive, even if the asynchronous processing happens
87
+ // out of order.
88
+ if ( messageWithId is not null )
89
+ {
90
+ combinedCts = CancellationTokenSource . CreateLinkedTokenSource ( cancellationToken ) ;
91
+ s_handlingRequests [ messageWithId . Id ] = combinedCts ;
92
+ }
93
+
94
+ // Fire and forget the message handling to avoid blocking the transport
95
+ // If awaiting the task, the transport will not be able to read more messages,
96
+ // which could lead to a deadlock if the handler sends a message back
97
+
78
98
#if NET
79
- await Task . CompletedTask . ConfigureAwait ( ConfigureAwaitOptions . ForceYielding ) ;
99
+ await Task . CompletedTask . ConfigureAwait ( ConfigureAwaitOptions . ForceYielding ) ;
80
100
#else
81
- await default ( ForceYielding ) ;
101
+ await default ( ForceYielding ) ;
82
102
#endif
83
- try
84
- {
85
- await HandleMessageAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
103
+
104
+ // Handle the message.
105
+ await HandleMessageAsync ( message , combinedCts ? . Token ?? cancellationToken ) . ConfigureAwait ( false ) ;
86
106
}
87
107
catch ( Exception ex )
88
108
{
89
- var payload = JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ;
90
- _logger . MessageHandlerError ( EndpointName , message . GetType ( ) . Name , payload , ex ) ;
109
+ // Only send responses for request errors that aren't user-initiated cancellation.
110
+ bool isUserCancellation =
111
+ ex is OperationCanceledException &&
112
+ ! cancellationToken . IsCancellationRequested &&
113
+ combinedCts ? . IsCancellationRequested is true ;
114
+
115
+ if ( ! isUserCancellation && message is JsonRpcRequest request )
116
+ {
117
+ _logger . RequestHandlerError ( EndpointName , request . Method , ex ) ;
118
+ await _transport . SendMessageAsync ( new JsonRpcError
119
+ {
120
+ Id = request . Id ,
121
+ JsonRpc = "2.0" ,
122
+ Error = new JsonRpcErrorDetail
123
+ {
124
+ Code = ErrorCodes . InternalError ,
125
+ Message = ex . Message
126
+ }
127
+ } , cancellationToken ) . ConfigureAwait ( false ) ;
128
+ }
129
+ else if ( ex is not OperationCanceledException )
130
+ {
131
+ var payload = JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ;
132
+ _logger . MessageHandlerError ( EndpointName , message . GetType ( ) . Name , payload , ex ) ;
133
+ }
134
+ }
135
+ finally
136
+ {
137
+ if ( messageWithId is not null )
138
+ {
139
+ s_handlingRequests . TryRemove ( messageWithId . Id , out _ ) ;
140
+ combinedCts ! . Dispose ( ) ;
141
+ }
91
142
}
92
143
}
93
144
}
@@ -123,6 +174,24 @@ private async Task HandleMessageAsync(IJsonRpcMessage message, CancellationToken
123
174
124
175
private async Task HandleNotification ( JsonRpcNotification notification )
125
176
{
177
+ // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.)
178
+ if ( notification . Method == NotificationMethods . CancelledNotification )
179
+ {
180
+ try
181
+ {
182
+ if ( GetCancelledNotificationParams ( notification . Params ) is CancelledNotification cn &&
183
+ s_handlingRequests . TryGetValue ( cn . RequestId , out var cts ) )
184
+ {
185
+ await cts . CancelAsync ( ) . ConfigureAwait ( false ) ;
186
+ }
187
+ }
188
+ catch
189
+ {
190
+ // "Invalid cancellation notifications SHOULD be ignored"
191
+ }
192
+ }
193
+
194
+ // Handle user-defined notifications.
126
195
if ( _notificationHandlers . TryGetValue ( notification . Method , out var handlers ) )
127
196
{
128
197
foreach ( var notificationHandler in handlers )
@@ -161,33 +230,15 @@ private async Task HandleRequest(JsonRpcRequest request, CancellationToken cance
161
230
{
162
231
if ( _requestHandlers . TryGetValue ( request . Method , out var handler ) )
163
232
{
164
- try
233
+ _logger . RequestHandlerCalled ( EndpointName , request . Method ) ;
234
+ var result = await handler ( request , cancellationToken ) . ConfigureAwait ( false ) ;
235
+ _logger . RequestHandlerCompleted ( EndpointName , request . Method ) ;
236
+ await _transport . SendMessageAsync ( new JsonRpcResponse
165
237
{
166
- _logger . RequestHandlerCalled ( EndpointName , request . Method ) ;
167
- var result = await handler ( request , cancellationToken ) . ConfigureAwait ( false ) ;
168
- _logger . RequestHandlerCompleted ( EndpointName , request . Method ) ;
169
- await _transport . SendMessageAsync ( new JsonRpcResponse
170
- {
171
- Id = request . Id ,
172
- JsonRpc = "2.0" ,
173
- Result = result
174
- } , cancellationToken ) . ConfigureAwait ( false ) ;
175
- }
176
- catch ( Exception ex )
177
- {
178
- _logger . RequestHandlerError ( EndpointName , request . Method , ex ) ;
179
- // Send error response
180
- await _transport . SendMessageAsync ( new JsonRpcError
181
- {
182
- Id = request . Id ,
183
- JsonRpc = "2.0" ,
184
- Error = new JsonRpcErrorDetail
185
- {
186
- Code = - 32000 , // Implementation defined error
187
- Message = ex . Message
188
- }
189
- } , cancellationToken ) . ConfigureAwait ( false ) ;
190
- }
238
+ Id = request . Id ,
239
+ JsonRpc = "2.0" ,
240
+ Result = result
241
+ } , cancellationToken ) . ConfigureAwait ( false ) ;
191
242
}
192
243
else
193
244
{
@@ -273,7 +324,7 @@ public async Task<TResult> SendRequestAsync<TResult>(JsonRpcRequest request, Can
273
324
}
274
325
}
275
326
276
- public Task SendMessageAsync ( IJsonRpcMessage message , CancellationToken cancellationToken = default )
327
+ public async Task SendMessageAsync ( IJsonRpcMessage message , CancellationToken cancellationToken = default )
277
328
{
278
329
Throw . IfNull ( message ) ;
279
330
@@ -288,7 +339,44 @@ public Task SendMessageAsync(IJsonRpcMessage message, CancellationToken cancella
288
339
_logger . SendingMessage ( EndpointName , JsonSerializer . Serialize ( message , _jsonOptions . GetTypeInfo < IJsonRpcMessage > ( ) ) ) ;
289
340
}
290
341
291
- return _transport . SendMessageAsync ( message , cancellationToken ) ;
342
+ await _transport . SendMessageAsync ( message , cancellationToken ) . ConfigureAwait ( false ) ;
343
+
344
+ // If the sent notification was a cancellation notification, cancel the pending request's await, as either the
345
+ // server won't be sending a response, or per the specification, the response should be ignored. There are inherent
346
+ // race conditions here, so it's possible and allowed for the operation to complete before we get to this point.
347
+ if ( message is JsonRpcNotification { Method : NotificationMethods . CancelledNotification } notification &&
348
+ GetCancelledNotificationParams ( notification . Params ) is CancelledNotification cn &&
349
+ _pendingRequests . TryRemove ( cn . RequestId , out var tcs ) )
350
+ {
351
+ tcs . TrySetCanceled ( default ) ;
352
+ }
353
+ }
354
+
355
+ private static CancelledNotification ? GetCancelledNotificationParams ( object ? notificationParams )
356
+ {
357
+ try
358
+ {
359
+ switch ( notificationParams )
360
+ {
361
+ case null :
362
+ return null ;
363
+
364
+ case CancelledNotification cn :
365
+ return cn ;
366
+
367
+ case JsonElement je :
368
+ return JsonSerializer . Deserialize ( je , McpJsonUtilities . DefaultOptions . GetTypeInfo < CancelledNotification > ( ) ) ;
369
+
370
+ default :
371
+ return JsonSerializer . Deserialize (
372
+ JsonSerializer . Serialize ( notificationParams , McpJsonUtilities . DefaultOptions . GetTypeInfo < object ? > ( ) ) ,
373
+ McpJsonUtilities . DefaultOptions . GetTypeInfo < CancelledNotification > ( ) ) ;
374
+ }
375
+ }
376
+ catch
377
+ {
378
+ return null ;
379
+ }
292
380
}
293
381
294
382
public void Dispose ( )
0 commit comments