Skip to content

Commit c6bb426

Browse files
committed
oauth2: add device flow support
Signed-off-by: Marcos Lilljedahl <marcosnils@gmail.com>
1 parent 128564f commit c6bb426

File tree

2 files changed

+140
-3
lines changed

2 files changed

+140
-3
lines changed

deviceauth.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package oauth2
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"golang.org/x/net/context/ctxhttp"
8+
"io"
9+
"io/ioutil"
10+
"net/http"
11+
"net/url"
12+
"strings"
13+
)
14+
15+
const (
16+
errAuthorizationPending = "authorization_pending"
17+
errSlowDown = "slow_down"
18+
errAccessDenied = "access_denied"
19+
errExpiredToken = "expired_token"
20+
)
21+
22+
type DeviceAuth struct {
23+
DeviceCode string `json:"device_code"`
24+
UserCode string `json:"user_code"`
25+
VerificationURI string `json:"verification_uri,verification_url"`
26+
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
27+
ExpiresIn int `json:"expires_in,string"`
28+
Interval int `json:"interval,string,omitempty"`
29+
raw map[string]interface{}
30+
}
31+
32+
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuth, error) {
33+
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
34+
if err != nil {
35+
return nil, err
36+
}
37+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
38+
39+
r, err := ctxhttp.Do(ctx, nil, req)
40+
if err != nil {
41+
return nil, err
42+
}
43+
44+
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
45+
if err != nil {
46+
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
47+
}
48+
if code := r.StatusCode; code < 200 || code > 299 {
49+
return nil, &RetrieveError{
50+
Response: r,
51+
Body: body,
52+
}
53+
}
54+
55+
var da = &DeviceAuth{}
56+
err = json.Unmarshal(body, &da)
57+
if err != nil {
58+
return nil, err
59+
}
60+
61+
_ = json.Unmarshal(body, &da.raw)
62+
63+
// Azure AD supplies verification_url instead of verification_uri
64+
if da.VerificationURI == "" {
65+
da.VerificationURI, _ = da.raw["verification_url"].(string)
66+
}
67+
68+
return da, nil
69+
}
70+
71+
func parseError(err error) string {
72+
e, ok := err.(*RetrieveError)
73+
if ok {
74+
eResp := make(map[string]string)
75+
_ = json.Unmarshal(e.Body, &eResp)
76+
return eResp["error"]
77+
}
78+
return ""
79+
}

oauth2.go

Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"net/url"
1717
"strings"
1818
"sync"
19+
"time"
1920

2021
"golang.org/x/oauth2/internal"
2122
)
@@ -70,8 +71,9 @@ type TokenSource interface {
7071
// Endpoint represents an OAuth 2.0 provider's authorization and token
7172
// endpoint URLs.
7273
type Endpoint struct {
73-
AuthURL string
74-
TokenURL string
74+
AuthURL string
75+
DeviceAuthURL string
76+
TokenURL string
7577

7678
// AuthStyle optionally specifies how the endpoint wants the
7779
// client ID & client secret sent. The zero value means to
@@ -224,6 +226,63 @@ func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOpti
224226
return retrieveToken(ctx, c, v)
225227
}
226228

229+
// AuthDevice returns a device auth struct which contains a device code
230+
// and authorization information provided for users to enter on another device.
231+
func (c *Config) AuthDevice(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuth, error) {
232+
v := url.Values{
233+
"client_id": {c.ClientID},
234+
}
235+
if len(c.Scopes) > 0 {
236+
v.Set("scope", strings.Join(c.Scopes, " "))
237+
}
238+
for _, opt := range opts {
239+
opt.setValue(v)
240+
}
241+
return retrieveDeviceAuth(ctx, c, v)
242+
}
243+
244+
// Poll does a polling to exchange an device code for a token.
245+
func (c *Config) Poll(ctx context.Context, da *DeviceAuth, opts ...AuthCodeOption) (*Token, error) {
246+
v := url.Values{
247+
"client_id": {c.ClientID},
248+
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
249+
"device_code": {da.DeviceCode},
250+
"code": {da.DeviceCode},
251+
}
252+
if len(c.Scopes) > 0 {
253+
v.Set("scope", strings.Join(c.Scopes, " "))
254+
}
255+
for _, opt := range opts {
256+
opt.setValue(v)
257+
}
258+
259+
// If no interval was provided, the client MUST use a reasonable default polling interval.
260+
// See https://tools.ietf.org/html/draft-ietf-oauth-device-flow-07#section-3.5
261+
interval := da.Interval
262+
if interval == 0 {
263+
interval = 5
264+
}
265+
266+
for {
267+
time.Sleep(time.Duration(interval) * time.Second)
268+
269+
tok, err := retrieveToken(ctx, c, v)
270+
if err == nil {
271+
return tok, nil
272+
}
273+
274+
errTyp := parseError(err)
275+
switch errTyp {
276+
case errAccessDenied, errExpiredToken:
277+
return tok, errors.New("oauth2: " + errTyp)
278+
case errSlowDown:
279+
interval += 5
280+
fallthrough
281+
case errAuthorizationPending:
282+
}
283+
}
284+
}
285+
227286
// Client returns an HTTP client using the provided token.
228287
// The token will auto-refresh as necessary. The underlying
229288
// HTTP transport will be obtained using the provided context.
@@ -271,7 +330,6 @@ func (tf *tokenRefresher) Token() (*Token, error) {
271330
"grant_type": {"refresh_token"},
272331
"refresh_token": {tf.refreshToken},
273332
})
274-
275333
if err != nil {
276334
return nil, err
277335
}

0 commit comments

Comments
 (0)