Skip to content

Commit

Permalink
feat: auto refresh session
Browse files Browse the repository at this point in the history
  • Loading branch information
speed2exe committed Oct 13, 2023
1 parent b994f84 commit 591b66b
Show file tree
Hide file tree
Showing 14 changed files with 197 additions and 67 deletions.
4 changes: 3 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ fancy-regex = "0.11.0"
validator = "0.16.0"
bytes = "1.4.0"
rcgen = { version = "0.10.0", features = ["pem", "x509-parser"] }
jsonwebtoken = "8.3.0"
mime = "0.3.17"
# aws-config = "0.56.1"
# aws-sdk-s3 = "0.31.1"
Expand Down
1 change: 1 addition & 0 deletions admin_frontend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ tower-http = { version = "0.4.4", features = ["cors"] }
tower = "0.4.13"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
jwt = "0.16.0"
5 changes: 0 additions & 5 deletions admin_frontend/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@ pub struct LoginRequest {
pub password: String,
}

#[derive(Deserialize)]
pub struct AddUserRequest {
pub email: String,
}

#[derive(Serialize)]
pub struct LoginResponse {
pub access_token: String,
Expand Down
61 changes: 59 additions & 2 deletions admin_frontend/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use std::time::{SystemTime, UNIX_EPOCH};

use axum::{
async_trait,
extract::FromRequestParts,
http::request::Parts,
response::{IntoResponse, Redirect},
};
use axum_extra::extract::CookieJar;
use gotrue::grant::{Grant, RefreshTokenGrant};
use jwt::{Claims, Header};
use redis::{aio::ConnectionManager, AsyncCommands, FromRedisValue, ToRedisArgs};
use serde::{de::DeserializeOwned, Deserialize, Serialize};

Expand Down Expand Up @@ -38,7 +42,7 @@ impl SessionStorage {
}
}

pub async fn put_user_session(&self, user_session: UserSession) -> redis::RedisResult<()> {
pub async fn put_user_session(&self, user_session: &UserSession) -> redis::RedisResult<()> {
let key = session_id_key(&user_session.session_id);
self
.redis_client
Expand Down Expand Up @@ -93,21 +97,70 @@ impl FromRequestParts<AppState> for UserSession {
.ok_or(SessionRejection::NoSessionId)?
.value();

let session = state
let mut session = state
.session_store
.get_user_session(session_id)
.await
.ok_or(SessionRejection::SessionNotFound)?;

if has_expired(session.access_token.as_str()) {
// Get new pair of access token and refresh token
let refresh_token = session.refresh_token;
let new_token = state
.gotrue_client
.clone()
.token(&Grant::RefreshToken(RefreshTokenGrant { refresh_token }))
.await
.map_err(|err| SessionRejection::RefreshTokenError(err.to_string()))?;

session.access_token = new_token.access_token;
session.refresh_token = new_token.refresh_token;

// Update session in redis
let _ = state
.session_store
.put_user_session(&session)
.await
.map_err(|err| {
tracing::error!("failed to update session in redis: {}", err);
});
}

Ok(session)
}
}

fn has_expired(access_token: &str) -> bool {
match get_session_expiration(access_token) {
Some(expiration_seconds) => {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("Time went backwards")
.as_secs();
now > expiration_seconds
},
None => false,
}
}

fn get_session_expiration(access_token: &str) -> Option<u64> {
// no need to verify, let the appflowy cloud server do it
// in that way, frontend server does not need to know the secret
match jwt::Token::<Header, Claims, _>::parse_unverified(access_token) {
Ok(unverified) => unverified.claims().registered.expiration,
Err(e) => {
tracing::error!("failed to parse unverified token: {}", e);
None
},
}
}

#[derive(Clone, Debug)]
pub enum SessionRejection {
NoSessionId,
SessionNotFound,
CookieError(String),
RefreshTokenError(String),
}

impl IntoResponse for SessionRejection {
Expand All @@ -119,6 +172,10 @@ impl IntoResponse for SessionRejection {
Redirect::temporary("/web/login").into_response()
},
SessionRejection::SessionNotFound => Redirect::temporary("/web/login").into_response(),
SessionRejection::RefreshTokenError(err) => {
tracing::warn!("refresh token error: {}", err);
Redirect::temporary("/web/login").into_response()
},
}
}
}
Expand Down
32 changes: 25 additions & 7 deletions admin_frontend/src/web_api.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,50 @@
use crate::error::WebApiError;
use crate::models::AddUserRequest;
use crate::response::WebApiResponse;
use crate::session::{self, UserSession};
use crate::{models::LoginRequest, AppState};
use axum::extract::Path;
use axum::http::status;
use axum::response::Result;
use axum::Json;
use axum::{extract::State, routing::post, Router};
use axum_extra::extract::cookie::Cookie;
use axum_extra::extract::CookieJar;
use gotrue::params::AdminUserParams;
use gotrue::params::{AdminDeleteUserParams, AdminUserParams};
use gotrue_entity::User;

