|
4 | 4 | using Microsoft.AspNetCore.WebUtilities; |
5 | 5 | using Microsoft.Extensions.Logging; |
6 | 6 | using Microsoft.Extensions.Options; |
7 | | -using Microsoft.Extensions.Primitives; |
8 | 7 | using Microsoft.Net.Http.Headers; |
9 | 8 | using ModelContextProtocol.AspNetCore.Stateless; |
10 | 9 | using ModelContextProtocol.Protocol; |
@@ -136,6 +135,7 @@ public async Task HandleDeleteRequestAsync(HttpContext context) |
136 | 135 | var transport = new StreamableHttpServerTransport |
137 | 136 | { |
138 | 137 | Stateless = true, |
| 138 | + SessionId = sessionId, |
139 | 139 | }; |
140 | 140 | session = await CreateSessionAsync(context, transport, sessionId, statelessSessionId); |
141 | 141 | } |
@@ -184,7 +184,10 @@ private async ValueTask<HttpMcpSession<StreamableHttpServerTransport>> StartNewS |
184 | 184 | if (!HttpServerTransportOptions.Stateless) |
185 | 185 | { |
186 | 186 | sessionId = MakeNewSessionId(); |
187 | | - transport = new(); |
| 187 | + transport = new() |
| 188 | + { |
| 189 | + SessionId = sessionId, |
| 190 | + }; |
188 | 191 | context.Response.Headers["mcp-session-id"] = sessionId; |
189 | 192 | } |
190 | 193 | else |
@@ -286,19 +289,22 @@ internal static string MakeNewSessionId() |
286 | 289 |
|
287 | 290 | private void ScheduleStatelessSessionIdWrite(HttpContext context, StreamableHttpServerTransport transport) |
288 | 291 | { |
289 | | - context.Response.OnStarting(() => |
| 292 | + transport.OnInitRequestReceived = initRequestParams => |
290 | 293 | { |
291 | 294 | var statelessId = new StatelessSessionId |
292 | 295 | { |
293 | | - ClientInfo = transport?.InitializeRequest?.ClientInfo, |
| 296 | + ClientInfo = initRequestParams.ClientInfo, |
294 | 297 | UserIdClaim = GetUserIdClaim(context.User), |
295 | 298 | }; |
296 | 299 |
|
297 | 300 | var sessionJson = JsonSerializer.Serialize(statelessId, StatelessSessionIdJsonContext.Default.StatelessSessionId); |
298 | | - var sessionId = Protector.Protect(sessionJson); |
299 | | - |
300 | | - context.Response.Headers["mcp-session-id"] = sessionId; |
| 301 | + transport.SessionId = Protector.Protect(sessionJson); |
| 302 | + }; |
301 | 303 |
|
| 304 | + context.Response.OnStarting(() => |
| 305 | + { |
| 306 | + Debug.Assert(transport.SessionId is not null); |
| 307 | + context.Response.Headers["mcp-session-id"] = transport.SessionId; |
302 | 308 | return Task.CompletedTask; |
303 | 309 | }); |
304 | 310 | } |
|
0 commit comments