From d0be3e361eb4aead183d86262d715526cc59db1f Mon Sep 17 00:00:00 2001 From: Simon Sapin Date: Mon, 22 Aug 2022 17:31:45 +0200 Subject: [PATCH] Add internal HTTP-level Tower service Fixes https://github.com/apollographql/router/issues/1496 --- apollo-router/src/axum_http_server_factory.rs | 230 +++++++++--------- apollo-router/src/http_server_factory.rs | 11 +- apollo-router/src/router.rs | 21 ++ apollo-router/src/state_machine.rs | 13 +- 4 files changed, 144 insertions(+), 131 deletions(-) diff --git a/apollo-router/src/axum_http_server_factory.rs b/apollo-router/src/axum_http_server_factory.rs index b7f1be1aa1..cc99b05e56 100644 --- a/apollo-router/src/axum_http_server_factory.rs +++ b/apollo-router/src/axum_http_server_factory.rs @@ -1,5 +1,4 @@ //! Axum http server factory. Axum provides routing capability on top of Hyper HTTP. -use std::collections::HashMap; use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; @@ -31,6 +30,7 @@ use http::header::CONTENT_ENCODING; use http::HeaderValue; use http::Request; use http::Uri; +use http_body::Body as _; use hyper::server::conn::Http; use hyper::Body; use opentelemetry::global; @@ -65,6 +65,7 @@ use crate::plugins::traffic_shaping::Elapsed; use crate::plugins::traffic_shaping::RateLimited; use crate::router::ApolloRouterError; use crate::router_factory::SupergraphServiceFactory; +use crate::services::transport; /// A basic http server using Axum. /// Uses streaming as primary method of response. @@ -78,6 +79,114 @@ impl AxumHttpServerFactory { } } +pub(crate) fn make_transport_service( + service_factory: RF, + configuration: Arc, +) -> Result +where + RF: SupergraphServiceFactory, +{ + let plugin_handlers = service_factory.custom_endpoints(); + let cors = configuration + .server + .cors + .clone() + .into_layer() + .map_err(|e| { + ApolloRouterError::ConfigError( + crate::configuration::ConfigurationError::LayerConfiguration { + layer: "Cors".to_string(), + error: e, + }, + ) + })?; + let graphql_endpoint = if configuration.server.endpoint.ends_with("/*") { + // Needed for axum (check the axum docs for more information about wildcards https://docs.rs/axum/latest/axum/struct.Router.html#wildcards) + format!("{}router_extra_path", configuration.server.endpoint) + } else { + configuration.server.endpoint.clone() + }; + let mut router = Router::::new() + .route( + &graphql_endpoint, + get({ + let display_landing_page = configuration.server.landing_page; + move |host: Host, Extension(service): Extension, http_request: Request| { + handle_get( + host, + service.new_service().boxed(), + http_request, + display_landing_page, + ) + } + }) + .post({ + move |host: Host, + uri: OriginalUri, + request: Json, + Extension(service): Extension, + header_map: HeaderMap| { + handle_post( + host, + uri, + request, + service.new_service().boxed(), + header_map, + ) + } + }), + ) + .layer(middleware::from_fn(decompress_request_body)) + .layer( + TraceLayer::new_for_http() + .make_span_with(PropagatingMakeSpan::new()) + .on_response(|resp: &Response<_>, _duration: Duration, span: &Span| { + if resp.status() >= StatusCode::BAD_REQUEST { + span.record( + "otel.status_code", + &opentelemetry::trace::StatusCode::Error.as_str(), + ); + } else { + span.record( + "otel.status_code", + &opentelemetry::trace::StatusCode::Ok.as_str(), + ); + } + }), + ) + .route(&configuration.server.health_check_path, get(health_check)) + .layer(Extension(service_factory)) + .layer(cors) + .layer(CompressionLayer::new()); // To compress response body + + for (plugin_name, handler) in plugin_handlers { + router = router.route( + &format!("/plugins/{}/*path", plugin_name), + get({ + let new_handler = handler.clone(); + move |host: Host, request_parts: Request| { + custom_plugin_handler(host, request_parts, new_handler) + } + }) + .post({ + let new_handler = handler.clone(); + move |host: Host, request_parts: Request| { + custom_plugin_handler(host, request_parts, new_handler) + } + }), + ); + } + Ok(router + .map_response(|response| { + response.map(|body| { + let body = Box::pin(body); + Body::wrap_stream(stream::poll_fn(move |ctx| body.as_mut().poll_data(ctx))) + }) + }) + .map_err(|error| match error {}) + .boxed_clone()) +} + impl HttpServerFactory for AxumHttpServerFactory { type Future = Pin> + Send>>; @@ -86,7 +195,6 @@ impl HttpServerFactory for AxumHttpServerFactory { service_factory: RF, configuration: Arc, listener: Option, - plugin_handlers: HashMap, ) -> Self::Future where RF: SupergraphServiceFactory, @@ -95,97 +203,7 @@ impl HttpServerFactory for AxumHttpServerFactory { let (shutdown_sender, shutdown_receiver) = oneshot::channel::<()>(); let listen_address = configuration.server.listen.clone(); - let cors = configuration - .server - .cors - .clone() - .into_layer() - .map_err(|e| { - ApolloRouterError::ConfigError( - crate::configuration::ConfigurationError::LayerConfiguration { - layer: "Cors".to_string(), - error: e, - }, - ) - })?; - let graphql_endpoint = if configuration.server.endpoint.ends_with("/*") { - // Needed for axum (check the axum docs for more information about wildcards https://docs.rs/axum/latest/axum/struct.Router.html#wildcards) - format!("{}router_extra_path", configuration.server.endpoint) - } else { - configuration.server.endpoint.clone() - }; - let mut router = Router::new() - .route( - &graphql_endpoint, - get({ - let display_landing_page = configuration.server.landing_page; - move |host: Host, - Extension(service): Extension, - http_request: Request| { - handle_get( - host, - service.new_service().boxed(), - http_request, - display_landing_page, - ) - } - }) - .post({ - move |host: Host, - uri: OriginalUri, - request: Json, - Extension(service): Extension, - header_map: HeaderMap| { - handle_post( - host, - uri, - request, - service.new_service().boxed(), - header_map, - ) - } - }), - ) - .layer(middleware::from_fn(decompress_request_body)) - .layer( - TraceLayer::new_for_http() - .make_span_with(PropagatingMakeSpan::new()) - .on_response(|resp: &Response<_>, _duration: Duration, span: &Span| { - if resp.status() >= StatusCode::BAD_REQUEST { - span.record( - "otel.status_code", - &opentelemetry::trace::StatusCode::Error.as_str(), - ); - } else { - span.record( - "otel.status_code", - &opentelemetry::trace::StatusCode::Ok.as_str(), - ); - } - }), - ) - .route(&configuration.server.health_check_path, get(health_check)) - .layer(Extension(service_factory)) - .layer(cors) - .layer(CompressionLayer::new()); // To compress response body - - for (plugin_name, handler) in plugin_handlers { - router = router.route( - &format!("/plugins/{}/*path", plugin_name), - get({ - let new_handler = handler.clone(); - move |host: Host, request_parts: Request| { - custom_plugin_handler(host, request_parts, new_handler) - } - }) - .post({ - let new_handler = handler.clone(); - move |host: Host, request_parts: Request| { - custom_plugin_handler(host, request_parts, new_handler) - } - }), - ); - } + let router = make_transport_service(service_factory, configuration)?; // if we received a TCP listener, reuse it, otherwise create a new one #[cfg_attr(not(unix), allow(unused_mut))] @@ -240,7 +258,7 @@ impl HttpServerFactory for AxumHttpServerFactory { max_open_file_warning = None; } - tokio::task::spawn(async move{ + tokio::task::spawn(async move { match res { NetworkStream::Tcp(stream) => { stream @@ -706,6 +724,7 @@ impl MakeSpan for PropagatingMakeSpan { #[cfg(test)] mod tests { + use std::collections::HashMap; use std::net::SocketAddr; use std::str::FromStr; @@ -840,7 +859,6 @@ mod tests { .build(), ), None, - HashMap::new(), ) .await .expect("Failed to create server factory"); @@ -858,7 +876,6 @@ mod tests { async fn init_with_config( mut mock: MockSupergraphService, conf: Configuration, - plugin_handlers: HashMap, ) -> (HttpServerHandle, Client) { let server_factory = AxumHttpServerFactory::new(); let (service, mut handle) = tower_test::mock::spawn(); @@ -880,7 +897,6 @@ mod tests { }, Arc::new(conf), None, - plugin_handlers, ) .await .expect("Failed to create server factory"); @@ -928,7 +944,6 @@ mod tests { .build(), ), None, - HashMap::new(), ) .await .expect("Failed to create server factory"); @@ -1255,7 +1270,7 @@ mod tests { .build(), ) .build(); - let (server, client) = init_with_config(expectations, conf, HashMap::new()).await; + let (server, client) = init_with_config(expectations, conf).await; let url = format!("{}/graphql", server.listen_address()); // Post query @@ -1324,7 +1339,7 @@ mod tests { .build(), ) .build(); - let (server, client) = init_with_config(expectations, conf, HashMap::new()).await; + let (server, client) = init_with_config(expectations, conf).await; let url = format!("{}/prefix/graphql", server.listen_address()); // Post query @@ -1393,7 +1408,7 @@ mod tests { .build(), ) .build(); - let (server, client) = init_with_config(expectations, conf, HashMap::new()).await; + let (server, client) = init_with_config(expectations, conf).await; for url in &[ format!("{}/graphql/test", server.listen_address()), format!("{}/graphql/anothertest", server.listen_address()), @@ -1612,7 +1627,7 @@ mod tests { .build(), ) .build(); - let (server, client) = init_with_config(expectations, conf, HashMap::new()).await; + let (server, client) = init_with_config(expectations, conf).await; let response = client .request(Method::OPTIONS, &format!("{}/", server.listen_address())) @@ -1798,7 +1813,7 @@ Content-Type: application/json\r ) .build(); let expectations = MockSupergraphService::new(); - let (server, client) = init_with_config(expectations, conf, HashMap::new()).await; + let (server, client) = init_with_config(expectations, conf).await; let url = format!("{}/health", server.listen_address()); let response = client.get(url).send().await.unwrap(); @@ -2026,8 +2041,7 @@ Content-Type: application/json\r .build(), ) .build(); - let (server, client) = - init_with_config(MockSupergraphService::new(), conf, HashMap::new()).await; + let (server, client) = init_with_config(MockSupergraphService::new(), conf).await; let url = format!("{}/", server.listen_address()); let response = @@ -2053,8 +2067,7 @@ Content-Type: application/json\r .build(), ) .build(); - let (server, client) = - init_with_config(MockSupergraphService::new(), conf, HashMap::new()).await; + let (server, client) = init_with_config(MockSupergraphService::new(), conf).await; let url = format!("{}/", server.listen_address()); let response = request_cors_with_origin(&client, url.as_str(), valid_origin).await; @@ -2084,8 +2097,7 @@ Content-Type: application/json\r .build(), ) .build(); - let (server, client) = - init_with_config(MockSupergraphService::new(), conf, HashMap::new()).await; + let (server, client) = init_with_config(MockSupergraphService::new(), conf).await; let url = format!("{}/", server.listen_address()); // regex tests diff --git a/apollo-router/src/http_server_factory.rs b/apollo-router/src/http_server_factory.rs index 2a9f22a99f..d2a29971ad 100644 --- a/apollo-router/src/http_server_factory.rs +++ b/apollo-router/src/http_server_factory.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; @@ -9,7 +8,6 @@ use futures::prelude::*; use super::router::ApolloRouterError; use crate::configuration::Configuration; use crate::configuration::ListenAddr; -use crate::plugin::Handler; use crate::router_factory::SupergraphServiceFactory; /// Factory for creating the http server component. @@ -24,7 +22,6 @@ pub(crate) trait HttpServerFactory { service_factory: RF, configuration: Arc, listener: Option, - plugin_handlers: HashMap, ) -> Self::Future where RF: SupergraphServiceFactory; @@ -81,7 +78,6 @@ impl HttpServerHandle { factory: &SF, router: RF, configuration: Arc, - plugin_handlers: HashMap, ) -> Result where SF: HttpServerFactory, @@ -113,12 +109,7 @@ impl HttpServerHandle { }; let handle = factory - .create( - router, - Arc::clone(&configuration), - listener, - plugin_handlers, - ) + .create(router, Arc::clone(&configuration), listener) .await?; tracing::debug!("restarted on {}", handle.listen_address()); diff --git a/apollo-router/src/router.rs b/apollo-router/src/router.rs index 72c23a2fa3..f3e97b7336 100644 --- a/apollo-router/src/router.rs +++ b/apollo-router/src/router.rs @@ -33,11 +33,32 @@ use crate::axum_http_server_factory::AxumHttpServerFactory; use crate::configuration::validate_configuration; use crate::configuration::Configuration; use crate::configuration::ListenAddr; +use crate::plugin::DynPlugin; +use crate::router_factory::SupergraphServiceConfigurator; use crate::router_factory::YamlSupergraphServiceFactory; +use crate::services::transport; +use crate::spec::Schema; use crate::state_machine::StateMachine; type SchemaStream = Pin + Send>>; +/// Could eventually add a public builder similar to `test_harness.rs` +/// https://github.com/apollographql/router/issues/1496 +async fn make_transport_service( + schema: &str, + configuration: Arc, + extra_plugins: Vec<(String, Box)>, +) -> Result { + let schema = Arc::new(Schema::parse(schema, &configuration)?); + let service_factory = YamlSupergraphServiceFactory + .create(configuration.clone(), schema, None, Some(extra_plugins)) + .await?; + Ok(crate::axum_http_server_factory::make_transport_service( + service_factory, + configuration, + )?) +} + /// Error types for FederatedServer. #[derive(Error, Debug, DisplayDoc)] pub enum ApolloRouterError { diff --git a/apollo-router/src/state_machine.rs b/apollo-router/src/state_machine.rs index ef6de6f11b..4ac1924379 100644 --- a/apollo-router/src/state_machine.rs +++ b/apollo-router/src/state_machine.rs @@ -306,16 +306,10 @@ where tracing::error!("cannot create the router: {}", err); Errored(ApolloRouterError::ServiceCreationError(err)) })?; - let plugin_handlers = router_factory.custom_endpoints(); let server_handle = self .http_server_factory - .create( - router_factory.clone(), - configuration.clone(), - None, - plugin_handlers, - ) + .create(router_factory.clone(), configuration.clone(), None) .await .map_err(|err| { tracing::error!("cannot start the router: {}", err); @@ -360,14 +354,11 @@ where .await { Ok(new_router_service) => { - let plugin_handlers = new_router_service.custom_endpoints(); - let server_handle = server_handle .restart( &self.http_server_factory, new_router_service.clone(), new_configuration.clone(), - plugin_handlers, ) .await .map_err(|err| { @@ -436,7 +427,6 @@ mod tests { use crate::http_ext::Response; use crate::http_server_factory::Listener; use crate::plugin::DynPlugin; - use crate::plugin::Handler; use crate::router_factory::SupergraphServiceConfigurator; use crate::router_factory::SupergraphServiceFactory; use crate::services::new_service::NewService; @@ -715,7 +705,6 @@ mod tests { _service_factory: RF, configuration: Arc, listener: Option, - _plugin_handlers: HashMap, ) -> Self::Future where RF: SupergraphServiceFactory,