Skip to content

Commit 6ae1e33

Browse files
authored
refactor: make token a watch channel (#39)
* refactor: make token a watch channel * fix: docs * nit: drive-by add security note * chore: bump version * chore: misc cleanup and documentation * lint: clippy
1 parent e17be44 commit 6ae1e33

File tree

4 files changed

+202
-61
lines changed

4 files changed

+202
-61
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ name = "init4-bin-base"
44
description = "Internal utilities for binaries produced by the init4 team"
55
keywords = ["init4", "bin", "base"]
66

7-
version = "0.4.3"
7+
version = "0.5.0"
88
edition = "2021"
99
rust-version = "1.81"
1010
authors = ["init4", "James Prestwich"]

examples/oauth.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ async fn main() -> eyre::Result<()> {
99
let _jh = authenticator.spawn();
1010

1111
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
12-
dbg!(token.read());
12+
dbg!(token.secret().await.unwrap());
1313

1414
Ok(())
1515
}

src/perms/oauth.rs

Lines changed: 197 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@ use crate::{
44
deps::tracing::{error, info},
55
utils::from_env::FromEnv,
66
};
7+
use core::fmt;
78
use oauth2::{
89
basic::{BasicClient, BasicTokenType},
9-
AuthUrl, ClientId, ClientSecret, EmptyExtraTokenFields, EndpointNotSet, EndpointSet,
10-
HttpClientError, RequestTokenError, StandardErrorResponse, StandardTokenResponse, TokenUrl,
10+
AccessToken, AuthUrl, ClientId, ClientSecret, EmptyExtraTokenFields, EndpointNotSet,
11+
EndpointSet, HttpClientError, RefreshToken, RequestTokenError, Scope, StandardErrorResponse,
12+
StandardTokenResponse, TokenResponse, TokenUrl,
13+
};
14+
use tokio::{
15+
sync::watch::{self, Ref},
16+
task::JoinHandle,
1117
};
12-
use std::sync::{Arc, Mutex};
13-
use tokio::task::JoinHandle;
1418

1519
type Token = StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>;
1620

@@ -57,38 +61,17 @@ impl OAuthConfig {
5761
}
5862
}
5963

60-
/// A shared token that can be read and written to by multiple threads.
61-
#[derive(Debug, Clone, Default)]
62-
pub struct SharedToken(Arc<Mutex<Option<Token>>>);
63-
64-
impl SharedToken {
65-
/// Read the token from the shared token.
66-
pub fn read(&self) -> Option<Token> {
67-
self.0.lock().unwrap().clone()
68-
}
69-
70-
/// Write a new token to the shared token.
71-
pub fn write(&self, token: Token) {
72-
let mut lock = self.0.lock().unwrap();
73-
*lock = Some(token);
74-
}
75-
76-
/// Check if the token is authenticated.
77-
pub fn is_authenticated(&self) -> bool {
78-
self.0.lock().unwrap().is_some()
79-
}
80-
}
81-
8264
/// A self-refreshing, periodically fetching authenticator for the block
83-
/// builder. This task periodically fetches a new token, and stores it in a
84-
/// [`SharedToken`].
65+
/// builder. This task periodically fetches a new token, and sends it to all
66+
/// active [`SharedToken`]s via a [`tokio::sync::watch`] channel..
8567
#[derive(Debug)]
8668
pub struct Authenticator {
8769
/// Configuration
88-
pub config: OAuthConfig,
70+
config: OAuthConfig,
8971
client: MyOAuthClient,
90-
token: SharedToken,
9172
reqwest: reqwest::Client,
73+
74+
token: watch::Sender<Option<Token>>,
9275
}
9376

