Skip to content

Commit

Permalink
feat: add header to specify query timeout (#4988)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexander Polanco <alxpolanc@gmail.com>
Co-authored-by: Alberto Schiabel <jkomyno@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 18, 2024
1 parent 155af62 commit 4fe298b
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 36 deletions.
29 changes: 29 additions & 0 deletions query-engine/core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use query_structure::DomainError;
use thiserror::Error;
use user_facing_errors::UnknownError;

use crate::response_ir::{Item, Map};

#[derive(Debug, Error)]
#[error(
"Error converting field \"{field}\" of expected non-nullable type \"{expected_type}\", found incompatible value of \"{found}\"."
Expand Down Expand Up @@ -62,6 +64,9 @@ pub enum CoreError {

#[error("Error in batch request {request_idx}: {error}")]
BatchError { request_idx: usize, error: Box<CoreError> },

#[error("Query timed out")]
QueryTimeout,
}

impl CoreError {
Expand Down Expand Up @@ -227,3 +232,27 @@ impl From<CoreError> for user_facing_errors::Error {
}
}
}

#[derive(Debug, serde::Serialize, PartialEq)]
pub struct ExtendedUserFacingError {
#[serde(flatten)]
user_facing_error: user_facing_errors::Error,

#[serde(skip_serializing_if = "indexmap::IndexMap::is_empty")]
extensions: Map,
}

impl ExtendedUserFacingError {
pub fn set_extension(&mut self, key: String, val: serde_json::Value) {
self.extensions.entry(key).or_insert(Item::Json(val));
}
}

impl From<CoreError> for ExtendedUserFacingError {
fn from(error: CoreError) -> Self {
ExtendedUserFacingError {
user_facing_error: error.into(),
extensions: Default::default(),
}
}
}
29 changes: 0 additions & 29 deletions query-engine/core/src/interactive_transactions/error.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
use thiserror::Error;

use crate::{
response_ir::{Item, Map},
CoreError,
};

