Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement CSRF token verification for OAuth example #2534

Merged
merged 3 commits into from
Nov 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 74 additions & 11 deletions examples/oauth/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
//! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth
//! ```

use anyhow::{Context, Result};
use anyhow::{anyhow, Context, Result};
use async_session::{MemoryStore, Session, SessionStore};
use axum::{
extract::{FromRef, FromRequestParts, Query, State},
Expand All @@ -28,6 +28,7 @@ use std::env;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

static COOKIE_NAME: &str = "SESSION";
static CSRF_TOKEN: &str = "csrf_token";

#[tokio::main]
async fn main() {
Expand Down Expand Up @@ -141,19 +142,37 @@ async fn index(user: Option<User>) -> impl IntoResponse {
}
}

async fn discord_auth(State(client): State<BasicClient>) -> impl IntoResponse {
// TODO: this example currently doesn't validate the CSRF token during login attempts. That
// makes it vulnerable to cross-site request forgery. If you copy code from this example make
// sure to add a check for the CSRF token.
//
// Issue for adding check to this example https://github.com/tokio-rs/axum/issues/2511
let (auth_url, _csrf_token) = client
async fn discord_auth(
State(client): State<BasicClient>,
State(store): State<MemoryStore>,
) -> Result<impl IntoResponse, AppError> {
let (auth_url, csrf_token) = client
.authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("identify".to_string()))
.url();

// Redirect to Discord's oauth service
Redirect::to(auth_url.as_ref())
// Create session to store csrf_token
let mut session = Session::new();
session
.insert(CSRF_TOKEN, &csrf_token)
.context("failed in inserting CSRF token into session")?;

// Store the session in MemoryStore and retrieve the session cookie
let cookie = store
mladedav marked this conversation as resolved.
Show resolved Hide resolved
.store_session(session)
.await
.context("failed to store CSRF token session")?
.context("unexpected error retrieving CSRF cookie value")?;

// Attach the session cookie to the response header
let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/");
let mut headers = HeaderMap::new();
headers.insert(
SET_COOKIE,
cookie.parse().context("failed to parse cookie")?,
);

Ok((headers, Redirect::to(auth_url.as_ref())))
}

// Valid user session required. If there is none, redirect to the auth page
Expand Down Expand Up @@ -194,11 +213,55 @@ struct AuthRequest {
state: String,
}

async fn csrf_token_validation_workflow(
auth_request: &AuthRequest,
cookies: &headers::Cookie,
store: &MemoryStore,
) -> Result<(), AppError> {
// Extract the cookie from the request
let cookie = cookies
.get(COOKIE_NAME)
.context("unexpected error getting cookie name")?
.to_string();

// Load the session
let session = match store
.load_session(cookie)
.await
.context("failed to load session")?
{
Some(session) => session,
None => return Err(anyhow!("Session not found").into()),
};

// Extract the CSRF token from the session
let stored_csrf_token = session
.get::<CsrfToken>(CSRF_TOKEN)
.context("CSRF token not found in session")?
.to_owned();

// Cleanup the CSRF token session
store
.destroy_session(session)
.await
.context("Failed to destroy old session")?;

// Validate CSRF token is the same as the one in the auth request
if *stored_csrf_token.secret() != auth_request.state {
return Err(anyhow!("CSRF token mismatch").into());
}

Ok(())
}

async fn login_authorized(
Query(query): Query<AuthRequest>,
State(store): State<MemoryStore>,
State(oauth_client): State<BasicClient>,
TypedHeader(cookies): TypedHeader<headers::Cookie>,
) -> Result<impl IntoResponse, AppError> {
csrf_token_validation_workflow(&query, &cookies, &store).await?;

// Get an auth token
let token = oauth_client
.exchange_code(AuthorizationCode::new(query.code.clone()))
Expand Down Expand Up @@ -233,7 +296,7 @@ async fn login_authorized(
.context("unexpected error retrieving cookie value")?;

// Build the cookie
let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; Path=/");
let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/");

// Set cookie
let mut headers = HeaderMap::new();
Expand Down