pub fn router() -> Router<AppState> {
Router::new()
// TODO
.route("/login", post(login_handler))
.route("/logout", post(logout_handler))
.route("/add_user", post(add_user_handler))
.route("/user/:param", post(post_user_handler).delete(delete_user_handler))
}

pub async fn add_user_handler(
pub async fn delete_user_handler(
State(state): State<AppState>,
session: UserSession,
Json(param): Json<AddUserRequest>,
Path(user_uuid): Path<String>,
) -> Result<WebApiResponse<()>, WebApiError<'static>> {
state
.gotrue_client
.admin_delete_user(
&session.access_token,
&user_uuid,
&AdminDeleteUserParams {
should_soft_delete: true,
},
)
.await?;
Ok(().into())
}

pub async fn post_user_handler(
State(state): State<AppState>,
session: UserSession,
Path(email): Path<String>,
) -> Result<WebApiResponse<User>, WebApiError<'static>> {
let add_user_params = AdminUserParams {
email: param.email.to_owned(),
email,
..Default::default()
};
let user = state
Expand Down Expand Up @@ -59,7 +77,7 @@ pub async fn login_handler(
token.access_token.to_string(),
token.refresh_token.to_owned(),
);
state.session_store.put_user_session(new_session).await?;
state.session_store.put_user_session(&new_session).await?;

let mut cookie = Cookie::new("session_id", new_session_id.to_string());
cookie.set_path("/");
Expand Down
10 changes: 2 additions & 8 deletions admin_frontend/templates/users.html
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,8 @@ <h1>User List</h1>
// Get the email from the user
const email = prompt('Please enter the new user email:');
if (email) {
fetch("/web-api/add_user", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
email: email,
}),
fetch(`/web-api/user/${email}`, {
method: "POST"
})
.then((response) => {
if (!response.ok) {
Expand Down
2 changes: 2 additions & 0 deletions libs/gotrue-entity/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.105"
anyhow = "1.0.75"
reqwest = "0.11.20"
lazy_static = "1.4.0"
jsonwebtoken = "8.3.0"
40 changes: 40 additions & 0 deletions libs/gotrue-entity/src/gotrue_jwt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
pub struct GoTrueJWTClaims {
// JWT standard claims
pub aud: Option<String>,
pub exp: Option<i64>,
pub jti: Option<String>,
pub iat: Option<i64>,
pub iss: Option<String>,
pub nbf: Option<i64>,
pub sub: Option<String>,

pub email: String,
pub phone: String,
pub app_metadata: serde_json::Value,
pub user_metadata: serde_json::Value,
pub role: String,
pub aal: Option<String>,
pub amr: Option<Vec<Amr>>,
pub session_id: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Amr {
pub method: String,
pub timestamp: u64,
pub provider: Option<String>,
}

lazy_static::lazy_static! {
pub static ref VALIDATION: Validation = Validation::new(Algorithm::HS256);
}

impl GoTrueJWTClaims {
pub fn verify(token: &str, secret: &[u8]) -> Result<Self, jsonwebtoken::errors::Error> {
Ok(decode(token, &DecodingKey::from_secret(secret), &VALIDATION)?.claims)
}
}
2 changes: 2 additions & 0 deletions libs/gotrue-entity/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub mod gotrue_jwt;

use serde::{Deserialize, Serialize};
use std::fmt::Formatter;
use std::{collections::BTreeMap, fmt::Display};
Expand Down
29 changes: 28 additions & 1 deletion libs/gotrue/src/api.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::params::{AdminUserParams, GenerateLinkParams, GenerateLinkResponse};
use crate::params::{
AdminDeleteUserParams, AdminUserParams, GenerateLinkParams, GenerateLinkResponse,
};
use anyhow::Context;

use super::grant::Grant;
Expand Down Expand Up @@ -145,6 +147,22 @@ impl Client {
to_gotrue_result(resp).await
}

pub async fn admin_delete_user(
&self,
access_token: &str,
user_uuid: &str,
delete_user_params: &AdminDeleteUserParams,
) -> Result<(), GoTrueError> {
let resp = self
.client
.delete(format!("{}/admin/users/{}", self.base_url, user_uuid))
.header("Authorization", format!("Bearer {}", access_token))
.json(&delete_user_params)
.send()
.await?;
check_gotrue_result(resp).await
}

pub async fn admin_add_user(
&self,
access_token: &str,
Expand Down Expand Up @@ -188,3 +206,12 @@ where
Err(err)
}
}

async fn check_gotrue_result(resp: reqwest::Response) -> Result<(), GoTrueError> {
if resp.status().is_success() {
Ok(())
} else {
let err: GoTrueError = from_body(resp).await?;
Err(err)
}
}
5 changes: 5 additions & 0 deletions libs/gotrue/src/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ use std::collections::btree_map::BTreeMap;
use gotrue_entity::{Factor, Identity};
use serde::{Deserialize, Serialize};

#[derive(Debug, Deserialize, Serialize)]
pub struct AdminDeleteUserParams {
pub should_soft_delete: bool,
}

#[derive(Debug, Default, Deserialize, Serialize)]
pub struct AdminUserParams {
pub aud: String,
Expand Down
Loading

0 comments on commit 591b66b

Please sign in to comment.