Skip to content

Commit afffdad

Browse files
committed
start, add an in-memory map of inflight requests to de-dupe them
1 parent 464a6b6 commit afffdad

File tree

5 files changed

+233
-36
lines changed

5 files changed

+233
-36
lines changed

src/frontegg-auth/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ reqwest = { version = "0.11.13", features = ["json"] }
1616
reqwest-middleware = "0.2.2"
1717
reqwest-retry = "0.2.2"
1818
serde = { version = "1.0.152", features = ["derive"] }
19+
serde_json = "1.0.89"
1920
thiserror = "1.0.37"
2021
tokio = { version = "1.24.2", features = ["macros"] }
2122
tracing = "0.1.37"
@@ -24,7 +25,7 @@ workspace-hack = { version = "0.0.0", path = "../workspace-hack" }
2425

2526
[dev-dependencies]
2627
axum = "0.6.20"
27-
serde_json = "1.0.89"
28+
mz-ore = { path = "../ore", features = ["network", "test"] }
2829
tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] }
2930

3031
[package.metadata.cargo-udeps.ignore]

src/frontegg-auth/src/app_password.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ impl FromStr for AppPassword {
9090
}
9191

9292
/// An error while parsing an [`AppPassword`].
93-
#[derive(Debug)]
93+
#[derive(Clone, Debug)]
9494
pub struct AppPasswordParseError;
9595

9696
impl Error for AppPasswordParseError {}

src/frontegg-auth/src/client.rs

Lines changed: 177 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,33 @@
77
// the Business Source License, use of this software will be governed
88
// by the Apache License, Version 2.0.
99

10+
use std::sync::{Arc, Mutex};
1011
use std::time::Duration;
1112

13+
use anyhow::Context;
14+
use mz_ore::collections::HashMap;
15+
use tokio::sync::oneshot;
16+
1217
use reqwest_retry::policies::ExponentialBackoff;
1318
use reqwest_retry::RetryTransientMiddleware;
1419

20+
use crate::client::tokens::RefreshTokenResponse;
21+
use crate::{ApiTokenArgs, ApiTokenResponse, Error, RefreshToken};
22+
1523
pub mod tokens;
1624

25+
/// Client for Frontegg auth requests.
26+
///
27+
/// Internally the client will attempt to de-dupe requests, e.g. if a single user tries to connect
28+
/// many clients at once, we'll de-dupe the authentication requests.
1729
#[derive(Clone, Debug)]
1830
pub struct Client {
1931
pub client: reqwest_middleware::ClientWithMiddleware,
32+
inflight_requests: Arc<Mutex<HashMap<Request, ResponseHandle>>>,
2033
}
2134

