Skip to content

Commit 42193ac

Browse files
committed
refactor token refresh tests
1 parent a6e9155 commit 42193ac

File tree

1 file changed

+105
-132
lines changed

1 file changed

+105
-132
lines changed

internal/oauth/http_transport_test.go

Lines changed: 105 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -9,127 +9,129 @@ import (
99
"time"
1010
)
1111

12-
type mockRoundTripper struct {
13-
handler func(*http.Request) (*http.Response, error)
14-
}
15-
16-
func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
17-
return m.handler(req)
18-
}
19-
20-
func TestTransport_SetsAuthorizationHeader(t *testing.T) {
21-
var capturedAuth string
22-
23-
transport := &Transport{
24-
Base: &mockRoundTripper{
25-
handler: func(req *http.Request) (*http.Response, error) {
26-
capturedAuth = req.Header.Get("Authorization")
27-
return &http.Response{StatusCode: 200}, nil
28-
},
29-
},
30-
Token: &Token{
31-
AccessToken: "test-token",
32-
ExpiresAt: time.Now().Add(time.Hour),
33-
},
34-
}
35-
36-
req := httptest.NewRequest("GET", "http://example.com", nil)
37-
_, err := transport.RoundTrip(req)
38-
if err != nil {
39-
t.Fatalf("RoundTrip() error = %v", err)
40-
}
12+
type roundTripperFunc func(*http.Request) (*http.Response, error)
4113

42-
if capturedAuth != "Bearer test-token" {
43-
t.Errorf("Authorization = %q, want %q", capturedAuth, "Bearer test-token")
44-
}
14+
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
15+
return f(req)
4516
}
4617

