Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions client/inprocess.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package client

import (
"context"

"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)

Expand All @@ -10,3 +13,26 @@ func NewInProcessClient(server *server.MCPServer) (*Client, error) {
inProcessTransport := transport.NewInProcessTransport(server)
return NewClient(inProcessTransport), nil
}

// NewInProcessClientWithSamplingHandler creates an in-process client with sampling support
func NewInProcessClientWithSamplingHandler(server *server.MCPServer, handler SamplingHandler) (*Client, error) {
// Create a wrapper that implements server.SamplingHandler
serverHandler := &inProcessSamplingHandlerWrapper{handler: handler}

inProcessTransport := transport.NewInProcessTransportWithOptions(server,
transport.WithSamplingHandler(serverHandler))

client := NewClient(inProcessTransport)
client.samplingHandler = handler

return client, nil
}

// inProcessSamplingHandlerWrapper wraps client.SamplingHandler to implement server.SamplingHandler
type inProcessSamplingHandlerWrapper struct {
handler SamplingHandler
}

func (w *inProcessSamplingHandlerWrapper) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
return w.handler.CreateMessage(ctx, request)
}
148 changes: 148 additions & 0 deletions client/inprocess_sampling_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package client

import (
"context"
"testing"

"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)

// MockSamplingHandler implements SamplingHandler for testing
type MockSamplingHandler struct{}

func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
return &mcp.CreateMessageResult{
SamplingMessage: mcp.SamplingMessage{
Role: mcp.RoleAssistant,
Content: mcp.TextContent{
Type: "text",
Text: "Mock response from sampling handler",
},
},
Model: "mock-model",
StopReason: "endTurn",
}, nil
}

func TestInProcessSampling(t *testing.T) {
// Create server with sampling enabled
mcpServer := server.NewMCPServer("test-server", "1.0.0")
mcpServer.EnableSampling()

// Add a tool that uses sampling
mcpServer.AddTool(mcp.Tool{
Name: "test_sampling",
Description: "Test sampling functionality",
InputSchema: mcp.ToolInputSchema{
Type: "object",
Properties: map[string]any{
"message": map[string]any{
"type": "string",
"description": "Message to send to LLM",
},
},
Required: []string{"message"},
},
}, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
message, err := request.RequireString("message")
if err != nil {
return nil, err
}

// Create sampling request
samplingRequest := mcp.CreateMessageRequest{
CreateMessageParams: mcp.CreateMessageParams{
Messages: []mcp.SamplingMessage{
{
Role: mcp.RoleUser,
Content: mcp.TextContent{
Type: "text",
Text: message,
},
},
},
MaxTokens: 100,
Temperature: 0.7,
},
}

// Request sampling from client
result, err := mcpServer.RequestSampling(ctx, samplingRequest)
if err != nil {
return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{
Type: "text",
Text: "Sampling failed: " + err.Error(),
},
},
IsError: true,
}, nil
}

return &mcp.CallToolResult{
Content: []mcp.Content{
mcp.TextContent{
Type: "text",
Text: "Sampling result: " + result.Content.(mcp.TextContent).Text,
},
},
}, nil
})

// Create client with sampling handler
mockHandler := &MockSamplingHandler{}
client, err := NewInProcessClientWithSamplingHandler(mcpServer, mockHandler)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}
defer client.Close()

// Start the client
ctx := context.Background()
if err := client.Start(ctx); err != nil {
t.Fatalf("Failed to start client: %v", err)
}

// Initialize
initRequest := mcp.InitializeRequest{}
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initRequest.Params.ClientInfo = mcp.Implementation{
Name: "test-client",
Version: "1.0.0",
}

_, err = client.Initialize(ctx, initRequest)
if err != nil {
t.Fatalf("Failed to initialize: %v", err)
}

// Call the tool that uses sampling
result, err := client.CallTool(ctx, mcp.CallToolRequest{
Params: mcp.CallToolParams{
Name: "test_sampling",
Arguments: map[string]any{
"message": "Hello, world!",
},
},
})
if err != nil {
t.Fatalf("Tool call failed: %v", err)
}

