Skip to content

Commit

Permalink
Add internal HTTP-level Tower service
Browse files Browse the repository at this point in the history
Fixes #1496
  • Loading branch information
SimonSapin committed Aug 22, 2022
1 parent d23e2cb commit d0be3e3
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 131 deletions.
230 changes: 121 additions & 109 deletions apollo-router/src/axum_http_server_factory.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//! Axum http server factory. Axum provides routing capability on top of Hyper HTTP.
use std::collections::HashMap;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
Expand Down Expand Up @@ -31,6 +30,7 @@ use http::header::CONTENT_ENCODING;
use http::HeaderValue;
use http::Request;
use http::Uri;
use http_body::Body as _;
use hyper::server::conn::Http;
use hyper::Body;
use opentelemetry::global;
Expand Down Expand Up @@ -65,6 +65,7 @@ use crate::plugins::traffic_shaping::Elapsed;
use crate::plugins::traffic_shaping::RateLimited;
use crate::router::ApolloRouterError;
use crate::router_factory::SupergraphServiceFactory;
use crate::services::transport;

/// A basic http server using Axum.
/// Uses streaming as primary method of response.
Expand All @@ -78,6 +79,114 @@ impl AxumHttpServerFactory {
}
}

pub(crate) fn make_transport_service<RF>(
service_factory: RF,
configuration: Arc<Configuration>,
) -> Result<transport::BoxCloneService, ApolloRouterError>
where
RF: SupergraphServiceFactory,
{
let plugin_handlers = service_factory.custom_endpoints();
let cors = configuration
.server
.cors
.clone()
.into_layer()
.map_err(|e| {
ApolloRouterError::ConfigError(
crate::configuration::ConfigurationError::LayerConfiguration {
layer: "Cors".to_string(),
error: e,
},
)
})?;
let graphql_endpoint = if configuration.server.endpoint.ends_with("/*") {
// Needed for axum (check the axum docs for more information about wildcards https://docs.rs/axum/latest/axum/struct.Router.html#wildcards)
format!("{}router_extra_path", configuration.server.endpoint)
} else {
configuration.server.endpoint.clone()
};
let mut router = Router::<hyper::Body>::new()
.route(
&graphql_endpoint,
get({
let display_landing_page = configuration.server.landing_page;
move |host: Host, Extension(service): Extension<RF>, http_request: Request<Body>| {
handle_get(
host,
service.new_service().boxed(),
http_request,
display_landing_page,
)
}
})
.post({
move |host: Host,
uri: OriginalUri,
request: Json<graphql::Request>,
Extension(service): Extension<RF>,
header_map: HeaderMap| {
handle_post(
host,
uri,
request,
service.new_service().boxed(),
header_map,
)
}
}),
)
.layer(middleware::from_fn(decompress_request_body))
.layer(
TraceLayer::new_for_http()
.make_span_with(PropagatingMakeSpan::new())
.on_response(|resp: &Response<_>, _duration: Duration, span: &Span| {
if resp.status() >= StatusCode::BAD_REQUEST {
span.record(
"otel.status_code",
&opentelemetry::trace::StatusCode::Error.as_str(),
);
} else {
span.record(
"otel.status_code",
&opentelemetry::trace::StatusCode::Ok.as_str(),
);
}
}),
)
.route(&configuration.server.health_check_path, get(health_check))
.layer(Extension(service_factory))
.layer(cors)
.layer(CompressionLayer::new()); // To compress response body

for (plugin_name, handler) in plugin_handlers {
router = router.route(
&format!("/plugins/{}/*path", plugin_name),
get({
let new_handler = handler.clone();
move |host: Host, request_parts: Request<Body>| {
custom_plugin_handler(host, request_parts, new_handler)
}
})
.post({
let new_handler = handler.clone();
move |host: Host, request_parts: Request<Body>| {
custom_plugin_handler(host, request_parts, new_handler)
}
}),
);
}
Ok(router
.map_response(|response| {
response.map(|body| {
let body = Box::pin(body);
Body::wrap_stream(stream::poll_fn(move |ctx| body.as_mut().poll_data(ctx)))
})
})
.map_err(|error| match error {})
.boxed_clone())
}

