Skip to content

Commit 61e0172

Browse files
committed
add basic http transport for oauth
1 parent 8d58ac1 commit 61e0172

File tree

3 files changed

+63
-11
lines changed

3 files changed

+63
-11
lines changed

cmd/src/login.go

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package main
22

33
import (
44
"context"
5-
"encoding/json"
65
"flag"
76
"fmt"
87
"io"
@@ -14,8 +13,6 @@ import (
1413
"github.com/sourcegraph/src-cli/internal/cmderrors"
1514
"github.com/sourcegraph/src-cli/internal/keyring"
1615
"github.com/sourcegraph/src-cli/internal/oauthdevice"
17-
18-
"github.com/sourcegraph/sourcegraph/lib/errors"
1916
)
2017

2118
func init() {
@@ -137,14 +134,14 @@ func loginCmd(ctx context.Context, p loginParams) error {
137134
cfg.Endpoint = endpointArg
138135

139136
if p.useDeviceFlow {
140-
resp, err := runDeviceFlow(ctx, endpointArg, out, p.deviceFlowClient)
137+
token, err := runDeviceFlow(ctx, endpointArg, out, p.deviceFlowClient)
141138
if err != nil {
142139
printProblem(fmt.Sprintf("Device flow authentication failed: %s", err))
143140
fmt.Fprintln(out, createAccessTokenMessage)
144141
return cmderrors.ExitCode1
145142
}
146143

147-
if err := oauthdevice.StoreToken(secretStore, &resp.Token); err != nil {
144+
if err := oauthdevice.StoreToken(secretStore, token); err != nil {
148145
printProblem(fmt.Sprintf("Failed to store token in keyring store: %s", err))
149146
return cmderrors.ExitCode1
150147
}
@@ -198,10 +195,7 @@ func loginCmd(ctx context.Context, p loginParams) error {
198195
return nil
199196
}
200197

201-
func storeToken(store *keyring.Store, token *oauthdevice.Token) error {
202-
}
203-
204-
func runDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauthdevice.Client) (*oauthdevice.TokenResponse, error) {
198+
func runDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauthdevice.Client) (*oauthdevice.Token, error) {
205199
authResp, err := client.Start(ctx, endpoint, nil)
206200
if err != nil {
207201
return nil, err
@@ -222,10 +216,10 @@ func runDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client o
222216
interval = 5 * time.Second
223217
}
224218

225-
tokenResp, err := client.Poll(ctx, endpoint, authResp.DeviceCode, interval, authResp.ExpiresIn)
219+
resp, err := client.Poll(ctx, endpoint, authResp.DeviceCode, interval, authResp.ExpiresIn)
226220
if err != nil {
227221
return nil, err
228222
}
229223

230-
return tokenResp, nil
224+
return resp.Token(endpoint), nil
231225
}

internal/oauthdevice/device_flow.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,15 @@ func (t *TokenResponse) Token(endpoint string) *Token {
380380
}
381381
}
382382

383+
func (t *Token) HasExpired() bool {
384+
return time.Now().After(t.ExpiresAt)
385+
}
386+
387+
func (t *Token) ExpiringIn(d time.Duration) bool {
388+
future := time.Now().Add(d)
389+
return future.After(t.ExpiresAt)
390+
}
391+
383392
func StoreToken(store *keyring.Store, token *Token) error {
384393
data, err := json.Marshal(token)
385394
if err != nil {
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package oauthdevice
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"time"
7+
)
8+
9+
var _ http.Transport
10+
11+
var _ http.RoundTripper = (*Transport)(nil)
12+
13+
type Transport struct {
14+
Base http.RoundTripper
15+
token *Token
16+
}
17+
18+
// RoundTrip implements http.RoundTripper.
19+
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
20+
ctx := req.Context()
21+
token, err := maybeRefresh(ctx, t.token)
22+
if err != nil {
23+
return nil, err
24+
}
25+
t.token = token
26+
27+
req2 := req.Clone(req.Context())
28+
req2.Header.Set("Authorization", "Bearer "+t.token.AccessToken)
29+
30+
if t.Base != nil {
31+
return t.Base.RoundTrip(req2)
32+
}
33+
return http.DefaultTransport.RoundTrip(req2)
34+
}
35+
36+
func maybeRefresh(ctx context.Context, token *Token) (*Token, error) {
37+
// token has NOT expired or NOT about to expire in 30s
38+
if !(token.HasExpired() || token.ExpiringIn(time.Duration(30)*time.Second)) {
39+
return token, nil
40+
}
41+
client := NewClient(DefaultClientID)
42+
43+
resp, err := client.Refresh(ctx, token)
44+
if err != nil {
45+
return nil, err
46+
}
47+
48+
return resp.Token(token.Endpoint), nil
49+
}

0 commit comments

Comments
 (0)