Skip to content

Commit d22932e

Browse files
committed
client roots for stdio and pass integration test
1 parent bfe07cb commit d22932e

File tree

4 files changed

+273
-1
lines changed

4 files changed

+273
-1
lines changed

examples/roots_client/main.go

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log"
7+
"os"
8+
"os/signal"
9+
"syscall"
10+
11+
"github.com/mark3labs/mcp-go/client"
12+
"github.com/mark3labs/mcp-go/client/transport"
13+
"github.com/mark3labs/mcp-go/mcp"
14+
)
15+
16+
// MockRootsHandler implements client.RootsHandler for demonstration.
17+
// In a real implementation, this would integrate with an actual LLM API.
18+
type MockRootsHandler struct{}
19+
20+
func (h *MockRootsHandler) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) {
21+
result := &mcp.ListRootsResult{
22+
Roots: []mcp.Root{
23+
{
24+
Name: "app",
25+
URI: "file:///User/haxxx/app",
26+
},
27+
{
28+
Name: "test-project",
29+
URI: "file:///User/haxxx/projects/test-project",
30+
},
31+
},
32+
}
33+
return result, nil
34+
}
35+
36+
func main() {
37+
if len(os.Args) < 2 {
38+
log.Fatal("Usage: roots_client <server_command>")
39+
}
40+
41+
serverCommand := os.Args[1]
42+
serverArgs := os.Args[2:]
43+
44+
// Create stdio transport to communicate with the server
45+
stdio := transport.NewStdio(serverCommand, nil, serverArgs...)
46+
47+
// Create roots handler
48+
rootsHandler := &MockRootsHandler{}
49+
50+
// Create client with roots capability
51+
mcpClient := client.NewClient(stdio, client.WithRootsHandler(rootsHandler))
52+
53+
ctx := context.Background()
54+
55+
// Start the client
56+
if err := mcpClient.Start(ctx); err != nil {
57+
log.Fatalf("Failed to start client: %v", err)
58+
}
59+
60+
// Setup graceful shutdown
61+
sigChan := make(chan os.Signal, 1)
62+
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
63+
64+
// Create a context that cancels on signal
65+
ctx, cancel := context.WithCancel(ctx)
66+
go func() {
67+
<-sigChan
68+
log.Println("Received shutdown signal, closing client...")
69+
cancel()
70+
}()
71+
72+
// Move defer after error checking
73+
defer func() {
74+
if err := mcpClient.Close(); err != nil {
75+
log.Printf("Error closing client: %v", err)
76+
}
77+
}()
78+
79+
// Initialize the connection
80+
initResult, err := mcpClient.Initialize(ctx, mcp.InitializeRequest{
81+
Params: mcp.InitializeParams{
82+
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
83+
ClientInfo: mcp.Implementation{
84+
Name: "roots-stdio-server",
85+
Version: "1.0.0",
86+
},
87+
Capabilities: mcp.ClientCapabilities{
88+
// Sampling capability will be automatically added by WithSamplingHandler
89+
},
90+
},
91+
})
92+
if err != nil {
93+
log.Fatalf("Failed to initialize: %v", err)
94+
}
95+
96+
log.Printf("Connected to server: %s v%s", initResult.ServerInfo.Name, initResult.ServerInfo.Version)
97+
log.Printf("Server capabilities: %+v", initResult.Capabilities)
98+
99+
// list tools
100+
toolsResult, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{})
101+
if err != nil {
102+
log.Fatalf("Failed to list tools: %v", err)
103+
}
104+
log.Printf("Available tools:")
105+
for _, tool := range toolsResult.Tools {
106+
log.Printf(" - %s: %s", tool.Name, tool.Description)
107+
}
108+
109+
// mock the root change
110+
if err := mcpClient.RootListChanges(ctx); err != nil {
111+
log.Printf("fail to notify root list change: %v", err)
112+
}
113+
114+
// call server tool
115+
request := mcp.CallToolRequest{}
116+
request.Params.Name = "roots"
117+
request.Params.Arguments = "{\"testonly\": \"yes\"}"
118+
result, err := mcpClient.CallTool(ctx, request)
119+
if err != nil {
120+
log.Fatalf("failed to call tool roots: %v", err)
121+
} else if len(result.Content) > 0 {
122+
resultStr := ""
123+
for _, content := range result.Content {
124+
if textContent, ok := content.(mcp.TextContent); ok {
125+
resultStr += fmt.Sprintf("%s\n", textContent.Text)
126+
}
127+
}
128+
fmt.Printf("client call tool result: %s", resultStr)
129+
}
130+
}

examples/roots_http_server/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func main() {
6868
}
6969
})
7070

71-
log.Println("Starting MCP server with roots support")
71+
log.Println("Starting MCP Http server with roots support")
7272
log.Println("Http Endpoint: http://localhost:8080/mcp")
7373
log.Println("")
7474
log.Println("This server supports roots over HTTP transport.")

