Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,12 @@ pub enum DBError {
Other(#[from] anyhow::Error),
#[error(transparent)]
TypeError(#[from] TypingError),
#[error("{error}, executing: `{sql}`")]
WithSql {
#[source]
error: Box<DBError>,
sql: Box<str>,
},
}

impl From<bflatn_to::Error> for DBError {
Expand Down
19 changes: 12 additions & 7 deletions crates/core/src/subscription/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use spacetimedb_execution::{pipelined::PipelinedProject, Datastore, DeltaStore};
use spacetimedb_lib::{metrics::ExecutionMetrics, Identity};
use spacetimedb_primitives::TableId;

use crate::error::DBError;
use crate::{db::db_metrics::DB_METRICS, execution_context::WorkloadType, worker_metrics::WORKER_METRICS};

pub mod delta;
Expand Down Expand Up @@ -133,23 +134,27 @@ pub fn execute_plans<Tx, F>(
comp: Compression,
tx: &Tx,
update_type: TableUpdateType,
) -> Result<(DatabaseUpdate<F>, ExecutionMetrics)>
) -> Result<(DatabaseUpdate<F>, ExecutionMetrics), DBError>
where
Tx: Datastore + DeltaStore + Sync,
F: WebsocketFormat,
{
plans
.par_iter()
.flat_map_iter(|plan| plan.plans_fragments())
.map(|plan| (plan, plan.subscribed_table_id(), plan.subscribed_table_name()))
.map(|(plan, table_id, table_name)| {
.flat_map_iter(|plan| plan.plans_fragments().map(|fragment| (plan.sql(), fragment)))
.map(|(sql, plan)| (sql, plan, plan.subscribed_table_id(), plan.subscribed_table_name()))
.map(|(sql, plan, table_id, table_name)| {
plan.physical_plan()
.clone()
.optimize()
.map(PipelinedProject::from)
.and_then(|plan| collect_table_update(&[plan], table_id, table_name.into(), comp, tx, update_type))
.map(|plan| (sql, PipelinedProject::from(plan)))
.and_then(|(_, plan)| collect_table_update(&[plan], table_id, table_name.into(), comp, tx, update_type))
.map_err(|err| DBError::WithSql {
sql: sql.into(),
error: Box::new(DBError::Other(err)),
})
})
.collect::<Result<Vec<_>>>()
.collect::<Result<Vec<_>, _>>()
.map(|table_updates_with_metrics| {
let n = table_updates_with_metrics.len();
let mut tables = Vec::with_capacity(n);
Expand Down
120 changes: 56 additions & 64 deletions crates/core/src/subscription/module_subscription_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use spacetimedb_client_api_messages::websocket::{
};
use spacetimedb_execution::pipelined::PipelinedProject;
use spacetimedb_expr::check::parse_and_type_sub;
use spacetimedb_expr::errors::TypingError;
use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_lib::metrics::ExecutionMetrics;
use spacetimedb_lib::Identity;
Expand Down Expand Up @@ -104,8 +105,11 @@ type FullSubscriptionUpdate = FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::

/// A utility for sending an error message to a client and returning early
macro_rules! return_on_err {
($expr:expr, $handler:expr) => {
match $expr {
($expr:expr, $sql:expr, $handler:expr) => {
match $expr.map_err(|err| DBError::WithSql {
sql: $sql.into(),
error: Box::new(DBError::Other(err.into())),
}) {
Ok(val) => val,
Err(e) => {
// TODO: Handle errors sending messages.
Expand All @@ -117,9 +121,8 @@ macro_rules! return_on_err {
}

/// Hash a sql query, using the caller's identity if necessary
fn hash_query(sql: &str, tx: &TxId, auth: &AuthCtx) -> Result<QueryHash, DBError> {
fn hash_query(sql: &str, tx: &TxId, auth: &AuthCtx) -> Result<QueryHash, TypingError> {
parse_and_type_sub(sql, &SchemaViewer::new(tx, auth), auth)
.map_err(DBError::from)
.map(|(_, has_param)| QueryHash::from_string(sql, auth.caller, has_param))
}

Expand Down Expand Up @@ -185,10 +188,10 @@ impl ModuleSubscriptions {

Ok(match sender.config.protocol {
Protocol::Binary => collect_table_update(&plans, table_id, table_name.into(), comp, &tx, update_type)
.map(|(table_update, metrics)| (FormatSwitch::Bsatn(table_update), metrics))?,
.map(|(table_update, metrics)| (FormatSwitch::Bsatn(table_update), metrics)),
Protocol::Text => collect_table_update(&plans, table_id, table_name.into(), comp, &tx, update_type)
.map(|(table_update, metrics)| (FormatSwitch::Json(table_update), metrics))?,
})
.map(|(table_update, metrics)| (FormatSwitch::Json(table_update), metrics)),
}?)
}

fn evaluate_queries(
Expand Down Expand Up @@ -254,7 +257,7 @@ impl ModuleSubscriptions {
let query = super::query::WHITESPACE.replace_all(&request.query, " ");
let sql = query.trim();

let hash = return_on_err!(hash_query(sql, &tx, &auth), send_err_msg);
let hash = return_on_err!(hash_query(sql, &tx, &auth), sql, send_err_msg);

let existing_query = {
let guard = self.subscriptions.read();
Expand All @@ -265,11 +268,13 @@ impl ModuleSubscriptions {
existing_query
.map(Ok)
.unwrap_or_else(|| compile_read_only_query(&auth, &tx, sql).map(Arc::new)),
sql,
send_err_msg
);

let (table_rows, metrics) = return_on_err!(
self.evaluate_initial_subscription(sender.clone(), query.clone(), &tx, &auth, TableUpdateType::Subscribe),
query.sql(),
send_err_msg
);

Expand Down Expand Up @@ -353,18 +358,11 @@ impl ModuleSubscriptions {
self.relational_db.release_tx(tx);
});
let auth = AuthCtx::new(self.owner_identity, sender.id.identity);
let eval_result =
self.evaluate_initial_subscription(sender.clone(), query.clone(), &tx, &auth, TableUpdateType::Unsubscribe);

// If execution error, send to client
let (table_rows, metrics) = match eval_result {
Ok(ok) => ok,
Err(e) => {
// Apparently we ignore errors sending messages.
let _ = send_err_msg(e.to_string().into());
return Ok(());
}
};
let (table_rows, metrics) = return_on_err!(
self.evaluate_initial_subscription(sender.clone(), query.clone(), &tx, &auth, TableUpdateType::Unsubscribe),
query.sql(),
send_err_msg
);

record_exec_metrics(
&WorkloadType::Subscribe,
Expand Down Expand Up @@ -499,12 +497,12 @@ impl ModuleSubscriptions {
continue;
}

let hash = return_on_err!(hash_query(sql, &tx, &auth), send_err_msg);
let hash = return_on_err!(hash_query(sql, &tx, &auth), sql, send_err_msg);

if let Some(unit) = guard.query(&hash) {
queries.push(unit);
} else {
let compiled = return_on_err!(compile_read_only_query(&auth, &tx, sql), send_err_msg);
let compiled = return_on_err!(compile_read_only_query(&auth, &tx, sql), sql, send_err_msg);
queries.push(Arc::new(compiled));
}
}
Expand Down Expand Up @@ -721,7 +719,7 @@ pub struct WriteConflict;
mod tests {
use super::{AssertTxFn, ModuleSubscriptions};
use crate::client::messages::{
SerializableMessage, SubscriptionMessage, SubscriptionResult, SubscriptionUpdateMessage,
SerializableMessage, SubscriptionError, SubscriptionMessage, SubscriptionResult, SubscriptionUpdateMessage,
TransactionUpdateMessage,
};
use crate::client::{ClientActorId, ClientConfig, ClientConnectionSender, ClientName, Protocol};
Expand Down Expand Up @@ -1075,6 +1073,21 @@ mod tests {
Ok(())
}

fn check_subscription_err(sql: &str, result: Option<SerializableMessage>) {
if let Some(SerializableMessage::Subscription(SubscriptionMessage {
result: SubscriptionResult::Error(SubscriptionError { message, .. }),
..
})) = result
{
assert!(
message.contains(sql),
"Expected error message to contain the SQL query: {sql}, but got: {message}",
);
return;
}
panic!("Expected a subscription error message, but got: {:?}", result);
}

/// Test that clients receive error messages on subscribe
#[tokio::test]
async fn subscribe_single_error() -> anyhow::Result<()> {
Expand All @@ -1087,15 +1100,11 @@ mod tests {
db.create_table_for_test("t", &[("x", AlgebraicType::U8)], &[])?;

// Subscribe to an invalid query (r is not in scope)
subscribe_single(&subs, "select r.* from t", tx, &mut 0)?;
let sql = "select r.* from t";
subscribe_single(&subs, sql, tx, &mut 0)?;

check_subscription_err(sql, rx.recv().await);

assert!(matches!(
rx.recv().await,
Some(SerializableMessage::Subscription(SubscriptionMessage {
result: SubscriptionResult::Error(..),
..
}))
));
Ok(())
}

Expand All @@ -1111,15 +1120,11 @@ mod tests {
db.create_table_for_test("t", &[("x", AlgebraicType::U8)], &[])?;

// Subscribe to an invalid query (r is not in scope)
subscribe_multi(&subs, &["select r.* from t"], tx, &mut 0)?;
let sql = "select r.* from t";
subscribe_multi(&subs, &[sql], tx, &mut 0)?;

check_subscription_err(sql, rx.recv().await);

assert!(matches!(
rx.recv().await,
Some(SerializableMessage::Subscription(SubscriptionMessage {
result: SubscriptionResult::Error(..),
..
}))
));
Ok(())
}

Expand Down Expand Up @@ -1147,7 +1152,8 @@ mod tests {
let mut query_id = 0;

// Subscribe to `t`
subscribe_single(&subs, "select * from t where id = 1", tx.clone(), &mut query_id)?;
let sql = "select * from t where id = 1";
subscribe_single(&subs, sql, tx.clone(), &mut query_id)?;

// The initial subscription should succeed
assert!(matches!(
Expand All @@ -1169,13 +1175,8 @@ mod tests {
// Specifically that we do not recompile queries on unsubscribe.
// We execute the cached plan which in this case is an index scan.
// The index no longer exists, and therefore it fails.
assert!(matches!(
rx.recv().await,
Some(SerializableMessage::Subscription(SubscriptionMessage {
result: SubscriptionResult::Error(..),
..
}))
));
check_subscription_err(sql, rx.recv().await);

Ok(())
}

Expand Down Expand Up @@ -1203,7 +1204,8 @@ mod tests {
let mut query_id = 0;

// Subscribe to `t`
subscribe_multi(&subs, &["select * from t where id = 1"], tx.clone(), &mut query_id)?;
let sql = "select * from t where id = 1";
subscribe_multi(&subs, &[sql], tx.clone(), &mut query_id)?;

// The initial subscription should succeed
assert!(matches!(
Expand All @@ -1225,13 +1227,8 @@ mod tests {
// Specifically that we do not recompile queries on unsubscribe.
// We execute the cached plan which in this case is an index scan.
// The index no longer exists, and therefore it fails.
assert!(matches!(
rx.recv().await,
Some(SerializableMessage::Subscription(SubscriptionMessage {
result: SubscriptionResult::Error(..),
..
}))
));
check_subscription_err(sql, rx.recv().await);

Ok(())
}

Expand All @@ -1256,8 +1253,8 @@ mod tests {
.unwrap()
})
})?;

subscribe_single(&subs, "select t.* from t join s on t.id = s.id", tx, &mut 0)?;
let sql = "select t.* from t join s on t.id = s.id";
subscribe_single(&subs, sql, tx, &mut 0)?;

// The initial subscription should succeed
assert!(matches!(
Expand Down Expand Up @@ -1285,13 +1282,8 @@ mod tests {
// Specifically, plans are cached on the initial subscribe.
// Hence we execute a cached plan which happens to be an index join.
// We've removed the index on `s`, and therefore it fails.
assert!(matches!(
rx.recv().await,
Some(SerializableMessage::Subscription(SubscriptionMessage {
result: SubscriptionResult::Error(..),
..
}))
));
check_subscription_err(sql, rx.recv().await);

Ok(())
}

Expand Down
15 changes: 13 additions & 2 deletions crates/core/src/subscription/module_subscription_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ impl Plan {
pub fn plans_fragments(&self) -> impl Iterator<Item = &SubscriptionPlan> + '_ {
self.plans.iter()
}

/// The `SQL` text of this subscription.
pub fn sql(&self) -> &str {
&self.sql
}
}

/// For each client, we hold a handle for sending messages, and we track the queries they are subscribed to.
Expand Down Expand Up @@ -634,8 +639,14 @@ impl SubscriptionManager {
sql = qstate.query.sql,
reason = ?err,
);
acc.errs
.extend(clients_for_query.map(|id| (id, err.to_string().into_boxed_str())))
let err = DBError::WithSql {
sql: qstate.query.sql.as_str().into(),
error: Box::new(err.into()),
}
.to_string()
.into_boxed_str();

acc.errs.extend(clients_for_query.map(|id| (id, err.clone())))
}
// The query didn't return any rows to update
Ok(None) => {}
Expand Down
Loading