Skip to content

Commit e682995

Browse files
author
Yashwanth H L
committed
Add WithStreamableHTTPServer option to StreamableHTTPServer to allow setting a custom HTTP server instance, similar to existing functionality in SSE.
1 parent c7c0e13 commit e682995

File tree

2 files changed

+72
-6
lines changed

2 files changed

+72
-6
lines changed

server/streamable_http.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ func WithHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption {
7373
}
7474
}
7575

76+
// WithStreamableHTTPServer sets the HTTP server instance for StreamableHTTPServer
77+
func WithStreamableHTTPServer(srv *http.Server) StreamableHTTPOption {
78+
return func(s *StreamableHTTPServer) {
79+
s.httpServer = srv
80+
}
81+
}
82+
7683
// WithLogger sets the logger for the server
7784
func WithLogger(logger util.Logger) StreamableHTTPOption {
7885
return func(s *StreamableHTTPServer) {
@@ -155,15 +162,24 @@ func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request)
155162
// s.Start(":8080")
156163
func (s *StreamableHTTPServer) Start(addr string) error {
157164
s.mu.Lock()
158-
mux := http.NewServeMux()
159-
mux.Handle(s.endpointPath, s)
160-
s.httpServer = &http.Server{
161-
Addr: addr,
162-
Handler: mux,
165+
if s.httpServer == nil {
166+
mux := http.NewServeMux()
167+
mux.Handle(s.endpointPath, s)
168+
s.httpServer = &http.Server{
169+
Addr: addr,
170+
Handler: mux,
171+
}
172+
} else {
173+
if s.httpServer.Addr == "" {
174+
s.httpServer.Addr = addr
175+
} else if s.httpServer.Addr != addr {
176+
return fmt.Errorf("conflicting listen address: WithStreamableHTTPServer(%q) vs Start(%q)", s.httpServer.Addr, addr)
177+
}
163178
}
179+
srv := s.httpServer
164180
s.mu.Unlock()
165181

166-
return s.httpServer.ListenAndServe()
182+
return srv.ListenAndServe()
167183
}
168184

169185
// Shutdown gracefully stops the server, closing all active sessions

server/streamable_http_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,56 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) {
670670
})
671671
}
672672

673+
func TestStreamableHTTPServer_WithOptions(t *testing.T) {
674+
t.Run("WithStreamableHTTPServer sets httpServer field", func(t *testing.T) {
675+
mcpServer := NewMCPServer("test", "1.0.0")
676+
customServer := &http.Server{Addr: ":9999"}
677+
httpServer := NewStreamableHTTPServer(mcpServer, WithStreamableHTTPServer(customServer))
678+
679+
if httpServer.httpServer != customServer {
680+
t.Errorf("Expected httpServer to be set to custom server instance, got %v", httpServer.httpServer)
681+
}
682+
})
683+
684+
t.Run("Start with conflicting address returns error", func(t *testing.T) {
685+
mcpServer := NewMCPServer("test", "1.0.0")
686+
customServer := &http.Server{Addr: ":9999"}
687+
httpServer := NewStreamableHTTPServer(mcpServer, WithStreamableHTTPServer(customServer))
688+
689+
err := httpServer.Start(":8888")
690+
if err == nil {
691+
t.Error("Expected error for conflicting address, got nil")
692+
} else if !strings.Contains(err.Error(), "conflicting listen address") {
693+
t.Errorf("Expected error message to contain 'conflicting listen address', got '%s'", err.Error())
694+
}
695+
})
696+
697+
t.Run("Options consistency test", func(t *testing.T) {
698+
mcpServer := NewMCPServer("test", "1.0.0")
699+
endpointPath := "/test-mcp"
700+
customServer := &http.Server{}
701+
702+
// Options to test
703+
options := []StreamableHTTPOption{
704+
WithEndpointPath(endpointPath),
705+
WithStreamableHTTPServer(customServer),
706+
}
707+
708+
// Apply options multiple times and verify consistency
709+
for i := 0; i < 10; i++ {
710+
server := NewStreamableHTTPServer(mcpServer, options...)
711+
712+
if server.endpointPath != endpointPath {
713+
t.Errorf("Expected endpointPath %s, got %s", endpointPath, server.endpointPath)
714+
}
715+
716+
if server.httpServer != customServer {
717+
t.Errorf("Expected httpServer to match, got %v", server.httpServer)
718+
}
719+
}
720+
})
721+
}
722+
673723
func postJSON(url string, bodyObject any) (*http.Response, error) {
674724
jsonBody, _ := json.Marshal(bodyObject)
675725
req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody))

0 commit comments

Comments
 (0)