Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 164 additions & 79 deletions crates/core/src/subscription/module_subscription_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::subscription::delta::eval_delta;
use crate::subscription::record_exec_metrics;
use hashbrown::hash_map::OccupiedError;
use hashbrown::{HashMap, HashSet};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use spacetimedb_client_api_messages::websocket::{
BsatnFormat, CompressableQueryUpdate, Compression, FormatSwitch, JsonFormat, QueryId, QueryUpdate, WebsocketFormat,
};
Expand All @@ -22,6 +22,7 @@ use spacetimedb_lib::metrics::ExecutionMetrics;
use spacetimedb_lib::{ConnectionId, Identity};
use spacetimedb_primitives::TableId;
use spacetimedb_subscription::SubscriptionPlan;
use std::collections::LinkedList;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

Expand All @@ -31,6 +32,7 @@ use std::sync::Arc;
type ClientId = (Identity, ConnectionId);
type Query = Arc<Plan>;
type Client = Arc<ClientConnectionSender>;
type SwitchedTableUpdate = FormatSwitch<TableUpdate<BsatnFormat>, TableUpdate<JsonFormat>>;
type SwitchedDbUpdate = FormatSwitch<ws::DatabaseUpdate<BsatnFormat>, ws::DatabaseUpdate<JsonFormat>>;

/// ClientQueryId is an identifier for a query set by the client.
Expand Down Expand Up @@ -156,9 +158,7 @@ impl QueryState {

/// Returns every client listening to this query: the legacy subscribers
/// first, followed by the regular subscriptions.
///
/// No deduplication happens here, so a client that has multiple
/// subscriptions for this query will appear more than once.
fn all_clients(&self) -> impl Iterator<Item = &ClientId> {
    // A single tail expression: both collections are borrowed and chained,
    // which yields the same sequence as `.iter().chain(.iter())`.
    itertools::chain(&self.legacy_subscribers, &self.subscriptions)
}
}

Expand Down Expand Up @@ -202,22 +202,6 @@ impl SubscriptionManager {
self.queries.get(hash).map(|state| state.query.clone())
}

/// Return all clients that are subscribed to a particular query.
/// Note this method filters out clients that have been dropped.
/// If you need all clients currently maintained by the manager,
/// regardless of drop status, do not use this method.
pub fn clients_for_query(&self, hash: &QueryHash) -> impl Iterator<Item = &ClientId> {
self.queries
.get(hash)
.into_iter()
.flat_map(|query| query.all_clients())
.filter(|id| {
self.clients
.get(*id)
.is_some_and(|info| !info.dropped.load(Ordering::Acquire))
})
}

