@@ -10,6 +10,8 @@ import (
1010 "os"
1111 "os/exec"
1212 "sync"
13+ "syscall"
14+ "time"
1315
1416 "github.com/mark3labs/mcp-go/mcp"
1517 "github.com/mark3labs/mcp-go/util"
@@ -24,21 +26,22 @@ type Stdio struct {
2426 args []string
2527 env []string
2628
27- cmd * exec.Cmd
28- cmdFunc CommandFunc
29- stdin io.WriteCloser
30- stdout * bufio.Scanner
31- stderr io.ReadCloser
32- responses map [string ]chan * JSONRPCResponse
33- mu sync.RWMutex
34- done chan struct {}
35- onNotification func (mcp.JSONRPCNotification )
36- notifyMu sync.RWMutex
37- onRequest RequestHandler
38- requestMu sync.RWMutex
39- ctx context.Context
40- ctxMu sync.RWMutex
41- logger util.Logger
29+ cmd * exec.Cmd
30+ cmdFunc CommandFunc
31+ stdin io.WriteCloser
32+ stdout * bufio.Scanner
33+ stderr io.ReadCloser
34+ responses map [string ]chan * JSONRPCResponse
35+ mu sync.RWMutex
36+ done chan struct {}
37+ onNotification func (mcp.JSONRPCNotification )
38+ notifyMu sync.RWMutex
39+ onRequest RequestHandler
40+ requestMu sync.RWMutex
41+ ctx context.Context
42+ ctxMu sync.RWMutex
43+ logger util.Logger
44+ terminateDuration time.Duration
4245}
4346
4447// StdioOption defines a function that configures a Stdio transport instance.
@@ -66,6 +69,13 @@ func WithCommandLogger(logger util.Logger) StdioOption {
6669 }
6770}
6871
72+ // WithTerminateDuration sets the duration to wait for graceful shutdown before sending SIGTERM.
73+ func WithTerminateDuration (duration time.Duration ) StdioOption {
74+ return func (s * Stdio ) {
75+ s .terminateDuration = duration
76+ }
77+ }
78+
6979// NewIO returns a new stdio-based transport using existing input, output, and
7080// logging streams instead of spawning a subprocess.
7181// This is useful for testing and simulating client behavior.
@@ -75,10 +85,11 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio
7585 stdout : bufio .NewScanner (input ),
7686 stderr : logging ,
7787
78- responses : make (map [string ]chan * JSONRPCResponse ),
79- done : make (chan struct {}),
80- ctx : context .Background (),
81- logger : util .DefaultLogger (),
88+ responses : make (map [string ]chan * JSONRPCResponse ),
89+ done : make (chan struct {}),
90+ ctx : context .Background (),
91+ logger : util .DefaultLogger (),
92+ terminateDuration : 5 * time .Second , // Default 5 second timeout
8293 }
8394}
8495
@@ -109,10 +120,11 @@ func NewStdioWithOptions(
109120 args : args ,
110121 env : env ,
111122
112- responses : make (map [string ]chan * JSONRPCResponse ),
113- done : make (chan struct {}),
114- ctx : context .Background (),
115- logger : util .DefaultLogger (),
123+ responses : make (map [string ]chan * JSONRPCResponse ),
124+ done : make (chan struct {}),
125+ ctx : context .Background (),
126+ logger : util .DefaultLogger (),
127+ terminateDuration : 5 * time .Second , // Default 5 second timeout
116128 }
117129
118130 for _ , opt := range opts {
@@ -189,8 +201,10 @@ func (c *Stdio) spawnCommand(ctx context.Context) error {
189201 return nil
190202}
191203
192- // Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
193- // Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
204+ // Close closes the input stream to the child process, and awaits normal
205+ // termination of the command. If the command does not exit, it is signalled to
206+ // terminate, and then eventually killed. This follows the MCP specification
207+ // for stdio transport shutdown.
194208func (c * Stdio ) Close () error {
195209 select {
196210 case <- c .done :
@@ -200,6 +214,8 @@ func (c *Stdio) Close() error {
200214 // cancel all in-flight request
201215 close (c .done )
202216
217+ // For the stdio transport, the client SHOULD initiate shutdown by:
218+ // First, closing the input stream to the child process (the server)
203219 if c .stdin != nil {
204220 if err := c .stdin .Close (); err != nil {
205221 return fmt .Errorf ("failed to close stdin: %w" , err )
@@ -212,7 +228,40 @@ func (c *Stdio) Close() error {
212228 }
213229
214230 if c .cmd != nil {
215- return c .cmd .Wait ()
231+ resChan := make (chan error , 1 )
232+ go func () {
233+ resChan <- c .cmd .Wait ()
234+ }()
235+
236+ // Waiting for the server to exit, or sending SIGTERM if the server does not exit within a reasonable time
237+ wait := func () (error , bool ) {
238+ select {
239+ case err := <- resChan :
240+ return err , true
241+ case <- time .After (c .terminateDuration ):
242+ }
243+ return nil , false
244+ }
245+
246+ if err , ok := wait (); ok {
247+ return err
248+ }
249+
250+ // Note the condition here: if sending SIGTERM fails, don't wait and just
251+ // move on to SIGKILL.
252+ if err := c .cmd .Process .Signal (syscall .SIGTERM ); err == nil {
253+ if err , ok := wait (); ok {
254+ return err
255+ }
256+ }
257+ // Sending SIGKILL if the server does not exit within a reasonable time after SIGTERM
258+ if err := c .cmd .Process .Kill (); err != nil {
259+ return err
260+ }
261+ if err , ok := wait (); ok {
262+ return err
263+ }
264+ return fmt .Errorf ("unresponsive subprocess" )
216265 }
217266
218267 return nil
0 commit comments