@@ -14,6 +14,9 @@ import (
14
14
"time"
15
15
)
16
16
17
+ // ErrNoToken is returned when no token is available in the token store
18
+ var ErrNoToken = errors .New ("no token available" )
19
+
17
20
// OAuthConfig holds the OAuth configuration for the client
18
21
type OAuthConfig struct {
19
22
// ClientID is the OAuth client ID
@@ -33,12 +36,26 @@ type OAuthConfig struct {
33
36
PKCEEnabled bool
34
37
}
35
38
36
- // TokenStore is an interface for storing and retrieving OAuth tokens
39
+ // TokenStore is an interface for storing and retrieving OAuth tokens.
40
+ //
41
+ // Implementations must:
42
+ // - Honor context cancellation and deadlines, returning context.Canceled
43
+ // or context.DeadlineExceeded as appropriate
44
+ // - Return ErrNoToken (or a sentinel error that wraps it) when no token
45
+ // is available, rather than conflating this with other operational errors
46
+ // - Properly propagate all other errors (database failures, I/O errors, etc.)
47
+ // - Check ctx.Done() before performing operations and return ctx.Err() if cancelled
37
48
type TokenStore interface {
38
- // GetToken returns the current token
39
- GetToken () (* Token , error )
40
- // SaveToken saves a token
41
- SaveToken (token * Token ) error
49
+ // GetToken returns the current token.
50
+ // Returns ErrNoToken if no token is available.
51
+ // Returns context.Canceled or context.DeadlineExceeded if ctx is cancelled.
52
+ // Returns other errors for operational failures (I/O, database, etc.).
53
+ GetToken (ctx context.Context ) (* Token , error )
54
+
55
+ // SaveToken saves a token.
56
+ // Returns context.Canceled or context.DeadlineExceeded if ctx is cancelled.
57
+ // Returns other errors for operational failures (I/O, database, etc.).
58
+ SaveToken (ctx context.Context , token * Token ) error
42
59
}
43
60
44
61
// Token represents an OAuth token
@@ -76,18 +93,27 @@ func NewMemoryTokenStore() *MemoryTokenStore {
76
93
return & MemoryTokenStore {}
77
94
}
78
95
79
- // GetToken returns the current token
80
- func (s * MemoryTokenStore ) GetToken () (* Token , error ) {
96
+ // GetToken returns the current token.
97
+ // Returns ErrNoToken if no token is available.
98
+ // Returns context.Canceled or context.DeadlineExceeded if ctx is cancelled.
99
+ func (s * MemoryTokenStore ) GetToken (ctx context.Context ) (* Token , error ) {
100
+ if err := ctx .Err (); err != nil {
101
+ return nil , err
102
+ }
81
103
s .mu .RLock ()
82
104
defer s .mu .RUnlock ()
83
105
if s .token == nil {
84
- return nil , errors . New ( "no token available" )
106
+ return nil , ErrNoToken
85
107
}
86
108
return s .token , nil
87
109
}
88
110
89
- // SaveToken saves a token
90
- func (s * MemoryTokenStore ) SaveToken (token * Token ) error {
111
+ // SaveToken saves a token.
112
+ // Returns context.Canceled or context.DeadlineExceeded if ctx is cancelled.
113
+ func (s * MemoryTokenStore ) SaveToken (ctx context.Context , token * Token ) error {
114
+ if err := ctx .Err (); err != nil {
115
+ return err
116
+ }
91
117
s .mu .Lock ()
92
118
defer s .mu .Unlock ()
93
119
s .token = token
@@ -150,7 +176,10 @@ func (h *OAuthHandler) GetAuthorizationHeader(ctx context.Context) (string, erro
150
176
151
177
// getValidToken returns a valid token, refreshing if necessary
152
178
func (h * OAuthHandler ) getValidToken (ctx context.Context ) (* Token , error ) {
153
- token , err := h .config .TokenStore .GetToken ()
179
+ token , err := h .config .TokenStore .GetToken (ctx )
180
+ if err != nil && ! errors .Is (err , ErrNoToken ) {
181
+ return nil , err
182
+ }
154
183
if err == nil && ! token .IsExpired () && token .AccessToken != "" {
155
184
return token , nil
156
185
}
@@ -218,13 +247,12 @@ func (h *OAuthHandler) refreshToken(ctx context.Context, refreshToken string) (*
218
247
}
219
248
220
249
// If no new refresh token is provided, keep the old one
221
- oldToken , _ := h .config .TokenStore .GetToken ()
222
- if tokenResp .RefreshToken == "" && oldToken != nil {
223
- tokenResp .RefreshToken = oldToken .RefreshToken
250
+ if tokenResp .RefreshToken == "" {
251
+ tokenResp .RefreshToken = refreshToken
224
252
}
225
253
226
254
// Save the token
227
- if err := h .config .TokenStore .SaveToken (& tokenResp ); err != nil {
255
+ if err := h .config .TokenStore .SaveToken (ctx , & tokenResp ); err != nil {
228
256
return nil , fmt .Errorf ("failed to save token: %w" , err )
229
257
}
230
258
@@ -637,7 +665,7 @@ func (h *OAuthHandler) ProcessAuthorizationResponse(ctx context.Context, code, s
637
665
}
638
666
639
667
// Save the token
640
- if err := h .config .TokenStore .SaveToken (& tokenResp ); err != nil {
668
+ if err := h .config .TokenStore .SaveToken (ctx , & tokenResp ); err != nil {
641
669
return fmt .Errorf ("failed to save token: %w" , err )
642
670
}
643
671
0 commit comments