pub fn calculate_gauge_stats(&self) -> SubscriptionGaugeStats {
let num_queries = self.queries.len();
let num_connections = self.clients.len();
Expand Down Expand Up @@ -520,25 +504,87 @@ impl SubscriptionManager {
// Put the main work on a rayon compute thread.
rayon::scope(|_| {
let span = tracing::info_span!("eval_incr").entered();
let (updates, errs, metrics) = tables

type ClientQueryUpdate<F> = (<F as WebsocketFormat>::QueryUpdate, /* num_rows */ u64);
struct ClientUpdate<'a> {
id: &'a ClientId,
table_id: TableId,
table_name: &'a str,
update: FormatSwitch<ClientQueryUpdate<BsatnFormat>, ClientQueryUpdate<JsonFormat>>,
}

// rayon has a fold-reduce idiom, which we use here. First, rayon splits
// the task onto a number of worker threads, and then on each thread, we *fold*
// each item of the iterator into an accumulator. This is that accumulator
// state; we use vecs because we're only ever going to be appending onto the
// end, so reallocation is more or less amortized.
#[derive(Default)]
struct FoldState<'a> {
updates: Vec<ClientUpdate<'a>>,
errs: Vec<(&'a ClientId, Box<str>)>,
metrics: ExecutionMetrics,
}

// Next, we *reduce* the results of multiple threads into one final output.
// This is the accumulator for that; we use `VecList`s here because they
// have good characteristics for this use case, namely cheap appending
// of each thread's result.
#[derive(Default)]
struct ReduceState<'a> {
updates: VecList<ClientUpdate<'a>>,
errs: VecList<(&'a ClientId, Box<str>)>,
metrics: ExecutionMetrics,
}
impl<'a> ReduceState<'a> {
/// Convert the result of a single-thread fold to get ready for a multi-thread reduce.
fn from_fold(acc: FoldState<'a>) -> Self {
Self {
updates: acc.updates.into(),
errs: acc.errs.into(),
metrics: acc.metrics,
}
}
/// Concatenate this `ReduceState` with another one.
///
/// This is a cheap operation, since `LinkedList::append` is `O(1)`.
fn append(mut self, rhs: Self) -> Self {
self.updates.append(rhs.updates);
self.errs.append(rhs.errs);
self.metrics.merge(rhs.metrics);
self
}
}

let plans = tables
.iter()
.filter(|table| !table.inserts.is_empty() || !table.deletes.is_empty())
.map(|DatabaseTableUpdate { table_id, .. }| table_id)
.filter_map(|table_id| self.tables.get(table_id))
.filter_map(|DatabaseTableUpdate { table_id, .. }| self.tables.get(table_id))
.flatten()
.collect::<HashSet<_>>()
.par_iter()
.filter_map(|&hash| self.queries.get(hash).map(|state| (hash, &state.query)))
.flat_map_iter(|(hash, plan)| {
plan.plans_fragments()
.map(move |plan_fragment| (&plan.sql, hash, plan_fragment, ExecutionMetrics::default()))
// deduplicate queries by their hash
.filter({
let mut seen = HashSet::new();
// (HashSet::insert returns true for novel elements)
move |&hash| seen.insert(hash)
})
.flat_map(|hash| {
let qstate = &self.queries[hash];
qstate
.query
.plans_fragments()
.map(move |plan_fragment| (qstate, plan_fragment))
})
// collect all plan fragments we want to do work on into a
// single vec, which is more efficient for rayon to work with.
.collect::<Vec<_>>();

let ReduceState { updates, errs, metrics } = plans
.into_par_iter()
// If N clients are subscribed to a query,
// we copy the DatabaseTableUpdate N times,
// which involves cloning BSATN (binary) or product values (json).
.map(|(sql, hash, plan, mut metrics)| {
.fold(FoldState::default, |mut acc, (qstate, plan)| {
let table_id = plan.subscribed_table_id();
let table_name: Box<str> = plan.subscribed_table_name().into();
let table_name = plan.subscribed_table_name();
// Store at most one copy of the serialization to BSATN x Compression
// and ditto for the "serialization" for JSON.
// Each subscriber gets to pick which of these they want,
Expand Down Expand Up @@ -574,23 +620,28 @@ impl SubscriptionManager {
(update, num_rows)
}

let updates = match eval_delta(tx, &mut metrics, plan) {
// filter out clients that've dropped
let clients_for_query = qstate.all_clients().filter(|id| {
self.clients
.get(*id)
.is_some_and(|info| !info.dropped.load(Ordering::Acquire))
});

match eval_delta(tx, &mut acc.metrics, plan) {
Err(err) => {
tracing::error!(
message = "Query errored during tx update",
sql = sql,
sql = qstate.query.sql,
reason = ?err,
);
Err(self
.clients_for_query(hash)
.map(|id| (id, err.to_string().into_boxed_str()))
.collect::<Vec<_>>())
acc.errs
.extend(clients_for_query.map(|id| (id, err.to_string().into_boxed_str())))
}
// The query didn't return any rows to update
Ok(None) => Ok(vec![]),
Ok(Some(delta_updates)) => Ok(self
.clients_for_query(hash)
.map(|id| {
Ok(None) => {}
// The query did return updates - process them and add them to the accumulator
Ok(Some(delta_updates)) => {
let row_iter = clients_for_query.map(|id| {
let client = &self.clients[id].outbound_ref;
let update = match client.config.protocol {
Protocol::Binary => Bsatn(memo_encode::<BsatnFormat>(
Expand All @@ -601,69 +652,59 @@ impl SubscriptionManager {
Compression::Gzip => &mut ops_bin_gzip,
Compression::None => &mut ops_bin_none,
},
&mut metrics,
&mut acc.metrics,
)),
Protocol::Text => Json(memo_encode::<JsonFormat>(
&delta_updates,
client,
&mut ops_json,
&mut metrics,
&mut acc.metrics,
)),
};
(id, table_id, table_name.clone(), update)
})
.collect::<Vec<_>>()),
};

(updates, metrics)
})
.fold(
|| (vec![], vec![], ExecutionMetrics::default()),
|(mut rows, mut errs, mut agg_metrics), (result, metrics)| {
match result {
Ok(x) => {
rows.extend(x);
}
Err(x) => {
errs.extend(x);
}
ClientUpdate {
id,
table_id,
table_name,
update,
}
});
acc.updates.extend(row_iter);
}
agg_metrics.merge(metrics);
(rows, errs, agg_metrics)
},
)
.reduce_with(|(mut acc_rows, mut acc_errs, mut acc_metrics), (rows, errs, metrics)| {
acc_rows.extend(rows);
acc_errs.extend(errs);
acc_metrics.merge(metrics);
(acc_rows, acc_errs, acc_metrics)
}

acc
})
.unwrap_or_default();
// it would be nice to use `.collect_into_vec()` here, and reap the
// benefits of having an `IndexedParallelIterator`, but we actually
// produce many elements per `SubscriptionPlan` and would need to
// `flatten` them, meaning it effectively becomes unindexed.
.map(ReduceState::from_fold)
.reduce(ReduceState::default, ReduceState::append);

record_exec_metrics(&WorkloadType::Update, database_identity, metrics);

let clients_with_errors = errs.iter().map(|(id, _)| id).collect::<HashSet<_>>();
let clients_with_errors = errs.iter().map(|(id, _)| *id).collect::<HashSet<_>>();

let mut eval = updates
.into_iter()
// Filter out clients whose subscriptions failed
.filter(|(id, ..)| !clients_with_errors.contains(id))
.filter(|upd| !clients_with_errors.contains(upd.id))
// For each subscriber, aggregate all the updates for the same table.
// That is, we build a map `(subscriber_id, table_id) -> updates`.
// A particular subscriber uses only one format,
// so their `TableUpdate` will contain either JSON (`Protocol::Text`)
// or BSATN (`Protocol::Binary`).
.fold(
HashMap::<(&ClientId, TableId), FormatSwitch<TableUpdate<_>, TableUpdate<_>>>::new(),
|mut tables, (id, table_id, table_name, update)| {
match tables.entry((id, table_id)) {
Entry::Occupied(mut entry) => match entry.get_mut().zip_mut(update) {
HashMap::<(&ClientId, TableId), SwitchedTableUpdate>::new(),
|mut tables, upd| {
match tables.entry((upd.id, upd.table_id)) {
Entry::Occupied(mut entry) => match entry.get_mut().zip_mut(upd.update) {
Bsatn((tbl_upd, update)) => tbl_upd.push(update),
Json((tbl_upd, update)) => tbl_upd.push(update),
},
Entry::Vacant(entry) => drop(entry.insert(match update {
Bsatn(update) => Bsatn(TableUpdate::new(table_id, table_name, update)),
Json(update) => Json(TableUpdate::new(table_id, table_name, update)),
Entry::Vacant(entry) => drop(entry.insert(match upd.update {
Bsatn(update) => Bsatn(TableUpdate::new(upd.table_id, upd.table_name.into(), update)),
Json(update) => Json(TableUpdate::new(upd.table_id, upd.table_name.into(), update)),
})),
}
tables
Expand Down Expand Up @@ -752,6 +793,50 @@ fn send_to_client(client: &ClientConnectionSender, message: impl Into<Serializab
}
}

/// A linked list of vecs.
///
/// To quote the docs for [`ParallelIterator::collect_vec_list`] (which I (Noa) also wrote):
///
/// > This is useful when you need to condense a parallel iterator into a
/// > collection, but have no specific requirements for what that collection
/// > should be. [...] This is a very efficient way to collect an unindexed
/// > parallel iterator, without much intermediate data movement.
///
/// Each node of the list is one `Vec<T>` chunk. Iteration visits the chunks
/// front to back and, within each chunk, the elements in their original order.
struct VecList<T>(LinkedList<Vec<T>>);

// Implemented by hand rather than derived so that no `T: Default` bound is required.
impl<T> Default for VecList<T> {
    /// Creates a `VecList` with no chunks.
    fn default() -> Self {
        Self(LinkedList::new())
    }
}

impl<T> From<Vec<T>> for VecList<T> {
    /// Wraps `vec` as the single chunk of a new `VecList`.
    ///
    /// An empty `vec` produces an empty list instead of a list holding one
    /// empty chunk, so this conversion never adds a useless node.
    fn from(vec: Vec<T>) -> Self {
        let mut list = LinkedList::new();
        if !vec.is_empty() {
            list.push_back(vec);
        }
        Self(list)
    }
}

impl<T> VecList<T> {
    /// Append another `VecList` onto this one.
    ///
    /// The chunks of `other` are moved to the back of `self` via
    /// [`LinkedList::append`]; no elements are copied.
    /// This operation is `O(1)`.
    fn append(&mut self, mut other: Self) {
        self.0.append(&mut other.0)
    }

    /// Iterate over the individual elements of this `VecList` by reference.
    fn iter(&self) -> impl Iterator<Item = &T> {
        self.0.iter().flatten()
    }
}

impl<T> IntoIterator for VecList<T> {
    type Item = T;
    type IntoIter = std::iter::Flatten<std::collections::linked_list::IntoIter<Vec<T>>>;

    /// Consume the list, yielding every element by value, chunk by chunk.
    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter().flatten()
    }
}

// Borrowing counterpart of the by-value impl above, so that `for x in &list`
// works and agrees with `VecList::iter`.
impl<'a, T> IntoIterator for &'a VecList<T> {
    type Item = &'a T;
    type IntoIter = std::iter::Flatten<std::collections::linked_list::Iter<'a, Vec<T>>>;

    /// Yield every element by reference, in the same order as [`VecList::iter`].
    fn into_iter(self) -> Self::IntoIter {
        self.0.iter().flatten()
    }
}

#[cfg(test)]
mod tests {
use std::{sync::Arc, time::Duration};
Expand Down
Loading