Skip to content

Commit 79af76f

Browse files
committed
fix deadlock sendTransport
1 parent 656a7b4 commit 79af76f

File tree

2 files changed

+106
-0
lines changed

2 files changed

+106
-0
lines changed

client/transport/stdio.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ func (c *Stdio) Start(ctx context.Context) error {
123123
// If an (optional) cmdFunc custom command factory function was configured, it will be used to construct the subprocess;
124124
// otherwise, the default behavior uses exec.CommandContext with the merged environment.
125125
// Initializes stdin, stdout, and stderr pipes for JSON-RPC communication.
126+
// A background goroutine is also started to wait for the subprocess to exit,
127+
// ensuring that the done channel is closed automatically if the process terminates unexpectedly.
126128
func (c *Stdio) spawnCommand(ctx context.Context) error {
127129
if c.command == "" {
128130
return nil
@@ -163,6 +165,16 @@ func (c *Stdio) spawnCommand(ctx context.Context) error {
163165
return fmt.Errorf("failed to start command: %w", err)
164166
}
165167

168+
go func() {
169+
_ = cmd.Wait()
170+
select {
171+
case <-c.done:
172+
// Already closed explicitly (via Close), do nothing
173+
default:
174+
close(c.done) // Automatically signal subprocess exit
175+
}
176+
}()
177+
166178
return nil
167179
}
168180

@@ -191,6 +203,16 @@ func (c *Stdio) Close() error {
191203
return nil
192204
}
193205

206+
// IsClosed reports whether the subprocess has exited and the transport is no longer usable.
207+
func (c *Stdio) IsClosed() bool {
208+
select {
209+
case <-c.done:
210+
return true
211+
default:
212+
return false
213+
}
214+
}
215+
194216
// GetSessionId returns the session ID of the transport.
195217
// Since stdio does not maintain a session ID, it returns an empty string.
196218
func (c *Stdio) GetSessionId() string {
@@ -293,6 +315,11 @@ func (c *Stdio) SendRequest(
293315
c.mu.Unlock()
294316
}
295317

318+
if c.IsClosed() {
319+
deleteResponseChan()
320+
return nil, fmt.Errorf("cannot send request: subprocess is closed")
321+
}
322+
296323
// Send request
297324
if _, err := c.stdin.Write(requestBytes); err != nil {
298325
deleteResponseChan()
@@ -303,6 +330,9 @@ func (c *Stdio) SendRequest(
303330
case <-ctx.Done():
304331
deleteResponseChan()
305332
return nil, ctx.Err()
333+
case <-c.done:
334+
deleteResponseChan()
335+
return nil, fmt.Errorf("subprocess exited while waiting for response")
306336
case response := <-responseChan:
307337
return response, nil
308338
}

client/transport/stdio_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,31 @@ func TestStdio(t *testing.T) {
384384
t.Errorf("Expected array with 3 items, got %v", result.Params["array"])
385385
}
386386
})
387+
388+
t.Run("SendRequestFailsIfSubprocessExited", func(t *testing.T) {
389+
// Start a subprocess that exits immediately
390+
ctx := context.Background()
391+
stdio := NewStdio("sh", nil, "-c", "exit 0")
392+
393+
err := stdio.Start(ctx)
394+
require.NoError(t, err)
395+
396+
// Wait for subprocess to exit
397+
require.Eventually(t, func() bool {
398+
return stdio.IsClosed()
399+
}, time.Second, 10*time.Millisecond)
400+
401+
// Try to send a request
402+
_, err = stdio.SendRequest(ctx, JSONRPCRequest{
403+
JSONRPC: "2.0",
404+
ID: mcp.NewRequestId("dead"),
405+
Method: "noop",
406+
})
407+
408+
require.Error(t, err)
409+
require.Contains(t, err.Error(), "subprocess")
410+
})
411+
387412
}
388413

389414
func TestStdioErrors(t *testing.T) {
@@ -609,6 +634,32 @@ func TestStdio_SpawnCommand_UsesCommandFunc_Error(t *testing.T) {
609634
require.EqualError(t, err, "test error")
610635
}
611636

637+
func TestStdio_DoneClosedWhenSubcommandExits(t *testing.T) {
638+
ctx := context.Background()
639+
640+
stdio := NewStdioWithOptions(
641+
"sh",
642+
nil,
643+
[]string{"-c", "exit 0"},
644+
)
645+
646+
require.NotNil(t, stdio)
647+
648+
err := stdio.spawnCommand(ctx)
649+
require.NoError(t, err)
650+
651+
t.Cleanup(func() {
652+
if stdio.cmd.Process != nil {
653+
_ = stdio.cmd.Process.Kill()
654+
}
655+
})
656+
657+
// Wait up to 200ms for the done channel to close
658+
require.Eventually(t, func() bool {
659+
return stdio.IsClosed()
660+
}, 200*time.Millisecond, 10*time.Millisecond, "expected done to be closed after subprocess exited")
661+
}
662+
612663
func TestStdio_NewStdioWithOptions_AppliesOptions(t *testing.T) {
613664
configured := false
614665

@@ -620,3 +671,28 @@ func TestStdio_NewStdioWithOptions_AppliesOptions(t *testing.T) {
620671
require.NotNil(t, stdio)
621672
require.True(t, configured, "option was not applied")
622673
}
674+
675+
func TestStdio_IsClosed(t *testing.T) {
676+
t.Run("returns false before Start", func(t *testing.T) {
677+
stdio := NewStdio("sh", nil, "-c", "sleep 1")
678+
require.False(t, stdio.IsClosed(), "expected IsClosed to be false before Start")
679+
})
680+
681+
t.Run("returns false after Start", func(t *testing.T) {
682+
stdio := NewStdio("sh", nil, "-c", "sleep 1")
683+
err := stdio.Start(context.Background())
684+
require.NoError(t, err)
685+
defer stdio.Close()
686+
require.False(t, stdio.IsClosed(), "expected IsClosed to be false right after Start")
687+
})
688+
689+
t.Run("returns true after subprocess exits", func(t *testing.T) {
690+
stdio := NewStdio("sh", nil, "-c", "exit 0")
691+
err := stdio.Start(context.Background())
692+
require.NoError(t, err)
693+
694+
require.Eventually(t, func() bool {
695+
return stdio.IsClosed()
696+
}, 200*time.Millisecond, 10*time.Millisecond, "expected IsClosed to return true after subprocess exits")
697+
})
698+
}

0 commit comments

Comments
 (0)