35+
type ResponseHandle = Vec<oneshot::Sender<Result<Response, Error>>>;
36+
2237
impl Default for Client {
2338
fn default() -> Self {
2439
// Re-use the envd defaults until there's a reason to use something else. This is a separate
@@ -44,6 +59,167 @@ impl Client {
4459
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
4560
.build();
4661

47-
Self { client }
62+
let inflight_requests = Arc::new(Mutex::new(HashMap::new()));
63+
64+
Self {
65+
client,
66+
inflight_requests,
67+
}
68+
}
69+
70+
/// Makes a request to the provided URL, possibly de-duping by attaching a listener to an
71+
/// already in-flight request.
72+
async fn make_request<Req, Resp>(&self, url: String, req: Req) -> Result<Resp, Error>
73+
where
74+
Req: AuthRequest,
75+
Resp: AuthResponse,
76+
{
77+
let req = req.into_request();
78+
79+
// Note: we get the reciever in a block to scope the access to the mutex.
80+
let rx = {
81+
let mut inflight_requests = self
82+
.inflight_requests
83+
.lock()
84+
.expect("Frontegg Auth Client panicked");
85+
let (tx, rx) = tokio::sync::oneshot::channel();
86+
87+
match inflight_requests.get_mut(&req) {
88+
// Already have an inflight request, add to our list of waiters.
89+
Some(senders) => {
90+
tracing::debug!("reusing request, {req:?}");
91+
senders.push(tx);
92+
rx
93+
}
94+
// New request! Need to queue one up.
95+
None => {
96+
tracing::debug!("spawning new request, {req:?}");
97+
98+
inflight_requests.insert(req.clone(), vec![tx]);
99+
100+
let client = self.client.clone();
101+
let inflight = Arc::clone(&self.inflight_requests);
102+
let req_ = req.clone();
103+
104+
mz_ore::task::spawn(move || "frontegg-auth-request", async move {
105+
// Make the actual request.
106+
let result = async {
107+
let resp = client
108+
.post(&url)
109+
.json(&req_.into_json())
110+
.send()
111+
.await?
112+
.error_for_status()?
113+
.json::<Resp>()
114+
.await?;
115+
Ok::<_, Error>(resp)
116+
}
117+
.await;
118+
119+
// Get all of our waiters.
120+
let mut inflight = inflight.lock().expect("Frontegg Auth Client panicked");
121+
let Some(waiters) = inflight.remove(&req) else {
122+
tracing::error!("Inflight entry already removed? {req:?}");
123+
return;
124+
};
125+
126+
// Tell all of our waiters about the result.
127+
let response = result.map(|r| r.into_response());
128+
for tx in waiters {
129+
let _ = tx.send(response.clone());
130+
}
131+
});
132+
133+
rx
134+
}
135+
}
136+
};
137+
138+
let resp = rx.await.context("waiting for inflight response")?;
139+
resp.map(|r| Resp::from_response(r))
140+
}
141+
}
142+
143+
/// Boilerplate for de-duping requests.
144+
///
145+
/// We maintain an in-memory map of inflight requests, and that map needs to have keys of a single
146+
/// type, so we wrap all of our request types an an enum to create that single type.
147+
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
148+
enum Request {
149+
ExchangeSecretForToken(ApiTokenArgs),
150+
RefreshToken(RefreshToken),
151+
}
152+
153+
impl Request {
154+
fn into_json(self) -> serde_json::Value {
155+
match self {
156+
Request::ExchangeSecretForToken(arg) => serde_json::to_value(arg),
157+
Request::RefreshToken(arg) => serde_json::to_value(arg),
158+
}
159+
.expect("converting to JSON cannot fail")
160+
}
161+
}
162+
163+
/// Boilerplate for de-duping requests.
164+
///
165+
/// Deduplicates the wrapping of request types into a [`Request`].
166+
trait AuthRequest: serde::Serialize + Clone {
167+
fn into_request(self) -> Request;
168+
}
169+
170+
impl AuthRequest for ApiTokenArgs {
171+
fn into_request(self) -> Request {
172+
Request::ExchangeSecretForToken(self)
173+
}
174+
}
175+
176+
impl AuthRequest for RefreshToken {
177+
fn into_request(self) -> Request {
178+
Request::RefreshToken(self)
179+
}
180+
}
181+
182+
/// Boilerplate for de-duping requests.
183+
///
184+
/// We maintain an in-memory map of inflight requests, the values of the map are a Vec of waiters
185+
/// that listen for a response. These listeners all need to have the same type, so we wrap all of
186+
/// our response types in an enum.
187+
#[derive(Clone, Debug)]
188+
enum Response {
189+
ExchangeSecretForToken(ApiTokenResponse),
190+
RefreshToken(RefreshTokenResponse),
191+
}
192+
193+
/// Boilerplate for de-duping requests.
194+
///
195+
/// Deduplicates the wrapping and unwrapping between response types and [`Response`].
196+
trait AuthResponse: serde::de::DeserializeOwned {
197+
fn into_response(self) -> Response;
198+
fn from_response(resp: Response) -> Self;
199+
}
200+
201+
impl AuthResponse for ApiTokenResponse {
202+
fn into_response(self) -> Response {
203+
Response::ExchangeSecretForToken(self)
204+
}
205+
206+
fn from_response(resp: Response) -> Self {
207+
let Response::ExchangeSecretForToken(result) = resp else {
208+
unreachable!("programming error!, didn't roundtrip {resp:?}")
209+
};
210+
result
211+
}
212+
}
213+
214+
impl AuthResponse for RefreshTokenResponse {
215+
fn into_response(self) -> Response {
216+
Response::RefreshToken(self)
217+
}
218+
219+
fn from_response(resp: Response) -> Self {
220+
let Response::RefreshToken(result) = resp else {
221+
unreachable!("programming error!, didn't roundtrip")
222+
};
223+
result
48224
}
49225
}

src/frontegg-auth/src/client/tokens.rs

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,11 @@ impl Client {
1919
secret: Uuid,
2020
admin_api_token_url: &str,
2121
) -> Result<ApiTokenResponse, Error> {
22-
let res = self
23-
.client
24-
.post(admin_api_token_url)
25-
.json(&ApiTokenArgs { client_id, secret })
26-
.send()
27-
.await?
28-
.error_for_status()?
29-
.json::<ApiTokenResponse>()
22+
let args = ApiTokenArgs { client_id, secret };
23+
let result = self
24+
.make_request(admin_api_token_url.to_string(), args)
3025
.await?;
31-
Ok(res)
26+
Ok(result)
3227
}
3328

3429
/// Exchanges a client id and secret for a jwt token.
@@ -37,20 +32,22 @@ impl Client {
3732
refresh_url: &str,
3833
refresh_token: &str,
3934
) -> Result<ApiTokenResponse, Error> {
40-
let res = self
41-
.client
42-
.post(refresh_url)
43-
.json(&RefreshToken { refresh_token })
44-
.send()
45-
.await?
46-
.error_for_status()?
47-
.json::<ApiTokenResponse>()
48-
.await?;
49-
Ok(res)
35+
let args = RefreshToken {
36+
refresh_token: refresh_token.to_string(),
37+
};
38+
let result = self.make_request(refresh_url.to_string(), args).await?;
39+
Ok(result)
5040
}
5141
}
5242