47-
func TestMaybeRefresh_RefreshesExpiredToken(t *testing.T) {
48-
server := newTestServer(t, testServerOptions{
18+
func newRefreshServer(t *testing.T, accessToken string) *httptest.Server {
19+
t.Helper()
20+
return newTestServer(t, testServerOptions{
4921
handlers: map[string]http.HandlerFunc{
50-
testTokenPath: func(w http.ResponseWriter, r *http.Request) {
22+
testTokenPath: func(w http.ResponseWriter, _ *http.Request) {
5123
w.Header().Set("Content-Type", "application/json")
52-
w.Write([]byte(`{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`))
24+
_, _ = w.Write([]byte(`{"access_token":"` + accessToken + `","refresh_token":"new-refresh","expires_in":3600}`))
5325
},
5426
},
5527
})
56-
defer server.Close()
57-
58-
token := &Token{
59-
Endpoint: server.URL,
60-
AccessToken: "expired-token",
61-
RefreshToken: "refresh-token",
62-
ExpiresAt: time.Now().Add(-time.Hour), // expired
63-
}
64-
65-
result, err := maybeRefresh(context.Background(), token)
66-
if err != nil {
67-
t.Fatalf("maybeRefresh() error = %v", err)
68-
}
69-
70-
if result.AccessToken != "new-token" {
71-
t.Errorf("AccessToken = %q, want %q", result.AccessToken, "new-token")
72-
}
7328
}
7429

75-
func TestMaybeRefresh_RefreshesTokenExpiringSoon(t *testing.T) {
76-
server := newTestServer(t, testServerOptions{
77-
handlers: map[string]http.HandlerFunc{
78-
testTokenPath: func(w http.ResponseWriter, r *http.Request) {
79-
w.Header().Set("Content-Type", "application/json")
80-
w.Write([]byte(`{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`))
81-
},
82-
},
83-
})
30+
func TestMaybeRefresh(t *testing.T) {
31+
server := newRefreshServer(t, "new-token")
8432
defer server.Close()
8533

86-
token := &Token{
87-
Endpoint: server.URL,
88-
AccessToken: "expiring-soon-token",
89-
RefreshToken: "refresh-token",
90-
ExpiresAt: time.Now().Add(10 * time.Second), // expires in 10s (< 30s threshold)
91-
}
92-
93-
result, err := maybeRefresh(context.Background(), token)
94-
if err != nil {
95-
t.Fatalf("maybeRefresh() error = %v", err)
34+
tests := []struct {
35+
name string
36+
token *Token
37+
wantAccess string
38+
wantSame bool
39+
}{
40+
{
41+
name: "unchanged when still valid",
42+
token: &Token{
43+
AccessToken: "valid-token",
44+
ExpiresAt: time.Now().Add(time.Hour),
45+
},
46+
wantAccess: "valid-token",
47+
wantSame: true,
48+
},
49+
{
50+
name: "refreshes expired token",
51+
token: &Token{
52+
Endpoint: server.URL,
53+
AccessToken: "expired-token",
54+
RefreshToken: "refresh-token",
55+
ExpiresAt: time.Now().Add(-time.Hour),
56+
},
57+
wantAccess: "new-token",
58+
},
59+
{
60+
name: "refreshes token expiring soon",
61+
token: &Token{
62+
Endpoint: server.URL,
63+
AccessToken: "expiring-soon-token",
64+
RefreshToken: "refresh-token",
65+
ExpiresAt: time.Now().Add(10 * time.Second),
66+
},
67+
wantAccess: "new-token",
68+
},
9669
}
9770

98-
if result.AccessToken != "new-token" {
99-
t.Errorf("AccessToken = %q, want %q", result.AccessToken, "new-token")
71+
for _, tt := range tests {
72+
t.Run(tt.name, func(t *testing.T) {
73+
got, err := maybeRefresh(context.Background(), tt.token)
74+
if err != nil {
75+
t.Fatalf("maybeRefresh() error = %v", err)
76+
}
77+
if got.AccessToken != tt.wantAccess {
78+
t.Errorf("AccessToken = %q, want %q", got.AccessToken, tt.wantAccess)
79+
}
80+
if tt.wantSame && got != tt.token {
81+
t.Errorf("token pointer changed for unexpired token")
82+
}
83+
})
10084
}
10185
}
10286

103-
func TestTransport_RefreshPersistence(t *testing.T) {
87+
func TestTransportRoundTrip(t *testing.T) {
10488
tests := []struct {
105-
name string
106-
needsRefresh bool
107-
persistErr error
108-
wantAuthHeaderVal string
109-
wantStoreCalls int
89+
name string
90+
token *Token
91+
persistErr error
92+
wantAuthHeader string
93+
wantStoreCalls int
11094
}{
11195
{
112-
name: "persists refreshed token",
113-
needsRefresh: true,
114-
wantAuthHeaderVal: "Bearer new-token",
115-
wantStoreCalls: 1,
96+
name: "uses existing token without persisting",
97+
token: &Token{
98+
AccessToken: "valid-token",
99+
ExpiresAt: time.Now().Add(time.Hour),
100+
},
101+
wantAuthHeader: "Bearer valid-token",
102+
wantStoreCalls: 0,
116103
},
117104
{
118-
name: "does not persist unchanged token",
119-
wantAuthHeaderVal: "Bearer valid-token",
120-
wantStoreCalls: 0,
105+
name: "persists refreshed token",
106+
token: &Token{
107+
AccessToken: "expired-token",
108+
RefreshToken: "refresh-token",
109+
ExpiresAt: time.Now().Add(-time.Hour),
110+
},
111+
wantAuthHeader: "Bearer new-token",
112+
wantStoreCalls: 1,
121113
},
122114
{
123-
name: "persist failure does not fail request",
124-
needsRefresh: true,
125-
persistErr: errors.New("persist failed"),
126-
wantAuthHeaderVal: "Bearer new-token",
127-
wantStoreCalls: 1,
115+
name: "ignores persist failures",
116+
token: &Token{
117+
AccessToken: "expired-token",
118+
RefreshToken: "refresh-token",
119+
ExpiresAt: time.Now().Add(-time.Hour),
120+
},
121+
persistErr: errors.New("persist failed"),
122+
wantAuthHeader: "Bearer new-token",
123+
wantStoreCalls: 1,
128124
},
129125
}
130126

131127
for _, tt := range tests {
132128
t.Run(tt.name, func(t *testing.T) {
129+
if tt.wantStoreCalls > 0 {
130+
server := newRefreshServer(t, "new-token")
131+
defer server.Close()
132+
tt.token.Endpoint = server.URL
133+
}
134+
133135
originalStoreFn := storeRefreshedTokenFn
134136
defer func() { storeRefreshedTokenFn = originalStoreFn }()
135137

@@ -141,57 +143,28 @@ func TestTransport_RefreshPersistence(t *testing.T) {
141143
return tt.persistErr
142144
}
143145

144-
token := &Token{
145-
AccessToken: "valid-token",
146-
ExpiresAt: time.Now().Add(time.Hour),
147-
}
148-
if tt.needsRefresh {
149-
server := newTestServer(t, testServerOptions{
150-
handlers: map[string]http.HandlerFunc{
151-
testTokenPath: func(w http.ResponseWriter, r *http.Request) {
152-
w.Header().Set("Content-Type", "application/json")
153-
w.Write([]byte(`{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}`))
154-
},
155-
},
156-
})
157-
defer server.Close()
158-
token.Endpoint = server.URL
159-
token.AccessToken = "expired-token"
160-
token.RefreshToken = "refresh-token"
161-
token.ExpiresAt = time.Now().Add(-time.Hour)
162-
}
163-
164146
var capturedAuth string
165-
transport := &Transport{
166-
Base: &mockRoundTripper{
167-
handler: func(req *http.Request) (*http.Response, error) {
168-
capturedAuth = req.Header.Get("Authorization")
169-
return &http.Response{StatusCode: 200}, nil
170-
},
171-
},
172-
Token: token,
147+
tr := &Transport{
148+
Base: roundTripperFunc(func(req *http.Request) (*http.Response, error) {
149+
capturedAuth = req.Header.Get("Authorization")
150+
return &http.Response{StatusCode: http.StatusOK, Body: http.NoBody}, nil
151+
}),
152+
Token: tt.token,
173153
}
174154

175-
req := httptest.NewRequest("GET", "http://example.com", nil)
176-
_, err := transport.RoundTrip(req)
155+
_, err := tr.RoundTrip(httptest.NewRequest(http.MethodGet, "http://example.com", nil))
177156
if err != nil {
178157
t.Fatalf("RoundTrip() error = %v", err)
179158
}
180159

181-
if capturedAuth != tt.wantAuthHeaderVal {
182-
t.Errorf("Authorization = %q, want %q", capturedAuth, tt.wantAuthHeaderVal)
160+
if capturedAuth != tt.wantAuthHeader {
161+
t.Errorf("Authorization = %q, want %q", capturedAuth, tt.wantAuthHeader)
183162
}
184163
if storeCalls != tt.wantStoreCalls {
185164
t.Errorf("store calls = %d, want %d", storeCalls, tt.wantStoreCalls)
186165
}
187-
188-
if tt.needsRefresh {
189-
if storedToken == nil {
190-
t.Fatal("stored token is nil")
191-
}
192-
if storedToken.AccessToken != "new-token" {
193-
t.Errorf("stored AccessToken = %q, want %q", storedToken.AccessToken, "new-token")
194-
}
166+
if tt.wantStoreCalls > 0 && (storedToken == nil || storedToken.AccessToken != "new-token") {
167+
t.Errorf("stored token = %#v, want access token %q", storedToken, "new-token")
195168
}
196169
})
197170
}

0 commit comments

Comments
 (0)