Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
.claude
coverage.out
coverage.txt
.vscode/launch.json
72 changes: 72 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type Client struct {
serverCapabilities mcp.ServerCapabilities
protocolVersion string
samplingHandler SamplingHandler
rootsHandler RootsHandler
elicitationHandler ElicitationHandler
}

Expand All @@ -44,6 +45,15 @@ func WithSamplingHandler(handler SamplingHandler) ClientOption {
}
}

// WithRootsHandler sets the roots handler for the client.
// WithRootsHandler returns a ClientOption that sets the client's RootsHandler.
// When provided, the client will declare the roots capability (ListChanged) during initialization.
func WithRootsHandler(handler RootsHandler) ClientOption {
return func(c *Client) {
c.rootsHandler = handler
}
}

// WithElicitationHandler sets the elicitation handler for the client.
// When set, the client will declare elicitation capability during initialization.
func WithElicitationHandler(handler ElicitationHandler) ClientOption {
Expand Down Expand Up @@ -177,6 +187,13 @@ func (c *Client) Initialize(
if c.samplingHandler != nil {
capabilities.Sampling = &struct{}{}
}
if c.rootsHandler != nil {
capabilities.Roots = &struct {
ListChanged bool `json:"listChanged,omitempty"`
}{
ListChanged: true,
}
}
// Add elicitation capability if handler is configured
if c.elicitationHandler != nil {
capabilities.Elicitation = &struct{}{}
Expand Down Expand Up @@ -464,6 +481,28 @@ func (c *Client) Complete(
return &result, nil
}

// RootListChanges sends a roots list-changed notification to the server.
func (c *Client) RootListChanges(
ctx context.Context,
) error {
// Send root list changes notification
notification := mcp.JSONRPCNotification{
JSONRPC: mcp.JSONRPC_VERSION,
Notification: mcp.Notification{
Method: mcp.MethodNotificationRootsListChanged,
},
}

err := c.transport.SendNotification(ctx, notification)
if err != nil {
return fmt.Errorf(
"failed to send root list change notification: %w",
err,
)
}
return nil
}

// handleIncomingRequest processes incoming requests from the server.
// This is the main entry point for server-to-client requests like sampling and elicitation.
func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
Expand All @@ -474,6 +513,8 @@ func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JS
return c.handleElicitationRequestTransport(ctx, request)
case string(mcp.MethodPing):
return c.handlePingRequestTransport(ctx, request)
case string(mcp.MethodListRoots):
return c.handleListRootsRequestTransport(ctx, request)
default:
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
}
Expand Down Expand Up @@ -536,6 +577,37 @@ func (c *Client) handleSamplingRequestTransport(ctx context.Context, request tra
return response, nil
}

// handleListRootsRequestTransport handles list roots requests at the transport level.
func (c *Client) handleListRootsRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
if c.rootsHandler == nil {
return nil, fmt.Errorf("no roots handler configured")
}

// Create the MCP request
mcpRequest := mcp.ListRootsRequest{
Request: mcp.Request{
Method: string(mcp.MethodListRoots),
},
}

// Call the list roots handler
result, err := c.rootsHandler.ListRoots(ctx, mcpRequest)
if err != nil {
return nil, err
}

// Marshal the result
resultBytes, err := json.Marshal(result)
if err != nil {
return nil, fmt.Errorf("failed to marshal result: %w", err)
}

// Create the transport response
response := transport.NewJSONRPCResultResponse(request.ID, json.RawMessage(resultBytes))

return response, nil
}

// handleElicitationRequestTransport handles elicitation requests at the transport level.
func (c *Client) handleElicitationRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
if c.elicitationHandler == nil {
Expand Down
17 changes: 17 additions & 0 deletions client/roots.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package client

import (
"context"

"github.com/mark3labs/mcp-go/mcp"
)

// RootsHandler defines the interface for handling roots requests from servers.
// Clients can implement this interface to provide roots list to servers.
type RootsHandler interface {
// ListRoots handles a list root request from the server and returns the roots list.
// The implementation should:
// 1. Validate input against the requested schema
// 2. Return the appropriate response
ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error)
}
11 changes: 9 additions & 2 deletions client/transport/inprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type InProcessTransport struct {
server *server.MCPServer
samplingHandler server.SamplingHandler
elicitationHandler server.ElicitationHandler
rootsHandler server.RootsHandler
session *server.InProcessSession
sessionID string

Expand All @@ -37,6 +38,12 @@ func WithElicitationHandler(handler server.ElicitationHandler) InProcessOption {
}
}

func WithRootsHandler(handler server.RootsHandler) InProcessOption {
return func(t *InProcessTransport) {
t.rootsHandler = handler
}
}

func NewInProcessTransport(server *server.MCPServer) *InProcessTransport {
return &InProcessTransport{
server: server,
Expand Down Expand Up @@ -66,8 +73,8 @@ func (c *InProcessTransport) Start(ctx context.Context) error {
c.startedMu.Unlock()

// Create and register session if we have handlers
if c.samplingHandler != nil || c.elicitationHandler != nil {
c.session = server.NewInProcessSessionWithHandlers(c.sessionID, c.samplingHandler, c.elicitationHandler)
if c.samplingHandler != nil || c.elicitationHandler != nil || c.rootsHandler != nil {
c.session = server.NewInProcessSessionWithHandlers(c.sessionID, c.samplingHandler, c.elicitationHandler, c.rootsHandler)
if err := c.server.RegisterSession(ctx, c.session); err != nil {
c.startedMu.Lock()
c.started = false
Expand Down
163 changes: 163 additions & 0 deletions examples/roots_client/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package main

import (
"context"
"fmt"
"log"
"net/url"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"

"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
)

// fileURI returns a file:// URI for both Unix and Windows absolute paths.
func fileURI(p string) string {
p = filepath.ToSlash(p)
if !strings.HasPrefix(p, "/") { // e.g., "C:/Users/..." on Windows
p = "/" + p
}
return (&url.URL{Scheme: "file", Path: p}).String()
}

// MockRootsHandler implements client.RootsHandler for demonstration.
// In a real implementation, this would enumerate workspace/project roots.
type MockRootsHandler struct{}

// ListRoots implements client.RootsHandler by returning example workspace roots.
func (h *MockRootsHandler) ListRoots(ctx context.Context, request mcp.ListRootsRequest) (*mcp.ListRootsResult, error) {
home, err := os.UserHomeDir()
if err != nil {
log.Printf("Warning: failed to get home directory: %v", err)
home = "/tmp" // fallback for demonstration
}
app := filepath.ToSlash(filepath.Join(home, "app"))
proj := filepath.ToSlash(filepath.Join(home, "projects", "test-project"))
result := &mcp.ListRootsResult{
Roots: []mcp.Root{
{
Name: "app",
URI: fileURI(app),
},
{
Name: "test-project",
URI: fileURI(proj),
},
},
}
return result, nil
}

// main starts a mock MCP roots client that communicates with a subprocess over stdio.
// It expects the server command as the first command-line argument, creates a stdio
// transport and an MCP client with a MockRootsHandler, starts and initializes the
// client, logs server info and available tools, notifies the server of root list
// changes, invokes the "roots" tool and prints any text content returned, and
// shuts down the client gracefully on SIGINT or SIGTERM.
func main() {
if len(os.Args) < 2 {
log.Fatal("Usage: roots_client <server_command>")
}

serverCommand := os.Args[1]
serverArgs := os.Args[2:]

// Create stdio transport to communicate with the server
stdio := transport.NewStdio(serverCommand, nil, serverArgs...)

// Create roots handler
rootsHandler := &MockRootsHandler{}

// Create client with roots capability
mcpClient := client.NewClient(stdio, client.WithRootsHandler(rootsHandler))

ctx := context.Background()

// Start the client
if err := mcpClient.Start(ctx); err != nil {
log.Fatalf("Failed to start client: %v", err)
}

// Setup graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

// Create a context that cancels on signal
ctx, cancel := context.WithCancel(ctx)
go func() {
<-sigChan
log.Println("Received shutdown signal, closing client...")
cancel()
}()

// Move defer after error checking
defer func() {
if err := mcpClient.Close(); err != nil {
log.Printf("Error closing client: %v", err)
}
}()

// Initialize the connection
initResult, err := mcpClient.Initialize(ctx, mcp.InitializeRequest{
Params: mcp.InitializeParams{
ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION,
ClientInfo: mcp.Implementation{
Name: "roots-stdio-client",
Version: "1.0.0",
},
Capabilities: mcp.ClientCapabilities{
// Roots capability will be automatically added by WithRootsHandler
},
},
})
if err != nil {
log.Fatalf("Failed to initialize: %v", err)
}

log.Printf("Connected to server: %s v%s", initResult.ServerInfo.Name, initResult.ServerInfo.Version)
log.Printf("Server capabilities: %+v", initResult.Capabilities)

// list tools
toolsResult, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{})
if err != nil {
log.Fatalf("Failed to list tools: %v", err)
}
log.Printf("Available tools:")
for _, tool := range toolsResult.Tools {
log.Printf(" - %s: %s", tool.Name, tool.Description)
}

// call server tool
request := mcp.CallToolRequest{}
request.Params.Name = "roots"
request.Params.Arguments = map[string]any{"testonly": "yes"}
result, err := mcpClient.CallTool(ctx, request)
if err != nil {
log.Fatalf("failed to call tool roots: %v", err)
} else if result.IsError {
log.Printf("tool reported error")
} else if len(result.Content) > 0 {
resultStr := ""
for _, content := range result.Content {
switch tc := content.(type) {
case mcp.TextContent:
resultStr += fmt.Sprintf("%s\n", tc.Text)
}
}
fmt.Printf("client call tool result: %s\n", resultStr)
}

// mock the root change
if err := mcpClient.RootListChanges(ctx); err != nil {
log.Printf("failed to notify root list change: %v", err)
}

// Keep running until cancelled by signal
<-ctx.Done()
log.Println("Client context cancelled")
}
Loading
Loading