impl HttpServerFactory for AxumHttpServerFactory {
type Future = Pin<Box<dyn Future<Output = Result<HttpServerHandle, ApolloRouterError>> + Send>>;

Expand All @@ -86,7 +195,6 @@ impl HttpServerFactory for AxumHttpServerFactory {
service_factory: RF,
configuration: Arc<Configuration>,
listener: Option<Listener>,
plugin_handlers: HashMap<String, Handler>,
) -> Self::Future
where
RF: SupergraphServiceFactory,
Expand All @@ -95,97 +203,7 @@ impl HttpServerFactory for AxumHttpServerFactory {
let (shutdown_sender, shutdown_receiver) = oneshot::channel::<()>();
let listen_address = configuration.server.listen.clone();

let cors = configuration
.server
.cors
.clone()
.into_layer()
.map_err(|e| {
ApolloRouterError::ConfigError(
crate::configuration::ConfigurationError::LayerConfiguration {
layer: "Cors".to_string(),
error: e,
},
)
})?;
let graphql_endpoint = if configuration.server.endpoint.ends_with("/*") {
// Needed for axum (check the axum docs for more information about wildcards https://docs.rs/axum/latest/axum/struct.Router.html#wildcards)
format!("{}router_extra_path", configuration.server.endpoint)
} else {
configuration.server.endpoint.clone()
};
let mut router = Router::new()
.route(
&graphql_endpoint,
get({
let display_landing_page = configuration.server.landing_page;
move |host: Host,
Extension(service): Extension<RF>,
http_request: Request<Body>| {
handle_get(
host,
service.new_service().boxed(),
http_request,
display_landing_page,
)
}
})
.post({
move |host: Host,
uri: OriginalUri,
request: Json<graphql::Request>,
Extension(service): Extension<RF>,
header_map: HeaderMap| {
handle_post(
host,
uri,
request,
service.new_service().boxed(),
header_map,
)
}
}),
)
.layer(middleware::from_fn(decompress_request_body))
.layer(
TraceLayer::new_for_http()
.make_span_with(PropagatingMakeSpan::new())
.on_response(|resp: &Response<_>, _duration: Duration, span: &Span| {
if resp.status() >= StatusCode::BAD_REQUEST {
span.record(
"otel.status_code",
&opentelemetry::trace::StatusCode::Error.as_str(),
);
} else {
span.record(
"otel.status_code",
&opentelemetry::trace::StatusCode::Ok.as_str(),
);
}
}),
)
.route(&configuration.server.health_check_path, get(health_check))
.layer(Extension(service_factory))
.layer(cors)
.layer(CompressionLayer::new()); // To compress response body

for (plugin_name, handler) in plugin_handlers {
router = router.route(
&format!("/plugins/{}/*path", plugin_name),
get({
let new_handler = handler.clone();
move |host: Host, request_parts: Request<Body>| {
custom_plugin_handler(host, request_parts, new_handler)
}
})
.post({
let new_handler = handler.clone();
move |host: Host, request_parts: Request<Body>| {
custom_plugin_handler(host, request_parts, new_handler)
}
}),
);
}
let router = make_transport_service(service_factory, configuration)?;

// if we received a TCP listener, reuse it, otherwise create a new one
#[cfg_attr(not(unix), allow(unused_mut))]
Expand Down Expand Up @@ -240,7 +258,7 @@ impl HttpServerFactory for AxumHttpServerFactory {
max_open_file_warning = None;
}

tokio::task::spawn(async move{
tokio::task::spawn(async move {
match res {
NetworkStream::Tcp(stream) => {
stream
Expand Down Expand Up @@ -706,6 +724,7 @@ impl<B> MakeSpan<B> for PropagatingMakeSpan {

#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::net::SocketAddr;
use std::str::FromStr;

Expand Down Expand Up @@ -840,7 +859,6 @@ mod tests {
.build(),
),
None,
HashMap::new(),
)
.await
.expect("Failed to create server factory");
Expand All @@ -858,7 +876,6 @@ mod tests {
async fn init_with_config(
mut mock: MockSupergraphService,
conf: Configuration,
plugin_handlers: HashMap<String, Handler>,
) -> (HttpServerHandle, Client) {
let server_factory = AxumHttpServerFactory::new();
let (service, mut handle) = tower_test::mock::spawn();
Expand All @@ -880,7 +897,6 @@ mod tests {
},
Arc::new(conf),
None,
plugin_handlers,
)
.await
.expect("Failed to create server factory");
Expand Down Expand Up @@ -928,7 +944,6 @@ mod tests {
.build(),
),
None,
HashMap::new(),
)
.await
.expect("Failed to create server factory");
Expand Down Expand Up @@ -1255,7 +1270,7 @@ mod tests {
.build(),
)
.build();
let (server, client) = init_with_config(expectations, conf, HashMap::new()).await;
let (server, client) = init_with_config(expectations, conf).await;
let url = format!("{}/graphql", server.listen_address());

// Post query
Expand Down Expand Up @@ -1324,7 +1339,7 @@ mod tests {
.build(),
)
.build();
let (server, client) = init_with_config(expectations, conf, HashMap::new()).await;
let (server, client) = init_with_config(expectations, conf).await;
let url = format!("{}/prefix/graphql", server.listen_address());

// Post query
Expand Down Expand Up @@ -1393,7 +1408,7 @@ mod tests {
.build(),
)
.build();
let (server, client) = init_with_config(expectations, conf, HashMap::new()).await;
let (server, client) = init_with_config(expectations, conf).await;
for url in &[
format!("{}/graphql/test", server.listen_address()),
format!("{}/graphql/anothertest", server.listen_address()),
Expand Down Expand Up @@ -1612,7 +1627,7 @@ mod tests {
.build(),
)
.build();
let (server, client) = init_with_config(expectations, conf, HashMap::new()).await;
let (server, client) = init_with_config(expectations, conf).await;

let response = client
.request(Method::OPTIONS, &format!("{}/", server.listen_address()))
Expand Down Expand Up @@ -1798,7 +1813,7 @@ Content-Type: application/json\r
)
.build();
let expectations = MockSupergraphService::new();
let (server, client) = init_with_config(expectations, conf, HashMap::new()).await;
let (server, client) = init_with_config(expectations, conf).await;
let url = format!("{}/health", server.listen_address());

let response = client.get(url).send().await.unwrap();
Expand Down Expand Up @@ -2026,8 +2041,7 @@ Content-Type: application/json\r
.build(),
)
.build();
let (server, client) =
init_with_config(MockSupergraphService::new(), conf, HashMap::new()).await;
let (server, client) = init_with_config(MockSupergraphService::new(), conf).await;
let url = format!("{}/", server.listen_address());

let response =
Expand All @@ -2053,8 +2067,7 @@ Content-Type: application/json\r
.build(),
)
.build();
let (server, client) =
init_with_config(MockSupergraphService::new(), conf, HashMap::new()).await;
let (server, client) = init_with_config(MockSupergraphService::new(), conf).await;
let url = format!("{}/", server.listen_address());

let response = request_cors_with_origin(&client, url.as_str(), valid_origin).await;
Expand Down Expand Up @@ -2084,8 +2097,7 @@ Content-Type: application/json\r
.build(),
)
.build();
let (server, client) =
init_with_config(MockSupergraphService::new(), conf, HashMap::new()).await;
let (server, client) = init_with_config(MockSupergraphService::new(), conf).await;
let url = format!("{}/", server.listen_address());

// regex tests
Expand Down
Loading

0 comments on commit d0be3e3

Please sign in to comment.