Skip to content

Commit 8d58ac1

Browse files
committed
create token struct from TokenResponse
- Token converts expiresIn to a timestamp - Store the token with the endpoint suffix
1 parent 4c95c34 commit 8d58ac1

File tree

2 files changed

+55
-36
lines changed

2 files changed

+55
-36
lines changed

internal/oauthdevice/device_flow.go

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ const (
2626
wellKnownPath = "/.well-known/openid-configuration"
2727

2828
// Key used to store the token in the store
29-
KeyOAuth = "oauth"
29+
KeyOAuth = "Sourcegraph CLI key storage"
3030

3131
GrantTypeDeviceCode string = "urn:ietf:params:oauth:grant-type:device_code"
3232

@@ -55,16 +55,20 @@ type DeviceAuthResponse struct {
5555
ExpiresIn int `json:"expires_in"`
5656
Interval int `json:"interval"`
5757
}
58-
type Token struct {
58+
59+
type TokenResponse struct {
5960
AccessToken string `json:"access_token"`
6061
RefreshToken string `json:"refresh_token,omitempty"`
6162
ExpiresIn int `json:"expires_in,omitempty"`
63+
TokenType string `json:"token_type"`
64+
Scope string `json:"scope,omitempty"`
6265
}
6366

64-
type TokenResponse struct {
65-
Token
66-
TokenType string `json:"token_type"`
67-
Scope string `json:"scope,omitempty"`
67+
type Token struct {
68+
Endpoint string `json:"endpoint"`
69+
AccessToken string `json:"access_token"`
70+
RefreshToken string `json:"refresh_token,omitempty"`
71+
ExpiresAt time.Time `json:"expires_at"`
6872
}
6973

7074
type ErrorResponse struct {
@@ -76,7 +80,7 @@ type Client interface {
7680
Discover(ctx context.Context, endpoint string) (*OIDCConfiguration, error)
7781
Start(ctx context.Context, endpoint string, scopes []string) (*DeviceAuthResponse, error)
7882
Poll(ctx context.Context, endpoint, deviceCode string, interval time.Duration, expiresIn int) (*TokenResponse, error)
79-
Refresh(ctx context.Context, endpoint, refreshToken string) (*TokenResponse, error)
83+
Refresh(ctx context.Context, token *Token) (*TokenResponse, error)
8084
}
8185

8286
type httpClient struct {
@@ -318,22 +322,20 @@ func (c *httpClient) pollOnce(ctx context.Context, tokenEndpoint, deviceCode str
318322
}
319323

320324
// Refresh exchanges a refresh token for a new access token.
321-
func (c *httpClient) Refresh(ctx context.Context, endpoint, refreshToken string) (*TokenResponse, error) {
322-
endpoint = strings.TrimRight(endpoint, "/")
323-
324-
config, err := c.Discover(ctx, endpoint)
325+
func (c *httpClient) Refresh(ctx context.Context, token *Token) (*TokenResponse, error) {
326+
config, err := c.Discover(ctx, token.Endpoint)
325327
if err != nil {
326-
return nil, errors.Wrap(err, "OIDC discovery failed")
328+
errors.Wrap(err, "failed to discover OIDC configuration")
327329
}
328330

329331
if config.TokenEndpoint == "" {
330-
return nil, errors.New("token endpoint not found in OIDC configuration")
332+
errors.New("OIDC configuration has no token endpoint")
331333
}
332334

333335
data := url.Values{}
334336
data.Set("client_id", c.clientID)
335337
data.Set("grant_type", "refresh_token")
336-
data.Set("refresh_token", refreshToken)
338+
data.Set("refresh_token", token.RefreshToken)
337339

338340
req, err := http.NewRequestWithContext(ctx, "POST", config.TokenEndpoint, strings.NewReader(data.Encode()))
339341
if err != nil {
@@ -366,7 +368,16 @@ func (c *httpClient) Refresh(ctx context.Context, endpoint, refreshToken string)
366368
return nil, errors.Wrap(err, "parsing refresh token response")
367369
}
368370

369-
return &tokenResp, nil
371+
return &tokenResp, err
372+
}
373+
374+
func (t *TokenResponse) Token(endpoint string) *Token {
375+
return &Token{
376+
Endpoint: strings.TrimRight(endpoint, "/"),
377+
RefreshToken: t.RefreshToken,
378+
AccessToken: t.AccessToken,
379+
ExpiresAt: time.Now().Add(time.Second * time.Duration(t.ExpiresIn)),
380+
}
370381
}
371382

372383
func StoreToken(store *keyring.Store, token *Token) error {
@@ -375,13 +386,18 @@ func StoreToken(store *keyring.Store, token *Token) error {
375386
return errors.Wrap(err, "failed to marshal token")
376387
}
377388

378-
// TODO(burmudar): do we need a suffix that is the endpoint? ex. oauth-sourcegraph.com
379-
return store.Set(KeyOAuth, data)
389+
if token.Endpoint == "" {
390+
return errors.New("token endpoint cannot be empty when storing the token")
391+
}
392+
393+
key := fmt.Sprintf("%s <%s>", KeyOAuth, token.Endpoint)
394+
return store.Set(key, data)
380395
}
381396

382-
func LoadToken(store *keyring.Store) (*Token, error) {
397+
func LoadToken(store *keyring.Store, endpoint string) (*Token, error) {
398+
key := fmt.Sprintf("%s <%s>", KeyOAuth, endpoint)
383399
var t Token
384-
data, err := store.Get(KeyOAuth)
400+
data, err := store.Get(key)
385401
if err != nil {
386402
return nil, errors.Wrap(err, "failed to get token from store")
387403
}

internal/oauthdevice/device_flow_test.go

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -266,12 +266,10 @@ func TestStart_NoDeviceEndpoint(t *testing.T) {
266266

267267
func TestPoll_Success(t *testing.T) {
268268
wantToken := TokenResponse{
269-
Token: Token{
270-
AccessToken: "test-access-token",
271-
ExpiresIn: 3600,
272-
},
273-
Scope: "read write",
274-
TokenType: "Bearer",
269+
AccessToken: "test-access-token",
270+
ExpiresIn: 3600,
271+
Scope: "read write",
272+
TokenType: "Bearer",
275273
}
276274

277275
server := newTestServer(t, testServerOptions{
@@ -315,6 +313,7 @@ func TestPoll_Success(t *testing.T) {
315313
if resp.TokenType != wantToken.TokenType {
316314
t.Errorf("TokenType = %q, want %q", resp.TokenType, wantToken.TokenType)
317315
}
316+
318317
}
319318

320319
func TestPoll_AuthorizationPending(t *testing.T) {
@@ -337,8 +336,8 @@ func TestPoll_AuthorizationPending(t *testing.T) {
337336
}
338337

339338
json.NewEncoder(w).Encode(TokenResponse{
340-
Token: Token{AccessToken: "test-access-token"},
341-
TokenType: "Bearer",
339+
AccessToken: "test-access-token",
340+
TokenType: "Bearer",
342341
})
343342
},
344343
},
@@ -379,8 +378,8 @@ func TestPoll_SlowDown(t *testing.T) {
379378
}
380379

381380
json.NewEncoder(w).Encode(TokenResponse{
382-
Token: Token{AccessToken: "test-access-token"},
383-
TokenType: "Bearer",
381+
AccessToken: "test-access-token",
382+
TokenType: "Bearer",
384383
})
385384
},
386385
},
@@ -527,20 +526,24 @@ func TestRefresh_Success(t *testing.T) {
527526

528527
w.Header().Set("Content-Type", "application/json")
529528
json.NewEncoder(w).Encode(TokenResponse{
530-
Token: Token{
531-
AccessToken: "new-access-token",
532-
RefreshToken: "new-refresh-token",
533-
ExpiresIn: 3600,
534-
},
535-
TokenType: "Bearer",
529+
AccessToken: "new-access-token",
530+
RefreshToken: "new-refresh-token",
531+
ExpiresIn: 3600,
532+
TokenType: "Bearer",
536533
})
537534
},
538535
},
539536
})
540537
defer server.Close()
541538

542539
client := NewClient(DefaultClientID)
543-
resp, err := client.Refresh(context.Background(), server.URL, "test-refresh-token")
540+
token := &Token{
541+
Endpoint: server.URL,
542+
AccessToken: "new-access-token",
543+
RefreshToken: "test-refresh-token",
544+
ExpiresAt: time.Now().Add(time.Second * time.Duration(3600)),
545+
}
546+
resp, err := client.Refresh(context.Background(), token)
544547
if err != nil {
545548
t.Fatalf("Refresh() error = %v", err)
546549
}

0 commit comments

Comments
 (0)