@@ -4,13 +4,17 @@ use crate::{
4
4
deps:: tracing:: { error, info} ,
5
5
utils:: from_env:: FromEnv ,
6
6
} ;
7
+ use core:: fmt;
7
8
use oauth2:: {
8
9
basic:: { BasicClient , BasicTokenType } ,
9
- AuthUrl , ClientId , ClientSecret , EmptyExtraTokenFields , EndpointNotSet , EndpointSet ,
10
- HttpClientError , RequestTokenError , StandardErrorResponse , StandardTokenResponse , TokenUrl ,
10
+ AccessToken , AuthUrl , ClientId , ClientSecret , EmptyExtraTokenFields , EndpointNotSet ,
11
+ EndpointSet , HttpClientError , RefreshToken , RequestTokenError , Scope , StandardErrorResponse ,
12
+ StandardTokenResponse , TokenResponse , TokenUrl ,
13
+ } ;
14
+ use tokio:: {
15
+ sync:: watch:: { self , Ref } ,
16
+ task:: JoinHandle ,
11
17
} ;
12
- use std:: sync:: { Arc , Mutex } ;
13
- use tokio:: task:: JoinHandle ;
14
18
15
19
type Token = StandardTokenResponse < EmptyExtraTokenFields , BasicTokenType > ;
16
20
@@ -57,38 +61,17 @@ impl OAuthConfig {
57
61
}
58
62
}
59
63
60
- /// A shared token that can be read and written to by multiple threads.
61
- #[ derive( Debug , Clone , Default ) ]
62
- pub struct SharedToken ( Arc < Mutex < Option < Token > > > ) ;
63
-
64
- impl SharedToken {
65
- /// Read the token from the shared token.
66
- pub fn read ( & self ) -> Option < Token > {
67
- self . 0 . lock ( ) . unwrap ( ) . clone ( )
68
- }
69
-
70
- /// Write a new token to the shared token.
71
- pub fn write ( & self , token : Token ) {
72
- let mut lock = self . 0 . lock ( ) . unwrap ( ) ;
73
- * lock = Some ( token) ;
74
- }
75
-
76
- /// Check if the token is authenticated.
77
- pub fn is_authenticated ( & self ) -> bool {
78
- self . 0 . lock ( ) . unwrap ( ) . is_some ( )
79
- }
80
- }
81
-
82
64
/// A self-refreshing, periodically fetching authenticator for the block
83
- /// builder. This task periodically fetches a new token, and stores it in a
84
- /// [`SharedToken`].
65
+ /// builder. This task periodically fetches a new token, and sends it to all
66
+ /// active [`SharedToken`]s via a [`tokio::sync::watch`] channel. .
85
67
#[ derive( Debug ) ]
86
68
pub struct Authenticator {
87
69
/// Configuration
88
- pub config : OAuthConfig ,
70
+ config : OAuthConfig ,
89
71
client : MyOAuthClient ,
90
- token : SharedToken ,
91
72
reqwest : reqwest:: Client ,
73
+
74
+ token : watch:: Sender < Option < Token > > ,
92
75
}
93
76
94
77
impl Authenticator {
@@ -99,6 +82,8 @@ impl Authenticator {
99
82
. set_auth_uri ( AuthUrl :: from_url ( config. oauth_authenticate_url . clone ( ) ) )
100
83
. set_token_uri ( TokenUrl :: from_url ( config. oauth_token_url . clone ( ) ) ) ;
101
84
85
+ // NB: this is MANDATORY
86
+ // https://docs.rs/oauth2/latest/oauth2/#security-warning
102
87
let rq_client = reqwest:: Client :: builder ( )
103
88
. redirect ( reqwest:: redirect:: Policy :: none ( ) )
104
89
. build ( )
@@ -107,8 +92,8 @@ impl Authenticator {
107
92
Self {
108
93
config : config. clone ( ) ,
109
94
client,
110
- token : Default :: default ( ) ,
111
95
reqwest : rq_client,
96
+ token : watch:: channel ( None ) . 0 ,
112
97
}
113
98
}
114
99
@@ -129,20 +114,20 @@ impl Authenticator {
129
114
130
115
/// Returns true if there is Some token set
131
116
pub fn is_authenticated ( & self ) -> bool {
132
- self . token . is_authenticated ( )
117
+ self . token . borrow ( ) . is_some ( )
133
118
}
134
119
135
120
/// Sets the Authenticator's token to the provided value
136
121
fn set_token ( & self , token : StandardTokenResponse < EmptyExtraTokenFields , BasicTokenType > ) {
137
- self . token . write ( token) ;
122
+ self . token . send_replace ( Some ( token) ) ;
138
123
}
139
124
140
125
/// Returns the currently set token
141
126
pub fn token ( & self ) -> SharedToken {
142
- self . token . clone ( )
127
+ self . token . subscribe ( ) . into ( )
143
128
}
144
129
145
- /// Fetches an oauth token
130
+ /// Fetches an oauth token.
146
131
pub async fn fetch_oauth_token (
147
132
& self ,
148
133
) -> Result <
@@ -161,25 +146,184 @@ impl Authenticator {
161
146
Ok ( token_result)
162
147
}
163
148
164
- /// Spawns a task that periodically fetches a new token every 300 seconds.
165
- pub fn spawn ( self ) -> JoinHandle < ( ) > {
149
+ /// Get a reference to the OAuth configuration.
150
+ pub const fn config ( & self ) -> & OAuthConfig {
151
+ & self . config
152
+ }
153
+
154
+ /// Create a future that contains the periodic refresh loop.
155
+ async fn task_future ( self ) {
166
156
let interval = self . config . oauth_token_refresh_interval ;
167
157
168
- let handle: JoinHandle < ( ) > = tokio:: spawn ( async move {
169
- loop {
170
- info ! ( "Refreshing oauth token" ) ;
171
- match self . authenticate ( ) . await {
172
- Ok ( _) => {
173
- info ! ( "Successfully refreshed oauth token" ) ;
174
- }
175
- Err ( e) => {
176
- error ! ( %e, "Failed to refresh oauth token" ) ;
177
- }
178
- } ;
179
- let _sleep = tokio:: time:: sleep ( tokio:: time:: Duration :: from_secs ( interval) ) . await ;
180
- }
181
- } ) ;
182
-
183
- handle
158
+ loop {
159
+ info ! ( "Refreshing oauth token" ) ;
160
+ match self . authenticate ( ) . await {
161
+ Ok ( _) => {
162
+ info ! ( "Successfully refreshed oauth token" ) ;
163
+ }
164
+ Err ( e) => {
165
+ error ! ( %e, "Failed to refresh oauth token" ) ;
166
+ }
167
+ } ;
168
+ let _sleep = tokio:: time:: sleep ( tokio:: time:: Duration :: from_secs ( interval) ) . await ;
169
+ }
170
+ }
171
+
172
+ /// Spawns a task that periodically fetches a new token. The refresh
173
+ /// interval may be configured via the
174
+ /// [`OAuthConfig::oauth_token_refresh_interval`] property.
175
+ pub fn spawn ( self ) -> JoinHandle < ( ) > {
176
+ tokio:: spawn ( self . task_future ( ) )
177
+ }
178
+ }
179
+
180
+ /// A shared token, wrapped in a [`tokio::sync::watch`] Receiver. The token is
181
+ /// periodically refreshed by an [`Authenticator`] task, and can be awaited
182
+ /// for when it becomes available.
183
+ ///
184
+ /// This allows multiple tasks to wait for the token to be available, and
185
+ /// provides a way to check if the token is authenticated without blocking.
186
+ /// Please consult the [`Receiver`] documentation for caveats regarding
187
+ /// usage.
188
+ ///
189
+ /// [`Receiver`]: tokio::sync::watch::Receiver
190
+ #[ derive( Debug , Clone ) ]
191
+ pub struct SharedToken ( watch:: Receiver < Option < Token > > ) ;
192
+
193
+ impl From < watch:: Receiver < Option < Token > > > for SharedToken {
194
+ fn from ( inner : watch:: Receiver < Option < Token > > ) -> Self {
195
+ Self ( inner)
196
+ }
197
+ }
198
+
199
+ impl SharedToken {
200
+ /// Wait for the token to be available, and get a reference to the secret.
201
+ ///
202
+ /// This is implemented using [`Receiver::wait_for`], and has the same
203
+ /// blocking, panics, errors, and cancel safety. However, it uses a clone
204
+ /// of the [`watch::Receiver`] and will not update the local view of the
205
+ /// channel.
206
+ ///
207
+ /// [`Receiver::wait_for`]: tokio::sync::watch::Receiver::wait_for
208
+ pub async fn secret ( & self ) -> Result < String , watch:: error:: RecvError > {
209
+ Ok ( self
210
+ . clone ( )
211
+ . token ( )
212
+ . await ?
213
+ . access_token ( )
214
+ . secret ( )
215
+ . to_owned ( ) )
216
+ }
217
+
218
+ /// Wait for the token to be available, then get a reference to it.
219
+ ///
220
+ /// Holding this reference will block the background task from updating
221
+ /// the token until it is dropped, so it is recommended to drop this
222
+ /// reference as soon as possible.
223
+ ///
224
+ /// This is implemented using [`Receiver::wait_for`], and has the same
225
+ /// blocking, panics, errors, and cancel safety. Unlike [`Self::secret`]
226
+ /// it is NOT implemented using a clone, and will update the local view of
227
+ /// the channel.
228
+ ///
229
+ /// Generally, prefer using [`Self::secret`] for simple use cases, and
230
+ /// this when deeper inspection of the token is required.
231
+ ///
232
+ /// [`Receiver::wait_for`]: tokio::sync::watch::Receiver::wait_for
233
+ pub async fn token ( & mut self ) -> Result < TokenRef < ' _ > , watch:: error:: RecvError > {
234
+ self . 0 . wait_for ( Option :: is_some) . await . map ( Into :: into)
235
+ }
236
+
237
+ /// Create a future that will resolve when the token is ready.
238
+ ///
239
+ /// This is implemented using [`Receiver::wait_for`], and has the same
240
+ /// blocking, panics, errors, and cancel safety.
241
+ ///
242
+ /// [`Receiver::wait_for`]: tokio::sync::watch::Receiver::wait_for
243
+ pub async fn wait ( & self ) -> Result < ( ) , watch:: error:: RecvError > {
244
+ self . clone ( ) . 0 . wait_for ( Option :: is_some) . await . map ( drop)
245
+ }
246
+
247
+ /// Borrow the current token, if available. If called before the token is
248
+ /// set by the authentication task, this will return `None`.
249
+ ///
250
+ /// Holding this reference will block the background task from updating
251
+ /// the token until it is dropped, so it is recommended to drop this
252
+ /// reference as soon as possible.
253
+ ///
254
+ /// This is implemented using [`Receiver::borrow`].
255
+ ///
256
+ /// [`Receiver::borrow`]: tokio::sync::watch::Receiver::borrow
257
+ pub fn borrow ( & mut self ) -> Ref < ' _ , Option < Token > > {
258
+ self . 0 . borrow ( )
259
+ }
260
+
261
+ /// Check if the background task has produced an authentication token.
262
+ ///
263
+ /// Holding this reference will block the background task from updating
264
+ /// the token until it is dropped, so it is recommended to drop this
265
+ /// reference as soon as possible.
266
+ ///
267
+ /// This is implemented using [`Receiver::borrow`].
268
+ ///
269
+ /// [`Receiver::borrow`]: tokio::sync::watch::Receiver::borrow
270
+ pub fn is_authenticated ( & self ) -> bool {
271
+ self . 0 . borrow ( ) . is_some ( )
272
+ }
273
+ }
274
+
275
+ /// A reference to token data, contained in a [`SharedToken`].
276
+ ///
277
+ /// This is implemented using [`watch::Ref`], and as a result holds a lock on
278
+ /// the token data. Holding this reference will block the background task
279
+ /// from updating the token until it is dropped, so it is recommended to drop
280
+ /// this reference as soon as possible.
281
+ pub struct TokenRef < ' a > {
282
+ inner : Ref < ' a , Option < Token > > ,
283
+ }
284
+
285
+ impl < ' a > From < Ref < ' a , Option < Token > > > for TokenRef < ' a > {
286
+ fn from ( inner : Ref < ' a , Option < Token > > ) -> Self {
287
+ Self { inner }
288
+ }
289
+ }
290
+
291
+ impl fmt:: Debug for TokenRef < ' _ > {
292
+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
293
+ f. debug_struct ( "TokenRef" ) . finish_non_exhaustive ( )
294
+ }
295
+ }
296
+
297
+ impl < ' a > TokenRef < ' a > {
298
+ /// Get a reference to the inner token.
299
+ pub fn inner ( & ' a self ) -> & ' a Token {
300
+ self . inner . as_ref ( ) . unwrap ( )
301
+ }
302
+
303
+ /// Get a reference to the [`AccessToken`] contained in the token.
304
+ pub fn access_token ( & self ) -> & AccessToken {
305
+ self . inner ( ) . access_token ( )
306
+ }
307
+
308
+ /// Get a reference to the [`TokenType`] instance contained in the token.
309
+ ///
310
+ /// [`TokenType`]: oauth2::TokenType
311
+ pub fn token_type ( & self ) -> & <Token as TokenResponse >:: TokenType {
312
+ self . inner ( ) . token_type ( )
313
+ }
314
+
315
+ /// Get a reference to the current token's expiration time, if it has one.
316
+ pub fn expires_in ( & self ) -> Option < std:: time:: Duration > {
317
+ self . inner ( ) . expires_in ( )
318
+ }
319
+
320
+ /// Get a reference to the refresh token, if it exists.
321
+ pub fn refresh_token ( & self ) -> Option < & RefreshToken > {
322
+ self . inner ( ) . refresh_token ( )
323
+ }
324
+
325
+ /// Get a reference to the scopes associated with the token, if any.
326
+ pub fn scopes ( & self ) -> Option < & Vec < Scope > > {
327
+ self . inner ( ) . scopes ( )
184
328
}
185
329
}
0 commit comments