Skip to content

Commit f06b092

Browse files
committed
fix data race in test
1 parent 5fc3c83 commit f06b092

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

client/sse_test.go

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"github.com/stretchr/testify/assert"
77
"net/http"
8+
"sync"
89
"testing"
910
"time"
1011

@@ -460,16 +461,19 @@ func TestSSEMCPClient(t *testing.T) {
460461
t.Fatalf("Failed to create client: %v", err)
461462
}
462463

464+
mu := sync.Mutex{}
463465
notificationNum := 0
464466
var messageNotification *mcp.JSONRPCNotification
465467
progressNotifications := make([]*mcp.JSONRPCNotification, 0)
466468
client.OnNotification(func(notification mcp.JSONRPCNotification) {
469+
mu.Lock()
467470
if notification.Method == string(mcp.MethodNotificationMessage) {
468471
messageNotification = &notification
469472
} else if notification.Method == string(mcp.MethodNotificationProgress) {
470473
progressNotifications = append(progressNotifications, &notification)
471474
}
472475
notificationNum += 1
476+
mu.Unlock()
473477
})
474478
defer client.Close()
475479

@@ -516,6 +520,7 @@ func TestSSEMCPClient(t *testing.T) {
516520

517521
time.Sleep(time.Millisecond * 200)
518522

523+
mu.Lock()
519524
assert.Equal(t, notificationNum, 3)
520525
assert.NotNil(t, messageNotification)
521526
assert.Equal(t, messageNotification.Method, string(mcp.MethodNotificationMessage))
@@ -537,6 +542,7 @@ func TestSSEMCPClient(t *testing.T) {
537542
assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["progress"])
538543
assert.Equal(t, "progress_token", progressNotifications[1].Params.AdditionalFields["progressToken"])
539544
assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["total"])
545+
mu.Lock()
540546
})
541547

542548
t.Run("Ensure the server does not send notifications", func(t *testing.T) {
@@ -545,9 +551,12 @@ func TestSSEMCPClient(t *testing.T) {
545551
t.Fatalf("Failed to create client: %v", err)
546552
}
547553

554+
mu := sync.Mutex{}
548555
notifications := make([]*mcp.JSONRPCNotification, 0)
549556
client.OnNotification(func(notification mcp.JSONRPCNotification) {
557+
mu.Lock()
550558
notifications = append(notifications, &notification)
559+
mu.Unlock()
551560
})
552561
defer client.Close()
553562

@@ -585,7 +594,9 @@ func TestSSEMCPClient(t *testing.T) {
585594
_, _ = client.CallTool(ctx, request)
586595
time.Sleep(time.Millisecond * 200)
587596

597+
mu.Lock()
588598
assert.Len(t, notifications, 0)
599+
mu.Unlock()
589600
})
590601

591602
t.Run("GetPrompt for testing log and progress notification", func(t *testing.T) {
@@ -594,17 +605,19 @@ func TestSSEMCPClient(t *testing.T) {
594605
t.Fatalf("Failed to create client: %v", err)
595606
}
596607

608+
mu := sync.Mutex{}
597609
var messageNotification *mcp.JSONRPCNotification
598610
progressNotifications := make([]*mcp.JSONRPCNotification, 0)
599611
notificationNum := 0
600612
client.OnNotification(func(notification mcp.JSONRPCNotification) {
601-
println(notification.Method)
613+
mu.Lock()
602614
if notification.Method == string(mcp.MethodNotificationMessage) {
603615
messageNotification = &notification
604616
} else if notification.Method == string(mcp.MethodNotificationProgress) {
605617
progressNotifications = append(progressNotifications, &notification)
606618
}
607619
notificationNum += 1
620+
mu.Unlock()
608621
})
609622
defer client.Close()
610623

@@ -645,6 +658,7 @@ func TestSSEMCPClient(t *testing.T) {
645658
if err != nil {
646659
t.Fatalf("GetPrompt failed: %v", err)
647660
}
661+
mu.Lock()
648662
assert.NotNil(t, result)
649663
assert.Len(t, result.Messages, 1)
650664
assert.Equal(t, result.Messages[0].Role, mcp.RoleAssistant)
@@ -675,6 +689,7 @@ func TestSSEMCPClient(t *testing.T) {
675689
assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["progress"])
676690
assert.Equal(t, "progress_token", progressNotifications[1].Params.AdditionalFields["progressToken"])
677691
assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["total"])
692+
mu.Unlock()
678693
})
679694

680695
t.Run("GetResource for testing log and progress notification", func(t *testing.T) {
@@ -683,17 +698,19 @@ func TestSSEMCPClient(t *testing.T) {
683698
t.Fatalf("Failed to create client: %v", err)
684699
}
685700

701+
mu := sync.Mutex{}
686702
var messageNotification *mcp.JSONRPCNotification
687703
progressNotifications := make([]*mcp.JSONRPCNotification, 0)
688704
notificationNum := 0
689705
client.OnNotification(func(notification mcp.JSONRPCNotification) {
690-
println(notification.Method)
706+
mu.Lock()
691707
if notification.Method == string(mcp.MethodNotificationMessage) {
692708
messageNotification = &notification
693709
} else if notification.Method == string(mcp.MethodNotificationProgress) {
694710
progressNotifications = append(progressNotifications, &notification)
695711
}
696712
notificationNum += 1
713+
mu.Unlock()
697714
})
698715
defer client.Close()
699716

@@ -735,6 +752,7 @@ func TestSSEMCPClient(t *testing.T) {
735752
t.Fatalf("ReadResource failed: %v", err)
736753
}
737754

755+
mu.Lock()
738756
assert.NotNil(t, result)
739757
assert.Len(t, result.Contents, 1)
740758
assert.Equal(t, result.Contents[0].(mcp.TextResourceContents).URI, "resource://testresource")
@@ -763,5 +781,6 @@ func TestSSEMCPClient(t *testing.T) {
763781
assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["progress"])
764782
assert.Equal(t, "progress_token", progressNotifications[1].Params.AdditionalFields["progressToken"])
765783
assert.EqualValues(t, 100, progressNotifications[1].Params.AdditionalFields["total"])
784+
mu.Unlock()
766785
})
767786
}

server/sse.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,8 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
519519
var message string
520520
if eventData, err := json.Marshal(response); err != nil {
521521
// If there is an error marshalling the response, send a generic error response
522-
log.Printf("failed to marshal response: %v", err)
522+
marshal, _ := json.Marshal(response)
523+
log.Printf("failed to marshal response: %v, response %s", err, string(marshal))
523524
message = "event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n"
524525
} else {
525526
message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData)

0 commit comments

Comments
 (0)