Skip to content

Commit 5088c93

Browse files
authored
feat: add SessionWithResourceTemplates for session-specific resource templates (#624)
* feat: add SessionWithResourceTemplates for session-specific resource templates Implements session-specific resource templates to achieve parity with SessionWithTools and SessionWithResources. This allows sessions to have their own resource templates that override global templates with the same URI pattern. Key changes: - Add SessionWithResourceTemplates interface to ClientSession hierarchy - Implement interface in both SSE and StreamableHTTP transports - Add AddSessionResourceTemplate(s) and DeleteSessionResourceTemplates methods - Update handleListResourceTemplates to merge session and global templates - Update handleReadResource to check session templates before global ones - Session templates trigger notifications/resources/list_changed when modified Closes #622 * test: add comprehensive tests for session resource templates - Add sessionTestClientWithResourceTemplates mock - Test AddSessionResourceTemplate and AddSessionResourceTemplates - Test DeleteSessionResourceTemplates - Test session template override behavior - Test notification behavior (enabled/disabled) - Test uninitialized session handling - Test error cases for unsupported sessions - Verify thread-safety through existing patterns Coverage increased from 70.17% to 72.8% (+2.63%) * refactor: address CodeRabbit review comments - Use atomic.Bool for initialized field in test mock to prevent data races - Add nil checks for URITemplate in handleReadResource for both session and global templates - Add validation in AddSessionResourceTemplates to prevent nil URITemplate, empty URI, or empty Name - Simplify Get/SetSessionResourceTemplates using maps.Clone - Fix test initialization states to match expected behavior All tests pass with race detector.
1 parent 8b7d60c commit 5088c93

File tree

7 files changed

+877
-46
lines changed

7 files changed

+877
-46
lines changed

server/errors.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@ var (
1313
ErrToolNotFound = errors.New("tool not found")
1414

1515
// Session-related errors
16-
ErrSessionNotFound = errors.New("session not found")
17-
ErrSessionExists = errors.New("session already exists")
18-
ErrSessionNotInitialized = errors.New("session not properly initialized")
19-
ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools")
20-
ErrSessionDoesNotSupportResources = errors.New("session does not support per-session resources")
21-
ErrSessionDoesNotSupportLogging = errors.New("session does not support setting logging level")
16+
ErrSessionNotFound = errors.New("session not found")
17+
ErrSessionExists = errors.New("session already exists")
18+
ErrSessionNotInitialized = errors.New("session not properly initialized")
19+
ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools")
20+
ErrSessionDoesNotSupportResources = errors.New("session does not support per-session resources")
21+
ErrSessionDoesNotSupportResourceTemplates = errors.New("session does not support resource templates")
22+
ErrSessionDoesNotSupportLogging = errors.New("session does not support setting logging level")
2223

2324
// Notification-related errors
2425
ErrNotificationNotInitialized = errors.New("notification channel not initialized")

server/server.go

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -880,12 +880,34 @@ func (s *MCPServer) handleListResourceTemplates(
880880
id any,
881881
request mcp.ListResourceTemplatesRequest,
882882
) (*mcp.ListResourceTemplatesResult, *requestError) {
883+
// Get global templates
883884
s.resourcesMu.RLock()
884-
templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates))
885-
for _, entry := range s.resourceTemplates {
886-
templates = append(templates, entry.template)
885+
templateMap := make(map[string]mcp.ResourceTemplate, len(s.resourceTemplates))
886+
for uri, entry := range s.resourceTemplates {
887+
templateMap[uri] = entry.template
887888
}
888889
s.resourcesMu.RUnlock()
890+
891+
// Check if there are session-specific resource templates
892+
session := ClientSessionFromContext(ctx)
893+
if session != nil {
894+
if sessionWithTemplates, ok := session.(SessionWithResourceTemplates); ok {
895+
if sessionTemplates := sessionWithTemplates.GetSessionResourceTemplates(); sessionTemplates != nil {
896+
// Merge session-specific templates with global templates
897+
// Session templates override global ones
898+
for uriTemplate, serverTemplate := range sessionTemplates {
899+
templateMap[uriTemplate] = serverTemplate.Template
900+
}
901+
}
902+
}
903+
}
904+
905+
// Convert map to slice for sorting and pagination
906+
templates := make([]mcp.ResourceTemplate, 0, len(templateMap))
907+
for _, template := range templateMap {
908+
templates = append(templates, template)
909+
}
910+
889911
sort.Slice(templates, func(i, j int) bool {
890912
return templates[i].Name < templates[j].Name
891913
})
@@ -971,18 +993,48 @@ func (s *MCPServer) handleReadResource(
971993
// If no direct handler found, try matching against templates
972994
var matchedHandler ResourceTemplateHandlerFunc
973995
var matched bool
974-
for _, entry := range s.resourceTemplates {
975-
template := entry.template
976-
if matchesTemplate(request.Params.URI, template.URITemplate) {
977-
matchedHandler = entry.handler
978-
matched = true
979-
matchedVars := template.URITemplate.Match(request.Params.URI)
980-
// Convert matched variables to a map
981-
request.Params.Arguments = make(map[string]any, len(matchedVars))
982-
for name, value := range matchedVars {
983-
request.Params.Arguments[name] = value.V
996+
997+
// First check session templates if available
998+
if session != nil {
999+
if sessionWithTemplates, ok := session.(SessionWithResourceTemplates); ok {
1000+
sessionTemplates := sessionWithTemplates.GetSessionResourceTemplates()
1001+
for _, serverTemplate := range sessionTemplates {
1002+
if serverTemplate.Template.URITemplate == nil {
1003+
continue
1004+
}
1005+
if matchesTemplate(request.Params.URI, serverTemplate.Template.URITemplate) {
1006+
matchedHandler = serverTemplate.Handler
1007+
matched = true
1008+
matchedVars := serverTemplate.Template.URITemplate.Match(request.Params.URI)
1009+
// Convert matched variables to a map
1010+
request.Params.Arguments = make(map[string]any, len(matchedVars))
1011+
for name, value := range matchedVars {
1012+
request.Params.Arguments[name] = value.V
1013+
}
1014+
break
1015+
}
1016+
}
1017+
}
1018+
}
1019+
1020+
// If not found in session templates, check global templates
1021+
if !matched {
1022+
for _, entry := range s.resourceTemplates {
1023+
template := entry.template
1024+
if template.URITemplate == nil {
1025+
continue
1026+
}
1027+
if matchesTemplate(request.Params.URI, template.URITemplate) {
1028+
matchedHandler = entry.handler
1029+
matched = true
1030+
matchedVars := template.URITemplate.Match(request.Params.URI)
1031+
// Convert matched variables to a map
1032+
request.Params.Arguments = make(map[string]any, len(matchedVars))
1033+
for name, value := range matchedVars {
1034+
request.Params.Arguments[name] = value.V
1035+
}
1036+
break
9841037
}
985-
break
9861038
}
9871039
}
9881040
s.resourcesMu.RUnlock()

server/session.go

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,17 @@ type SessionWithResources interface {
5151
SetSessionResources(resources map[string]ServerResource)
5252
}
5353

54+
// SessionWithResourceTemplates is an extension of ClientSession that can store session-specific resource template data
55+
type SessionWithResourceTemplates interface {
56+
ClientSession
57+
// GetSessionResourceTemplates returns the resource templates specific to this session, if any
58+
// This method must be thread-safe for concurrent access
59+
GetSessionResourceTemplates() map[string]ServerResourceTemplate
60+
// SetSessionResourceTemplates sets resource templates specific to this session
61+
// This method must be thread-safe for concurrent access
62+
SetSessionResourceTemplates(templates map[string]ServerResourceTemplate)
63+
}
64+
5465
// SessionWithClientInfo is an extension of ClientSession that can store client info
5566
type SessionWithClientInfo interface {
5667
ClientSession
@@ -613,3 +624,137 @@ func (s *MCPServer) DeleteSessionResources(sessionID string, uris ...string) err
613624

614625
return nil
615626
}
627+
628+
// AddSessionResourceTemplate adds a resource template for a specific session
629+
func (s *MCPServer) AddSessionResourceTemplate(sessionID string, template mcp.ResourceTemplate, handler ResourceTemplateHandlerFunc) error {
630+
return s.AddSessionResourceTemplates(sessionID, ServerResourceTemplate{
631+
Template: template,
632+
Handler: handler,
633+
})
634+
}
635+
636+
// AddSessionResourceTemplates adds resource templates for a specific session
637+
func (s *MCPServer) AddSessionResourceTemplates(sessionID string, templates ...ServerResourceTemplate) error {
638+
sessionValue, ok := s.sessions.Load(sessionID)
639+
if !ok {
640+
return ErrSessionNotFound
641+
}
642+
643+
session, ok := sessionValue.(SessionWithResourceTemplates)
644+
if !ok {
645+
return ErrSessionDoesNotSupportResourceTemplates
646+
}
647+
648+
// For session resource templates, enable listChanged by default
649+
// This is the same behavior as session resources
650+
s.implicitlyRegisterCapabilities(
651+
func() bool { return s.capabilities.resources != nil },
652+
func() { s.capabilities.resources = &resourceCapabilities{listChanged: true} },
653+
)
654+
655+
// Get existing templates (this returns a thread-safe copy)
656+
sessionTemplates := session.GetSessionResourceTemplates()
657+
658+
// Create a new map to avoid modifying the returned copy
659+
newTemplates := make(map[string]ServerResourceTemplate, len(sessionTemplates)+len(templates))
660+
661+
// Copy existing templates
662+
for k, v := range sessionTemplates {
663+
newTemplates[k] = v
664+
}
665+
666+
// Validate and add new templates
667+
for _, t := range templates {
668+
if t.Template.URITemplate == nil {
669+
return fmt.Errorf("resource template URITemplate cannot be nil")
670+
}
671+
raw := t.Template.URITemplate.Raw()
672+
if raw == "" {
673+
return fmt.Errorf("resource template URITemplate cannot be empty")
674+
}
675+
if t.Template.Name == "" {
676+
return fmt.Errorf("resource template name cannot be empty")
677+
}
678+
newTemplates[raw] = t
679+
}
680+
681+
// Set the new templates (this method must handle thread-safety)
682+
session.SetSessionResourceTemplates(newTemplates)
683+
684+
// Send notification if the session is initialized and listChanged is enabled
685+
if session.Initialized() && s.capabilities.resources != nil && s.capabilities.resources.listChanged {
686+
if err := s.SendNotificationToSpecificClient(sessionID, "notifications/resources/list_changed", nil); err != nil {
687+
// Log the error but don't fail the operation
688+
if s.hooks != nil && len(s.hooks.OnError) > 0 {
689+
hooks := s.hooks
690+
go func(sID string, hooks *Hooks) {
691+
ctx := context.Background()
692+
hooks.onError(ctx, nil, "notification", map[string]any{
693+
"method": "notifications/resources/list_changed",
694+
"sessionID": sID,
695+
}, fmt.Errorf("failed to send notification after adding resource templates: %w", err))
696+
}(sessionID, hooks)
697+
}
698+
}
699+
}
700+
701+
return nil
702+
}
703+
704+
// DeleteSessionResourceTemplates removes resource templates from a specific session
705+
func (s *MCPServer) DeleteSessionResourceTemplates(sessionID string, uriTemplates ...string) error {
706+
sessionValue, ok := s.sessions.Load(sessionID)
707+
if !ok {
708+
return ErrSessionNotFound
709+
}
710+
711+
session, ok := sessionValue.(SessionWithResourceTemplates)
712+
if !ok {
713+
return ErrSessionDoesNotSupportResourceTemplates
714+
}
715+
716+
// Get existing templates (this returns a thread-safe copy)
717+
sessionTemplates := session.GetSessionResourceTemplates()
718+
719+
// Track if any were actually deleted
720+
deletedAny := false
721+
722+
// Create a new map without the deleted templates
723+
newTemplates := make(map[string]ServerResourceTemplate, len(sessionTemplates))
724+
for k, v := range sessionTemplates {
725+
newTemplates[k] = v
726+
}
727+
728+
// Delete specified templates
729+
for _, uriTemplate := range uriTemplates {
730+
if _, exists := newTemplates[uriTemplate]; exists {
731+
delete(newTemplates, uriTemplate)
732+
deletedAny = true
733+
}
734+
}
735+
736+
// Only update if something was actually deleted
737+
if deletedAny {
738+
// Set the new templates (this method must handle thread-safety)
739+
session.SetSessionResourceTemplates(newTemplates)
740+
741+
// Send notification if the session is initialized and listChanged is enabled
742+
if session.Initialized() && s.capabilities.resources != nil && s.capabilities.resources.listChanged {
743+
if err := s.SendNotificationToSpecificClient(sessionID, "notifications/resources/list_changed", nil); err != nil {
744+
// Log the error but don't fail the operation
745+
if s.hooks != nil && len(s.hooks.OnError) > 0 {
746+
hooks := s.hooks
747+
go func(sID string, hooks *Hooks) {
748+
ctx := context.Background()
749+
hooks.onError(ctx, nil, "notification", map[string]any{
750+
"method": "notifications/resources/list_changed",
751+
"sessionID": sID,
752+
}, fmt.Errorf("failed to send notification after deleting resource templates: %w", err))
753+
}(sessionID, hooks)
754+
}
755+
}
756+
}
757+
}
758+
759+
return nil
760+
}

0 commit comments

Comments
 (0)