Skip to content

Commit

Permalink
auth: de-dupe inflight requests (#21801)
Browse files Browse the repository at this point in the history
This PR updates the `Client` we use to make Frontegg Auth requests to
dedupe inflight requests. Specifically, when a request is made we check
if a request to that endpoint with those arguments is already inflight,
if so, we do not issue a second request and instead we attach a
listener/waiter to the already inflight request.

### Motivation

Helps improve:
https://github.com/MaterializeInc/materialize/issues/21782

[Frontegg
documents](https://docs.frontegg.com/docs/frontegg-rate-limit-policies#limits-for-frontegg-workspaces)
the API we use for getting auth tokens as accepting 100 requests
per-second. As we attempt to scale to supporting thousands of concurrent
connection requests (per-user), we hit Frontegg's request limit.

With this change +
#21783 we have the
following latencies when opening concurrent connections:

num requests | p50      | p90     | p99
-------------|----------|---------|------- 
32           |   34ms   | 371ms   | 495ms
64           |   25ms   | 285ms   | 367ms
128          |   31ms   | 189ms   | 331ms
256          |   66ms   | 565ms   | 660ms
512          |  4044ms  | 4828ms  | 4977ms
1024         |  9114ms  | 9880ms  | 10038ms
2048         |  20031ms | 20784ms | 20931ms
4096         |  21550ms | 22269ms | 22424ms
8192         |  23174ms | 24440ms | 24571ms

Something is still happening when we reach 512 connections, but this is
about a 10x improvement over the current state of the world.

### Checklist

- [ ] This PR has adequate test coverage / QA involvement has been duly
considered.
- [ ] This PR has an associated up-to-date [design
doc](https://github.com/MaterializeInc/materialize/blob/main/doc/developer/design/README.md),
is a design doc
([template](https://github.com/MaterializeInc/materialize/blob/main/doc/developer/design/00000000_template.md)),
or is sufficiently small to not require a design.
  <!-- Reference the design in the description. -->
- [ ] If this PR evolves [an existing `$T ⇔ Proto$T`
mapping](https://github.com/MaterializeInc/materialize/blob/main/doc/developer/command-and-response-binary-encoding.md)
(possibly in a backwards-incompatible way), then it is tagged with a
`T-proto` label.
- [ ] If this PR will require changes to cloud orchestration or tests,
there is a companion cloud PR to account for those changes that is
tagged with the release-blocker label
([example](MaterializeInc/cloud#5021)).
<!-- Ask in #team-cloud on Slack if you need help preparing the cloud
PR. -->
- [x] This PR includes the following [user-facing behavior
changes](https://github.com/MaterializeInc/materialize/blob/main/doc/developer/guide-changes.md#what-changes-require-a-release-note):
  - N/a
  • Loading branch information
ParkMyCar authored Sep 20, 2023
1 parent 2a65a47 commit 02bf079
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 37 deletions.
2 changes: 1 addition & 1 deletion src/environmentd/tests/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ fn start_mzcloud(
.refresh_tokens
.lock()
.unwrap()
.get(args.refresh_token),
.get(&args.refresh_token),
context.enable_refresh.load(Ordering::Relaxed),
) {
(Some(email), true) => email.to_string(),
Expand Down
3 changes: 2 additions & 1 deletion src/frontegg-auth/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ reqwest = { version = "0.11.13", features = ["json"] }
reqwest-middleware = "0.2.2"
reqwest-retry = "0.2.2"
serde = { version = "1.0.152", features = ["derive"] }
serde_json = "1.0.89"
thiserror = "1.0.37"
tokio = { version = "1.24.2", features = ["macros"] }
tracing = "0.1.37"
Expand All @@ -24,7 +25,7 @@ workspace-hack = { version = "0.0.0", path = "../workspace-hack" }

[dev-dependencies]
axum = "0.6.20"
serde_json = "1.0.89"
mz-ore = { path = "../ore", features = ["network", "test"] }
tokio = { version = "1.24.2", features = ["macros", "rt-multi-thread"] }

[package.metadata.cargo-udeps.ignore]
Expand Down
2 changes: 1 addition & 1 deletion src/frontegg-auth/src/app_password.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ impl FromStr for AppPassword {
}

/// An error while parsing an [`AppPassword`].
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct AppPasswordParseError;

impl Error for AppPasswordParseError {}
Expand Down
177 changes: 176 additions & 1 deletion src/frontegg-auth/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,32 @@
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use std::sync::{Arc, Mutex};
use std::time::Duration;

use anyhow::Context;
use mz_ore::collections::HashMap;
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::RetryTransientMiddleware;
use tokio::sync::oneshot;

use crate::client::tokens::RefreshTokenResponse;
use crate::{ApiTokenArgs, ApiTokenResponse, Error, RefreshToken};

pub mod tokens;

/// Client for Frontegg auth requests.
///
/// Internally the client will attempt to de-dupe requests, e.g. if a single user tries to connect
/// many clients at once, we'll de-dupe the authentication requests.
#[derive(Clone, Debug)]
pub struct Client {
pub client: reqwest_middleware::ClientWithMiddleware,
inflight_requests: Arc<Mutex<HashMap<Request, ResponseHandle>>>,
}

type ResponseHandle = Vec<oneshot::Sender<Result<Response, Error>>>;

impl Default for Client {
fn default() -> Self {
// Re-use the envd defaults until there's a reason to use something else. This is a separate
Expand All @@ -44,6 +58,167 @@ impl Client {
.with(RetryTransientMiddleware::new_with_policy(retry_policy))
.build();

Self { client }
let inflight_requests = Arc::new(Mutex::new(HashMap::new()));

Self {
client,
inflight_requests,
}
}

/// Makes a request to the provided URL, possibly de-duping by attaching a listener to an
/// already in-flight request.
async fn make_request<Req, Resp>(&self, url: String, req: Req) -> Result<Resp, Error>
where
Req: AuthRequest,
Resp: AuthResponse,
{
let req = req.into_request();

// Note: we get the reciever in a block to scope the access to the mutex.
let rx = {
let mut inflight_requests = self
.inflight_requests
.lock()
.expect("Frontegg Auth Client panicked");
let (tx, rx) = tokio::sync::oneshot::channel();

match inflight_requests.get_mut(&req) {
// Already have an inflight request, add to our list of waiters.
Some(senders) => {
tracing::debug!("reusing request, {req:?}");
senders.push(tx);
rx
}
// New request! Need to queue one up.
None => {
tracing::debug!("spawning new request, {req:?}");

inflight_requests.insert(req.clone(), vec![tx]);

let client = self.client.clone();
let inflight = Arc::clone(&self.inflight_requests);
let req_ = req.clone();

mz_ore::task::spawn(move || "frontegg-auth-request", async move {
// Make the actual request.
let result = async {
let resp = client
.post(&url)
.json(&req_.into_json())
.send()
.await?
.error_for_status()?
.json::<Resp>()
.await?;
Ok::<_, Error>(resp)
}
.await;

// Get all of our waiters.
let mut inflight = inflight.lock().expect("Frontegg Auth Client panicked");
let Some(waiters) = inflight.remove(&req) else {
tracing::error!("Inflight entry already removed? {req:?}");
return;
};

// Tell all of our waiters about the result.
let response = result.map(|r| r.into_response());
for tx in waiters {
let _ = tx.send(response.clone());
}
});

rx
}
}
};

let resp = rx.await.context("waiting for inflight response")?;
resp.map(|r| Resp::from_response(r))
}
}

/// Boilerplate for de-duping requests.
///
/// We maintain an in-memory map of inflight requests, and that map needs to have keys of a single
/// type, so we wrap all of our request types an an enum to create that single type.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
enum Request {
ExchangeSecretForToken(ApiTokenArgs),
RefreshToken(RefreshToken),
}

impl Request {
fn into_json(self) -> serde_json::Value {
match self {
Request::ExchangeSecretForToken(arg) => serde_json::to_value(arg),
Request::RefreshToken(arg) => serde_json::to_value(arg),
}
.expect("converting to JSON cannot fail")
}
}

/// Boilerplate for de-duping requests.
///
/// Deduplicates the wrapping of request types into a [`Request`].
trait AuthRequest: serde::Serialize + Clone {
fn into_request(self) -> Request;
}

impl AuthRequest for ApiTokenArgs {
fn into_request(self) -> Request {
Request::ExchangeSecretForToken(self)
}
}

impl AuthRequest for RefreshToken {
fn into_request(self) -> Request {
Request::RefreshToken(self)
}
}

/// Boilerplate for de-duping requests.
///
/// We maintain an in-memory map of inflight requests, the values of the map are a Vec of waiters
/// that listen for a response. These listeners all need to have the same type, so we wrap all of
/// our response types in an enum.
#[derive(Clone, Debug)]
enum Response {
ExchangeSecretForToken(ApiTokenResponse),
RefreshToken(RefreshTokenResponse),
}

/// Boilerplate for de-duping requests.
///
/// Deduplicates the wrapping and unwrapping between response types and [`Response`].
trait AuthResponse: serde::de::DeserializeOwned {
fn into_response(self) -> Response;
fn from_response(resp: Response) -> Self;
}

impl AuthResponse for ApiTokenResponse {
fn into_response(self) -> Response {
Response::ExchangeSecretForToken(self)
}

fn from_response(resp: Response) -> Self {
let Response::ExchangeSecretForToken(result) = resp else {
unreachable!("programming error!, didn't roundtrip {resp:?}")
};
result
}
}

impl AuthResponse for RefreshTokenResponse {
fn into_response(self) -> Response {
Response::RefreshToken(self)
}

fn from_response(resp: Response) -> Self {
let Response::RefreshToken(result) = resp else {
unreachable!("programming error!, didn't roundtrip")
};
result
}
}
53 changes: 26 additions & 27 deletions src/frontegg-auth/src/client/tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,11 @@ impl Client {
secret: Uuid,
admin_api_token_url: &str,
) -> Result<ApiTokenResponse, Error> {
let res = self
.client
.post(admin_api_token_url)
.json(&ApiTokenArgs { client_id, secret })
.send()
.await?
.error_for_status()?
.json::<ApiTokenResponse>()
let args = ApiTokenArgs { client_id, secret };
let result = self
.make_request(admin_api_token_url.to_string(), args)
.await?;
Ok(res)
Ok(result)
}

/// Exchanges a client id and secret for a jwt token.
Expand All @@ -37,20 +32,22 @@ impl Client {
refresh_url: &str,
refresh_token: &str,
) -> Result<ApiTokenResponse, Error> {
let res = self
.client
.post(refresh_url)
.json(&RefreshToken { refresh_token })
.send()
.await?
.error_for_status()?
.json::<ApiTokenResponse>()
.await?;
Ok(res)
let args = RefreshToken {
refresh_token: refresh_token.to_string(),
};
let result = self.make_request(refresh_url.to_string(), args).await?;
Ok(result)
}
}

#[derive(Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiTokenArgs {
pub client_id: Uuid,
pub secret: Uuid,
}

#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiTokenResponse {
pub expires: String,
Expand All @@ -59,17 +56,19 @@ pub struct ApiTokenResponse {
pub refresh_token: String,
}

#[derive(Debug, serde::Serialize, serde::Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiTokenArgs {
pub client_id: Uuid,
pub secret: Uuid,
pub struct RefreshToken {
pub refresh_token: String,
}

#[derive(Debug, serde::Serialize, serde::Deserialize)]
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RefreshToken<'a> {
pub refresh_token: &'a str,
pub struct RefreshTokenResponse {
pub expires: String,
pub expires_in: i64,
pub access_token: String,
pub refresh_token: String,
}

#[cfg(test)]
Expand Down
33 changes: 27 additions & 6 deletions src/frontegg-auth/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,56 @@
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use std::sync::Arc;
use thiserror::Error;

use crate::AppPasswordParseError;

#[derive(Error, Debug)]
#[derive(Clone, Error, Debug)]
pub enum Error {
#[error(transparent)]
InvalidPasswordFormat(#[from] AppPasswordParseError),
#[error("invalid token format: {0}")]
InvalidTokenFormat(#[from] jsonwebtoken::errors::Error),
#[error("authentication token exchange failed: {0}")]
ReqwestError(#[from] reqwest::Error),
ReqwestError(Arc<reqwest::Error>),
#[error("middleware programming error: {0}")]
MiddlewareError(anyhow::Error),
MiddlewareError(Arc<anyhow::Error>),
#[error("authentication token expired")]
TokenExpired,
#[error("unauthorized organization")]
UnauthorizedTenant,
#[error("email in access token did not match the expected email")]
WrongEmail,
#[error("request timeout")]
Timeout(#[from] tokio::time::error::Elapsed),
Timeout(Arc<tokio::time::error::Elapsed>),
#[error("internal error")]
Internal(Arc<anyhow::Error>),
}

impl From<anyhow::Error> for Error {
fn from(value: anyhow::Error) -> Self {
Error::Internal(Arc::new(value))
}
}

impl From<tokio::time::error::Elapsed> for Error {
fn from(value: tokio::time::error::Elapsed) -> Self {
Error::Timeout(Arc::new(value))
}
}

impl From<reqwest::Error> for Error {
fn from(value: reqwest::Error) -> Self {
Error::ReqwestError(Arc::new(value))
}
}

impl From<reqwest_middleware::Error> for Error {
fn from(value: reqwest_middleware::Error) -> Self {
match value {
reqwest_middleware::Error::Middleware(e) => Error::MiddlewareError(e),
reqwest_middleware::Error::Reqwest(e) => Error::ReqwestError(e),
reqwest_middleware::Error::Middleware(e) => Error::MiddlewareError(Arc::new(e)),
reqwest_middleware::Error::Reqwest(e) => Error::ReqwestError(Arc::new(e)),
}
}
}

0 comments on commit 02bf079

Please sign in to comment.