#[derive(Debug, Error, PartialEq)]
pub enum TransactionError {
#[error("Unable to start a transaction in the given time.")]
Expand All @@ -22,27 +17,3 @@ pub enum TransactionError {
#[error("Unexpected response: {reason}.")]
Unknown { reason: String },
}

#[derive(Debug, serde::Serialize, PartialEq)]
pub struct ExtendedTransactionUserFacingError {
#[serde(flatten)]
user_facing_error: user_facing_errors::Error,

#[serde(skip_serializing_if = "indexmap::IndexMap::is_empty")]
extensions: Map,
}

impl ExtendedTransactionUserFacingError {
pub fn set_extension(&mut self, key: String, val: serde_json::Value) {
self.extensions.entry(key).or_insert(Item::Json(val));
}
}

impl From<CoreError> for ExtendedTransactionUserFacingError {
fn from(error: CoreError) -> Self {
ExtendedTransactionUserFacingError {
user_facing_error: error.into(),
extensions: Default::default(),
}
}
}
4 changes: 2 additions & 2 deletions query-engine/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ pub mod telemetry;

pub use self::telemetry::*;
pub use self::{
error::{CoreError, FieldConversionError},
error::{CoreError, ExtendedUserFacingError, FieldConversionError},
executor::{QueryExecutor, TransactionOptions},
interactive_transactions::{ExtendedTransactionUserFacingError, TransactionError, TxId},
interactive_transactions::{TransactionError, TxId},
query_document::*,
};

Expand Down
48 changes: 43 additions & 5 deletions query-engine/query-engine/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ use opentelemetry::trace::TraceContextExt;
use opentelemetry::{global, propagation::Extractor};
use query_core::helpers::*;
use query_core::telemetry::capturing::TxTraceExt;
use query_core::{telemetry, ExtendedTransactionUserFacingError, TransactionOptions, TxId};
use query_core::{telemetry, ExtendedUserFacingError, TransactionOptions, TxId};
use request_handlers::{dmmf, render_graphql_schema, RequestBody, RequestHandler};
use serde::Serialize;
use serde_json::json;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Instant;
use std::time::{Duration, Instant};
use tracing::{field, Instrument, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;

Expand Down Expand Up @@ -116,6 +116,8 @@ async fn request_handler(cx: Arc<PrismaContext>, req: Request<Body>) -> Result<R
let tx_id = transaction_id(headers);
let tracing_cx = get_parent_span_context(headers);

let query_timeout = query_timeout(headers);

let span = if tx_id.is_none() {
let span = info_span!("prisma:engine", user_facing = true);
span.set_parent(tracing_cx);
Expand Down Expand Up @@ -169,13 +171,14 @@ async fn request_handler(cx: Arc<PrismaContext>, req: Request<Body>) -> Result<R
let full_body = hyper::body::to_bytes(body_start).await?;
let serialized_body = RequestBody::try_from_slice(full_body.as_ref(), cx.engine_protocol());

let capture_config = &capture_config;
let work = async move {
match serialized_body {
Ok(body) => {
let handler = RequestHandler::new(cx.executor(), cx.query_schema(), cx.engine_protocol());
let mut result = handler.handle(body, tx_id, traceparent).instrument(span).await;

if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config {
if let telemetry::capturing::Capturer::Enabled(capturer) = capture_config {
let telemetry = capturer.fetch_captures().await;
if let Some(telemetry) = telemetry {
result.set_extension("traces".to_owned(), json!(telemetry.traces));
Expand All @@ -202,7 +205,32 @@ async fn request_handler(cx: Arc<PrismaContext>, req: Request<Body>) -> Result<R
}
};

work.await
let query_timeout_fut = async {
match query_timeout {
Some(timeout) => tokio::time::sleep(timeout).await,
// Never return if timeout isn't set.
None => std::future::pending().await,
}
};

tokio::select! {
_ = query_timeout_fut => {
let captured_telemetry = if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config {
capturer.fetch_captures().await
} else {
None
};

// Note: this relies on the fact that client will rollback the transaction after the
// error. If the client continues using this transaction (and later commits it), data
// corruption might happen because some write queries (but not all of them) might be
// already executed by the database before the timeout is fired.
Ok(err_to_http_resp(query_core::CoreError::QueryTimeout, captured_telemetry))
}
result = work => {
result
}
}
}

/// Expose the GraphQL playground if enabled.
Expand Down Expand Up @@ -454,11 +482,13 @@ fn err_to_http_resp(
query_core::TransactionError::Unknown { reason: _ } => StatusCode::INTERNAL_SERVER_ERROR,
},

query_core::CoreError::QueryTimeout => StatusCode::REQUEST_TIMEOUT,

// All other errors are treated as 500s, most of these paths should never be hit, only connector errors may occur.
_ => StatusCode::INTERNAL_SERVER_ERROR,
};

let mut err: ExtendedTransactionUserFacingError = err.into();
let mut err: ExtendedUserFacingError = err.into();
if let Some(telemetry) = captured_telemetry {
err.set_extension("traces".to_owned(), json!(telemetry.traces));
err.set_extension("logs".to_owned(), json!(telemetry.logs));
Expand Down Expand Up @@ -513,6 +543,14 @@ fn transaction_id(headers: &HeaderMap) -> Option<TxId> {
.map(TxId::from)
}

fn query_timeout(headers: &HeaderMap) -> Option<Duration> {
headers
.get("X-query-timeout")
.and_then(|h| h.to_str().ok())
.and_then(|value| value.parse::<u64>().ok())
.map(Duration::from_millis)
}

/// If the client sends us a trace and span id, extracting a new context if the
/// headers are set. If not, returns current context.
fn get_parent_span_context(headers: &HeaderMap) -> opentelemetry::Context {
Expand Down

0 comments on commit 4fe298b

Please sign in to comment.