examples/roots_server/main.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log"
7+
8+
"github.com/mark3labs/mcp-go/mcp"
9+
"github.com/mark3labs/mcp-go/server"
10+
)
11+
12+
func handleNotification(ctx context.Context, notification mcp.JSONRPCNotification) {
13+
fmt.Printf("notification received: %v", notification.Notification.Method)
14+
}
15+
16+
func main() {
17+
// Enable roots capability
18+
opts := []server.ServerOption{
19+
server.WithToolCapabilities(true),
20+
server.WithRoots(),
21+
}
22+
// Create MCP server with roots capability
23+
mcpServer := server.NewMCPServer("roots-stdio-server", "1.0.0", opts...)
24+
25+
// Add list root list change notification
26+
mcpServer.AddNotificationHandler(mcp.MethodNotificationToolsListChanged, handleNotification)
27+
mcpServer.EnableSampling()
28+
29+
// Add a simple tool to test roots list
30+
mcpServer.AddTool(mcp.Tool{
31+
Name: "roots",
32+
Description: "list root result",
33+
InputSchema: mcp.ToolInputSchema{
34+
Type: "object",
35+
Properties: map[string]any{
36+
"testonly": map[string]any{
37+
"type": "string",
38+
"description": "is this test only?",
39+
},
40+
},
41+
Required: []string{"testonly"},
42+
},
43+
}, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
44+
rootRequest := mcp.ListRootsRequest{
45+
Request: mcp.Request{
46+
Method: string(mcp.MethodListRoots),
47+
},
48+
}
49+
50+
if result, err := mcpServer.RequestRoots(ctx, rootRequest); err == nil {
51+
return &mcp.CallToolResult{
52+
Content: []mcp.Content{
53+
mcp.TextContent{
54+
Type: "text",
55+
Text: fmt.Sprintf("Root list: %v", result.Roots),
56+
},
57+
},
58+
}, nil
59+
60+
} else {
61+
return &mcp.CallToolResult{
62+
Content: []mcp.Content{
63+
mcp.TextContent{
64+
Type: "text",
65+
Text: fmt.Sprintf("Fail to list root, %v", err),
66+
},
67+
},
68+
}, err
69+
}
70+
})
71+
72+
// Create stdio server
73+
if err := server.ServeStdio(mcpServer); err != nil {
74+
log.Fatalf("Server Stdio error: %v\n", err)
75+
}
76+
}

server/stdio.go

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,11 @@ func (s *StdioServer) processMessage(
592592
return nil
593593
}
594594

595+
// Check if this is a response to an list roots request
596+
if s.handleListRootsResponse(rawMessage) {
597+
return nil
598+
}
599+
595600
// Check if this is a tool call that might need sampling (and thus should be processed concurrently)
596601
var baseMessage struct {
597602
Method string `json:"method"`
@@ -762,6 +767,67 @@ func (s *stdioSession) handleElicitationResponse(rawMessage json.RawMessage) boo
762767
return true
763768
}
764769

770+
// handleListRootsResponse checks if the message is a response to an list roots request
771+
// and routes it to the appropriate pending request channel.
772+
func (s *StdioServer) handleListRootsResponse(rawMessage json.RawMessage) bool {
773+
return stdioSessionInstance.handleListRootsResponse(rawMessage)
774+
}
775+
776+
// handleListRootsResponse handles incoming list root responses for this session
777+
func (s *stdioSession) handleListRootsResponse(rawMessage json.RawMessage) bool {
778+
// Try to parse as a JSON-RPC response
779+
var response struct {
780+
JSONRPC string `json:"jsonrpc"`
781+
ID json.Number `json:"id"`
782+
Result json.RawMessage `json:"result,omitempty"`
783+
Error *struct {
784+
Code int `json:"code"`
785+
Message string `json:"message"`
786+
} `json:"error,omitempty"`
787+
}
788+
789+
if err := json.Unmarshal(rawMessage, &response); err != nil {
790+
return false
791+
}
792+
// Parse the ID as int64
793+
id, err := response.ID.Int64()
794+
if err != nil || (response.Result == nil && response.Error == nil) {
795+
return false
796+
}
797+
798+
// Check if we have a pending list root request with this ID
799+
s.pendingMu.RLock()
800+
responseChan, exists := s.pendingRoots[id]
801+
s.pendingMu.RUnlock()
802+
803+
if !exists {
804+
return false
805+
}
806+
807+
// Parse and send the response
808+
rootsResp := &rootsResponse{}
809+
810+
if response.Error != nil {
811+
rootsResp.err = fmt.Errorf("list root request failed: %s", response.Error.Message)
812+
} else {
813+
var result mcp.ListRootsResult
814+
if err := json.Unmarshal(response.Result, &result); err != nil {
815+
rootsResp.err = fmt.Errorf("failed to unmarshal list root response: %w", err)
816+
} else {
817+
rootsResp.result = &result
818+
}
819+
}
820+
821+
// Send the response (non-blocking)
822+
select {
823+
case responseChan <- rootsResp:
824+
default:
825+
// Channel is full or closed, ignore
826+
}
827+
828+
return true
829+
}
830+
765831
// writeResponse marshals and writes a JSON-RPC response message followed by a newline.
766832
// Returns an error if marshaling or writing fails.
767833
func (s *StdioServer) writeResponse(

0 commit comments

Comments
 (0)