Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bin/router/src/pipeline/coerce_variables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub struct CoerceVariablesPayload {
pub fn coerce_request_variables(
req: &HttpRequest,
supergraph: &SupergraphData,
execution_params: ExecutionRequest,
execution_params: &mut ExecutionRequest,
normalized_operation: &Arc<GraphQLNormalizationPayload>,
) -> Result<CoerceVariablesPayload, PipelineError> {
if req.method() == Method::GET {
Expand All @@ -37,7 +37,7 @@ pub fn coerce_request_variables(

match collect_variables(
&normalized_operation.operation_for_plan,
execution_params.variables,
&mut execution_params.variables,
&supergraph.metadata,
) {
Ok(values) => {
Expand Down
26 changes: 5 additions & 21 deletions bin/router/src/pipeline/execution.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::sync::Arc;

Expand All @@ -11,11 +10,10 @@ use crate::shared_state::RouterSharedState;
use hive_router_plan_executor::execute_query_plan;
use hive_router_plan_executor::execution::jwt_forward::JwtAuthForwardingPlan;
use hive_router_plan_executor::execution::plan::{
ClientRequestDetails, OperationDetails, PlanExecutionOutput, QueryPlanExecutionContext,
ClientRequestDetails, PlanExecutionOutput, QueryPlanExecutionContext,
};
use hive_router_plan_executor::introspection::resolve::IntrospectionContext;
use hive_router_query_planner::planner::plan_nodes::QueryPlan;
use hive_router_query_planner::state::supergraph_state::OperationKind;
use http::HeaderName;
use ntex::web::HttpRequest;

Expand All @@ -29,14 +27,14 @@ enum ExposeQueryPlanMode {
}

#[inline]
pub async fn execute_plan<'a>(
req: &mut HttpRequest,
query: Cow<'a, str>,
pub async fn execute_plan(
req: &HttpRequest,
supergraph: &SupergraphData,
app_state: &Arc<RouterSharedState>,
normalized_payload: &Arc<GraphQLNormalizationPayload>,
query_plan_payload: &Arc<QueryPlan>,
variable_payload: &CoerceVariablesPayload,
client_request_details: &ClientRequestDetails<'_, '_>,
) -> Result<PlanExecutionOutput, PipelineError> {
let mut expose_query_plan = ExposeQueryPlanMode::No;

Expand Down Expand Up @@ -86,21 +84,7 @@ pub async fn execute_plan<'a>(
headers_plan: &app_state.headers_plan,
variable_values: &variable_payload.variables_map,
extensions,
client_request: ClientRequestDetails {
method: req.method().clone(),
url: req.uri().clone(),
headers: req.headers(),
operation: OperationDetails {
name: normalized_payload.operation_for_plan.name.clone(),
kind: match normalized_payload.operation_for_plan.operation_kind {
Some(OperationKind::Query) => "query",
Some(OperationKind::Mutation) => "mutation",
Some(OperationKind::Subscription) => "subscription",
None => "query",
},
query,
},
},
client_request: client_request_details,
introspection_context: &introspection_context,
operation_type_name: normalized_payload.root_type_name,
jwt_auth_forwarding: &jwt_forward_plan,
Expand Down
18 changes: 14 additions & 4 deletions bin/router/src/pipeline/execution_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use http::Method;
use ntex::util::Bytes;
use ntex::web::types::Query;
use ntex::web::HttpRequest;
use serde::Deserialize;
use serde::{Deserialize, Deserializer};
use sonic_rs::Value;
use tracing::{trace, warn};

Expand All @@ -25,12 +25,22 @@ struct GETQueryParams {
pub struct ExecutionRequest {
pub query: String,
pub operation_name: Option<String>,
pub variables: Option<HashMap<String, Value>>,
#[serde(default, deserialize_with = "deserialize_null_default")]
pub variables: HashMap<String, Value>,
// TODO: We don't use extensions yet, but we definitely will in the future.
#[allow(dead_code)]
pub extensions: Option<HashMap<String, Value>>,
}

fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
T: Default + Deserialize<'de>,
D: Deserializer<'de>,
{
let opt = Option::<T>::deserialize(deserializer)?;
Ok(opt.unwrap_or_default())
}

impl TryInto<ExecutionRequest> for GETQueryParams {
type Error = PipelineErrorVariant;

Expand All @@ -42,12 +52,12 @@ impl TryInto<ExecutionRequest> for GETQueryParams {

let variables = match self.variables.as_deref() {
Some(v_str) if !v_str.is_empty() => match sonic_rs::from_str(v_str) {
Ok(vars) => Some(vars),
Ok(vars) => vars,
Err(e) => {
return Err(PipelineErrorVariant::FailedToParseVariables(e));
}
},
_ => None,
_ => HashMap::new(),
};

let extensions = match self.extensions.as_deref() {
Expand Down
51 changes: 25 additions & 26 deletions bin/router/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{borrow::Cow, sync::Arc};
use std::sync::Arc;

use hive_router_plan_executor::execution::plan::{
ClientRequestDetails, OperationDetails, PlanExecutionOutput,
Expand Down Expand Up @@ -110,7 +110,7 @@ pub async fn execute_pipeline(
) -> Result<PlanExecutionOutput, PipelineError> {
perform_csrf_prevention(req, &shared_state.router_config.csrf)?;

let execution_request = get_execution_request(req, body_bytes).await?;
let mut execution_request = get_execution_request(req, body_bytes).await?;
let parser_payload = parse_operation_with_cache(req, shared_state, &execution_request).await?;
validate_operation_with_cache(req, supergraph, schema_state, shared_state, &parser_payload)
.await?;
Expand All @@ -123,34 +123,33 @@ pub async fn execute_pipeline(
&parser_payload,
)
.await?;
let query: Cow<'_, str> = Cow::Owned(execution_request.query.clone());
let variable_payload =
coerce_request_variables(req, supergraph, execution_request, &normalize_payload)?;
coerce_request_variables(req, supergraph, &mut execution_request, &normalize_payload)?;

let query_plan_cancellation_token =
CancellationToken::with_timeout(shared_state.router_config.query_planner.timeout);

let progressive_override_ctx =
request_override_context(&shared_state.override_labels_evaluator, || {
ClientRequestDetails {
method: req.method().clone(),
url: req.uri().clone(),
headers: req.headers(),
operation: OperationDetails {
name: normalize_payload.operation_for_plan.name.clone(),
kind: match normalize_payload.operation_for_plan.operation_kind {
Some(OperationKind::Query) => "query",
Some(OperationKind::Mutation) => "mutation",
Some(OperationKind::Subscription) => "subscription",
None => "query",
},
query: query.clone(),
},
}
})
.map_err(|error| {
req.new_pipeline_error(PipelineErrorVariant::LabelEvaluationError(error))
})?;
let client_request_details = ClientRequestDetails {
method: req.method(),
url: req.uri(),
headers: req.headers(),
operation: OperationDetails {
name: normalize_payload.operation_for_plan.name.as_deref(),
kind: match normalize_payload.operation_for_plan.operation_kind {
Some(OperationKind::Query) => "query",
Some(OperationKind::Mutation) => "mutation",
Some(OperationKind::Subscription) => "subscription",
None => "query",
},
query: &execution_request.query,
},
};

let progressive_override_ctx = request_override_context(
&shared_state.override_labels_evaluator,
&client_request_details,
)
.map_err(|error| req.new_pipeline_error(PipelineErrorVariant::LabelEvaluationError(error)))?;

let query_plan_payload = plan_operation_with_cache(
req,
Expand All @@ -164,12 +163,12 @@ pub async fn execute_pipeline(

let execution_result = execute_plan(
req,
query,
supergraph,
shared_state,
&normalize_payload,
&query_plan_payload,
&variable_payload,
&client_request_details,
)
.await?;

Expand Down
26 changes: 8 additions & 18 deletions bin/router/src/pipeline/progressive_override.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,11 @@ pub struct RequestOverrideContext {
}

#[inline]
pub fn request_override_context<'req, F>(
pub fn request_override_context<'exec, 'req>(
override_labels_evaluator: &OverrideLabelsEvaluator,
get_client_request: F,
) -> Result<RequestOverrideContext, LabelEvaluationError>
where
F: FnOnce() -> ClientRequestDetails<'req>,
{
let active_flags = override_labels_evaluator.evaluate(get_client_request)?;
client_request_details: &ClientRequestDetails<'exec, 'req>,
) -> Result<RequestOverrideContext, LabelEvaluationError> {
let active_flags = override_labels_evaluator.evaluate(client_request_details)?;

// Generate the random percentage value for this request.
// Percentage is 0 - 100_000_000_000 (100*PERCENTAGE_SCALE_FACTOR)
Expand Down Expand Up @@ -161,25 +158,18 @@ impl OverrideLabelsEvaluator {
})
}

pub(crate) fn evaluate<'req, F>(
pub(crate) fn evaluate<'exec, 'req>(
&self,
get_client_request: F,
) -> Result<HashSet<String>, LabelEvaluationError>
where
F: FnOnce() -> ClientRequestDetails<'req>,
{
client_request: &ClientRequestDetails<'exec, 'req>,
) -> Result<HashSet<String>, LabelEvaluationError> {
let mut active_flags = self.static_enabled_labels.clone();

if self.expressions.is_empty() {
return Ok(active_flags);
}

let client_request = get_client_request();
let mut target = VrlTargetValue {
value: VrlValue::Object(BTreeMap::from([(
"request".into(),
(&client_request).into(),
)])),
value: VrlValue::Object(BTreeMap::from([("request".into(), client_request.into())])),
metadata: VrlValue::Object(BTreeMap::new()),
secrets: VrlSecrets::default(),
};
Expand Down
41 changes: 19 additions & 22 deletions lib/executor/src/execution/plan.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use std::{
borrow::Cow,
collections::{BTreeSet, HashMap},
};
use std::collections::{BTreeSet, HashMap};

use bytes::{BufMut, Bytes};
use futures::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
Expand Down Expand Up @@ -51,26 +48,26 @@ use crate::{
},
};

pub struct OperationDetails<'a> {
pub name: Option<String>,
pub query: Cow<'a, str>,
pub kind: &'a str,
pub struct OperationDetails<'exec> {
pub name: Option<&'exec str>,
pub query: &'exec str,
pub kind: &'static str,
}

pub struct ClientRequestDetails<'a> {
pub method: Method,
pub url: http::Uri,
pub headers: &'a NtexHeaderMap,
pub operation: OperationDetails<'a>,
pub struct ClientRequestDetails<'exec, 'req> {
pub method: &'req Method,
pub url: &'req http::Uri,
pub headers: &'req NtexHeaderMap,
pub operation: OperationDetails<'exec>,
}

pub struct QueryPlanExecutionContext<'exec> {
pub struct QueryPlanExecutionContext<'exec, 'req> {
pub query_plan: &'exec QueryPlan,
pub projection_plan: &'exec Vec<FieldProjectionPlan>,
pub headers_plan: &'exec HeaderRulesPlan,
pub variable_values: &'exec Option<HashMap<String, sonic_rs::Value>>,
pub extensions: Option<HashMap<String, sonic_rs::Value>>,
pub client_request: ClientRequestDetails<'exec>,
pub client_request: &'exec ClientRequestDetails<'exec, 'req>,
pub introspection_context: &'exec IntrospectionContext<'exec, 'static>,
pub operation_type_name: &'exec str,
pub executors: &'exec SubgraphExecutorMap,
Expand All @@ -82,8 +79,8 @@ pub struct PlanExecutionOutput {
pub headers: HeaderMap,
}

pub async fn execute_query_plan<'exec>(
ctx: QueryPlanExecutionContext<'exec>,
pub async fn execute_query_plan<'exec, 'req>(
ctx: QueryPlanExecutionContext<'exec, 'req>,
) -> Result<PlanExecutionOutput, PlanExecutionError> {
let init_value = if let Some(introspection_query) = ctx.introspection_context.query {
resolve_introspection(introspection_query, ctx.introspection_context)
Expand All @@ -96,7 +93,7 @@ pub async fn execute_query_plan<'exec>(
ctx.variable_values,
ctx.executors,
ctx.introspection_context.metadata,
&ctx.client_request,
ctx.client_request,
ctx.headers_plan,
ctx.jwt_auth_forwarding,
// Deduplicate subgraph requests only if the operation type is a query
Expand Down Expand Up @@ -137,11 +134,11 @@ pub async fn execute_query_plan<'exec>(
})
}

pub struct Executor<'exec> {
pub struct Executor<'exec, 'req> {
variable_values: &'exec Option<HashMap<String, sonic_rs::Value>>,
schema_metadata: &'exec SchemaMetadata,
executors: &'exec SubgraphExecutorMap,
client_request: &'exec ClientRequestDetails<'exec>,
client_request: &'exec ClientRequestDetails<'exec, 'req>,
headers_plan: &'exec HeaderRulesPlan,
jwt_forwarding_plan: &'exec Option<JwtAuthForwardingPlan>,
dedupe_subgraph_requests: bool,
Expand Down Expand Up @@ -231,12 +228,12 @@ struct PreparedFlattenData {
representation_hash_to_index: HashMap<u64, usize>,
}

impl<'exec> Executor<'exec> {
impl<'exec, 'req> Executor<'exec, 'req> {
pub fn new(
variable_values: &'exec Option<HashMap<String, sonic_rs::Value>>,
executors: &'exec SubgraphExecutorMap,
schema_metadata: &'exec SchemaMetadata,
client_request: &'exec ClientRequestDetails<'exec>,
client_request: &'exec ClientRequestDetails<'exec, 'req>,
headers_plan: &'exec HeaderRulesPlan,
jwt_forwarding_plan: &'exec Option<JwtAuthForwardingPlan>,
dedupe_subgraph_requests: bool,
Expand Down
8 changes: 4 additions & 4 deletions lib/executor/src/executors/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ impl SubgraphExecutorMap {
Ok(subgraph_executor_map)
}

pub async fn execute<'a>(
pub async fn execute<'a, 'req>(
&self,
subgraph_name: &str,
execution_request: HttpExecutionRequest<'a>,
client_request: &ClientRequestDetails<'a>,
client_request: &ClientRequestDetails<'a, 'req>,
) -> HttpExecutionResponse {
match self.get_or_create_executor(subgraph_name, client_request) {
Ok(Some(executor)) => executor.execute(execution_request).await,
Expand Down Expand Up @@ -164,7 +164,7 @@ impl SubgraphExecutorMap {
fn get_or_create_executor(
&self,
subgraph_name: &str,
client_request: &ClientRequestDetails<'_>,
client_request: &ClientRequestDetails<'_, '_>,
) -> Result<Option<SubgraphExecutorBoxedArc>, SubgraphExecutorError> {
let from_expression =
self.get_or_create_executor_from_expression(subgraph_name, client_request)?;
Expand All @@ -183,7 +183,7 @@ impl SubgraphExecutorMap {
fn get_or_create_executor_from_expression(
&self,
subgraph_name: &str,
client_request: &ClientRequestDetails<'_>,
client_request: &ClientRequestDetails<'_, '_>,
) -> Result<Option<SubgraphExecutorBoxedArc>, SubgraphExecutorError> {
if let Some(expression) = self.expressions_by_subgraph.get(subgraph_name) {
let original_url_value = VrlValue::Bytes(Bytes::from(
Expand Down
Loading
Loading