Skip to content

Commit 3942767

Browse files
committed
feat: improve process killing
1 parent 7657c15 commit 3942767

File tree

1 file changed

+75
-26
lines changed

1 file changed

+75
-26
lines changed

client/transport/stdio.go

Lines changed: 75 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"os"
1111
"os/exec"
1212
"sync"
13+
"syscall"
14+
"time"
1315

1416
"github.com/mark3labs/mcp-go/mcp"
1517
"github.com/mark3labs/mcp-go/util"
@@ -24,21 +26,22 @@ type Stdio struct {
2426
args []string
2527
env []string
2628

27-
cmd *exec.Cmd
28-
cmdFunc CommandFunc
29-
stdin io.WriteCloser
30-
stdout *bufio.Scanner
31-
stderr io.ReadCloser
32-
responses map[string]chan *JSONRPCResponse
33-
mu sync.RWMutex
34-
done chan struct{}
35-
onNotification func(mcp.JSONRPCNotification)
36-
notifyMu sync.RWMutex
37-
onRequest RequestHandler
38-
requestMu sync.RWMutex
39-
ctx context.Context
40-
ctxMu sync.RWMutex
41-
logger util.Logger
29+
cmd *exec.Cmd
30+
cmdFunc CommandFunc
31+
stdin io.WriteCloser
32+
stdout *bufio.Scanner
33+
stderr io.ReadCloser
34+
responses map[string]chan *JSONRPCResponse
35+
mu sync.RWMutex
36+
done chan struct{}
37+
onNotification func(mcp.JSONRPCNotification)
38+
notifyMu sync.RWMutex
39+
onRequest RequestHandler
40+
requestMu sync.RWMutex
41+
ctx context.Context
42+
ctxMu sync.RWMutex
43+
logger util.Logger
44+
terminateDuration time.Duration
4245
}
4346

4447
// StdioOption defines a function that configures a Stdio transport instance.
@@ -66,6 +69,13 @@ func WithCommandLogger(logger util.Logger) StdioOption {
6669
}
6770
}
6871

72+
// WithTerminateDuration sets the duration to wait for graceful shutdown before sending SIGTERM.
73+
func WithTerminateDuration(duration time.Duration) StdioOption {
74+
return func(s *Stdio) {
75+
s.terminateDuration = duration
76+
}
77+
}
78+
6979
// NewIO returns a new stdio-based transport using existing input, output, and
7080
// logging streams instead of spawning a subprocess.
7181
// This is useful for testing and simulating client behavior.
@@ -75,10 +85,11 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio
7585
stdout: bufio.NewScanner(input),
7686
stderr: logging,
7787

78-
responses: make(map[string]chan *JSONRPCResponse),
79-
done: make(chan struct{}),
80-
ctx: context.Background(),
81-
logger: util.DefaultLogger(),
88+
responses: make(map[string]chan *JSONRPCResponse),
89+
done: make(chan struct{}),
90+
ctx: context.Background(),
91+
logger: util.DefaultLogger(),
92+
terminateDuration: 5 * time.Second, // Default 5 second timeout
8293
}
8394
}
8495

@@ -109,10 +120,11 @@ func NewStdioWithOptions(
109120
args: args,
110121
env: env,
111122

112-
responses: make(map[string]chan *JSONRPCResponse),
113-
done: make(chan struct{}),
114-
ctx: context.Background(),
115-
logger: util.DefaultLogger(),
123+
responses: make(map[string]chan *JSONRPCResponse),
124+
done: make(chan struct{}),
125+
ctx: context.Background(),
126+
logger: util.DefaultLogger(),
127+
terminateDuration: 5 * time.Second, // Default 5 second timeout
116128
}
117129

118130
for _, opt := range opts {
@@ -189,8 +201,10 @@ func (c *Stdio) spawnCommand(ctx context.Context) error {
189201
return nil
190202
}
191203

192-
// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
193-
// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
204+
// Close closes the input stream to the child process, and awaits normal
205+
// termination of the command. If the command does not exit, it is signalled to
206+
// terminate, and then eventually killed. This follows the MCP specification
207+
// for stdio transport shutdown.
194208
func (c *Stdio) Close() error {
195209
select {
196210
case <-c.done:
@@ -200,6 +214,8 @@ func (c *Stdio) Close() error {
200214
// cancel all in-flight request
201215
close(c.done)
202216

217+
// For the stdio transport, the client SHOULD initiate shutdown by:
218+
// First, closing the input stream to the child process (the server)
203219
if c.stdin != nil {
204220
if err := c.stdin.Close(); err != nil {
205221
return fmt.Errorf("failed to close stdin: %w", err)
@@ -212,7 +228,40 @@ func (c *Stdio) Close() error {
212228
}
213229

214230
if c.cmd != nil {
215-
return c.cmd.Wait()
231+
resChan := make(chan error, 1)
232+
go func() {
233+
resChan <- c.cmd.Wait()
234+
}()
235+
236+
// Waiting for the server to exit, or sending SIGTERM if the server does not exit within a reasonable time
237+
wait := func() (error, bool) {
238+
select {
239+
case err := <-resChan:
240+
return err, true
241+
case <-time.After(c.terminateDuration):
242+
}
243+
return nil, false
244+
}
245+
246+
if err, ok := wait(); ok {
247+
return err
248+
}
249+
250+
// Note the condition here: if sending SIGTERM fails, don't wait and just
251+
// move on to SIGKILL.
252+
if err := c.cmd.Process.Signal(syscall.SIGTERM); err == nil {
253+
if err, ok := wait(); ok {
254+
return err
255+
}
256+
}
257+
// Sending SIGKILL if the server does not exit within a reasonable time after SIGTERM
258+
if err := c.cmd.Process.Kill(); err != nil {
259+
return err
260+
}
261+
if err, ok := wait(); ok {
262+
return err
263+
}
264+
return fmt.Errorf("unresponsive subprocess")
216265
}
217266

218267
return nil

0 commit comments

Comments
 (0)