Skip to content

Commit

Permalink
proxy IC API, pass headers through Agent transport, other minor
Browse files Browse the repository at this point in the history
  • Loading branch information
blind-oracle committed May 30, 2024
1 parent 20c9ad9 commit fb936d0
Show file tree
Hide file tree
Showing 13 changed files with 469 additions and 73 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ hyper-util = "0.1"
ic-agent = { version = "0.35", features = ["reqwest"] }
#ic-http-gateway = { git = "https://github.com/dfinity/http-gateway" }
ic-http-gateway = { path = "../http-gateway/packages/ic-http-gateway" }
ic-transport-types = "0.35"
instant-acme = "0.4"
jemallocator = "0.5"
jemalloc-ctl = "0.5"
Expand Down Expand Up @@ -86,6 +87,7 @@ rustls-acme = "0.10"
rustls-pemfile = "2"
sha1 = "0.10"
serde = "1.0"
serde_cbor = "0.11"
serde_json = "1.0"
strum = "0.26"
strum_macros = "0.26"
Expand Down
10 changes: 10 additions & 0 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ pub struct Cli {
#[command(flatten, next_help_heading = "HTTP Server")]
pub http_server: HttpServer,

#[command(flatten, next_help_heading = "IC")]
pub ic: Ic,

#[command(flatten, next_help_heading = "Certificates")]
pub cert: Cert,

Expand Down Expand Up @@ -131,6 +134,13 @@ pub struct HttpServer {
pub grace_period: Duration,
}

#[derive(Args)]
pub struct Ic {
/// URLs to use to connect to the IC network
#[clap(long = "ic-url")]
pub url: Vec<Url>,
}

#[derive(Args)]
pub struct Cert {
/// Read certificates from given directories, each certificate should be a pair .pem + .key files with the same base name
Expand Down
8 changes: 4 additions & 4 deletions src/http/client.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::{sync::Arc, time::Duration};
use std::{fmt, sync::Arc, time::Duration};

use async_trait::async_trait;
use mockall::automock;
use reqwest::dns::Resolve;

#[automock]
#[async_trait]
pub trait Client: Send + Sync {
pub trait Client: Send + Sync + fmt::Debug {
async fn execute(&self, req: reqwest::Request) -> Result<reqwest::Response, reqwest::Error>;
}

Expand Down Expand Up @@ -43,11 +43,11 @@ pub fn new(
Ok(client)
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct ReqwestClient(reqwest::Client);

impl ReqwestClient {
pub fn new(client: reqwest::Client) -> Self {
pub const fn new(client: reqwest::Client) -> Self {
Self(client)
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/log/clickhouse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ pub struct Row {
pub host: String,
pub path: String,
pub canister_id: String,
pub ic_streaming: bool,
pub ic_upgrade: bool,
pub error_cause: String,
pub tls_version: String,
pub tls_cipher: String,
Expand Down
15 changes: 14 additions & 1 deletion src/metrics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ use crate::{
server::{ConnInfo, TlsInfo},
},
log::clickhouse::{Clickhouse, Row},
routing::{error_cause::ErrorCause, middleware::request_id::RequestId, RequestCtx},
routing::{
error_cause::ErrorCause, ic::IcResponseStatus, middleware::request_id::RequestId,
RequestCtx,
},
tasks::{Run, TaskManager},
};
use body::CountingBody;
Expand Down Expand Up @@ -278,6 +281,7 @@ pub async fn middleware(

let ctx = response.extensions().get::<Arc<RequestCtx>>().cloned();
let error_cause = response.extensions().get::<ErrorCause>().cloned();
let ic_status = response.extensions().get::<IcResponseStatus>().cloned();
let status = response.status().as_u16();

// By this time the channel should already have the data
Expand Down Expand Up @@ -347,6 +351,11 @@ pub async fn middleware(
let conn_sent = conn_info.traffic.sent();
let conn_req_count = conn_info.req_count.load(Ordering::SeqCst);

let (ic_streaming, ic_upgrade) = ic_status
.as_ref()
.map(|x| (x.streaming, x.metadata.upgraded_to_update_call))
.unwrap_or((false, false));

// Log the request
info!(
request_id = request_id.to_string(),
Expand All @@ -361,6 +370,8 @@ pub async fn middleware(
host,
path,
canister_id,
ic_streaming,
ic_upgrade,
error = error_cause,
req_size = request_size,
resp_size = response_size,
Expand All @@ -386,6 +397,8 @@ pub async fn middleware(
host: host.into(),
path: path.into(),
canister_id,
ic_streaming,
ic_upgrade,
error_cause,
tls_version: tls_version.into(),
tls_cipher: tls_cipher.into(),
Expand Down
1 change: 1 addition & 0 deletions src/policy/denylist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ mod tests {
use super::*;
use async_trait::async_trait;

#[derive(Debug)]
struct TestClient(reqwest::Client);

#[async_trait]
Expand Down
13 changes: 0 additions & 13 deletions src/routing/error_cause.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use strum_macros::Display;
// Process error chain trying to find given error type
pub fn error_infer<E: StdError + Send + Sync + 'static>(error: &anyhow::Error) -> Option<&E> {
for cause in error.chain() {
println!("{:?}", cause);
if let Some(e) = cause.downcast_ref() {
return Some(e);
}
Expand Down Expand Up @@ -245,17 +244,5 @@ mod test {
ErrorCause::from(err),
ErrorCause::BackendTLSErrorOther(_)
));

let err = Box::new(rustls::Error::BadMaxFragmentSize) as Box<dyn StdError + Send + Sync>;
let err2 = anyhow!(err);
// let mut iter = err2.chain();
// println!("{:?}", iter.next());
// println!("{:?}", iter.next());

//let err = ErrorCause::from_boxed(err);
println!("{:?}", error_infer::<Box<rustls::Error>>(&err2));

assert!(false);
//assert!(matches!(err, ErrorCause::BackendTLSErrorOther(_)));
}
}
80 changes: 42 additions & 38 deletions src/routing/handler.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,37 @@
use std::sync::Arc;

use axum::{
body::Body,
extract::{Request, State},
response::{IntoResponse, Response},
response::Response,
Extension,
};
use bytes::Bytes;
use futures::StreamExt;
use http::Uri;
use http_body::Frame;
use http_body_util::{BodyExt, Full, LengthLimitError, Limited, StreamBody};
use ic_http_gateway::{
CanisterRequest, HttpGatewayClient, HttpGatewayRequestArgs, HttpGatewayResponse,
HttpGatewayResponseBody,
};
use http::{HeaderValue, Uri};
use http_body_util::{BodyExt, LengthLimitError, Limited};
use ic_http_gateway::{CanisterRequest, HttpGatewayClient, HttpGatewayRequestArgs};

use super::{error_cause::ErrorCause, RequestCtx};
use super::{
error_cause::ErrorCause,
ic::{
self,
transport::{PassHeaders, PASS_HEADERS},
IcResponseStatus,
},
middleware::{self, request_id::RequestId},
RequestCtx,
};

const MAX_REQUEST_BODY_SIZE: usize = 10_485_760;
const MAX_REQUEST_BODY_SIZE: usize = 10 * 1_048_576;

#[derive(derive_new::new)]
pub struct HandlerState {
gw: HttpGatewayClient,
}

fn convert_response(resp: HttpGatewayResponse) -> Response {
let (parts, body) = resp.canister_response.into_parts();

match body {
HttpGatewayResponseBody::Bytes(v) => {
Response::from_parts(parts, Body::new(Full::new(v.into()))).into_response()
}

HttpGatewayResponseBody::Stream(v) => {
let v = v.map(|x| x.map(|y| Frame::data(Bytes::from(y))));
let body = StreamBody::new(v);
let body = Body::new(body);

Response::from_parts(parts, body).into_response()
}
}
client: HttpGatewayClient,
}

pub async fn handler(
State(state): State<Arc<HandlerState>>,
Extension(ctx): Extension<Arc<RequestCtx>>,
Extension(request_id): Extension<RequestId>,
request: Request,
) -> Result<Response, ErrorCause> {
let (mut parts, body) = request.into_parts();
Expand Down Expand Up @@ -74,15 +60,33 @@ pub async fn handler(
canister_id: ctx.canister.id,
};

// Execute the request
let resp = state
.gw
.request(args)
//.unsafe_allow_skip_verification()
.send()
// Pass headers in/out the IC request
let resp = PASS_HEADERS
.scope(PassHeaders::new(), async {
PASS_HEADERS.with(|x| {
let hdr =
HeaderValue::from_maybe_shared(Bytes::from(request_id.to_string())).unwrap();

x.borrow_mut()
.headers_out
.insert(middleware::request_id::HEADER, hdr)
});

// Execute the request
state
.client
.request(args)
//.unsafe_allow_skip_verification()
.send()
.await
})
.await
.map_err(ErrorCause::from_err)?;

let ic_status = IcResponseStatus::from(&resp);
let mut response = ic::convert_response(resp.canister_response);
response.extensions_mut().insert(ic_status);

// Convert it into Axum response
Ok(convert_response(resp))
Ok(response)
}
72 changes: 72 additions & 0 deletions src/routing/ic/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
pub mod transport;

use std::sync::Arc;

use anyhow::Error;
use axum::{
body::Body,
response::{IntoResponse, Response},
};
use bytes::Bytes;
use futures::StreamExt;
use http_body::Frame;
use http_body_util::{Full, StreamBody};
use ic_agent::agent::http_transport::route_provider::RouteProvider;
use ic_http_gateway::{
HttpGatewayClient, HttpGatewayResponse, HttpGatewayResponseBody, HttpGatewayResponseMetadata,
};

use crate::{http::Client as HttpClient, Cli};

#[derive(Clone)]
pub struct IcResponseStatus {
pub streaming: bool,
pub metadata: HttpGatewayResponseMetadata,
}

impl From<&HttpGatewayResponse> for IcResponseStatus {
fn from(value: &HttpGatewayResponse) -> Self {
Self {
streaming: matches!(
value.canister_response.body(),
HttpGatewayResponseBody::Stream(_)
),
metadata: value.metadata.clone(),
}
}
}

pub fn convert_response(resp: Response<HttpGatewayResponseBody>) -> Response {
let (parts, body) = resp.into_parts();

match body {
HttpGatewayResponseBody::Bytes(v) => {
Response::from_parts(parts, Body::new(Full::new(v.into()))).into_response()
}

HttpGatewayResponseBody::Stream(v) => {
let v = v.map(|x| x.map(|y| Frame::data(Bytes::from(y))));
let body = StreamBody::new(v);
let body = Body::new(body);

Response::from_parts(parts, body).into_response()
}
}
}

pub fn setup(
_cli: &Cli,
http_client: Arc<dyn HttpClient>,
route_provider: Arc<dyn RouteProvider>,
) -> Result<HttpGatewayClient, Error> {
let transport =
transport::ReqwestTransport::create_with_client_route(route_provider, http_client)?;
let agent = ic_agent::Agent::builder()
.with_transport(transport)
.build()?;
let client = ic_http_gateway::HttpGatewayClientBuilder::new()
.with_agent(agent)
.build()?;

Ok(client)
}
Loading

0 comments on commit fb936d0

Please sign in to comment.