Skip to content

Commit d70b8d2

Browse files
authored
feat!: add context.Context to TokenStore methods (#557)
* feat!: add context.Context to TokenStore methods In distributed environments a TokenStore will be implemented as a database, and often a context is required when running get/store operations to ensure cancellation propagation, instrumentation, etc works when running token store ops. This is a breaking change but I think it makes sense (and it's pretty easy to fix any breakages). * Add checks for context cancellation; add tests * Document TokenStore requirements and add ErrNoToken sentinel error
1 parent c87c957 commit d70b8d2

File tree

5 files changed

+467
-30
lines changed

5 files changed

+467
-30
lines changed

client/oauth_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
)
1414

1515
func TestNewOAuthStreamableHttpClient(t *testing.T) {
16+
ctx := context.Background()
1617
// Create a test server
1718
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1819
// Check for Authorization header
@@ -51,7 +52,7 @@ func TestNewOAuthStreamableHttpClient(t *testing.T) {
5152
ExpiresIn: 3600,
5253
ExpiresAt: time.Now().Add(1 * time.Hour), // Valid for 1 hour
5354
}
54-
if err := tokenStore.SaveToken(validToken); err != nil {
55+
if err := tokenStore.SaveToken(ctx, validToken); err != nil {
5556
t.Fatalf("Failed to save token: %v", err)
5657
}
5758

client/transport/oauth.go

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ import (
1414
"time"
1515
)
1616

17+
// ErrNoToken is returned when no token is available in the token store
18+
var ErrNoToken = errors.New("no token available")
19+
1720
// OAuthConfig holds the OAuth configuration for the client
1821
type OAuthConfig struct {
1922
// ClientID is the OAuth client ID
@@ -33,12 +36,26 @@ type OAuthConfig struct {
3336
PKCEEnabled bool
3437
}
3538

36-
// TokenStore is an interface for storing and retrieving OAuth tokens
39+
// TokenStore is an interface for storing and retrieving OAuth tokens.
40+
//
41+
// Implementations must:
42+
// - Honor context cancellation and deadlines, returning context.Canceled
43+
// or context.DeadlineExceeded as appropriate
44+
// - Return ErrNoToken (or a sentinel error that wraps it) when no token
45+
// is available, rather than conflating this with other operational errors
46+
// - Properly propagate all other errors (database failures, I/O errors, etc.)
47+
// - Check ctx.Done() before performing operations and return ctx.Err() if cancelled
3748
type TokenStore interface {
38-
// GetToken returns the current token
39-
GetToken() (*Token, error)
40-
// SaveToken saves a token
41-
SaveToken(token *Token) error
49+
// GetToken returns the current token.
50+
// Returns ErrNoToken if no token is available.
51+
// Returns context.Canceled or context.DeadlineExceeded if ctx is cancelled.
52+
// Returns other errors for operational failures (I/O, database, etc.).
53+
GetToken(ctx context.Context) (*Token, error)
54+
55+
// SaveToken saves a token.
56+
// Returns context.Canceled or context.DeadlineExceeded if ctx is cancelled.
57+
// Returns other errors for operational failures (I/O, database, etc.).
58+
SaveToken(ctx context.Context, token *Token) error
4259
}
4360

4461
// Token represents an OAuth token
@@ -76,18 +93,27 @@ func NewMemoryTokenStore() *MemoryTokenStore {
7693
return &MemoryTokenStore{}
7794
}
7895

79-
// GetToken returns the current token
80-
func (s *MemoryTokenStore) GetToken() (*Token, error) {
96+
// GetToken returns the current token.
97+
// Returns ErrNoToken if no token is available.
98+
// Returns context.Canceled or context.DeadlineExceeded if ctx is cancelled.
99+
func (s *MemoryTokenStore) GetToken(ctx context.Context) (*Token, error) {
100+
if err := ctx.Err(); err != nil {
101+
return nil, err
102+
}
81103
s.mu.RLock()
82104
defer s.mu.RUnlock()
83105
if s.token == nil {
84-
return nil, errors.New("no token available")
106+
return nil, ErrNoToken
85107
}
86108
return s.token, nil
87109
}
88110

89-
// SaveToken saves a token
90-
func (s *MemoryTokenStore) SaveToken(token *Token) error {
111+
// SaveToken saves a token.
112+
// Returns context.Canceled or context.DeadlineExceeded if ctx is cancelled.
113+
func (s *MemoryTokenStore) SaveToken(ctx context.Context, token *Token) error {
114+
if err := ctx.Err(); err != nil {
115+
return err
116+
}
91117
s.mu.Lock()
92118
defer s.mu.Unlock()
93119
s.token = token
@@ -150,7 +176,10 @@ func (h *OAuthHandler) GetAuthorizationHeader(ctx context.Context) (string, erro
150176

151177
// getValidToken returns a valid token, refreshing if necessary
152178
func (h *OAuthHandler) getValidToken(ctx context.Context) (*Token, error) {
153-
token, err := h.config.TokenStore.GetToken()
179+
token, err := h.config.TokenStore.GetToken(ctx)
180+
if err != nil && !errors.Is(err, ErrNoToken) {
181+
return nil, err
182+
}
154183
if err == nil && !token.IsExpired() && token.AccessToken != "" {
155184
return token, nil
156185
}
@@ -218,13 +247,12 @@ func (h *OAuthHandler) refreshToken(ctx context.Context, refreshToken string) (*
218247
}
219248

220249
// If no new refresh token is provided, keep the old one
221-
oldToken, _ := h.config.TokenStore.GetToken()
222-
if tokenResp.RefreshToken == "" && oldToken != nil {
223-
tokenResp.RefreshToken = oldToken.RefreshToken
250+
if tokenResp.RefreshToken == "" {
251+
tokenResp.RefreshToken = refreshToken
224252
}
225253

226254
// Save the token
227-
if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil {
255+
if err := h.config.TokenStore.SaveToken(ctx, &tokenResp); err != nil {
228256
return nil, fmt.Errorf("failed to save token: %w", err)
229257
}
230258

@@ -637,7 +665,7 @@ func (h *OAuthHandler) ProcessAuthorizationResponse(ctx context.Context, code, s
637665
}
638666

639667
// Save the token
640-
if err := h.config.TokenStore.SaveToken(&tokenResp); err != nil {
668+
if err := h.config.TokenStore.SaveToken(ctx, &tokenResp); err != nil {
641669
return fmt.Errorf("failed to save token: %w", err)
642670
}
643671

0 commit comments

Comments
 (0)