Skip to content

Commit

Permalink
Implement CSRF token verification for OAuth example (#2534)
Browse files Browse the repository at this point in the history
Co-authored-by: Logan Nielsen <loganbn@amazon.com>
Co-authored-by: Logan Nielsen <logan.nielsen@one.app>
  • Loading branch information
3 people authored Nov 17, 2024
1 parent dc5c202 commit 7e59625
Showing 1 changed file with 74 additions and 11 deletions.
85 changes: 74 additions & 11 deletions examples/oauth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
//! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth
//! ```

use anyhow::{Context, Result};
use anyhow::{anyhow, Context, Result};
use async_session::{MemoryStore, Session, SessionStore};
use axum::{
extract::{FromRef, FromRequestParts, Query, State},
Expand All @@ -28,6 +28,7 @@ use std::env;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

static COOKIE_NAME: &str = "SESSION";
static CSRF_TOKEN: &str = "csrf_token";

#[tokio::main]
async fn main() {
Expand Down Expand Up @@ -141,19 +142,37 @@ async fn index(user: Option<User>) -> impl IntoResponse {
}
}

async fn discord_auth(State(client): State<BasicClient>) -> impl IntoResponse {
// TODO: this example currently doesn't validate the CSRF token during login attempts. That
// makes it vulnerable to cross-site request forgery. If you copy code from this example make
// sure to add a check for the CSRF token.
//
// Issue for adding check to this example https://github.com/tokio-rs/axum/issues/2511
let (auth_url, _csrf_token) = client
async fn discord_auth(
State(client): State<BasicClient>,
State(store): State<MemoryStore>,
) -> Result<impl IntoResponse, AppError> {
let (auth_url, csrf_token) = client
.authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("identify".to_string()))
.url();

// Redirect to Discord's oauth service
Redirect::to(auth_url.as_ref())
// Create session to store csrf_token
let mut session = Session::new();
session
.insert(CSRF_TOKEN, &csrf_token)
.context("failed in inserting CSRF token into session")?;

// Store the session in MemoryStore and retrieve the session cookie
let cookie = store
.store_session(session)
.await
.context("failed to store CSRF token session")?
.context("unexpected error retrieving CSRF cookie value")?;

// Attach the session cookie to the response header
let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/");
let mut headers = HeaderMap::new();
headers.insert(
SET_COOKIE,
cookie.parse().context("failed to parse cookie")?,
);

Ok((headers, Redirect::to(auth_url.as_ref())))
}

// Valid user session required. If there is none, redirect to the auth page
Expand Down Expand Up @@ -194,11 +213,55 @@ struct AuthRequest {
state: String,
}

async fn csrf_token_validation_workflow(
auth_request: &AuthRequest,
cookies: &headers::Cookie,
store: &MemoryStore,
) -> Result<(), AppError> {
// Extract the cookie from the request
let cookie = cookies
.get(COOKIE_NAME)
.context("unexpected error getting cookie name")?
.to_string();

// Load the session
let session = match store
.load_session(cookie)
.await
.context("failed to load session")?
{
Some(session) => session,
None => return Err(anyhow!("Session not found").into()),
};

// Extract the CSRF token from the session
let stored_csrf_token = session
.get::<CsrfToken>(CSRF_TOKEN)
.context("CSRF token not found in session")?
.to_owned();

// Cleanup the CSRF token session
store
.destroy_session(session)
.await
.context("Failed to destroy old session")?;

// Validate CSRF token is the same as the one in the auth request
if *stored_csrf_token.secret() != auth_request.state {
return Err(anyhow!("CSRF token mismatch").into());
}

Ok(())
}

async fn login_authorized(
Query(query): Query<AuthRequest>,
State(store): State<MemoryStore>,
State(oauth_client): State<BasicClient>,
TypedHeader(cookies): TypedHeader<headers::Cookie>,
) -> Result<impl IntoResponse, AppError> {
csrf_token_validation_workflow(&query, &cookies, &store).await?;

// Get an auth token
let token = oauth_client
.exchange_code(AuthorizationCode::new(query.code.clone()))
Expand Down Expand Up @@ -233,7 +296,7 @@ async fn login_authorized(
.context("unexpected error retrieving cookie value")?;

// Build the cookie
let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; Path=/");
let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/");

// Set cookie
let mut headers = HeaderMap::new();
Expand Down

0 comments on commit 7e59625

Please sign in to comment.