9477
impl Authenticator {
@@ -99,6 +82,8 @@ impl Authenticator {
9982
.set_auth_uri(AuthUrl::from_url(config.oauth_authenticate_url.clone()))
10083
.set_token_uri(TokenUrl::from_url(config.oauth_token_url.clone()));
10184

85+
// NB: this is MANDATORY
86+
// https://docs.rs/oauth2/latest/oauth2/#security-warning
10287
let rq_client = reqwest::Client::builder()
10388
.redirect(reqwest::redirect::Policy::none())
10489
.build()
@@ -107,8 +92,8 @@ impl Authenticator {
10792
Self {
10893
config: config.clone(),
10994
client,
110-
token: Default::default(),
11195
reqwest: rq_client,
96+
token: watch::channel(None).0,
11297
}
11398
}
11499

@@ -129,20 +114,20 @@ impl Authenticator {
129114

130115
/// Returns true if there is Some token set
131116
pub fn is_authenticated(&self) -> bool {
132-
self.token.is_authenticated()
117+
self.token.borrow().is_some()
133118
}
134119

135120
/// Sets the Authenticator's token to the provided value
136121
fn set_token(&self, token: StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>) {
137-
self.token.write(token);
122+
self.token.send_replace(Some(token));
138123
}
139124

140125
/// Returns the currently set token
141126
pub fn token(&self) -> SharedToken {
142-
self.token.clone()
127+
self.token.subscribe().into()
143128
}
144129

145-
/// Fetches an oauth token
130+
/// Fetches an oauth token.
146131
pub async fn fetch_oauth_token(
147132
&self,
148133
) -> Result<
@@ -161,25 +146,184 @@ impl Authenticator {
161146
Ok(token_result)
162147
}
163148

164-
/// Spawns a task that periodically fetches a new token every 300 seconds.
165-
pub fn spawn(self) -> JoinHandle<()> {
149+
/// Get a reference to the OAuth configuration.
150+
pub const fn config(&self) -> &OAuthConfig {
151+
&self.config
152+
}
153+
154+
/// Create a future that contains the periodic refresh loop.
155+
async fn task_future(self) {
166156
let interval = self.config.oauth_token_refresh_interval;
167157

168-
let handle: JoinHandle<()> = tokio::spawn(async move {
169-
loop {
170-
info!("Refreshing oauth token");
171-
match self.authenticate().await {
172-
Ok(_) => {
173-
info!("Successfully refreshed oauth token");
174-
}
175-
Err(e) => {
176-
error!(%e, "Failed to refresh oauth token");
177-
}
178-
};
179-
let _sleep = tokio::time::sleep(tokio::time::Duration::from_secs(interval)).await;
180-
}
181-
});
182-
183-
handle
158+
loop {
159+
info!("Refreshing oauth token");
160+
match self.authenticate().await {
161+
Ok(_) => {
162+
info!("Successfully refreshed oauth token");
163+
}
164+
Err(e) => {
165+
error!(%e, "Failed to refresh oauth token");
166+
}
167+
};
168+
let _sleep = tokio::time::sleep(tokio::time::Duration::from_secs(interval)).await;
169+
}
170+
}
171+
172+
/// Spawns a task that periodically fetches a new token. The refresh
173+
/// interval may be configured via the
174+
/// [`OAuthConfig::oauth_token_refresh_interval`] property.
175+
pub fn spawn(self) -> JoinHandle<()> {
176+
tokio::spawn(self.task_future())
177+
}
178+
}
179+
180+
/// A shared token, wrapped in a [`tokio::sync::watch`] Receiver. The token is
181+
/// periodically refreshed by an [`Authenticator`] task, and can be awaited
182+
/// for when it becomes available.
183+
///
184+
/// This allows multiple tasks to wait for the token to be available, and
185+
/// provides a way to check if the token is authenticated without blocking.
186+
/// Please consult the [`Receiver`] documentation for caveats regarding
187+
/// usage.
188+
///
189+
/// [`Receiver`]: tokio::sync::watch::Receiver
190+
#[derive(Debug, Clone)]
191+
pub struct SharedToken(watch::Receiver<Option<Token>>);
192+
193+
impl From<watch::Receiver<Option<Token>>> for SharedToken {
194+
fn from(inner: watch::Receiver<Option<Token>>) -> Self {
195+
Self(inner)
196+
}
197+
}
198+
199+
impl SharedToken {
200+
/// Wait for the token to be available, and get a reference to the secret.
201+
///
202+
/// This is implemented using [`Receiver::wait_for`], and has the same
203+
/// blocking, panics, errors, and cancel safety. However, it uses a clone
204+
/// of the [`watch::Receiver`] and will not update the local view of the
205+
/// channel.
206+
///
207+
/// [`Receiver::wait_for`]: tokio::sync::watch::Receiver::wait_for
208+
pub async fn secret(&self) -> Result<String, watch::error::RecvError> {
209+
Ok(self
210+
.clone()
211+
.token()
212+
.await?
213+
.access_token()
214+
.secret()
215+
.to_owned())
216+
}
217+
218+
/// Wait for the token to be available, then get a reference to it.
219+
///
220+
/// Holding this reference will block the background task from updating
221+
/// the token until it is dropped, so it is recommended to drop this
222+
/// reference as soon as possible.
223+
///
224+
/// This is implemented using [`Receiver::wait_for`], and has the same
225+
/// blocking, panics, errors, and cancel safety. Unlike [`Self::secret`]
226+
/// it is NOT implemented using a clone, and will update the local view of
227+
/// the channel.
228+
///
229+
/// Generally, prefer using [`Self::secret`] for simple use cases, and
230+
/// this when deeper inspection of the token is required.
231+
///
232+
/// [`Receiver::wait_for`]: tokio::sync::watch::Receiver::wait_for
233+
pub async fn token(&mut self) -> Result<TokenRef<'_>, watch::error::RecvError> {
234+
self.0.wait_for(Option::is_some).await.map(Into::into)
235+
}
236+
237+
/// Create a future that will resolve when the token is ready.
238+
///
239+
/// This is implemented using [`Receiver::wait_for`], and has the same
240+
/// blocking, panics, errors, and cancel safety.
241+
///
242+
/// [`Receiver::wait_for`]: tokio::sync::watch::Receiver::wait_for
243+
pub async fn wait(&self) -> Result<(), watch::error::RecvError> {
244+
self.clone().0.wait_for(Option::is_some).await.map(drop)
245+
}
246+
247+
/// Borrow the current token, if available. If called before the token is
248+
/// set by the authentication task, this will return `None`.
249+
///
250+
/// Holding this reference will block the background task from updating
251+
/// the token until it is dropped, so it is recommended to drop this
252+
/// reference as soon as possible.
253+
///
254+
/// This is implemented using [`Receiver::borrow`].
255+
///
256+
/// [`Receiver::borrow`]: tokio::sync::watch::Receiver::borrow
257+
pub fn borrow(&mut self) -> Ref<'_, Option<Token>> {
258+
self.0.borrow()
259+
}
260+
261+
/// Check if the background task has produced an authentication token.
262+
///
263+
/// Holding this reference will block the background task from updating
264+
/// the token until it is dropped, so it is recommended to drop this
265+
/// reference as soon as possible.
266+
///
267+
/// This is implemented using [`Receiver::borrow`].
268+
///
269+
/// [`Receiver::borrow`]: tokio::sync::watch::Receiver::borrow
270+
pub fn is_authenticated(&self) -> bool {
271+
self.0.borrow().is_some()
272+
}
273+
}
274+
275+
/// A reference to token data, contained in a [`SharedToken`].
276+
///
277+
/// This is implemented using [`watch::Ref`], and as a result holds a lock on
278+
/// the token data. Holding this reference will block the background task
279+
/// from updating the token until it is dropped, so it is recommended to drop
280+
/// this reference as soon as possible.
281+
pub struct TokenRef<'a> {
282+
inner: Ref<'a, Option<Token>>,
283+
}
284+
285+
impl<'a> From<Ref<'a, Option<Token>>> for TokenRef<'a> {
286+
fn from(inner: Ref<'a, Option<Token>>) -> Self {
287+
Self { inner }
288+
}
289+
}
290+
291+
impl fmt::Debug for TokenRef<'_> {
292+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293+
f.debug_struct("TokenRef").finish_non_exhaustive()
294+
}
295+
}
296+
297+
impl<'a> TokenRef<'a> {
298+
/// Get a reference to the inner token.
299+
pub fn inner(&'a self) -> &'a Token {
300+
self.inner.as_ref().unwrap()
301+
}
302+
303+
/// Get a reference to the [`AccessToken`] contained in the token.
304+
pub fn access_token(&self) -> &AccessToken {
305+
self.inner().access_token()
306+
}
307+
308+
/// Get a reference to the [`TokenType`] instance contained in the token.
309+
///
310+
/// [`TokenType`]: oauth2::TokenType
311+
pub fn token_type(&self) -> &<Token as TokenResponse>::TokenType {
312+
self.inner().token_type()
313+
}
314+
315+
/// Get a reference to the current token's expiration time, if it has one.
316+
pub fn expires_in(&self) -> Option<std::time::Duration> {
317+
self.inner().expires_in()
318+
}
319+
320+
/// Get a reference to the refresh token, if it exists.
321+
pub fn refresh_token(&self) -> Option<&RefreshToken> {
322+
self.inner().refresh_token()
323+
}
324+
325+
/// Get a reference to the scopes associated with the token, if any.
326+
pub fn scopes(&self) -> Option<&Vec<Scope>> {
327+
self.inner().scopes()
184328
}
185329
}

src/perms/tx_cache.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use crate::perms::oauth::SharedToken;
2-
use eyre::{bail, Result};
3-
use oauth2::TokenResponse;
2+
use eyre::Result;
43
use serde::de::DeserializeOwned;
54
use signet_tx_cache::{
65
client::TxCache,
@@ -53,14 +52,12 @@ impl BuilderTxCache {
5352

5453
async fn get_inner_with_token<T: DeserializeOwned>(&self, join: &str) -> Result<T> {
5554
let url = self.tx_cache.url().join(join)?;
56-
let Some(token) = self.token.read() else {
57-
bail!("No token available for authentication");
58-
};
55+
let secret = self.token.secret().await?;
5956

6057
self.tx_cache
6158
.client()
6259
.get(url)
63-
.bearer_auth(token.access_token().secret())
60+
.bearer_auth(secret)
6461
.send()
6562
.await
6663
.inspect_err(|e| warn!(%e, "Failed to get object from transaction cache"))?

0 commit comments

Comments
 (0)