// Verify the result contains the mock response
if len(result.Content) == 0 {
t.Fatal("Expected content in result")
}

textContent, ok := result.Content[0].(mcp.TextContent)
if !ok {
t.Fatal("Expected text content")
}

expectedText := "Sampling result: Mock response from sampling handler"
if textContent.Text != expectedText {
t.Errorf("Expected %q, got %q", expectedText, textContent.Text)
}
}
43 changes: 41 additions & 2 deletions client/transport/inprocess.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,50 @@ import (
)

type InProcessTransport struct {
server *server.MCPServer
server *server.MCPServer
samplingHandler server.SamplingHandler
session *server.InProcessSession
sessionID string

onNotification func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
}

type InProcessOption func(*InProcessTransport)

func WithSamplingHandler(handler server.SamplingHandler) InProcessOption {
return func(t *InProcessTransport) {
t.samplingHandler = handler
}
}

func NewInProcessTransport(server *server.MCPServer) *InProcessTransport {
return &InProcessTransport{
server: server,
}
}

func NewInProcessTransportWithOptions(server *server.MCPServer, opts ...InProcessOption) *InProcessTransport {
t := &InProcessTransport{
server: server,
sessionID: server.GenerateInProcessSessionID(),
}

for _, opt := range opts {
opt(t)
}

return t
}

func (c *InProcessTransport) Start(ctx context.Context) error {
// Create and register session if we have a sampling handler
if c.samplingHandler != nil {
c.session = server.NewInProcessSession(c.sessionID, c.samplingHandler)
if err := c.server.RegisterSession(ctx, c.session); err != nil {
return fmt.Errorf("failed to register session: %w", err)
}
}
return nil
}

Expand All @@ -34,6 +65,11 @@ func (c *InProcessTransport) SendRequest(ctx context.Context, request JSONRPCReq
}
requestBytes = append(requestBytes, '\n')

// Add session to context if available
if c.session != nil {
ctx = c.server.WithContext(ctx, c.session)
}

respMessage := c.server.HandleMessage(ctx, requestBytes)
respByte, err := json.Marshal(respMessage)
if err != nil {
Expand Down Expand Up @@ -65,7 +101,10 @@ func (c *InProcessTransport) SetNotificationHandler(handler func(notification mc
c.onNotification = handler
}

func (*InProcessTransport) Close() error {
func (c *InProcessTransport) Close() error {
if c.session != nil {
c.server.UnregisterSession(context.Background(), c.sessionID)
}
return nil
}

Expand Down
39 changes: 39 additions & 0 deletions examples/inprocess_sampling/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# InProcess Sampling Example

This example demonstrates how to use sampling with in-process MCP client/server communication.

## Overview

The example shows:
- Creating an MCP server with sampling enabled
- Adding a tool that uses sampling to request LLM completions
- Creating an in-process client with a sampling handler
- Making tool calls that trigger sampling requests

## Key Components

### Server Side
- `mcpServer.EnableSampling()` - Enables sampling capability
- Tool handler calls `mcpServer.RequestSampling()` to request LLM completions
- Sampling requests are handled directly by the client's sampling handler

### Client Side
- `MockSamplingHandler` - Implements the `SamplingHandler` interface
- `NewInProcessClientWithSamplingHandler()` - Creates client with sampling support
- The handler receives sampling requests and returns mock LLM responses

## Running the Example

```bash
go run main.go
```

## Expected Output

```
Tool result: LLM Response (model: mock-llm-v1): Mock LLM response to: 'What is the capital of France?'
```

## Real LLM Integration

To integrate with a real LLM service (OpenAI, Anthropic, etc.), replace the `MockSamplingHandler` with an implementation that calls your preferred LLM API. See the [client sampling documentation](https://mcp-go.dev/clients/advanced-sampling) for examples with real LLM providers.
Loading