Skip to content

Commit 35199e0

Browse files
committed
Add checks for context cancellation; add tests
1 parent 9a1aea7 commit 35199e0

File tree

3 files changed

+424
-6
lines changed

3 files changed

+424
-6
lines changed

client/transport/oauth.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ func NewMemoryTokenStore() *MemoryTokenStore {
7878

7979
// GetToken returns the current token
8080
func (s *MemoryTokenStore) GetToken(ctx context.Context) (*Token, error) {
81+
if err := ctx.Err(); err != nil {
82+
return nil, err
83+
}
8184
s.mu.RLock()
8285
defer s.mu.RUnlock()
8386
if s.token == nil {
@@ -88,6 +91,9 @@ func (s *MemoryTokenStore) GetToken(ctx context.Context) (*Token, error) {
8891

8992
// SaveToken saves a token
9093
func (s *MemoryTokenStore) SaveToken(ctx context.Context, token *Token) error {
94+
if err := ctx.Err(); err != nil {
95+
return err
96+
}
9197
s.mu.Lock()
9298
defer s.mu.Unlock()
9399
s.token = token
@@ -151,6 +157,11 @@ func (h *OAuthHandler) GetAuthorizationHeader(ctx context.Context) (string, erro
151157
// getValidToken returns a valid token, refreshing if necessary
152158
func (h *OAuthHandler) getValidToken(ctx context.Context) (*Token, error) {
153159
token, err := h.config.TokenStore.GetToken(ctx)
160+
if err != nil {
161+
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
162+
return nil, err
163+
}
164+
}
154165
if err == nil && !token.IsExpired() && token.AccessToken != "" {
155166
return token, nil
156167
}
@@ -218,7 +229,10 @@ func (h *OAuthHandler) refreshToken(ctx context.Context, refreshToken string) (*
218229
}
219230

220231
// If no new refresh token is provided, keep the old one
221-
oldToken, _ := h.config.TokenStore.GetToken(ctx)
232+
oldToken, oldErr := h.config.TokenStore.GetToken(ctx)
233+
if oldErr != nil && (errors.Is(oldErr, context.Canceled) || errors.Is(oldErr, context.DeadlineExceeded)) {
234+
return nil, oldErr
235+
}
222236
if tokenResp.RefreshToken == "" && oldToken != nil {
223237
tokenResp.RefreshToken = oldToken.RefreshToken
224238
}

0 commit comments

Comments
 (0)