diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index 30d7d41c88..4241db2a7c 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -8,7 +8,7 @@ //! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth //! ``` -use anyhow::{Context, Result}; +use anyhow::{anyhow, Context, Result}; use async_session::{MemoryStore, Session, SessionStore}; use axum::{ extract::{FromRef, FromRequestParts, Query, State}, @@ -28,6 +28,7 @@ use std::env; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; static COOKIE_NAME: &str = "SESSION"; +static CSRF_TOKEN: &str = "csrf_token"; #[tokio::main] async fn main() { @@ -141,19 +142,37 @@ async fn index(user: Option) -> impl IntoResponse { } } -async fn discord_auth(State(client): State) -> impl IntoResponse { - // TODO: this example currently doesn't validate the CSRF token during login attempts. That - // makes it vulnerable to cross-site request forgery. If you copy code from this example make - // sure to add a check for the CSRF token. - // - // Issue for adding check to this example https://github.com/tokio-rs/axum/issues/2511 - let (auth_url, _csrf_token) = client +async fn discord_auth( + State(client): State, + State(store): State, +) -> Result { + let (auth_url, csrf_token) = client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("identify".to_string())) .url(); - // Redirect to Discord's oauth service - Redirect::to(auth_url.as_ref()) + // Create session to store csrf_token + let mut session = Session::new(); + session + .insert(CSRF_TOKEN, &csrf_token) + .context("failed in inserting CSRF token into session")?; + + // Store the session in MemoryStore and retrieve the session cookie + let cookie = store + .store_session(session) + .await + .context("failed to store CSRF token session")? + .context("unexpected error retrieving CSRF cookie value")?; + + // Attach the session cookie to the response header + let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/"); + let mut headers = HeaderMap::new(); + headers.insert( + SET_COOKIE, + cookie.parse().context("failed to parse cookie")?, + ); + + Ok((headers, Redirect::to(auth_url.as_ref()))) } // Valid user session required. If there is none, redirect to the auth page @@ -194,11 +213,55 @@ struct AuthRequest { state: String, } +async fn csrf_token_validation_workflow( + auth_request: &AuthRequest, + cookies: &headers::Cookie, + store: &MemoryStore, +) -> Result<(), AppError> { + // Extract the cookie from the request + let cookie = cookies + .get(COOKIE_NAME) + .context("unexpected error getting cookie name")? + .to_string(); + + // Load the session + let session = match store + .load_session(cookie) + .await + .context("failed to load session")? + { + Some(session) => session, + None => return Err(anyhow!("Session not found").into()), + }; + + // Extract the CSRF token from the session + let stored_csrf_token = session + .get::(CSRF_TOKEN) + .context("CSRF token not found in session")? + .to_owned(); + + // Cleanup the CSRF token session + store + .destroy_session(session) + .await + .context("Failed to destroy old session")?; + + // Validate CSRF token is the same as the one in the auth request + if *stored_csrf_token.secret() != auth_request.state { + return Err(anyhow!("CSRF token mismatch").into()); + } + + Ok(()) +} + async fn login_authorized( Query(query): Query, State(store): State, State(oauth_client): State, + TypedHeader(cookies): TypedHeader, ) -> Result { + csrf_token_validation_workflow(&query, &cookies, &store).await?; + // Get an auth token let token = oauth_client .exchange_code(AuthorizationCode::new(query.code.clone())) @@ -233,7 +296,7 @@ async fn login_authorized( .context("unexpected error retrieving cookie value")?; // Build the cookie - let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; Path=/"); + let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/"); // Set cookie let mut headers = HeaderMap::new();