Skip to content

Commit 7cf8880

Browse files
committed
oauth2: add device flow support
1 parent d668ce9 commit 7cf8880

File tree

2 files changed

+140
-2
lines changed

2 files changed

+140
-2
lines changed

deviceauth.go

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package oauth2
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"golang.org/x/net/context/ctxhttp"
8+
"io"
9+
"io/ioutil"
10+
"net/http"
11+
"net/url"
12+
"strings"
13+
)
14+
15+
const (
16+
errAuthorizationPending = "authorization_pending"
17+
errSlowDown = "slow_down"
18+
errAccessDenied = "access_denied"
19+
errExpiredToken = "expired_token"
20+
)
21+
22+
type DeviceAuth struct {
23+
DeviceCode string `json:"device_code"`
24+
UserCode string `json:"user_code"`
25+
VerificationURI string `json:"verification_uri,verification_url"`
26+
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
27+
ExpiresIn int `json:"expires_in,string"`
28+
Interval int `json:"interval,string,omitempty"`
29+
raw map[string]interface{}
30+
}
31+
32+
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuth, error) {
33+
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
34+
if err != nil {
35+
return nil, err
36+
}
37+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
38+
39+
r, err := ctxhttp.Do(ctx, nil, req)
40+
if err != nil {
41+
return nil, err
42+
}
43+
44+
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
45+
if err != nil {
46+
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
47+
}
48+
if code := r.StatusCode; code < 200 || code > 299 {
49+
return nil, &RetrieveError{
50+
Response: r,
51+
Body: body,
52+
}
53+
}
54+
55+
var da = &DeviceAuth{}
56+
err = json.Unmarshal(body, &da)
57+
if err != nil {
58+
return nil, err
59+
}
60+
61+
_ = json.Unmarshal(body, &da.raw)
62+
63+
// Azure AD supplies verification_url instead of verification_uri
64+
if da.VerificationURI == "" {
65+
da.VerificationURI, _ = da.raw["verification_url"].(string)
66+
}
67+
68+
return da, nil
69+
}
70+
71+
func parseError(err error) string {
72+
e, ok := err.(*RetrieveError)
73+
if ok {
74+
eResp := make(map[string]string)
75+
_ = json.Unmarshal(e.Body, &eResp)
76+
return eResp["error"]
77+
}
78+
return ""
79+
}

oauth2.go

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"net/url"
1717
"strings"
1818
"sync"
19+
"time"
1920

2021
"golang.org/x/oauth2/internal"
2122
)
@@ -74,8 +75,9 @@ type TokenSource interface {
7475
// Endpoint contains the OAuth 2.0 provider's authorization and token
7576
// endpoint URLs.
7677
type Endpoint struct {
77-
AuthURL string
78-
TokenURL string
78+
AuthURL string
79+
DeviceAuthURL string
80+
TokenURL string
7981
}
8082

8183
var (
@@ -203,6 +205,63 @@ func (c *Config) Exchange(ctx context.Context, code string, opts ...AuthCodeOpti
203205
return retrieveToken(ctx, c, v)
204206
}
205207

208+
// AuthDevice returns a device auth struct which contains a device code
209+
// and authorization information provided for users to enter on another device.
210+
func (c *Config) AuthDevice(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuth, error) {
211+
v := url.Values{
212+
"client_id": {c.ClientID},
213+
}
214+
if len(c.Scopes) > 0 {
215+
v.Set("scope", strings.Join(c.Scopes, " "))
216+
}
217+
for _, opt := range opts {
218+
opt.setValue(v)
219+
}
220+
return retrieveDeviceAuth(ctx, c, v)
221+
}
222+
223+
// Poll does a polling to exchange an device code for a token.
224+
func (c *Config) Poll(ctx context.Context, da *DeviceAuth, opts ...AuthCodeOption) (*Token, error) {
225+
v := url.Values{
226+
"client_id": {c.ClientID},
227+
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
228+
"device_code": {da.DeviceCode},
229+
"code": {da.DeviceCode},
230+
}
231+
if len(c.Scopes) > 0 {
232+
v.Set("scope", strings.Join(c.Scopes, " "))
233+
}
234+
for _, opt := range opts {
235+
opt.setValue(v)
236+
}
237+
238+
// If no interval was provided, the client MUST use a reasonable default polling interval.
239+
// See https://tools.ietf.org/html/draft-ietf-oauth-device-flow-07#section-3.5
240+
interval := da.Interval
241+
if interval == 0 {
242+
interval = 5
243+
}
244+
245+
for {
246+
time.Sleep(time.Duration(interval) * time.Second)
247+
248+
tok, err := retrieveToken(ctx, c, v)
249+
if err == nil {
250+
return tok, nil
251+
}
252+
253+
errTyp := parseError(err)
254+
switch errTyp {
255+
case errAccessDenied, errExpiredToken:
256+
return tok, errors.New("oauth2: " + errTyp)
257+
case errSlowDown:
258+
interval += 5
259+
fallthrough
260+
case errAuthorizationPending:
261+
}
262+
}
263+
}
264+
206265
// Client returns an HTTP client using the provided token.
207266
// The token will auto-refresh as necessary. The underlying
208267
// HTTP transport will be obtained using the provided context.

0 commit comments

Comments
 (0)