-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathtransport.go
285 lines (251 loc) · 8.94 KB
/
transport.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
package client
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sync"
)
const (
	// ADMAuthSession is the HTTP header carrying the DEP session token. It is
	// set on outgoing requests and read from responses, where the server may
	// hand back an updated token.
	ADMAuthSession = "X-ADM-Auth-Session"

	// ServerProtocolVersion is the HTTP header that selects which DEP protocol
	// version the server should speak.
	ServerProtocolVersion = "X-Server-Protocol-Version"

	// DefaultServerProtocolVersion is the ServerProtocolVersion header value
	// that Transport injects when a request does not already set one.
	DefaultServerProtocolVersion = "7"

	// SessionEndpoint is the authentication path. Transport resolves it
	// against each request's URL to build the /session request.
	SessionEndpoint = "/session"

	// bodyForbidden is the marker looked for in a 403 response body. It
	// presumably indicates an expired or unknown session token (observed with
	// the depsim simulator).
	bodyForbidden = "FORBIDDEN"
)
var (
	// ErrMissingName is the error Transport.RoundTrip reports when the
	// request context carries no DEP name (attach one with WithName).
	ErrMissingName = errors.New("transport: missing DEP name in HTTP request context")
)
// ctxKeyName is the private context key type under which the DEP name is
// stored. Because the type is unexported, no other package can collide with it.
type ctxKeyName struct{}

// WithName returns a derived copy of ctx that carries the DEP name.
func WithName(ctx context.Context, name string) context.Context {
	var key ctxKeyName
	return context.WithValue(ctx, key, name)
}

// GetName returns the DEP name stored in ctx by WithName. It returns the
// empty string when no name is present.
func GetName(ctx context.Context) string {
	if name, ok := ctx.Value(ctxKeyName{}).(string); ok {
		return name
	}
	return ""
}
// AuthTokensRetriever supplies the OAuth1 tokens that Transport uses to
// authenticate against the DEP /session endpoint for a given DEP name.
type AuthTokensRetriever interface {
	// RetrieveAuthTokens retrieves the OAuth tokens from storage for name (DEP name).
	// If the name or tokens do not exist storage.ErrNotFound should be returned.
	// Any error aborts the in-flight request that needed authentication.
	RetrieveAuthTokens(ctx context.Context, name string) (*OAuth1Tokens, error)
}
// SessionStore persists DEP session tokens, keyed by DEP name, so that an
// authenticated session can be reused by later requests.
type SessionStore interface {
	// SetSessionToken records session as the current session token for the
	// DEP name.
	SetSessionToken(ctx context.Context, name, session string) error

	// GetSessionToken returns the stored session token for the DEP name.
	// Transport treats an empty token as "not yet authenticated", so
	// implementations should return "" (and a nil error) when nothing is
	// stored. A non-nil error aborts the request.
	GetSessionToken(ctx context.Context, name string) (string, error)
}
// sessionMap is a minimal SessionStore that keeps DEP session tokens in a Go
// map guarded by an embedded RWMutex. The map is local to this process, so
// sessions are not shared with other instances. Because of that, the Apple
// DEP servers may not support the resulting multiple concurrent sessions.
type sessionMap struct {
	sessions map[string]string
	sync.RWMutex
}

// newSessionMap returns an empty sessionMap that is ready for use.
func newSessionMap() *sessionMap {
	m := new(sessionMap)
	m.sessions = map[string]string{}
	return m
}

// SetSessionToken stores session under name. An empty session removes any
// token stored for name. It always returns nil.
func (s *sessionMap) SetSessionToken(_ context.Context, name, session string) error {
	s.Lock()
	defer s.Unlock()
	if session == "" {
		delete(s.sessions, name)
		return nil
	}
	s.sessions[name] = session
	return nil
}

