11package io .modelcontextprotocol .server .transport ;
22
3+ import java .awt .PageAttributes ;
4+
5+ import com .fasterxml .jackson .core .JsonProcessingException ;
36import com .fasterxml .jackson .core .type .TypeReference ;
47import com .fasterxml .jackson .databind .ObjectMapper ;
8+
59import io .modelcontextprotocol .spec .DefaultMcpTransportContext ;
610import io .modelcontextprotocol .spec .McpError ;
711import io .modelcontextprotocol .spec .McpSchema ;
1014import io .modelcontextprotocol .spec .McpStreamableServerTransportProvider ;
1115import io .modelcontextprotocol .spec .McpTransportContext ;
1216import io .modelcontextprotocol .util .Assert ;
17+
1318import org .slf4j .Logger ;
1419import org .slf4j .LoggerFactory ;
1520import org .springframework .http .HttpStatus ;
1924import org .springframework .web .reactive .function .server .RouterFunctions ;
2025import org .springframework .web .reactive .function .server .ServerRequest ;
2126import org .springframework .web .reactive .function .server .ServerResponse ;
27+
2228import reactor .core .Disposable ;
2329import reactor .core .Exceptions ;
2430import reactor .core .publisher .Flux ;
2531import reactor .core .publisher .FluxSink ;
2632import reactor .core .publisher .Mono ;
2733
2834import java .io .IOException ;
35+ import java .util .ArrayList ;
36+ import java .util .List ;
2937import java .util .concurrent .ConcurrentHashMap ;
3038import java .util .function .Function ;
3139
40+ /**
41+ * Server-side implementation of the Model Context Protocol (MCP) streamable transport
42+ * layer using HTTP with Server-Sent Events (SSE) through Spring WebFlux.
43+ *
44+ * <p>
45+ *
46+ * @author Dariusz Jędrzejczyk
47+ * @author Zachary German
48+ * @see McpStreamableServerTransportProvider
49+ * @see RouterFunction
50+ */
3251public class WebFluxStreamableServerTransportProvider implements McpStreamableServerTransportProvider {
3352
3453 private static final Logger logger = LoggerFactory .getLogger (WebFluxStreamableServerTransportProvider .class );
@@ -37,6 +56,12 @@ public class WebFluxStreamableServerTransportProvider implements McpStreamableSe
3756
3857 public static final String DEFAULT_BASE_URL = "" ;
3958
59+ private static final String MCP_SESSION_ID = "mcp-session-id" ;
60+
61+ private static final String LAST_EVENT_ID = "Last-Event-ID" ;
62+
63+ private static final String ACCEPT = "Accept" ;
64+
4065 private final ObjectMapper objectMapper ;
4166
4267 private final String baseUrl ;
@@ -195,21 +220,40 @@ private Mono<ServerResponse> handleGet(ServerRequest request) {
195220 McpTransportContext transportContext = this .contextExtractor .apply (request );
196221
197222 return Mono .defer (() -> {
198- if (!request .headers ().asHttpHeaders ().containsKey ("mcp-session-id" )) {
199- return ServerResponse .badRequest ().build (); // TODO: say we need a session
200- // id
223+ List <String > badRequestErrors = new ArrayList <>();
224+
225+ String accept = request .headers ().asHttpHeaders ().getFirst (ACCEPT );
226+ if (accept == null || !accept .contains (MediaType .TEXT_EVENT_STREAM_VALUE )) {
227+ badRequestErrors .add ("text/event-stream required in Accept header" );
201228 }
202229
203- String sessionId = request .headers ().asHttpHeaders ().getFirst ("mcp-session-id" );
230+ String sessionId = request .headers ().asHttpHeaders ().getFirst (MCP_SESSION_ID );
231+
232+ if (sessionId == null || sessionId .isBlank ()) {
233+ badRequestErrors .add ("Session ID required in mcp-session-id header" );
234+ }
235+
236+ if (!badRequestErrors .isEmpty ()) {
237+ String combinedMessage = String .join ("; " , badRequestErrors );
238+ try {
239+ String errorJson = objectMapper .writeValueAsString (new McpError (combinedMessage ));
240+ return ServerResponse .badRequest ().bodyValue (errorJson );
241+ }
242+ catch (JsonProcessingException e ) {
243+ logger .debug ("Failed to serialize McpError: {}" , e );
244+ return ServerResponse .status (HttpStatus .INTERNAL_SERVER_ERROR )
245+ .bodyValue ("Failed to serialize error message." );
246+ }
247+ }
204248
205249 McpStreamableServerSession session = this .sessions .get (sessionId );
206250
207251 if (session == null ) {
208252 return ServerResponse .notFound ().build ();
209253 }
210254
211- if (request .headers ().asHttpHeaders ().containsKey ("mcp-last-id" )) {
212- String lastId = request .headers ().asHttpHeaders ().getFirst ("mcp-last-id" );
255+ if (request .headers ().asHttpHeaders ().containsKey (LAST_EVENT_ID )) {
256+ String lastId = request .headers ().asHttpHeaders ().getFirst (LAST_EVENT_ID );
213257 return ServerResponse .ok ()
214258 .contentType (MediaType .TEXT_EVENT_STREAM )
215259 .body (session .replay (lastId ), ServerSentEvent .class );
@@ -252,9 +296,31 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
252296
253297 return request .bodyToMono (String .class ).<ServerResponse >flatMap (body -> {
254298 try {
299+ List <String > badRequestErrors = new ArrayList <>();
300+
301+ String accept = request .headers ().asHttpHeaders ().getFirst (ACCEPT );
302+ if (accept == null || !accept .contains (MediaType .TEXT_EVENT_STREAM_VALUE )) {
303+ badRequestErrors .add ("text/event-stream required in Accept header" );
304+ }
305+ if (accept == null || !accept .contains (MediaType .APPLICATION_JSON_VALUE )) {
306+ badRequestErrors .add ("application/json required in Accept header" );
307+ }
308+
255309 McpSchema .JSONRPCMessage message = McpSchema .deserializeJsonRpcMessage (objectMapper , body );
256310 if (message instanceof McpSchema .JSONRPCRequest jsonrpcRequest
257311 && jsonrpcRequest .method ().equals (McpSchema .METHOD_INITIALIZE )) {
312+ if (!badRequestErrors .isEmpty ()) {
313+ String combinedMessage = String .join ("; " , badRequestErrors );
314+ try {
315+ String errorJson = objectMapper .writeValueAsString (new McpError (combinedMessage ));
316+ return ServerResponse .badRequest ().bodyValue (errorJson );
317+ }
318+ catch (JsonProcessingException e ) {
319+ logger .debug ("Failed to serialize McpError: {}" , e );
320+ return ServerResponse .status (HttpStatus .INTERNAL_SERVER_ERROR )
321+ .bodyValue ("Failed to serialize error message." );
322+ }
323+ }
258324 McpSchema .InitializeRequest initializeRequest = objectMapper .convertValue (jsonrpcRequest .params (),
259325 new TypeReference <McpSchema .InitializeRequest >() {
260326 });
@@ -274,15 +340,29 @@ private Mono<ServerResponse> handlePost(ServerRequest request) {
274340 })
275341 .flatMap (initResult -> ServerResponse .ok ()
276342 .contentType (MediaType .APPLICATION_JSON )
277- .header ("mcp-session-id" , init .session ().getId ())
343+ .header (MCP_SESSION_ID , init .session ().getId ())
278344 .bodyValue (initResult ));
279345 }
280346
281- if (!request .headers ().asHttpHeaders ().containsKey ("mcp-session-id" )) {
282- return ServerResponse .badRequest ().bodyValue (new McpError ("Session ID missing" ));
347+ String sessionId = request .headers ().asHttpHeaders ().getFirst (MCP_SESSION_ID );
348+
349+ if (sessionId == null || sessionId .isBlank ()) {
350+ badRequestErrors .add ("Session ID required in mcp-session-id header" );
351+ }
352+
353+ if (!badRequestErrors .isEmpty ()) {
354+ String combinedMessage = String .join ("; " , badRequestErrors );
355+ try {
356+ String errorJson = objectMapper .writeValueAsString (new McpError (combinedMessage ));
357+ return ServerResponse .badRequest ().bodyValue (errorJson );
358+ }
359+ catch (JsonProcessingException e ) {
360+ logger .debug ("Failed to serialize McpError: {}" , e );
361+ return ServerResponse .status (HttpStatus .INTERNAL_SERVER_ERROR )
362+ .bodyValue ("Failed to serialize error message." );
363+ }
283364 }
284365
285- String sessionId = request .headers ().asHttpHeaders ().getFirst ("mcp-session-id" );
286366 McpStreamableServerSession session = sessions .get (sessionId );
287367
288368 if (session == null ) {
@@ -330,7 +410,7 @@ private Mono<ServerResponse> handleDelete(ServerRequest request) {
330410 McpTransportContext transportContext = this .contextExtractor .apply (request );
331411
332412 return Mono .defer (() -> {
333- if (!request .headers ().asHttpHeaders ().containsKey ("mcp-session-id" )) {
413+ if (!request .headers ().asHttpHeaders ().containsKey (MCP_SESSION_ID )) {
334414 return ServerResponse .badRequest ().build (); // TODO: say we need a session
335415 // id
336416 }
@@ -340,7 +420,7 @@ private Mono<ServerResponse> handleDelete(ServerRequest request) {
340420 return ServerResponse .status (HttpStatus .METHOD_NOT_ALLOWED ).build ();
341421 }
342422
343- String sessionId = request .headers ().asHttpHeaders ().getFirst ("mcp-session-id" );
423+ String sessionId = request .headers ().asHttpHeaders ().getFirst (MCP_SESSION_ID );
344424
345425 McpStreamableServerSession session = this .sessions .get (sessionId );
346426
0 commit comments