Skip to content

Commit

Permalink
avoid use url path directly.
Browse files Browse the repository at this point in the history
  • Loading branch information
youngsofun committed Aug 12, 2024
1 parent 5a0400f commit 195b3d1
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 70 deletions.
4 changes: 2 additions & 2 deletions src/query/service/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub struct AuthMgr {
}

pub enum Credential {
Databend {
DatabendToken {
token: String,
token_type: TokenType,
set_user: bool,
Expand Down Expand Up @@ -75,7 +75,7 @@ impl AuthMgr {
pub async fn auth(&self, session: &mut Session, credential: &Credential) -> Result<()> {
let user_api = UserApiProvider::instance();
match credential {
Credential::Databend {
Credential::DatabendToken {
token,
set_user,
token_type,
Expand Down
37 changes: 26 additions & 11 deletions src/query/service/src/servers/http/http_services.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ use poem::post;
use poem::put;
use poem::Endpoint;
use poem::EndpointExt;
use poem::IntoEndpoint;
use poem::Route;

use super::v1::upload_to_stage;
use crate::auth::AuthMgr;
use crate::servers::http::middleware::EndpointKind;
use crate::servers::http::middleware::HTTPSessionMiddleware;
use crate::servers::http::middleware::PanicHandler;
use crate::servers::http::v1::clickhouse_router;
Expand Down Expand Up @@ -85,25 +86,39 @@ impl HttpHandler {
})
}

fn wrap_auth(&self, ep: Route) -> impl Endpoint {
let auth_manager = AuthMgr::instance();
let session_middleware = HTTPSessionMiddleware::create(self.kind, auth_manager);
fn wrap_auth<E>(&self, ep: E, auth_type: EndpointKind) -> impl Endpoint
where
E: IntoEndpoint,
E::Endpoint: 'static,
{
let session_middleware = HTTPSessionMiddleware::create(self.kind, auth_type);
ep.with(session_middleware).boxed()
}

#[allow(clippy::let_with_type_underscore)]
#[async_backtrace::framed]
async fn build_router(&self, sock: SocketAddr) -> impl Endpoint {
let ep_v1 = Route::new()
.nest("/query", query_route())
.at("/session/login", post(login_handler))
.at("/session/renew", post(renew_handler))
.at("/upload_to_stage", put(upload_to_stage))
.at("/suggested_background_tasks", get(list_suggestions));
let ep_v1 = self.wrap_auth(ep_v1);
.nest("/query", self.wrap_auth(query_route(), EndpointKind::Query))
.at(
"/session/login",
self.wrap_auth(post(login_handler), EndpointKind::Login),
)
.at(
"/session/renew",
self.wrap_auth(post(renew_handler), EndpointKind::Refresh),
)
.at(
"/upload_to_stage",
self.wrap_auth(put(upload_to_stage), EndpointKind::Query),
)
.at(
"/suggested_background_tasks",
self.wrap_auth(get(list_suggestions), EndpointKind::Query),
);

let ep_clickhouse = Route::new().nest("/", clickhouse_router());
let ep_clickhouse = self.wrap_auth(ep_clickhouse);
let ep_clickhouse = self.wrap_auth(ep_clickhouse, EndpointKind::Clickhouse);

let ep_usage = Route::new().at(
"/",
Expand Down
59 changes: 42 additions & 17 deletions src/query/service/src/servers/http/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,31 @@ use crate::servers::HttpHandlerKind;
use crate::sessions::SessionManager;
use crate::sessions::SessionType;

#[derive(Copy, Clone)]
pub enum EndpointKind {
Login,
Refresh,
Query,
Clickhouse,
}

const USER_AGENT: &str = "User-Agent";
const TRACE_PARENT: &str = "traceparent";

pub struct HTTPSessionMiddleware {
pub kind: HttpHandlerKind,
pub endpoint_kind: EndpointKind,
pub auth_manager: Arc<AuthMgr>,
}

impl HTTPSessionMiddleware {
pub fn create(kind: HttpHandlerKind, auth_manager: Arc<AuthMgr>) -> HTTPSessionMiddleware {
HTTPSessionMiddleware { kind, auth_manager }
pub fn create(kind: HttpHandlerKind, endpoint_kind: EndpointKind) -> HTTPSessionMiddleware {
let auth_manager = AuthMgr::instance();
HTTPSessionMiddleware {
kind,
endpoint_kind,
auth_manager,
}
}
}

Expand Down Expand Up @@ -109,7 +123,11 @@ fn extract_baggage_from_headers(headers: &HeaderMap) -> Option<Vec<(String, Stri
Some(result)
}

fn get_credential(req: &Request, kind: HttpHandlerKind) -> Result<Credential> {
fn get_credential(
req: &Request,
kind: HttpHandlerKind,
endpoint_kind: EndpointKind,
) -> Result<Credential> {
let std_auth_headers: Vec<_> = req.headers().get_all(AUTHORIZATION).iter().collect();
if std_auth_headers.len() > 1 {
let msg = &format!("Multiple {} headers detected", AUTHORIZATION);
Expand All @@ -125,7 +143,12 @@ fn get_credential(req: &Request, kind: HttpHandlerKind) -> Result<Credential> {
))
}
} else {
auth_by_header(&std_auth_headers, client_ip, req.uri().path())
auth_by_header(
&std_auth_headers,
client_ip,
endpoint_kind,
req.uri().path(),
)
}
}

Expand Down Expand Up @@ -159,6 +182,7 @@ pub fn get_client_ip(req: &Request) -> Option<String> {
fn auth_by_header(
std_auth_headers: &[&HeaderValue],
client_ip: Option<String>,
endpoint_kind: EndpointKind,
path: &str,
) -> Result<Credential> {
let value = &std_auth_headers[0];
Expand All @@ -182,18 +206,17 @@ fn auth_by_header(
Some(bearer) => {
let token = bearer.token().to_string();
if SessionClaim::is_databend_token(&token) {
let (token_type, set_user) = if path == "/query" {
(TokenType::Session, true)
} else if path == "/session/renew" {
(TokenType::Refresh, true)
} else if path != "/session/login" {
(TokenType::Session, false)
} else {
return Err(ErrorCode::AuthenticateFailure(format!(
"should not use databend auth when accessing {path}"
)));
let (token_type, set_user) = match endpoint_kind {
EndpointKind::Login => (TokenType::Session, false),
EndpointKind::Refresh => (TokenType::Refresh, true),
EndpointKind::Query => (TokenType::Session, true),
EndpointKind::Clickhouse => {
return Err(ErrorCode::AuthenticateFailure(format!(
"should not use databend auth when accessing {path}"
)));
}
};
Ok(Credential::Databend {
Ok(Credential::DatabendToken {
token,
token_type,
set_user,
Expand Down Expand Up @@ -246,6 +269,7 @@ impl<E: Endpoint> Middleware<E> for HTTPSessionMiddleware {
HTTPSessionEndpoint {
ep,
kind: self.kind,
endpoint_kind: self.endpoint_kind,
auth_manager: self.auth_manager.clone(),
}
}
Expand All @@ -254,13 +278,14 @@ impl<E: Endpoint> Middleware<E> for HTTPSessionMiddleware {
pub struct HTTPSessionEndpoint<E> {
ep: E,
pub kind: HttpHandlerKind,
pub endpoint_kind: EndpointKind,
pub auth_manager: Arc<AuthMgr>,
}

impl<E> HTTPSessionEndpoint<E> {
#[async_backtrace::framed]
async fn auth(&self, req: &Request, query_id: String) -> Result<HttpQueryContext> {
let credential = get_credential(req, self.kind)?;
let credential = get_credential(req, self.kind, self.endpoint_kind)?;

let session_manager = SessionManager::instance();

Expand All @@ -274,7 +299,7 @@ impl<E> HTTPSessionEndpoint<E> {

self.auth_manager.auth(&mut session, &credential).await?;
let databend_token = match credential {
Credential::Databend { token, .. } => Some(token),
Credential::DatabendToken { token, .. } => Some(token),
_ => None,
};

Expand Down
4 changes: 2 additions & 2 deletions src/query/service/tests/it/servers/http/clickhouse_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use std::collections::HashMap;

use databend_common_base::base::tokio;
use databend_query::auth::AuthMgr;
use databend_query::servers::http::middleware::EndpointKind;
use databend_query::servers::http::middleware::HTTPSessionEndpoint;
use databend_query::servers::http::middleware::HTTPSessionMiddleware;
use databend_query::servers::http::v1::clickhouse_router;
Expand Down Expand Up @@ -321,7 +321,7 @@ struct Server {
impl Server {
pub async fn new() -> Self {
let session_middleware =
HTTPSessionMiddleware::create(HttpHandlerKind::Clickhouse, AuthMgr::instance());
HTTPSessionMiddleware::create(HttpHandlerKind::Clickhouse, EndpointKind::Clickhouse);
let endpoint = Route::new()
.nest("/", clickhouse_router())
.with(session_middleware);
Expand Down
48 changes: 10 additions & 38 deletions src/query/service/tests/it/servers/http/http_query_handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use databend_common_users::CustomClaims;
use databend_common_users::EnsureUser;
use databend_query::auth::AuthMgr;
use databend_query::servers::http::middleware::get_client_ip;
use databend_query::servers::http::middleware::EndpointKind;
use databend_query::servers::http::middleware::HTTPSessionEndpoint;
use databend_query::servers::http::middleware::HTTPSessionMiddleware;
use databend_query::servers::http::v1::make_page_uri;
Expand Down Expand Up @@ -98,12 +99,7 @@ struct TestHttpQueryRequest {

impl TestHttpQueryRequest {
fn new(json: serde_json::Value) -> Self {
let session_middleware =
HTTPSessionMiddleware::create(HttpHandlerKind::Query, AuthMgr::instance());

let ep = Route::new()
.nest("/v1/query", query_route())
.with(session_middleware);
let ep = create_endpoint().unwrap();

let root_auth_header = {
let mut headers = HeaderMap::new();
Expand Down Expand Up @@ -660,11 +656,7 @@ async fn test_result_timeout() -> Result<()> {
async fn test_system_tables() -> Result<()> {
let _fixture = TestFixture::setup().await?;

let session_middleware =
HTTPSessionMiddleware::create(HttpHandlerKind::Query, AuthMgr::instance());
let ep = Route::new()
.nest("/v1/query", query_route())
.with(session_middleware);
let ep = create_endpoint()?;

let sql = "select name from system.tables where database='system' order by name";

Expand Down Expand Up @@ -740,13 +732,7 @@ async fn test_insert() -> Result<()> {
#[tokio::test(flavor = "current_thread")]
async fn test_query_log() -> Result<()> {
let _fixture = TestFixture::setup().await?;

let session_middleware =
HTTPSessionMiddleware::create(HttpHandlerKind::Query, AuthMgr::instance());

let ep = Route::new()
.nest("/v1/query", query_route())
.with(session_middleware);
let ep = create_endpoint()?;

let sql = "create table t1(a int)";
let (status, result) = post_sql_to_endpoint(&ep, sql, 10).await?;
Expand Down Expand Up @@ -807,12 +793,7 @@ async fn test_query_log() -> Result<()> {
async fn test_query_log_killed() -> Result<()> {
let _fixture = TestFixture::setup().await?;

let session_middleware =
HTTPSessionMiddleware::create(HttpHandlerKind::Query, AuthMgr::instance());

let ep = Route::new()
.nest("/v1/query", query_route())
.with(session_middleware);
let ep = create_endpoint()?;

let sql = "select sleep(2)";
let json = serde_json::json!({"sql": sql.to_string(), "pagination": {"wait_time_secs": 0}});
Expand Down Expand Up @@ -876,17 +857,17 @@ async fn post_sql(sql: &str, wait_time_secs: u64) -> Result<(StatusCode, QueryRe
post_json(&json).await
}

pub async fn create_endpoint() -> Result<EndpointType> {
pub fn create_endpoint() -> Result<EndpointType> {
let session_middleware =
HTTPSessionMiddleware::create(HttpHandlerKind::Query, AuthMgr::instance());
HTTPSessionMiddleware::create(HttpHandlerKind::Query, EndpointKind::Query);

Ok(Route::new()
.nest("/v1/query", query_route())
.with(session_middleware))
}

async fn post_json(json: &serde_json::Value) -> Result<(StatusCode, QueryResponse)> {
let ep = create_endpoint().await?;
let ep = create_endpoint()?;
post_json_to_endpoint(&ep, json, HeaderMap::default()).await
}

Expand Down Expand Up @@ -963,12 +944,7 @@ async fn test_auth_jwt() -> Result<()> {
.build();
let _fixture = TestFixture::setup_with_config(&config).await?;

let session_middleware =
HTTPSessionMiddleware::create(HttpHandlerKind::Query, AuthMgr::instance());

let ep = Route::new()
.nest("/v1/query", query_route())
.with(session_middleware);
let ep = create_endpoint()?;

let now = Clock::now_since_epoch();
let claims = JWTClaims {
Expand Down Expand Up @@ -1154,11 +1130,7 @@ async fn test_auth_jwt_with_create_user() -> Result<()> {
.build();
let _fixture = TestFixture::setup_with_config(&config).await?;

let session_middleware =
HTTPSessionMiddleware::create(HttpHandlerKind::Query, AuthMgr::instance());
let ep = Route::new()
.nest("/v1/query", query_route())
.with(session_middleware);
let ep = create_endpoint()?;

let now = Clock::now_since_epoch();
let claims = JWTClaims {
Expand Down

0 comments on commit 195b3d1

Please sign in to comment.