53-
#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
43+
#[derive(Clone, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
44+
#[serde(rename_all = "camelCase")]
45+
pub struct ApiTokenArgs {
46+
pub client_id: Uuid,
47+
pub secret: Uuid,
48+
}
49+
50+
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
5451
#[serde(rename_all = "camelCase")]
5552
pub struct ApiTokenResponse {
5653
pub expires: String,
@@ -59,17 +56,19 @@ pub struct ApiTokenResponse {
5956
pub refresh_token: String,
6057
}
6158

62-
#[derive(Debug, serde::Serialize, serde::Deserialize)]
59+
#[derive(Clone, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
6360
#[serde(rename_all = "camelCase")]
64-
pub struct ApiTokenArgs {
65-
pub client_id: Uuid,
66-
pub secret: Uuid,
61+
pub struct RefreshToken {
62+
pub refresh_token: String,
6763
}
6864

69-
#[derive(Debug, serde::Serialize, serde::Deserialize)]
65+
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
7066
#[serde(rename_all = "camelCase")]
71-
pub struct RefreshToken<'a> {
72-
pub refresh_token: &'a str,
67+
pub struct RefreshTokenResponse {
68+
pub expires: String,
69+
pub expires_in: i64,
70+
pub access_token: String,
71+
pub refresh_token: String,
7372
}
7473

7574
#[cfg(test)]

src/frontegg-auth/src/error.rs

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,56 @@
77
// the Business Source License, use of this software will be governed
88
// by the Apache License, Version 2.0.
99

10+
use std::sync::Arc;
1011
use thiserror::Error;
1112

1213
use crate::AppPasswordParseError;
1314

14-
#[derive(Error, Debug)]
15+
#[derive(Clone, Error, Debug)]
1516
pub enum Error {
1617
#[error(transparent)]
1718
InvalidPasswordFormat(#[from] AppPasswordParseError),
1819
#[error("invalid token format: {0}")]
1920
InvalidTokenFormat(#[from] jsonwebtoken::errors::Error),
2021
#[error("authentication token exchange failed: {0}")]
21-
ReqwestError(#[from] reqwest::Error),
22+
ReqwestError(Arc<reqwest::Error>),
2223
#[error("middleware programming error: {0}")]
23-
MiddlewareError(anyhow::Error),
24+
MiddlewareError(Arc<anyhow::Error>),
2425
#[error("authentication token expired")]
2526
TokenExpired,
2627
#[error("unauthorized organization")]
2728
UnauthorizedTenant,
2829
#[error("email in access token did not match the expected email")]
2930
WrongEmail,
3031
#[error("request timeout")]
31-
Timeout(#[from] tokio::time::error::Elapsed),
32+
Timeout(Arc<tokio::time::error::Elapsed>),
33+
#[error("internal error")]
34+
Internal(Arc<anyhow::Error>),
35+
}
36+
37+
impl From<anyhow::Error> for Error {
38+
fn from(value: anyhow::Error) -> Self {
39+
Error::Internal(Arc::new(value))
40+
}
41+
}
42+
43+
impl From<tokio::time::error::Elapsed> for Error {
44+
fn from(value: tokio::time::error::Elapsed) -> Self {
45+
Error::Timeout(Arc::new(value))
46+
}
47+
}
48+
49+
impl From<reqwest::Error> for Error {
50+
fn from(value: reqwest::Error) -> Self {
51+
Error::ReqwestError(Arc::new(value))
52+
}
3253
}
3354

3455
impl From<reqwest_middleware::Error> for Error {
3556
fn from(value: reqwest_middleware::Error) -> Self {
3657
match value {
37-
reqwest_middleware::Error::Middleware(e) => Error::MiddlewareError(e),
38-
reqwest_middleware::Error::Reqwest(e) => Error::ReqwestError(e),
58+
reqwest_middleware::Error::Middleware(e) => Error::MiddlewareError(Arc::new(e)),
59+
reqwest_middleware::Error::Reqwest(e) => Error::ReqwestError(Arc::new(e)),
3960
}
4061
}
4162
}

0 commit comments

Comments
 (0)