Skip to content

Commit 9a1aea7

Browse files
committed
feat!: add context.Context to TokenStore methods
In distributed environments a TokenStore will be implemented as a database, and often a context is required when running get/store operations to ensure cancellation propagation, instrumentation, etc works when running token store ops. This is a breaking change but I think it makes sense (and it's pretty easy to fix any breakages).
1 parent 35ebaa5 commit 9a1aea7

File tree

5 files changed

+20
-15
lines changed

5 files changed

+20
-15
lines changed

client/oauth_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
)
1414

1515
func TestNewOAuthStreamableHttpClient(t *testing.T) {
16+
ctx := context.Background()
1617
// Create a test server
1718
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1819
// Check for Authorization header
@@ -51,7 +52,7 @@ func TestNewOAuthStreamableHttpClient(t *testing.T) {
5152
ExpiresIn: 3600,
5253
ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour
5354
}
54-
if err := tokenStore.SaveToken(validToken); err != nil {
55+
if err := tokenStore.SaveToken(ctx, validToken); err != nil {
5556
t.Fatalf("Failed to save token: %v", err)
5657
}
5758

client/transport/oauth.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ type OAuthConfig struct {
3636
// TokenStore is an interface for storing and retrieving OAuth tokens
3737
type TokenStore interface {
3838
// GetToken returns the current token
39-
GetToken() (*Token, error)
39+
GetToken(ctx context.Context) (*Token, error)
4040
// SaveToken saves a token
41-
SaveToken(token *Token) error
41+
SaveToken(ctx context.Context, token *Token) error
4242
}
4343

4444
// Token represents an OAuth token
@@ -77,7 +77,7 @@ func NewMemoryTokenStore() *MemoryTokenStore {
7777
}
7878

7979
// GetToken returns the current token
80-
func (s *MemoryTokenStore) GetToken() (*Token, error) {
80+
func (s *MemoryTokenStore) GetToken(ctx context.Context) (*Token, error) {
8181
s.mu.RLock()
8282
defer s.mu.RUnlock()
8383
if s.token == nil {
@@ -87,7 +87,7 @@ func (s *MemoryTokenStore) GetToken() (*Token, error) {
8787
}
8888

8989
// SaveToken saves a token
90-
func (s *MemoryTokenStore) SaveToken(token *Token) error {
90+
func (s *MemoryTokenStore) SaveToken(ctx context.Context, token *Token) error {
9191
s.mu.Lock()
9292
defer s.mu.Unlock()
9393
s.token = token
@@ -150,7 +150,7 @@ func (h *OAuthHandler) GetAuthorizationHeader(ctx context.Context) (string, erro
150150

151151
// getValidToken returns a valid token, refreshing if necessary
152152
func (h *OAuthHandler) getValidToken(ctx context.Context) (*Token, error) {
153-
token, err := h.config.TokenStore.GetToken()
153+
token, err := h.config.TokenStore.GetToken(ctx)
154154
if err == nil && !token.IsExpired() && token.AccessToken != "" {
155155
return token, nil
156156
}
@@ -218,13 +218,13 @@ func (h *OAuthHandler) refreshToken(ctx context.Context, refreshToken string) (*
218218
}
219219

220220
// If no new refresh token is provided, keep the old one
221-
oldToken, _ := h.config.TokenStore.GetToken()
221+
oldToken, _ := h.config.TokenStore.GetToken(ctx)
222222
if tokenResp.RefreshToken == "" && oldToken != nil {
223223
tokenResp.RefreshToken = oldToken.RefreshToken
224224
}
225225

226226
// Save the token
227-
if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil {
227+
if err := h.config.TokenStore.SaveToken(ctx, &tokenResp); err != nil {
228228
return nil, fmt.Errorf("failed to save token: %w", err)
229229
}
230230

@@ -637,7 +637,7 @@ func (h *OAuthHandler) ProcessAuthorizationResponse(ctx context.Context, code, s
637637
}
638638

639639
// Save the token
640-
if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil {
640+
if err := h.config.TokenStore.SaveToken(ctx, &tokenResp); err != nil {
641641
return fmt.Errorf("failed to save token: %w", err)
642642
}
643643

client/transport/oauth_test.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,10 @@ func TestToken_IsExpired(t *testing.T) {
5454
func TestMemoryTokenStore(t *testing.T) {
5555
// Create a token store
5656
store := NewMemoryTokenStore()
57+
ctx := context.Background()
5758

5859
// Test getting token from empty store
59-
_, err := store.GetToken()
60+
_, err := store.GetToken(ctx)
6061
if err == nil {
6162
t.Errorf("Expected error when getting token from empty store")
6263
}
@@ -71,13 +72,13 @@ func TestMemoryTokenStore(t *testing.T) {
7172
}
7273

7374
// Save the token
74-
err = store.SaveToken(token)
75+
err = store.SaveToken(ctx, token)
7576
if err != nil {
7677
t.Fatalf("Failed to save token: %v", err)
7778
}
7879

7980
// Get the token
80-
retrievedToken, err := store.GetToken()
81+
retrievedToken, err := store.GetToken(ctx)
8182
if err != nil {
8283
t.Fatalf("Failed to get token: %v", err)
8384
}
@@ -158,6 +159,7 @@ func TestValidateRedirectURI(t *testing.T) {
158159

159160
func TestOAuthHandler_GetAuthorizationHeader_EmptyAccessToken(t *testing.T) {
160161
// Create a token store with a token that has an empty access token
162+
ctx := context.Background()
161163
tokenStore := NewMemoryTokenStore()
162164
invalidToken := &Token{
163165
AccessToken: "", // Empty access token
@@ -166,7 +168,7 @@ func TestOAuthHandler_GetAuthorizationHeader_EmptyAccessToken(t *testing.T) {
166168
ExpiresIn: 3600,
167169
ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour
168170
}
169-
if err := tokenStore.SaveToken(invalidToken); err != nil {
171+
if err := tokenStore.SaveToken(ctx, invalidToken); err != nil {
170172
t.Fatalf("Failed to save token: %v", err)
171173
}
172174

client/transport/sse_oauth_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
)
1212

1313
func TestSSE_WithOAuth(t *testing.T) {
14+
ctx := context.Background()
1415
// Track request count to simulate 401 on first request, then success
1516
requestCount := 0
1617
authHeaderReceived := ""
@@ -80,7 +81,7 @@ func TestSSE_WithOAuth(t *testing.T) {
8081
ExpiresIn: 3600,
8182
ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour
8283
}
83-
if err := tokenStore.SaveToken(validToken); err != nil {
84+
if err := tokenStore.SaveToken(ctx, validToken); err != nil {
8485
t.Fatalf("Failed to save token: %v", err)
8586
}
8687

client/transport/streamable_http_oauth_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
)
1414

1515
func TestStreamableHTTP_WithOAuth(t *testing.T) {
16+
ctx := context.Background()
1617
// Track request count to simulate 401 on first request, then success
1718
requestCount := 0
1819
authHeaderReceived := ""
@@ -59,7 +60,7 @@ func TestStreamableHTTP_WithOAuth(t *testing.T) {
5960
ExpiresIn: 3600,
6061
ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour
6162
}
62-
if err := tokenStore.SaveToken(validToken); err != nil {
63+
if err := tokenStore.SaveToken(ctx, validToken); err != nil {
6364
t.Fatalf("Failed to save token: %v", err)
6465
}
6566

0 commit comments

Comments
 (0)