Skip to content

Commit 204b273

Browse files
andigclaude
andcommitted
fix: implement EnableSampling() to properly declare sampling capability
Previously, EnableSampling() was a no-op that didn't actually enable the sampling capability in the server's declared capabilities. Changes: - Add Sampling field to mcp.ServerCapabilities struct - Add sampling field to internal serverCapabilities struct - Update EnableSampling() to set the sampling capability flag - Update handleInitialize() to include sampling in capability response - Add test to verify sampling capability is properly declared Now when EnableSampling() is called, the server will properly declare sampling capability during initialization, allowing clients to know that the server supports sending sampling requests. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 1cae3a9 commit 204b273

File tree

4 files changed

+49
-0
lines changed

4 files changed

+49
-0
lines changed

mcp/types.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,8 @@ type ServerCapabilities struct {
472472
// list.
473473
ListChanged bool `json:"listChanged,omitempty"`
474474
} `json:"resources,omitempty"`
475+
// Present if the server supports sending sampling requests to clients.
476+
Sampling *struct{} `json:"sampling,omitempty"`
475477
// Present if the server offers any tools to call.
476478
Tools *struct {
477479
// Whether this server supports notifications for changes to the tool list.

server/sampling.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ import (
1212
func (s *MCPServer) EnableSampling() {
1313
s.capabilitiesMu.Lock()
1414
defer s.capabilitiesMu.Unlock()
15+
16+
enabled := true
17+
s.capabilities.sampling = &enabled
1518
}
1619

1720
// RequestSampling sends a sampling request to the client.

server/sampling_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,42 @@ func TestMCPServer_RequestSampling_Success(t *testing.T) {
113113
t.Errorf("expected model %q, got %q", "test-model", result.Model)
114114
}
115115
}
116+
117+
func TestMCPServer_EnableSampling_SetsCapability(t *testing.T) {
118+
server := NewMCPServer("test", "1.0.0")
119+
120+
// Verify sampling capability is not set initially
121+
ctx := context.Background()
122+
initRequest := mcp.InitializeRequest{
123+
Params: mcp.InitializeParams{
124+
ProtocolVersion: "2025-03-26",
125+
ClientInfo: mcp.Implementation{
126+
Name: "test-client",
127+
Version: "1.0.0",
128+
},
129+
Capabilities: mcp.ClientCapabilities{},
130+
},
131+
}
132+
133+
result, err := server.handleInitialize(ctx, 1, initRequest)
134+
if err != nil {
135+
t.Fatalf("unexpected error: %v", err)
136+
}
137+
138+
if result.Capabilities.Sampling != nil {
139+
t.Error("sampling capability should not be set before EnableSampling() is called")
140+
}
141+
142+
// Enable sampling
143+
server.EnableSampling()
144+
145+
// Verify sampling capability is now set
146+
result, err = server.handleInitialize(ctx, 2, initRequest)
147+
if err != nil {
148+
t.Fatalf("unexpected error after EnableSampling(): %v", err)
149+
}
150+
151+
if result.Capabilities.Sampling == nil {
152+
t.Error("sampling capability should be set after EnableSampling() is called")
153+
}
154+
}

server/server.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ type serverCapabilities struct {
181181
resources *resourceCapabilities
182182
prompts *promptCapabilities
183183
logging *bool
184+
sampling *bool
184185
}
185186

186187
// resourceCapabilities defines the supported resource-related features
@@ -580,6 +581,10 @@ func (s *MCPServer) handleInitialize(
580581
capabilities.Logging = &struct{}{}
581582
}
582583

584+
if s.capabilities.sampling != nil && *s.capabilities.sampling {
585+
capabilities.Sampling = &struct{}{}
586+
}
587+
583588
result := mcp.InitializeResult{
584589
ProtocolVersion: s.protocolVersion(request.Params.ProtocolVersion),
585590
ServerInfo: mcp.Implementation{

0 commit comments

Comments
 (0)