Skip to content

Commit

Permalink
fix: run on_upstream_http_request hook in federation source (#440)
Browse files Browse the repository at this point in the history
Co-authored-by: Dotan Simha <dotansimha@gmail.com>
  • Loading branch information
YassinEldeeb and dotansimha authored Feb 14, 2024
1 parent 585c1ce commit ccb9ff0
Show file tree
Hide file tree
Showing 10 changed files with 339 additions and 249 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions libs/common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod graphql;
pub mod http;
pub mod json;
pub mod plugin;
pub mod plugin_manager;
pub mod serde_utils;
pub mod vrl_functions;
pub mod vrl_utils;
Expand Down
28 changes: 28 additions & 0 deletions libs/common/src/plugin_manager.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use crate::{
execute::RequestExecutionContext,
graphql::GraphQLRequest,
http::{ConductorHttpRequest, ConductorHttpResponse},
};
use reqwest::Response;

#[async_trait::async_trait(?Send)]
pub trait PluginManager: std::fmt::Debug + Send + Sync {
async fn on_downstream_http_request(&self, context: &mut RequestExecutionContext);
fn on_downstream_http_response(
&self,
context: &mut RequestExecutionContext,
response: &mut ConductorHttpResponse,
);
async fn on_downstream_graphql_request(&self, context: &mut RequestExecutionContext);
async fn on_upstream_graphql_request<'a>(&self, req: &mut GraphQLRequest);
async fn on_upstream_http_request<'a>(
&self,
ctx: &mut RequestExecutionContext,
request: &mut ConductorHttpRequest,
);
async fn on_upstream_http_response<'a>(
&self,
ctx: &mut RequestExecutionContext,
response: &Result<Response, reqwest_middleware::Error>,
);
}
18 changes: 10 additions & 8 deletions libs/engine/src/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use conductor_common::{
graphql::{ExtractGraphQLOperationError, GraphQLRequest, GraphQLResponse, ParsedGraphQLRequest},
http::{ConductorHttpRequest, ConductorHttpResponse, Url},
plugin::PluginError,
plugin_manager::PluginManager,
};
use conductor_config::{ConductorConfig, EndpointDefinition, SourceDefinition};
use conductor_tracing::{
Expand All @@ -17,7 +18,7 @@ use reqwest::{Method, StatusCode};
use tracing::error;

use crate::{
plugin_manager::PluginManager,
plugin_manager::PluginManagerImpl,
source::{
federation_source::FederationSourceRuntime,
graphql_source::GraphQLSourceRuntime,
Expand All @@ -30,7 +31,7 @@ use crate::{
pub struct ConductorGatewayRouteData {
pub endpoint: String,
pub tenant_id: u32,
pub plugin_manager: Arc<PluginManager>,
pub plugin_manager: Arc<Box<dyn PluginManager>>,
pub to: Arc<Box<dyn SourceRuntime>>,
}

Expand Down Expand Up @@ -102,9 +103,10 @@ impl ConductorGateway {
.cloned()
.collect::<Vec<_>>();

let plugin_manager = PluginManager::new(&Some(combined_plugins), tracing_manager, tenant_id)
.await
.map_err(GatewayError::PluginManagerInitError)?;
let plugin_manager =
PluginManagerImpl::new(&Some(combined_plugins), tracing_manager, tenant_id)
.await
.map_err(GatewayError::PluginManagerInitError)?;

let upstream_source: Box<dyn SourceRuntime> = config_object
.sources
Expand All @@ -115,7 +117,7 @@ impl ConductorGateway {
let route_data = ConductorGatewayRouteData {
endpoint: endpoint_config.path.clone(),
to: Arc::new(upstream_source),
plugin_manager: Arc::new(plugin_manager),
plugin_manager: Arc::new(Box::new(plugin_manager)),
tenant_id,
};

Expand Down Expand Up @@ -159,10 +161,10 @@ impl ConductorGateway {
plugins: Vec<Box<dyn conductor_common::plugin::Plugin>>,
request: ConductorHttpRequest,
) -> ConductorHttpResponse {
let plugin_manager = PluginManager::new_from_vec(plugins);
let plugin_manager = PluginManagerImpl::new_from_vec(plugins);
let route_data = ConductorGatewayRouteData {
endpoint: "/".to_string(),
plugin_manager: Arc::new(plugin_manager),
plugin_manager: Arc::new(Box::new(plugin_manager)),
to: source,
tenant_id: 0,
};
Expand Down
26 changes: 15 additions & 11 deletions libs/engine/src/plugin_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,23 @@ use conductor_common::{
graphql::GraphQLRequest,
http::{ConductorHttpRequest, ConductorHttpResponse},
plugin::{CreatablePlugin, Plugin, PluginError},
plugin_manager::PluginManager,
};
use conductor_config::PluginDefinition;
use conductor_tracing::minitrace_mgr::MinitraceManager;
use reqwest::Response;

#[derive(Debug, Default)]
pub struct PluginManager {
pub struct PluginManagerImpl {
plugins: Vec<Box<dyn Plugin>>,
}

impl PluginManager {
impl PluginManagerImpl {
pub fn new_from_vec(plugins: Vec<Box<dyn Plugin>>) -> Self {
let mut pm = Self { plugins };

// We want to make sure to register default plugins last, in order to ensure it's setting the value correctly
for p in PluginManager::default_plugins() {
for p in PluginManagerImpl::default_plugins() {
pm.register_boxed_plugin(p);
}

Expand All @@ -34,7 +35,7 @@ impl PluginManager {
tracing_manager: &mut MinitraceManager,
tenant_id: u32,
) -> Result<Self, PluginError> {
let mut instance = PluginManager::default();
let mut instance = PluginManagerImpl::default();

if let Some(config_defs) = plugins_config {
for plugin_def in config_defs.iter() {
Expand Down Expand Up @@ -98,7 +99,7 @@ impl PluginManager {
};

// We want to make sure to register these last, in order to ensure it's setting the value correctly
for p in PluginManager::default_plugins() {
for p in PluginManagerImpl::default_plugins() {
instance.register_boxed_plugin(p);
}

Expand All @@ -116,14 +117,17 @@ impl PluginManager {
pub fn register_plugin(&mut self, plugin: impl Plugin + 'static) {
self.plugins.push(Box::new(plugin));
}
}

#[async_trait::async_trait(?Send)]
impl PluginManager for PluginManagerImpl {
#[tracing::instrument(
level = "debug",
skip(self, context),
name = "on_downstream_http_request"
)]
#[inline]
pub async fn on_downstream_http_request(&self, context: &mut RequestExecutionContext) {
async fn on_downstream_http_request(&self, context: &mut RequestExecutionContext) {
let p = &self.plugins;

for plugin in p.iter() {
Expand All @@ -141,7 +145,7 @@ impl PluginManager {
name = "on_downstream_http_response"
)]
#[inline]
pub fn on_downstream_http_response(
fn on_downstream_http_response(
&self,
context: &mut RequestExecutionContext,
response: &mut ConductorHttpResponse,
Expand All @@ -163,7 +167,7 @@ impl PluginManager {
name = "on_downstream_graphql_request"
)]
#[inline]
pub async fn on_downstream_graphql_request(&self, context: &mut RequestExecutionContext) {
async fn on_downstream_graphql_request(&self, context: &mut RequestExecutionContext) {
let p = &self.plugins;

for plugin in p.iter() {
Expand All @@ -177,7 +181,7 @@ impl PluginManager {

#[tracing::instrument(level = "debug", skip(self, req), name = "on_upstream_graphql_request")]
#[inline]
pub async fn on_upstream_graphql_request<'a>(&self, req: &mut GraphQLRequest) {
async fn on_upstream_graphql_request<'a>(&self, req: &mut GraphQLRequest) {
let p = &self.plugins;

for plugin in p.iter() {
Expand All @@ -191,7 +195,7 @@ impl PluginManager {
name = "on_upstream_http_request"
)]
#[inline]
pub async fn on_upstream_http_request<'a>(
async fn on_upstream_http_request<'a>(
&self,
ctx: &mut RequestExecutionContext,
request: &mut ConductorHttpRequest,
Expand All @@ -213,7 +217,7 @@ impl PluginManager {
name = "on_upstream_http_response"
)]
#[inline]
pub async fn on_upstream_http_response<'a>(
async fn on_upstream_http_response<'a>(
&self,
ctx: &mut RequestExecutionContext,
response: &Result<Response, reqwest_middleware::Error>,
Expand Down
28 changes: 15 additions & 13 deletions libs/engine/src/source/federation_source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ use base64::{engine, Engine};
use conductor_common::execute::RequestExecutionContext;
use conductor_common::graphql::GraphQLResponse;
use conductor_config::{FederationSourceConfig, SupergraphSourceConfig};
use federation_query_planner::execute_federation;
use federation_query_planner::supergraph::{parse_supergraph, Supergraph};
use federation_query_planner::FederationExecutor;
use futures::lock::Mutex;
use minitrace_reqwest::{traced_reqwest, TracedHttpClient};
use std::collections::HashMap;
use std::sync::Arc;
use std::{future::Future, pin::Pin};

#[derive(Debug)]
Expand Down Expand Up @@ -156,7 +158,7 @@ impl FederationSourceRuntime {
}

pub async fn update_supergraph(&mut self, new_schema: String) {
let new_supergraph = parse_supergraph(&new_schema).unwrap();
let new_supergraph: Supergraph = parse_supergraph(&new_schema).unwrap();
self.supergraph = new_supergraph;
}

Expand Down Expand Up @@ -190,7 +192,7 @@ impl SourceRuntime for FederationSourceRuntime {

fn execute<'a>(
&'a self,
_route_data: &'a ConductorGatewayRouteData,
route_data: &'a ConductorGatewayRouteData,
request_context: &'a mut RequestExecutionContext,
) -> Pin<Box<(dyn Future<Output = Result<GraphQLResponse, SourceError>> + 'a)>> {
Box::pin(wasm_polyfills::call_async(async move {
Expand All @@ -199,17 +201,17 @@ impl SourceRuntime for FederationSourceRuntime {
.take()
.expect("GraphQL request isn't available at the time of execution");

// let source_req = &mut downstream_request.request;

// TODO: this needs to be called by conductor execution when fetching subgarphs
// route_data
// .plugin_manager
// .on_upstream_graphql_request(source_req)
// .await;

let operation = downstream_request.parsed_operation;

match execute_federation(&self.client, &self.supergraph, operation).await {
let executor = FederationExecutor {
client: &self.client,
plugin_manager: route_data.plugin_manager.clone(),
supergraph: &self.supergraph,
};

match executor
.execute_federation(Arc::new(Mutex::new(request_context)), operation)
.await
{
Ok((response_data, query_plan)) => {
let mut response = serde_json::from_str::<GraphQLResponse>(&response_data).unwrap();

Expand Down
1 change: 1 addition & 0 deletions libs/federation_query_planner/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ bench = false
serde = { workspace = true }
wasm_polyfills = { path = "../wasm_polyfills" }
conductor_tracing = { path = "../tracing" }
conductor_common = { path = "../common" }
serde_json = { workspace = true }
async-trait = { workspace = true }
anyhow = { workspace = true }
Expand Down
Loading

0 comments on commit ccb9ff0

Please sign in to comment.