Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 68 additions & 24 deletions tee-worker/omni-executor/config-loader/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ const DEFAULT_MAILER_TYPE: &str = "sendgrid";
const DEFAULT_MAILER_API_KEY: &str = "";
const DEFAULT_MAILER_FROM_EMAIL: &str = "no-reply@example.com";
const DEFAULT_MAILER_FROM_NAME: &str = "Heima Verify";
const DEFAULT_GOOGLE_CLIENT_ID: &str = "";
const DEFAULT_GOOGLE_CLIENT_SECRET: &str = "";
const DEFAULT_ETHEREUM_URL: &str = "https://eth-mainnet.g.alchemy.com/v2/";
const DEFAULT_SOLANA_URL: &str = "https://solana-mainnet.g.alchemy.com/v2/";
const DEFAULT_BSC_URL: &str = "https://bnb-mainnet.g.alchemy.com/v2/";
Expand Down Expand Up @@ -86,11 +84,16 @@ impl Default for MailerConfig {
}
}

#[derive(Debug, Clone)]
pub struct GoogleOAuth2Config {
pub client_id: String,
pub client_secret: String,
}

#[derive(Debug, Clone)]
pub struct ConfigLoader {
pub mailer_configs: HashMap<String, MailerConfig>,
pub google_client_id: String,
pub google_client_secret: String,
pub google_oauth2_configs: HashMap<String, GoogleOAuth2Config>,
pub ethereum_url: String,
pub solana_url: String,
pub bsc_url: String,
Expand Down Expand Up @@ -137,24 +140,6 @@ impl ConfigLoader {
info!("Executing: {}", std::env::args().collect::<Vec<_>>().join(" "));

let vars: HashMap<&str, EnvVar> = HashMap::from([
(
"google_client_id",
EnvVar {
env_key: "OE_GOOGLE_CLIENT_ID",
default: DEFAULT_GOOGLE_CLIENT_ID,
sensitive: false,
optional: false,
},
),
(
"google_client_secret",
EnvVar {
env_key: "OE_GOOGLE_CLIENT_SECRET",
default: DEFAULT_GOOGLE_CLIENT_SECRET,
sensitive: true,
optional: false,
},
),
(
"ethereum_url",
EnvVar {
Expand Down Expand Up @@ -350,11 +335,11 @@ impl ConfigLoader {
let get_opt = |key: &str| get_env_value(&vars[key]);

let mailer_configs = Self::load_mailer_configs();
let google_oauth2_configs = Self::load_google_oauth2_configs();

ConfigLoader {
mailer_configs,
google_client_id: get("google_client_id"),
google_client_secret: get("google_client_secret"),
google_oauth2_configs,
ethereum_url: append_key(&get("ethereum_url")),
solana_url: append_key(&get("solana_url")),
bsc_url: append_key(&get("bsc_url")),
Expand Down Expand Up @@ -466,4 +451,63 @@ impl ConfigLoader {
clients.sort();
clients
}

/// Load Google OAuth2 configurations for multiple clients from environment variables
/// Format: OE_GOOGLE_CLIENT_ID_{CLIENT}, OE_GOOGLE_CLIENT_SECRET_{CLIENT}
/// CLIENT can be WILDMETA, HEIMA, etc.
fn load_google_oauth2_configs() -> HashMap<String, GoogleOAuth2Config> {
let mut configs = HashMap::new();

let env_vars: HashMap<String, String> = std::env::vars().collect();

let mut clients = std::collections::HashSet::new();
for key in env_vars.keys() {
if key.starts_with("OE_GOOGLE_CLIENT_ID_") {
if let Some(client) = key.strip_prefix("OE_GOOGLE_CLIENT_ID_") {
info!("Found Google OAuth2 configuration for client: {}", client);
clients.insert(client.to_lowercase());
}
}
}

info!("Total discovered Google OAuth2 clients: {:?}", clients);

if clients.is_empty() {
warn!("No Google OAuth2 configurations found in environment variables.");
return configs;
}

for client in clients {
let client_upper = client.to_uppercase();

let client_id =
std::env::var(format!("OE_GOOGLE_CLIENT_ID_{}", client_upper)).unwrap_or_default();
let client_secret = std::env::var(format!("OE_GOOGLE_CLIENT_SECRET_{}", client_upper))
.unwrap_or_default();

if client_id.is_empty() || client_secret.is_empty() {
warn!(
"Incomplete Google OAuth2 config for client '{}': client_id_empty={}, client_secret_empty={}",
client,
client_id.is_empty(),
client_secret.is_empty()
);
continue;
}

let config = GoogleOAuth2Config { client_id, client_secret };

info!("Loaded Google OAuth2 config for client '{}'", client);

configs.insert(client.clone(), config);
}

configs
}

/// Get Google OAuth2 configuration for a specific client
pub fn get_google_oauth2_config(&self, client_id: &str) -> Option<GoogleOAuth2Config> {
let client_key = client_id.to_lowercase();
self.google_oauth2_configs.get(&client_key).cloned()
}
}
2 changes: 1 addition & 1 deletion tee-worker/omni-executor/config-loader/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
mod config;

pub use config::{ConfigLoader, MailerConfig, MailerType};
pub use config::{ConfigLoader, GoogleOAuth2Config, MailerConfig, MailerType};
5 changes: 3 additions & 2 deletions tee-worker/omni-executor/executor-primitives/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ pub enum OmniAuth {
Web3(String, Identity, HeimaMultiSignature), // (client_id, Signer, Signature)
Email(String, Email, VerificationCode), // (client_id, Email, VerificationCode)
AuthToken(JwtToken),
OAuth2(Identity, OAuth2Data), // (Sender, OAuth2Data)
OAuth2(String, Identity, OAuth2Data), // (client_id, Sender, OAuth2Data)
Passkey(PasskeyData),
}

Expand Down Expand Up @@ -176,6 +176,7 @@ pub struct OAuth2Data {
pub code: String,
pub state: String,
pub redirect_uri: String,
pub uid: String, // A unique identifier for the user/session requesting the OAuth2
}

#[derive(Encode, Decode, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
Expand Down Expand Up @@ -257,7 +258,7 @@ pub fn to_omni_auth(
UserAuth::OAuth2(data) => {
let identity =
Identity::try_from(user_id.clone()).map_err(|_| "Invalid user ID format")?;
OmniAuth::OAuth2(identity, data.clone())
OmniAuth::OAuth2(client_id.to_string(), identity, data.clone())
},
UserAuth::Passkey(data) => OmniAuth::Passkey(data.clone()),
};
Expand Down
79 changes: 79 additions & 0 deletions tee-worker/omni-executor/rpc-server/src/google_oauth2_factory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
use config_loader::{ConfigLoader, GoogleOAuth2Config};
use std::collections::HashMap;
use std::sync::Arc;

pub struct GoogleOAuth2Factory {
config_loader: Arc<ConfigLoader>,
config_cache: std::sync::RwLock<HashMap<String, GoogleOAuth2Config>>,
}

impl GoogleOAuth2Factory {
pub fn new(config_loader: Arc<ConfigLoader>) -> Self {
Self { config_loader, config_cache: std::sync::RwLock::new(HashMap::new()) }
}

pub fn get_google_config_for_client(
&self,
client_id: &str,
) -> Result<GoogleOAuth2Config, Box<dyn std::error::Error>> {
let client_key = client_id.to_lowercase();

let cache = self.config_cache.read().map_err(|e| format!("Failed to read cache: {}", e))?;
if let Some(config) = cache.get(&client_key) {
return Ok(config.clone());
}

drop(cache);

let config = self.config_loader.get_google_oauth2_config(client_id).ok_or_else(|| {
let available_clients = self.config_loader.list_available_clients();
format!(
"No Google OAuth2 configuration found for client '{}'. Available clients: {:?}",
client_id, available_clients
)
})?;

tracing::info!("Loaded Google OAuth2 config for client '{}'", client_id);

let mut cache =
self.config_cache.write().map_err(|e| format!("Failed to write cache: {}", e))?;
cache.insert(client_key.clone(), config.clone());

Ok(config)
}
}

#[cfg(test)]
mod tests {
use super::*;
use config_loader::ConfigLoader;

#[test]
fn test_google_oauth2_factory_caching() {
std::env::set_var("OE_GOOGLE_CLIENT_ID_TESTCLIENT", "test_client_id");
std::env::set_var("OE_GOOGLE_CLIENT_SECRET_TESTCLIENT", "test_client_secret");

let config = ConfigLoader::from_env();
let factory = GoogleOAuth2Factory::new(Arc::new(config));

let config1 = factory
.get_google_config_for_client("testclient")
.expect("Should create config");
let config2 = factory
.get_google_config_for_client("testclient")
.expect("Should get cached config");

assert_eq!(config1.client_id, config2.client_id);
assert_eq!(config1.client_secret, config2.client_secret);
}

#[test]
fn test_google_oauth2_factory_missing_client() {
let config = ConfigLoader::from_env();
let factory = GoogleOAuth2Factory::new(Arc::new(config));

let result = factory.get_google_config_for_client("nonexistent");
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("No Google OAuth2 configuration found"));
}
}
1 change: 1 addition & 0 deletions tee-worker/omni-executor/rpc-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod auth_utils;
mod config;
mod detailed_error;
mod error_code;
mod google_oauth2_factory;
mod mailer_factory;
mod methods;
mod middlewares;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,25 @@
use crate::server::RpcContext;
use crate::{
detailed_error::DetailedError, error_code::EXTERNAL_API_ERROR_CODE, server::RpcContext,
};
use executor_core::intent_executor::IntentExecutor;
use executor_primitives::{utils::hex::ToHexPrefixed, Identity, Web2IdentityType};
use executor_crypto::hashing::blake2_256;
use executor_primitives::Hash;
use executor_storage::{OAuth2StateVerifierStorage, Storage};
use heima_identity_verification::web2::google;
use jsonrpsee::{
types::{ErrorCode, ErrorObject},
RpcModule,
};
use parity_scale_codec::Encode;
use serde::Deserialize;
use tracing::error;

#[derive(Debug, Deserialize)]
struct GetOAuth2GoogleAuthorizationUrlParams {
pub uid: String, // A unique identifier for the user/session requesting the OAuth2 URL
pub redirect_uri: String,
pub client_id: String,
}

pub fn register_get_oauth2_google_authorization_url<
EthereumIntentExecutor: IntentExecutor + Send + Sync + 'static,
Expand All @@ -21,20 +34,39 @@ pub fn register_get_oauth2_google_authorization_url<
.register_async_method(
"omni_getOAuth2GoogleAuthorizationUrl",
|params, ctx, _| async move {
match params.parse::<(String, String)>() {
Ok((google_account, redirect_uri)) => {
let google_identity =
Identity::from_web2_account(&google_account, Web2IdentityType::Google);
let authorization_data =
google::get_authorize_data(&ctx.google_client_id, &redirect_uri);
let storage = OAuth2StateVerifierStorage::new(ctx.storage_db.clone());
storage
.insert(&google_identity.hash(), authorization_data.state.clone())
.map_err(|_| ErrorCode::InternalError)?;
Ok::<String, ErrorObject>(authorization_data.authorize_url.to_hex())
},
Err(_) => Err(ErrorCode::ParseError.into()),
}
let params = params
.parse::<GetOAuth2GoogleAuthorizationUrlParams>()
.map_err(|_| ErrorCode::ParseError)?;

let google_config = ctx
.google_oauth2_factory
.get_google_config_for_client(&params.client_id)
.map_err(|e| {
error!(
"Failed to get Google OAuth2 config for client '{}': {}",
params.client_id, e
);
DetailedError::new(
EXTERNAL_API_ERROR_CODE,
"Failed to get Google OAuth2 configuration",
)
.with_field("client_id")
.with_received(&params.client_id)
.with_reason(format!("Error: {}", e))
.to_error_object()
})?;

let authorization_data =
google::get_authorize_data(&google_config.client_id, &params.redirect_uri);
let storage = OAuth2StateVerifierStorage::new(ctx.storage_db.clone());
let key: Hash =
blake2_256((params.client_id.clone(), params.uid.clone()).encode().as_slice())
.into();

storage
.insert(&key, authorization_data.state.clone())
.map_err(|_| ErrorCode::InternalError)?;
Ok::<String, ErrorObject>(authorization_data.authorize_url)
},
)
.expect("Failed to register omni_getOAuth2GoogleAuthorizationUrl method");
Expand Down
Loading