// GetSessionToken returns the token stored under name, or "" when there is
// none. It always returns a nil error.
func (s *sessionMap) GetSessionToken(_ context.Context, name string) (string, error) {
	s.RLock()
	defer s.RUnlock()
	return s.sessions[name], nil
}
// Transport is an http.RoundTripper that handles Apple DEP API authentication
// and session token bookkeeping on behalf of its callers. The request flow is
// described on the RoundTrip method.
type Transport struct {
	// transport is the wrapped RoundTripper that performs the actual HTTP
	// exchanges for API requests.
	transport http.RoundTripper

	// client sends the raw authentication requests to the /session
	// endpoint, from which session tokens are captured.
	client Doer

	// tokens supplies the OAuth tokens used when authenticating.
	tokens AuthTokensRetriever

	// sessions saves and loads session tokens per DEP name.
	sessions SessionStore

	// sessionURL caches the pre-parsed /session path. It is a path only, not
	// an absolute URL, and is resolved against each request's URL.
	sessionURL *url.URL
}
// NewTransport builds a Transport. The actual HTTP exchanges go through t,
// and the /session authentication request goes through c. Session tokens are
// stored in and loaded from s, and OAuth tokens come from tokens.
//
// Nil arguments fall back to defaults: t becomes http.DefaultTransport, c
// becomes http.DefaultClient, and s becomes an in-process map, so sessions
// are then not shared between processes. tokens is mandatory, and
// NewTransport panics if it is nil.
func NewTransport(t http.RoundTripper, c Doer, tokens AuthTokensRetriever, s SessionStore) *Transport {
	if tokens == nil {
		panic("nil token retriever")
	}
	if t == nil {
		t = http.DefaultTransport
	}
	if c == nil {
		c = http.DefaultClient
	}
	if s == nil {
		s = newSessionMap()
	}

	sessionURL, err := url.Parse(SessionEndpoint)
	if err != nil {
		// SessionEndpoint is a constant path; failing to parse it is a
		// programming error, not a runtime condition.
		panic(err)
	}

	tr := &Transport{
		transport:  t,
		client:     c,
		tokens:     tokens,
		sessions:   s,
		sessionURL: sessionURL,
	}
	return tr
}
// TeeReadCloser wraps rc so that every byte read from the returned value is
// also written to w, exactly as io.TeeReader does. Close is forwarded to rc.
func TeeReadCloser(rc io.ReadCloser, w io.Writer) io.ReadCloser {
	tee := io.TeeReader(rc, w)
	return &struct {
		io.Reader
		io.Closer
	}{Reader: tee, Closer: rc}
}
// RoundTrip transparently handles DEP server authentication and session token
// management. Practically speaking this means we make up to three individual
// requests for a given single request: the initial request attempt, a
// possible authentication request followed by a re-try of the original, now
// authenticated, request. Note also that we try to be helpful and inject the
// `X-Server-Protocol-Version` into the request headers if it is missing.
//
// The DEP name must be attached to the request context with WithName;
// otherwise ErrMissingName is returned.
//
// In keeping with the http.RoundTripper contract, the caller's request is not
// modified: headers are set on a clone of req. If a retry is needed, the
// request body is replayed. RoundTrip uses req.GetBody when it is set, and
// otherwise a copy captured while the first attempt was being sent. A
// response that is superseded by a retry is drained and closed.
//
// See https://developer.apple.com/documentation/devicemanagement/device_assignment/authenticating_with_a_device_enrollment_program_dep_server
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
	ctx := req.Context()
	name := GetName(ctx)
	if name == "" {
		return nil, ErrMissingName
	}

	// RoundTrippers must not modify the caller's request (other than
	// consuming its body). Work on a clone, which has its own header map.
	req = req.Clone(ctx)
	if req.Header == nil {
		req.Header = make(http.Header)
	}

	// Apple DEP servers support differing requests and responses based on the
	// protocol version header. Try to be helpful and use the latest protocol
	// version mentioned in the docs.
	if _, ok := req.Header[ServerProtocolVersion]; !ok {
		req.Header.Set(ServerProtocolVersion, DefaultServerProtocolVersion)
	}

	// if previous requests have already authenticated try to use that session token
	session, err := t.sessions.GetSessionToken(ctx, name)
	if err != nil {
		return nil, fmt.Errorf("transport: retrieving session token: %w", err)
	}

	var (
		resp         *http.Response
		reqBodyBuf   *bytes.Buffer // body bytes sent on the first attempt, kept for replay
		roundTripped bool          // whether the first attempt was actually sent
		forbidden    bool          // whether the first attempt hit the session "FORBIDDEN" 403
	)

	if session != "" {
		// if we have a session token for this DEP name then try to inject it
		req.Header.Set(ADMAuthSession, session)
		if req.Body != nil && req.GetBody == nil {
			// ContentLength is -1 when the length is unknown; a negative
			// capacity would make make() panic.
			var capacity int64
			if req.ContentLength > 0 {
				capacity = req.ContentLength
			}
			reqBodyBuf = bytes.NewBuffer(make([]byte, 0, capacity))
			// stream the body to both the wrapped transport and our buffer in case we need to retry
			req.Body = TeeReadCloser(req.Body, reqBodyBuf)
		}
		resp, err = t.transport.RoundTrip(req)
		if err != nil {
			return resp, err
		}
		roundTripped = true
	}

	if resp != nil && resp.StatusCode == http.StatusForbidden {
		// the DEP simulator depsim showed this specific 403 Forbidden
		// "FORBIDDEN" error when you restart the simulator. this indicates,
		// I think, an expired/unknown session token but this isn't documented
		// for the DEP service. specifically test and handle this error so we
		// do not accidentally capture any other 403 errors (e.g. T&C).
		// unfortunately this means reading (and replacing) the body, which is
		// rather verbose.
		respBodyBytes, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			// the response is unusable with a half-read body; per the
			// RoundTripper contract return no response alongside an error.
			return nil, fmt.Errorf("transport: reading response body: %w", err)
		}
		resp.Body = io.NopCloser(bytes.NewReader(respBodyBytes))
		forbidden = bytes.Contains(respBodyBytes, []byte(bodyForbidden))
	}

	if session == "" || resp.StatusCode == http.StatusUnauthorized || forbidden {
		// either we have no session token yet or the DEP server doesn't like
		// our provided token. let's authenticate.
		if resp != nil {
			// this first response is being thrown away in favor of the retry.
			// Drain and close it so its connection can be reused rather than
			// leaked.
			_, _ = io.Copy(io.Discard, resp.Body)
			resp.Body.Close()
		}

		tokens, err := t.tokens.RetrieveAuthTokens(ctx, name)
		if err != nil {
			return nil, fmt.Errorf("transport: retrieving auth tokens: %w", err)
		}

		// assemble the /session URL from the original request "base" URL.
		sessionURL := req.URL.ResolveReference(t.sessionURL)
		sessionReq, err := http.NewRequestWithContext(ctx, http.MethodGet, sessionURL.String(), nil)
		if err != nil {
			// checked before sessionReq is touched: it is nil on error.
			return nil, fmt.Errorf("transport: creating session request: %w", err)
		}
		if userAgent := req.Header.Get("User-Agent"); userAgent != "" {
			// copy the UA from the original request to the auth request
			sessionReq.Header.Set("User-Agent", userAgent)
		}
		// use the same version header from the original request (which we
		// likely set ourselves anyway)
		sessionReq.Header.Set(ServerProtocolVersion, req.Header.Get(ServerProtocolVersion))

		session, err = DoAuth(t.client, sessionReq, tokens)
		if err != nil {
			return nil, err
		}

		// save our session token for use by following requests
		if err = t.sessions.SetSessionToken(ctx, name, session); err != nil {
			return nil, fmt.Errorf("transport: setting auth session token: %w", err)
		}

		// now that we've received and saved the session token let's use it
		// to actually make the (same) request.
		req.Header.Set(ADMAuthSession, session)

		// reset our body reader if needed
		if roundTripped && req.Body != nil {
			if req.GetBody != nil {
				// (ab)use the body replay hook net/http uses for 307/308 redirects
				req.Body, err = req.GetBody()
				if err != nil {
					return nil, err
				}
			} else if reqBodyBuf != nil {
				// NOTE(review): the buffer holds only the bytes the wrapped
				// transport actually read on the first attempt; if it stopped
				// early the replayed body may be incomplete.
				req.Body = io.NopCloser(reqBodyBuf)
			}
		}

		resp, err = t.transport.RoundTrip(req)
		if err != nil {
			return resp, err
		}
	}

	// check if the session token has changed. Apple says that the session
	// token can be updated from the server. save it if so.
	if respSession := resp.Header.Get(ADMAuthSession); respSession != "" && respSession != session {
		if err := t.sessions.SetSessionToken(ctx, name, respSession); err != nil {
			// the response is dropped, so release its body.
			resp.Body.Close()
			return nil, fmt.Errorf("transport: setting response session token: %w", err)
		}
	}
	return resp, nil
}