@@ -9,127 +9,129 @@ import (
99 "time"
1010)
1111
12- type mockRoundTripper struct {
13- handler func (* http.Request ) (* http.Response , error )
14- }
15-
16- func (m * mockRoundTripper ) RoundTrip (req * http.Request ) (* http.Response , error ) {
17- return m .handler (req )
18- }
19-
20- func TestTransport_SetsAuthorizationHeader (t * testing.T ) {
21- var capturedAuth string
22-
23- transport := & Transport {
24- Base : & mockRoundTripper {
25- handler : func (req * http.Request ) (* http.Response , error ) {
26- capturedAuth = req .Header .Get ("Authorization" )
27- return & http.Response {StatusCode : 200 }, nil
28- },
29- },
30- Token : & Token {
31- AccessToken : "test-token" ,
32- ExpiresAt : time .Now ().Add (time .Hour ),
33- },
34- }
35-
36- req := httptest .NewRequest ("GET" , "http://example.com" , nil )
37- _ , err := transport .RoundTrip (req )
38- if err != nil {
39- t .Fatalf ("RoundTrip() error = %v" , err )
40- }
12+ type roundTripperFunc func (* http.Request ) (* http.Response , error )
4113
42- if capturedAuth != "Bearer test-token" {
43- t .Errorf ("Authorization = %q, want %q" , capturedAuth , "Bearer test-token" )
44- }
14+ func (f roundTripperFunc ) RoundTrip (req * http.Request ) (* http.Response , error ) {
15+ return f (req )
4516}
4617
47- func TestMaybeRefresh_RefreshesExpiredToken (t * testing.T ) {
48- server := newTestServer (t , testServerOptions {
18+ func newRefreshServer (t * testing.T , accessToken string ) * httptest.Server {
19+ t .Helper ()
20+ return newTestServer (t , testServerOptions {
4921 handlers : map [string ]http.HandlerFunc {
50- testTokenPath : func (w http.ResponseWriter , r * http.Request ) {
22+ testTokenPath : func (w http.ResponseWriter , _ * http.Request ) {
5123 w .Header ().Set ("Content-Type" , "application/json" )
52- w .Write ([]byte (`{"access_token":"new-token ","refresh_token":"new-refresh","expires_in":3600}` ))
24+ _ , _ = w .Write ([]byte (`{"access_token":"` + accessToken + ` ","refresh_token":"new-refresh","expires_in":3600}` ))
5325 },
5426 },
5527 })
56- defer server .Close ()
57-
58- token := & Token {
59- Endpoint : server .URL ,
60- AccessToken : "expired-token" ,
61- RefreshToken : "refresh-token" ,
62- ExpiresAt : time .Now ().Add (- time .Hour ), // expired
63- }
64-
65- result , err := maybeRefresh (context .Background (), token )
66- if err != nil {
67- t .Fatalf ("maybeRefresh() error = %v" , err )
68- }
69-
70- if result .AccessToken != "new-token" {
71- t .Errorf ("AccessToken = %q, want %q" , result .AccessToken , "new-token" )
72- }
7328}
7429
75- func TestMaybeRefresh_RefreshesTokenExpiringSoon (t * testing.T ) {
76- server := newTestServer (t , testServerOptions {
77- handlers : map [string ]http.HandlerFunc {
78- testTokenPath : func (w http.ResponseWriter , r * http.Request ) {
79- w .Header ().Set ("Content-Type" , "application/json" )
80- w .Write ([]byte (`{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}` ))
81- },
82- },
83- })
30+ func TestMaybeRefresh (t * testing.T ) {
31+ server := newRefreshServer (t , "new-token" )
8432 defer server .Close ()
8533
86- token := & Token {
87- Endpoint : server .URL ,
88- AccessToken : "expiring-soon-token" ,
89- RefreshToken : "refresh-token" ,
90- ExpiresAt : time .Now ().Add (10 * time .Second ), // expires in 10s (< 30s threshold)
91- }
92-
93- result , err := maybeRefresh (context .Background (), token )
94- if err != nil {
95- t .Fatalf ("maybeRefresh() error = %v" , err )
34+ tests := []struct {
35+ name string
36+ token * Token
37+ wantAccess string
38+ wantSame bool
39+ }{
40+ {
41+ name : "unchanged when still valid" ,
42+ token : & Token {
43+ AccessToken : "valid-token" ,
44+ ExpiresAt : time .Now ().Add (time .Hour ),
45+ },
46+ wantAccess : "valid-token" ,
47+ wantSame : true ,
48+ },
49+ {
50+ name : "refreshes expired token" ,
51+ token : & Token {
52+ Endpoint : server .URL ,
53+ AccessToken : "expired-token" ,
54+ RefreshToken : "refresh-token" ,
55+ ExpiresAt : time .Now ().Add (- time .Hour ),
56+ },
57+ wantAccess : "new-token" ,
58+ },
59+ {
60+ name : "refreshes token expiring soon" ,
61+ token : & Token {
62+ Endpoint : server .URL ,
63+ AccessToken : "expiring-soon-token" ,
64+ RefreshToken : "refresh-token" ,
65+ ExpiresAt : time .Now ().Add (10 * time .Second ),
66+ },
67+ wantAccess : "new-token" ,
68+ },
9669 }
9770
98- if result .AccessToken != "new-token" {
99- t .Errorf ("AccessToken = %q, want %q" , result .AccessToken , "new-token" )
71+ for _ , tt := range tests {
72+ t .Run (tt .name , func (t * testing.T ) {
73+ got , err := maybeRefresh (context .Background (), tt .token )
74+ if err != nil {
75+ t .Fatalf ("maybeRefresh() error = %v" , err )
76+ }
77+ if got .AccessToken != tt .wantAccess {
78+ t .Errorf ("AccessToken = %q, want %q" , got .AccessToken , tt .wantAccess )
79+ }
80+ if tt .wantSame && got != tt .token {
81+ t .Errorf ("token pointer changed for unexpired token" )
82+ }
83+ })
10084 }
10185}
10286
103- func TestTransport_RefreshPersistence (t * testing.T ) {
87+ func TestTransportRoundTrip (t * testing.T ) {
10488 tests := []struct {
105- name string
106- needsRefresh bool
107- persistErr error
108- wantAuthHeaderVal string
109- wantStoreCalls int
89+ name string
90+ token * Token
91+ persistErr error
92+ wantAuthHeader string
93+ wantStoreCalls int
11094 }{
11195 {
112- name : "persists refreshed token" ,
113- needsRefresh : true ,
114- wantAuthHeaderVal : "Bearer new-token" ,
115- wantStoreCalls : 1 ,
96+ name : "uses existing token without persisting" ,
97+ token : & Token {
98+ AccessToken : "valid-token" ,
99+ ExpiresAt : time .Now ().Add (time .Hour ),
100+ },
101+ wantAuthHeader : "Bearer valid-token" ,
102+ wantStoreCalls : 0 ,
116103 },
117104 {
118- name : "does not persist unchanged token" ,
119- wantAuthHeaderVal : "Bearer valid-token" ,
120- wantStoreCalls : 0 ,
105+ name : "persists refreshed token" ,
106+ token : & Token {
107+ AccessToken : "expired-token" ,
108+ RefreshToken : "refresh-token" ,
109+ ExpiresAt : time .Now ().Add (- time .Hour ),
110+ },
111+ wantAuthHeader : "Bearer new-token" ,
112+ wantStoreCalls : 1 ,
121113 },
122114 {
123- name : "persist failure does not fail request" ,
124- needsRefresh : true ,
125- persistErr : errors .New ("persist failed" ),
126- wantAuthHeaderVal : "Bearer new-token" ,
127- wantStoreCalls : 1 ,
115+ name : "ignores persist failures" ,
116+ token : & Token {
117+ AccessToken : "expired-token" ,
118+ RefreshToken : "refresh-token" ,
119+ ExpiresAt : time .Now ().Add (- time .Hour ),
120+ },
121+ persistErr : errors .New ("persist failed" ),
122+ wantAuthHeader : "Bearer new-token" ,
123+ wantStoreCalls : 1 ,
128124 },
129125 }
130126
131127 for _ , tt := range tests {
132128 t .Run (tt .name , func (t * testing.T ) {
129+ if tt .wantStoreCalls > 0 {
130+ server := newRefreshServer (t , "new-token" )
131+ defer server .Close ()
132+ tt .token .Endpoint = server .URL
133+ }
134+
133135 originalStoreFn := storeRefreshedTokenFn
134136 defer func () { storeRefreshedTokenFn = originalStoreFn }()
135137
@@ -141,57 +143,28 @@ func TestTransport_RefreshPersistence(t *testing.T) {
141143 return tt .persistErr
142144 }
143145
144- token := & Token {
145- AccessToken : "valid-token" ,
146- ExpiresAt : time .Now ().Add (time .Hour ),
147- }
148- if tt .needsRefresh {
149- server := newTestServer (t , testServerOptions {
150- handlers : map [string ]http.HandlerFunc {
151- testTokenPath : func (w http.ResponseWriter , r * http.Request ) {
152- w .Header ().Set ("Content-Type" , "application/json" )
153- w .Write ([]byte (`{"access_token":"new-token","refresh_token":"new-refresh","expires_in":3600}` ))
154- },
155- },
156- })
157- defer server .Close ()
158- token .Endpoint = server .URL
159- token .AccessToken = "expired-token"
160- token .RefreshToken = "refresh-token"
161- token .ExpiresAt = time .Now ().Add (- time .Hour )
162- }
163-
164146 var capturedAuth string
165- transport := & Transport {
166- Base : & mockRoundTripper {
167- handler : func (req * http.Request ) (* http.Response , error ) {
168- capturedAuth = req .Header .Get ("Authorization" )
169- return & http.Response {StatusCode : 200 }, nil
170- },
171- },
172- Token : token ,
147+ tr := & Transport {
148+ Base : roundTripperFunc (func (req * http.Request ) (* http.Response , error ) {
149+ capturedAuth = req .Header .Get ("Authorization" )
150+ return & http.Response {StatusCode : http .StatusOK , Body : http .NoBody }, nil
151+ }),
152+ Token : tt .token ,
173153 }
174154
175- req := httptest .NewRequest ("GET" , "http://example.com" , nil )
176- _ , err := transport .RoundTrip (req )
155+ _ , err := tr .RoundTrip (httptest .NewRequest (http .MethodGet , "http://example.com" , nil ))
177156 if err != nil {
178157 t .Fatalf ("RoundTrip() error = %v" , err )
179158 }
180159
181- if capturedAuth != tt .wantAuthHeaderVal {
182- t .Errorf ("Authorization = %q, want %q" , capturedAuth , tt .wantAuthHeaderVal )
160+ if capturedAuth != tt .wantAuthHeader {
161+ t .Errorf ("Authorization = %q, want %q" , capturedAuth , tt .wantAuthHeader )
183162 }
184163 if storeCalls != tt .wantStoreCalls {
185164 t .Errorf ("store calls = %d, want %d" , storeCalls , tt .wantStoreCalls )
186165 }
187-
188- if tt .needsRefresh {
189- if storedToken == nil {
190- t .Fatal ("stored token is nil" )
191- }
192- if storedToken .AccessToken != "new-token" {
193- t .Errorf ("stored AccessToken = %q, want %q" , storedToken .AccessToken , "new-token" )
194- }
166+ if tt .wantStoreCalls > 0 && (storedToken == nil || storedToken .AccessToken != "new-token" ) {
167+ t .Errorf ("stored token = %#v, want access token %q" , storedToken , "new-token" )
195168 }
196169 })
197170 }
0 commit comments