Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,4 @@ env_logger.workspace = true
pretty_assertions.workspace = true
jsonwebtoken.workspace = true
axum.workspace = true
reqwest.workspace = true
68 changes: 62 additions & 6 deletions crates/core/src/auth/token_validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,14 @@ impl TokenValidator for OidcTokenValidator {
let raw_issuer = get_raw_issuer(token)?;
// TODO: Consider checking for trailing slashes or requiring a scheme.
let oidc_url = format!("{}/.well-known/openid-configuration", raw_issuer);
let keys = Jwks::from_oidc_url(oidc_url).await?;
log::debug!("Fetching key for issuer {}", raw_issuer.clone());
let key_or_error = Jwks::from_oidc_url(oidc_url).await;
// TODO: We should probably add debouncing to avoid spamming the logs.
// Alternatively we could add a backoff before retrying.
if let Err(e) = &key_or_error {
log::warn!("Error fetching public key for issuer {}: {:?}", raw_issuer, e);
}
let keys = key_or_error?;
let validator = JwksValidator {
issuer: raw_issuer,
keyset: keys,
Expand Down Expand Up @@ -317,6 +324,8 @@ impl TokenValidator for JwksValidator {

#[cfg(test)]
mod tests {
use std::time::Duration;

use crate::auth::identity::{IncomingClaims, SpacetimeIdentityClaims};
use crate::auth::token_validation::{
BasicTokenValidator, CachingOidcTokenValidator, FullTokenValidator, OidcTokenValidator, TokenSigner,
Expand Down Expand Up @@ -533,6 +542,29 @@ mod tests {
.unwrap();
});

// Wait for server to be ready
let client = reqwest::Client::new();
let health_check_url = format!("{}/ok", base_url);

let mut attempts = 0;
const MAX_ATTEMPTS: u32 = 10;
const DELAY_MS: u64 = 50;

while attempts < MAX_ATTEMPTS {
match client.get(&health_check_url).send().await {
Ok(response) if response.status().is_success() => break,
_ => {
log::debug!("Server not ready. Waiting...");
tokio::time::sleep(Duration::from_millis(DELAY_MS)).await;
attempts += 1;
}
}
}

if attempts == MAX_ATTEMPTS {
return Err(anyhow::anyhow!("Server failed to start after maximum attempts"));
}

Ok(OIDCServerHandle {
base_url,
shutdown_tx,
Expand Down Expand Up @@ -590,13 +622,19 @@ mod tests {

#[tokio::test]
async fn test_oidc_flow() -> anyhow::Result<()> {
run_oidc_test(OidcTokenValidator).await
for _ in 0..10 {
run_oidc_test(OidcTokenValidator).await?
}
Ok(())
}

#[tokio::test]
async fn test_caching_oidc_flow() -> anyhow::Result<()> {
let v = CachingOidcTokenValidator::get_default();
run_oidc_test(v).await
for _ in 0..10 {
let v = CachingOidcTokenValidator::get_default();
run_oidc_test(v).await?;
}
Ok(())
}

#[tokio::test]
Expand Down Expand Up @@ -639,8 +677,26 @@ mod tests {
let mut y = openssl::bn::BigNum::new()?;
eck.public_key().affine_coordinates(&group, &mut x, &mut y, &mut ctx)?;

let x_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(x.to_vec());
let y_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(y.to_vec());
let x_bytes = x.to_vec();
let y_bytes = y.to_vec();

let x_padded = if x_bytes.len() < 32 {
let mut padded = vec![0u8; 32];
padded[32 - x_bytes.len()..].copy_from_slice(&x_bytes);
padded
} else {
x_bytes
};

let y_padded = if y_bytes.len() < 32 {
let mut padded = vec![0u8; 32];
padded[32 - y_bytes.len()..].copy_from_slice(&y_bytes);
padded
} else {
y_bytes
};
let x_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(x_padded);
let y_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(y_padded);

let mut jwks = serde_json::json!(
{
Expand Down
Loading