Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions crates/core/src/auth/token_validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,7 @@ impl async_cache::Fetcher<Arc<JwksValidator>> for KeyFetcher {
// TODO: Make this stored in the struct so we don't need to keep creating it.
let raw_issuer = key.to_string();
log::info!("Fetching key for issuer {}", raw_issuer.clone());
// TODO: Consider checking for trailing slashes or requiring a scheme.
let oidc_url = format!("{}/.well-known/openid-configuration", raw_issuer);
let oidc_url = format!("{}/.well-known/openid-configuration", raw_issuer.trim_end_matches('/'));
let key_or_error = Jwks::from_oidc_url(oidc_url).await;
// TODO: We should probably add debouncing to avoid spamming the logs.
// Alternatively we could add a backoff before retrying.
Expand Down Expand Up @@ -259,8 +258,7 @@ impl TokenValidator for OidcTokenValidator {
async fn validate_token(&self, token: &str) -> Result<SpacetimeIdentityClaims, TokenValidationError> {
// TODO: Make this stored in the struct so we don't need to keep creating it.
let raw_issuer = get_raw_issuer(token)?;
// TODO: Consider checking for trailing slashes or requiring a scheme.
let oidc_url = format!("{}/.well-known/openid-configuration", raw_issuer);
let oidc_url = format!("{}/.well-known/openid-configuration", raw_issuer.trim_end_matches('/'));
log::debug!("Fetching key for issuer {}", raw_issuer.clone());
let key_or_error = Jwks::from_oidc_url(oidc_url).await;
// TODO: We should probably add debouncing to avoid spamming the logs.
Expand Down Expand Up @@ -573,7 +571,12 @@ mod tests {
}
}

async fn run_oidc_test<T: TokenValidator>(validator: T) -> anyhow::Result<()> {
#[derive(Debug, Default, Copy, Clone)]
struct TestOptions {
pub issuer_trailing_slash: bool,
}

async fn run_oidc_test<T: TokenValidator>(validator: T, opts: &TestOptions) -> anyhow::Result<()> {
// We will put 2 keys in the keyset.
let mut kp1 = JwtKeys::generate()?;
let mut kp2 = JwtKeys::generate()?;
Expand All @@ -593,6 +596,11 @@ mod tests {
let handle = OIDCServerHandle::start_new(jwks).await?;

let issuer = handle.base_url.clone();
let issuer = if opts.issuer_trailing_slash {
format!("{}/", issuer)
} else {
issuer
};
let subject = "test_subject";

let orig_claims = IncomingClaims {
Expand Down Expand Up @@ -623,16 +631,27 @@ mod tests {
#[tokio::test]
async fn test_oidc_flow() -> anyhow::Result<()> {
for _ in 0..10 {
run_oidc_test(OidcTokenValidator).await?
run_oidc_test(OidcTokenValidator, &Default::default()).await?
}
Ok(())
}

#[tokio::test]
async fn test_issuer_slash() -> anyhow::Result<()> {
let opts = TestOptions {
issuer_trailing_slash: true,
};

run_oidc_test(OidcTokenValidator, &opts).await?;
run_oidc_test(CachingOidcTokenValidator::get_default(), &opts).await?;
Ok(())
}

#[tokio::test]
async fn test_caching_oidc_flow() -> anyhow::Result<()> {
for _ in 0..10 {
let v = CachingOidcTokenValidator::get_default();
run_oidc_test(v).await?;
run_oidc_test(v, &Default::default()).await?;
}
Ok(())
}
Expand All @@ -645,7 +664,7 @@ mod tests {
local_issuer: "local_issuer".to_string(),
oidc_validator: OidcTokenValidator,
};
run_oidc_test(v).await
run_oidc_test(v, &Default::default()).await
}

/// Convert a set of keys to a JWKS JSON string.
Expand Down
Loading