From beec1f637199969f4a79c7e7b50ff61db2eda61b Mon Sep 17 00:00:00 2001 From: Jackson Newhouse Date: Tue, 9 Apr 2024 23:00:02 +0000 Subject: [PATCH] Non-windowed updating aggregates using datafusion. --- Cargo.lock | 30 +- Cargo.toml | 14 +- crates/arroyo-datastream/src/logical.rs | 1 + crates/arroyo-df/src/builder.rs | 20 +- crates/arroyo-df/src/extension/aggregate.rs | 6 +- crates/arroyo-df/src/extension/mod.rs | 68 ++- crates/arroyo-df/src/extension/sink.rs | 50 +- .../src/extension/updating_aggregate.rs | 162 +++++++ crates/arroyo-df/src/lib.rs | 20 +- crates/arroyo-df/src/physical.rs | 3 +- crates/arroyo-df/src/plan/aggregate.rs | 122 ++++- crates/arroyo-df/src/plan/join.rs | 22 +- crates/arroyo-df/src/plan/mod.rs | 3 + .../test/queries/create_table_updating.sql | 20 + .../error_missing_window_basic_tumble.sql | 1 - .../error_no_aggregate_over_debezium.sql | 12 + .../error_no_nested_updating_aggregates.sql | 7 + ...no_inserting_updates_into_non_updating.sql | 2 +- .../src/test/queries/no_updating_joins.sql | 21 + crates/arroyo-rpc/proto/api.proto | 10 + crates/arroyo-rpc/src/df.rs | 61 ++- crates/arroyo-rpc/src/lib.rs | 26 +- .../src/tables/expiring_time_key_map.rs | 455 +++++++++++++----- crates/arroyo-state/src/tables/mod.rs | 5 - .../arroyo-state/src/tables/table_manager.rs | 33 +- .../src/arrow/join_with_expiration.rs | 4 +- crates/arroyo-worker/src/arrow/mod.rs | 1 + .../src/arrow/updating_aggregator.rs | 336 +++++++++++++ crates/arroyo-worker/src/engine.rs | 2 + 29 files changed, 1328 insertions(+), 189 deletions(-) create mode 100644 crates/arroyo-df/src/extension/updating_aggregate.rs create mode 100644 crates/arroyo-df/src/test/queries/create_table_updating.sql create mode 100644 crates/arroyo-df/src/test/queries/error_no_aggregate_over_debezium.sql create mode 100644 crates/arroyo-df/src/test/queries/error_no_nested_updating_aggregates.sql create mode 100644 crates/arroyo-df/src/test/queries/no_updating_joins.sql create mode 100644 crates/arroyo-worker/src/arrow/updating_aggregator.rs diff --git a/Cargo.lock b/Cargo.lock index 05e96886f..220434259 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -345,7 +345,7 @@ dependencies = [ [[package]] name = "arrow-json" version = "50.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-rs?branch=timestamp_formats#6f9ca64cc65b23c5ca9350b57f37560a0b9863fa" +source = "git+https://github.com/ArroyoSystems/arrow-rs?branch=50.0.0/json#8b568ec317aa1f2ab7d1e33ea6fef1fc5697c5fc" dependencies = [ "arrow-array", "arrow-buffer", @@ -2823,7 +2823,7 @@ dependencies = [ [[package]] name = "datafusion" version = "36.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=reset_execs_36#ac33c34268424edf3f7eafa694a14bd5ee4027ea" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=36_combine_partial#56fe262dfba8370dbc602be6ba7112a356d477e9" dependencies = [ "ahash 0.8.7", "arrow", @@ -2892,7 +2892,7 @@ dependencies = [ [[package]] name = "datafusion-common" version = "36.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=reset_execs_36#ac33c34268424edf3f7eafa694a14bd5ee4027ea" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=36_combine_partial#56fe262dfba8370dbc602be6ba7112a356d477e9" dependencies = [ "ahash 0.8.7", "arrow", @@ -2932,7 +2932,7 @@ dependencies = [ [[package]] name = "datafusion-execution" version = "36.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=reset_execs_36#ac33c34268424edf3f7eafa694a14bd5ee4027ea" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=36_combine_partial#56fe262dfba8370dbc602be6ba7112a356d477e9" dependencies = [ "arrow", "chrono", @@ -2968,7 +2968,7 @@ dependencies = [ [[package]] name = "datafusion-expr" version = "36.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=reset_execs_36#ac33c34268424edf3f7eafa694a14bd5ee4027ea" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=36_combine_partial#56fe262dfba8370dbc602be6ba7112a356d477e9" dependencies = [ "ahash 0.8.7", "arrow", @@ -2983,7 +2983,7 @@ dependencies = [ [[package]] name = "datafusion-functions" version = "36.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=reset_execs_36#ac33c34268424edf3f7eafa694a14bd5ee4027ea" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=36_combine_partial#56fe262dfba8370dbc602be6ba7112a356d477e9" dependencies = [ "arrow", "base64 0.21.7", @@ -2997,7 +2997,7 @@ dependencies = [ [[package]] name = "datafusion-functions-array" version = "36.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=reset_execs_36#ac33c34268424edf3f7eafa694a14bd5ee4027ea" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=36_combine_partial#56fe262dfba8370dbc602be6ba7112a356d477e9" dependencies = [ "arrow", "datafusion-common 36.0.0", @@ -3028,7 +3028,7 @@ dependencies = [ [[package]] name = "datafusion-optimizer" version = "36.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=reset_execs_36#ac33c34268424edf3f7eafa694a14bd5ee4027ea" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=36_combine_partial#56fe262dfba8370dbc602be6ba7112a356d477e9" dependencies = [ "arrow", "async-trait", @@ -3079,7 +3079,7 @@ dependencies = [ [[package]] name = "datafusion-physical-expr" version = "36.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=reset_execs_36#ac33c34268424edf3f7eafa694a14bd5ee4027ea" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=36_combine_partial#56fe262dfba8370dbc602be6ba7112a356d477e9" dependencies = [ "ahash 0.8.7", "arrow", @@ -3145,7 +3145,7 @@ dependencies = [ [[package]] name = "datafusion-physical-plan" version = "36.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=reset_execs_36#ac33c34268424edf3f7eafa694a14bd5ee4027ea" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=36_combine_partial#56fe262dfba8370dbc602be6ba7112a356d477e9" dependencies = [ "ahash 0.8.7", "arrow", @@ -3190,7 +3190,7 @@ dependencies = [ [[package]] name = "datafusion-proto" version = "36.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=reset_execs_36#ac33c34268424edf3f7eafa694a14bd5ee4027ea" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=36_combine_partial#56fe262dfba8370dbc602be6ba7112a356d477e9" dependencies = [ "arrow", "chrono", @@ -3218,7 +3218,7 @@ dependencies = [ [[package]] name = "datafusion-sql" version = "36.0.0" -source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=reset_execs_36#ac33c34268424edf3f7eafa694a14bd5ee4027ea" +source = "git+https://github.com/ArroyoSystems/arrow-datafusion?branch=36_combine_partial#56fe262dfba8370dbc602be6ba7112a356d477e9" dependencies = [ "arrow", "arrow-schema", @@ -4703,7 +4703,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.5.6", + "socket2 0.4.10", "tokio", "tower-service", "tracing", @@ -6753,7 +6753,7 @@ checksum = "c55e02e35260070b6f716a2423c2ff1c3bb1642ddca6f99e1f26d06268a0e2d2" dependencies = [ "bytes", "heck", - "itertools 0.11.0", + "itertools 0.10.5", "log", "multimap", "once_cell", @@ -6787,7 +6787,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "efb6c9a1dd1def8e2124d17e83a20af56f1570d6c2d2bd9e266ccb768df3840e" dependencies = [ "anyhow", - "itertools 0.11.0", + "itertools 0.10.5", "proc-macro2", "quote", "syn 2.0.52", diff --git a/Cargo.toml b/Cargo.toml index 54fc41fc5..5b0428f95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,10 +60,10 @@ arrow-array = {git = 'https://github.com/ArroyoSystems/arrow-rs', branch = '50.0 arrow-schema = {git = 'https://github.com/ArroyoSystems/arrow-rs', branch = '50.0.0/parquet_bytes'} arrow-json = {git = 'https://github.com/ArroyoSystems/arrow-rs', branch = '50.0.0/json'} object_store = {git = 'https://github.com/ArroyoSystems/arrow-rs', branch = '0.9.0/put_part_api'} -datafusion = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = 'reset_execs_36'} -datafusion-common = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = 'reset_execs_36'} -datafusion-execution = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = 'reset_execs_36'} -datafusion-expr = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = 'reset_execs_36'} -datafusion-physical-expr = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = 'reset_execs_36'} -datafusion-physical-plan = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = 'reset_execs_36'} -datafusion-proto = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = 'reset_execs_36'} +datafusion = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = '36_combine_partial'} +datafusion-common = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = '36_combine_partial'} +datafusion-execution = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = '36_combine_partial'} +datafusion-expr = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = '36_combine_partial'} +datafusion-physical-expr = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = '36_combine_partial'} +datafusion-physical-plan = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = '36_combine_partial'} +datafusion-proto = {git = 'https://github.com/ArroyoSystems/arrow-datafusion', branch = '36_combine_partial'} diff --git a/crates/arroyo-datastream/src/logical.rs b/crates/arroyo-datastream/src/logical.rs index 0e86f6a5a..f1abb3ecc 100644 --- a/crates/arroyo-datastream/src/logical.rs +++ b/crates/arroyo-datastream/src/logical.rs @@ -33,6 +33,7 @@ pub enum OperatorName { TumblingWindowAggregate, SlidingWindowAggregate, SessionWindowAggregate, + UpdatingAggregate, ConnectorSource, ConnectorSink, } diff --git a/crates/arroyo-df/src/builder.rs b/crates/arroyo-df/src/builder.rs index d6cb7bea4..fa2a595f3 100644 --- a/crates/arroyo-df/src/builder.rs +++ b/crates/arroyo-df/src/builder.rs @@ -27,6 +27,7 @@ use datafusion_proto::protobuf::{PhysicalExprNode, PhysicalPlanNode}; use petgraph::graph::{DiGraph, NodeIndex}; use tokio::runtime::Runtime; use tokio::sync::oneshot; +use tracing::info; use crate::extension::debezium::{DEBEZIUM_UNROLLING_EXTENSION_NAME, TO_DEBEZIUM_EXTENSION_NAME}; use crate::extension::key_calculation::KeyCalculationExtension; @@ -124,6 +125,7 @@ impl<'a> Planner<'a> { &self, key_indices: Vec, aggregate: &LogicalPlan, + add_timestamp_field: bool, ) -> DFResult { let physical_plan = self.sync_plan(aggregate)?; let codec = ArroyoPhysicalExtensionCodec { @@ -173,11 +175,18 @@ impl<'a> Planner<'a> { physical_plan_type: Some(PhysicalPlanType::Aggregate(final_aggregate_proto)), }; - let partial_schema = ArroyoSchema::new_keyed( - add_timestamp_field_arrow(partial_schema.clone()), - partial_schema.fields().len(), - key_indices, - ); + info!("partial schema begins as :{:?}", partial_schema); + + let (partial_schema, timestamp_index) = if add_timestamp_field { + ( + add_timestamp_field_arrow(partial_schema.clone()), + partial_schema.fields().len(), + ) + } else { + (partial_schema.clone(), partial_schema.fields().len() - 1) + }; + + let partial_schema = ArroyoSchema::new_keyed(partial_schema, timestamp_index, key_indices); Ok(SplitPlanOutput { partial_aggregation_plan, @@ -362,6 +371,7 @@ impl<'a> TreeNodeVisitor for PlanToGraphVisitor<'a> { } else { vec![] }; + info!("building node: {:?}", node); self.build_extension(input_nodes, arroyo_extension) .map_err(|e| DataFusionError::Plan(format!("error building extension: {}", e)))?; diff --git a/crates/arroyo-df/src/extension/aggregate.rs b/crates/arroyo-df/src/extension/aggregate.rs index f8180b0df..3d8c86e79 100644 --- a/crates/arroyo-df/src/extension/aggregate.rs +++ b/crates/arroyo-df/src/extension/aggregate.rs @@ -78,7 +78,7 @@ impl AggregateExtension { partial_aggregation_plan, partial_schema, finish_plan, - } = planner.split_physical_plan(self.key_fields.clone(), &self.aggregate)?; + } = planner.split_physical_plan(self.key_fields.clone(), &self.aggregate, true)?; let final_physical_plan = planner.sync_plan(&self.final_calculation)?; let final_physical_plan_node = PhysicalPlanNode::try_from_physical_plan( @@ -126,7 +126,7 @@ impl AggregateExtension { partial_aggregation_plan, partial_schema, finish_plan, - } = planner.split_physical_plan(self.key_fields.clone(), &self.aggregate)?; + } = planner.split_physical_plan(self.key_fields.clone(), &self.aggregate, true)?; let final_physical_plan = planner.sync_plan(&self.final_calculation)?; let final_physical_plan_node = PhysicalPlanNode::try_from_physical_plan( @@ -251,7 +251,7 @@ impl AggregateExtension { partial_aggregation_plan, partial_schema, finish_plan, - } = planner.split_physical_plan(self.key_fields.clone(), &self.aggregate)?; + } = planner.split_physical_plan(self.key_fields.clone(), &self.aggregate, true)?; let config = TumblingWindowAggregateOperator { name: "InstantWindow".to_string(), diff --git a/crates/arroyo-df/src/extension/mod.rs b/crates/arroyo-df/src/extension/mod.rs index 4cac9100f..a3a32ab82 100644 --- a/crates/arroyo-df/src/extension/mod.rs +++ b/crates/arroyo-df/src/extension/mod.rs @@ -1,9 +1,13 @@ use std::sync::Arc; use anyhow::Result; +use arrow_schema::{DataType, TimeUnit}; use arroyo_datastream::logical::{LogicalEdge, LogicalNode}; use arroyo_rpc::df::{ArroyoSchema, ArroyoSchemaRef}; -use datafusion_common::{DFSchemaRef, DataFusionError, OwnedTableReference, Result as DFResult}; +use arroyo_rpc::{IS_RETRACT_FIELD, TIMESTAMP_FIELD}; +use datafusion_common::{ + DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, Result as DFResult, +}; use datafusion_expr::{Expr, LogicalPlan, UserDefinedLogicalNode, UserDefinedLogicalNodeCore}; use watermark_node::WatermarkNode; @@ -15,6 +19,7 @@ use self::debezium::{ DebeziumUnrollingExtension, ToDebeziumExtension, DEBEZIUM_UNROLLING_EXTENSION_NAME, TO_DEBEZIUM_EXTENSION_NAME, }; +use self::updating_aggregate::{UpdatingAggregateExtension, UPDATING_AGGREGATE_EXTENSION_NAME}; use self::{ aggregate::{AggregateExtension, AGGREGATE_EXTENSION_NAME}, join::JOIN_NODE_NAME, @@ -33,6 +38,7 @@ pub(crate) mod key_calculation; pub(crate) mod remote_table; pub(crate) mod sink; pub(crate) mod table_source; +pub(crate) mod updating_aggregate; pub(crate) mod watermark_node; pub(crate) mod window_fn; pub(crate) trait ArroyoExtension { @@ -117,6 +123,13 @@ impl<'a> TryFrom<&'a dyn UserDefinedLogicalNode> for &'a dyn ArroyoExtension { node.as_any().downcast_ref::().unwrap(); Ok(to_debezium_extension as &dyn ArroyoExtension) } + UPDATING_AGGREGATE_EXTENSION_NAME => { + let updating_aggregate_extension = node + .as_any() + .downcast_ref::() + .unwrap(); + Ok(updating_aggregate_extension as &dyn ArroyoExtension) + } other => Err(DataFusionError::Plan(format!("unexpected node: {}", other))), } } @@ -186,3 +199,56 @@ impl UserDefinedLogicalNodeCore for TimestampAppendExtension { Self::new(inputs[0].clone(), self.qualifier.clone()) } } + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct IsRetractExtension { + pub(crate) input: LogicalPlan, + pub(crate) schema: DFSchemaRef, +} + +impl IsRetractExtension { + pub(crate) fn new(input: LogicalPlan) -> Self { + let mut output_fields = input.schema().fields().clone(); + let timestamp_index = output_fields.len() - 1; + output_fields[timestamp_index] = DFField::new_unqualified( + TIMESTAMP_FIELD, + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ); + output_fields.push(DFField::new_unqualified( + IS_RETRACT_FIELD, + DataType::Boolean, + false, + )); + let schema = Arc::new( + DFSchema::new_with_metadata(output_fields, input.schema().metadata().clone()).unwrap(), + ); + Self { input, schema } + } +} + +impl UserDefinedLogicalNodeCore for IsRetractExtension { + fn name(&self) -> &str { + "IsRetractExtension" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + fn schema(&self) -> &DFSchemaRef { + &self.schema + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "IsRetractExtension") + } + + fn from_template(&self, _exprs: &[Expr], inputs: &[LogicalPlan]) -> Self { + Self::new(inputs[0].clone()) + } +} diff --git a/crates/arroyo-df/src/extension/sink.rs b/crates/arroyo-df/src/extension/sink.rs index e360582f7..72b672da8 100644 --- a/crates/arroyo-df/src/extension/sink.rs +++ b/crates/arroyo-df/src/extension/sink.rs @@ -7,11 +7,14 @@ use arroyo_rpc::{ df::{ArroyoSchema, ArroyoSchemaRef}, IS_RETRACT_FIELD, }; -use datafusion_common::{DFSchemaRef, DataFusionError, OwnedTableReference, Result as DFResult}; +use datafusion_common::{ + plan_err, DFSchemaRef, DataFusionError, OwnedTableReference, Result as DFResult, +}; use datafusion_expr::{Expr, Extension, LogicalPlan, UserDefinedLogicalNodeCore}; use prost::Message; +use tracing::info; use crate::{ builder::{NamedNode, Planner}, @@ -40,12 +43,47 @@ impl SinkExtension { mut schema: DFSchemaRef, mut input: Arc, ) -> DFResult { - if input + let input_is_updating = input .schema() - .has_column_with_unqualified_name(IS_RETRACT_FIELD) - { - if let Table::ConnectorTable(connector_table) = &table { - if connector_table.is_updating() { + .has_column_with_unqualified_name(IS_RETRACT_FIELD); + info!( + "input_is_updating: {}\nschema:\n{:?}\ninput schema:\n{:?}", + input_is_updating, + schema + .fields() + .iter() + .map(|f| f.qualified_name()) + .collect::>() + .join(","), + input + .schema() + .fields() + .iter() + .map(|f| f.qualified_name()) + .collect::>() + .join(",") + ); + match &table { + Table::ConnectorTable(connector_table) => { + match (input_is_updating, connector_table.is_updating()) { + (_, true) => { + let to_debezium_extension = + ToDebeziumExtension::try_new(input.as_ref().clone())?; + input = Arc::new(LogicalPlan::Extension(Extension { + node: Arc::new(to_debezium_extension), + })); + schema = input.schema().clone(); + } + (true, false) => { + return plan_err!("input is updating, but sink is not updating"); + } + (false, false) => {} + } + } + Table::MemoryTable { .. } => return plan_err!("memory tables not supported"), + Table::TableFromQuery { .. } => {} + Table::PreviewSink { .. } => { + if input_is_updating { let to_debezium_extension = ToDebeziumExtension::try_new(input.as_ref().clone())?; input = Arc::new(LogicalPlan::Extension(Extension { diff --git a/crates/arroyo-df/src/extension/updating_aggregate.rs b/crates/arroyo-df/src/extension/updating_aggregate.rs new file mode 100644 index 000000000..af5880798 --- /dev/null +++ b/crates/arroyo-df/src/extension/updating_aggregate.rs @@ -0,0 +1,162 @@ +use std::sync::Arc; + +use anyhow::{bail, Result}; +use arrow_schema::{DataType, Field, Schema, TimeUnit}; +use arroyo_datastream::logical::{LogicalEdge, LogicalEdgeType, LogicalNode, OperatorName}; +use arroyo_rpc::{df::ArroyoSchema, grpc::api::UpdatingAggregateOperator, TIMESTAMP_FIELD}; +use datafusion_expr::{Extension, LogicalPlan, UserDefinedLogicalNodeCore}; +use datafusion_proto::protobuf::{physical_plan_node::PhysicalPlanType, PhysicalPlanNode}; + +use crate::builder::{NamedNode, SplitPlanOutput}; + +use super::{ArroyoExtension, IsRetractExtension, NodeWithIncomingEdges}; +use prost::Message; + +pub(crate) const UPDATING_AGGREGATE_EXTENSION_NAME: &'static str = "UpdatingAggregateExtension"; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct UpdatingAggregateExtension { + pub(crate) aggregate: LogicalPlan, + pub(crate) key_fields: Vec, + pub(crate) final_calculation: LogicalPlan, +} + +impl UpdatingAggregateExtension { + pub fn new(aggregate: LogicalPlan, key_fields: Vec) -> Self { + let final_calculation = LogicalPlan::Extension(Extension { + node: Arc::new(IsRetractExtension::new(aggregate.clone())), + }); + Self { + aggregate, + key_fields, + final_calculation, + } + } +} + +impl UserDefinedLogicalNodeCore for UpdatingAggregateExtension { + fn name(&self) -> &str { + UPDATING_AGGREGATE_EXTENSION_NAME + } + + fn inputs(&self) -> Vec<&datafusion_expr::LogicalPlan> { + vec![&self.aggregate] + } + + fn schema(&self) -> &datafusion_common::DFSchemaRef { + self.final_calculation.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "UpdatingAggregateExtension") + } + + fn from_template( + &self, + _exprs: &[datafusion::prelude::Expr], + inputs: &[datafusion_expr::LogicalPlan], + ) -> Self { + Self::new(inputs[0].clone(), self.key_fields.clone()) + } +} + +impl ArroyoExtension for UpdatingAggregateExtension { + fn node_name(&self) -> Option { + None + } + + fn plan_node( + &self, + planner: &crate::builder::Planner, + index: usize, + input_schemas: Vec, + ) -> Result { + if input_schemas.len() != 1 { + bail!( + "UpdatingAggregateExtension requires exactly one input schema, found {}", + input_schemas.len() + ); + } + let input_schema = input_schemas[0].clone(); + let SplitPlanOutput { + partial_aggregation_plan, + partial_schema, + finish_plan, + } = planner.split_physical_plan(self.key_fields.clone(), &self.aggregate, false)?; + let mut state_fields = partial_schema.schema.fields().to_vec(); + state_fields[partial_schema.timestamp_index] = Arc::new(Field::new( + TIMESTAMP_FIELD, + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + )); + let state_partial_schema = ArroyoSchema::new_keyed( + Arc::new(Schema::new_with_metadata( + state_fields, + partial_schema.schema.metadata().clone(), + )), + partial_schema.timestamp_index, + self.key_fields.clone(), + ); + let mut state_final_fields = self + .aggregate + .schema() + .fields() + .iter() + .map(|df_field| df_field.field().clone()) + .collect::>(); + let timestamp_index = state_final_fields.len() - 1; + state_final_fields[timestamp_index] = Arc::new(Field::new( + TIMESTAMP_FIELD, + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + )); + let state_final_schema = ArroyoSchema::new_keyed( + Arc::new(Schema::new_with_metadata( + state_final_fields, + self.aggregate.schema().metadata().clone(), + )), + timestamp_index, + self.key_fields.clone(), + ); + + let Some(PhysicalPlanType::Aggregate(aggregate)) = finish_plan.physical_plan_type.as_ref() + else { + bail!("expect finish plan to be an aggregate"); + }; + let mut combine_aggregate = aggregate.as_ref().clone(); + combine_aggregate.set_mode(datafusion_proto::protobuf::AggregateMode::CombinePartial); + let combine_plan = PhysicalPlanNode { + physical_plan_type: Some(PhysicalPlanType::Aggregate(Box::new(combine_aggregate))), + }; + + let config = UpdatingAggregateOperator { + name: "UpdatingAggregate".to_string(), + partial_schema: Some(partial_schema.try_into()?), + state_partial_schema: Some(state_partial_schema.try_into()?), + state_final_schema: Some(state_final_schema.try_into()?), + partial_aggregation_plan: partial_aggregation_plan.encode_to_vec(), + combine_plan: combine_plan.encode_to_vec(), + final_aggregation_plan: finish_plan.encode_to_vec(), + }; + let node = LogicalNode { + operator_id: format!("updating_aggregate_{}", index), + description: format!("UpdatingAggregate"), + operator_name: OperatorName::UpdatingAggregate, + operator_config: config.encode_to_vec(), + parallelism: 1, + }; + let edge = LogicalEdge::project_all(LogicalEdgeType::Shuffle, (*input_schema).clone()); + Ok(NodeWithIncomingEdges { + node, + edges: vec![edge], + }) + } + + fn output_schema(&self) -> arroyo_rpc::df::ArroyoSchema { + ArroyoSchema::from_schema_unkeyed(Arc::new(self.schema().as_ref().into())).unwrap() + } +} diff --git a/crates/arroyo-df/src/lib.rs b/crates/arroyo-df/src/lib.rs index 51a4211f7..4b8a819dd 100644 --- a/crates/arroyo-df/src/lib.rs +++ b/crates/arroyo-df/src/lib.rs @@ -615,10 +615,28 @@ pub async fn parse_and_get_arrow_program( }; let plan_rewrite = plan.rewrite(&mut UnnestRewriter {})?; + info!( + "fields before rewrite: {:?}", + plan_rewrite + .schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>() + ); + let plan_rewrite = plan_rewrite.rewrite(&mut ArroyoRewriter { schema_provider: &schema_provider, })?; - + info!( + "fields after rewrite: {:?}", + plan_rewrite + .schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>() + ); let mut metadata = SourceMetadataVisitor::new(&schema_provider); plan_rewrite.visit(&mut metadata)?; used_connections.extend(metadata.connection_ids.iter()); diff --git a/crates/arroyo-df/src/physical.rs b/crates/arroyo-df/src/physical.rs index 7e303f003..9b832d059 100644 --- a/crates/arroyo-df/src/physical.rs +++ b/crates/arroyo-df/src/physical.rs @@ -1043,14 +1043,13 @@ pub struct ToDebeziumExec { impl ToDebeziumExec { pub fn try_new(input: Arc) -> DFResult { let input_schema = input.schema(); - let is_retract_index = input_schema.index_of(IS_RETRACT_FIELD)?; let timestamp_index = input_schema.index_of(TIMESTAMP_FIELD)?; let struct_fields: Vec<_> = input_schema .fields() .into_iter() .enumerate() .filter_map(|(index, field)| { - if index == is_retract_index || index == timestamp_index { + if field.name() == IS_RETRACT_FIELD || index == timestamp_index { None } else { Some(field.clone()) diff --git a/crates/arroyo-df/src/plan/aggregate.rs b/crates/arroyo-df/src/plan/aggregate.rs index 73cdd70f5..ce4b223ac 100644 --- a/crates/arroyo-df/src/plan/aggregate.rs +++ b/crates/arroyo-df/src/plan/aggregate.rs @@ -1,15 +1,122 @@ use crate::extension::aggregate::AggregateExtension; use crate::extension::key_calculation::KeyCalculationExtension; +use crate::extension::updating_aggregate::UpdatingAggregateExtension; use crate::plan::WindowDetectingVisitor; use crate::{find_window, WindowBehavior}; +use arroyo_rpc::{IS_RETRACT_FIELD, TIMESTAMP_FIELD}; use datafusion_common::tree_node::{TreeNode, TreeNodeRewriter}; use datafusion_common::{plan_err, DFField, DFSchema, DataFusionError, Result as DFResult}; -use datafusion_expr::{Aggregate, Expr, Extension, LogicalPlan}; +use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::{aggregate_function, Aggregate, Expr, Extension, LogicalPlan}; use std::sync::Arc; +use tracing::info; #[derive(Debug, Default)] pub struct AggregateRewriter {} +impl AggregateRewriter { + pub fn rewrite_non_windowed_aggregate( + input: Arc, + mut key_fields: Vec, + group_expr: Vec, + mut aggr_expr: Vec, + schema: Arc, + ) -> DFResult { + if input + .schema() + .has_column_with_unqualified_name(IS_RETRACT_FIELD) + { + return plan_err!("can't currently nest updating aggregates"); + } + + let key_count = key_fields.len(); + key_fields.extend(input.schema().fields().clone()); + + let key_schema = Arc::new(DFSchema::new_with_metadata( + key_fields, + schema.metadata().clone(), + )?); + + let mut key_projection_expressions = group_expr.clone(); + key_projection_expressions.extend( + input + .schema() + .fields() + .iter() + .map(|field| Expr::Column(field.qualified_column())), + ); + + let key_projection = + LogicalPlan::Projection(datafusion_expr::Projection::try_new_with_schema( + key_projection_expressions.clone(), + input.clone(), + key_schema.clone(), + )?); + + info!( + "key projection fields: {:?}", + key_projection + .schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>() + ); + + let key_plan = LogicalPlan::Extension(Extension { + node: Arc::new(KeyCalculationExtension::new( + key_projection, + (0..key_count).collect(), + )), + }); + let Some(timestamp_field) = key_plan + .schema() + .fields() + .iter() + .find(|field| field.name() == TIMESTAMP_FIELD) + else { + return plan_err!("no timestamp field found in schema"); + }; + let column = timestamp_field.qualified_column(); + aggr_expr.push(Expr::AggregateFunction(AggregateFunction::new( + aggregate_function::AggregateFunction::Max, + vec![Expr::Column(column.clone())], + false, + None, + None, + ))); + let mut output_schema_fields = schema.fields().clone(); + output_schema_fields.push(timestamp_field.clone()); + let output_schema = Arc::new(DFSchema::new_with_metadata( + output_schema_fields, + schema.metadata().clone(), + )?); + let aggregate = Aggregate::try_new_with_schema( + Arc::new(key_plan), + group_expr, + aggr_expr, + output_schema, + )?; + info!( + "aggregate field names: {:?}", + aggregate + .schema + .fields() + .iter() + .map(|f| f.name()) + .collect::>() + ); + let updating_aggregate_extension = UpdatingAggregateExtension::new( + LogicalPlan::Aggregate(aggregate), + (0..key_count).collect(), + ); + let final_plan = LogicalPlan::Extension(Extension { + node: Arc::new(updating_aggregate_extension), + }); + Ok(final_plan) + } +} + impl TreeNodeRewriter for AggregateRewriter { type N = LogicalPlan; @@ -85,7 +192,12 @@ impl TreeNodeRewriter for AggregateRewriter { group_expr.remove(window_index); key_fields.remove(window_index); let window_field = schema.field(window_index).clone(); - WindowBehavior::FromOperator { window: input_window, window_field, window_index, is_nested: true } + WindowBehavior::FromOperator { + window: input_window, + window_field, + window_index, + is_nested: true, + } } } } @@ -104,9 +216,9 @@ impl TreeNodeRewriter for AggregateRewriter { } } (false, false) => { - return Err(DataFusionError::NotImplemented( - format!("must have window in aggregate. Make sure you are calling one of the windowing functions (hop, tumble, session) or using the window field of the input"), - )) + return Self::rewrite_non_windowed_aggregate( + input, key_fields, group_expr, aggr_expr, schema, + ); } }; diff --git a/crates/arroyo-df/src/plan/join.rs b/crates/arroyo-df/src/plan/join.rs index 81eebde78..391612e3f 100644 --- a/crates/arroyo-df/src/plan/join.rs +++ b/crates/arroyo-df/src/plan/join.rs @@ -2,10 +2,11 @@ use crate::extension::join::JoinExtension; use crate::extension::key_calculation::KeyCalculationExtension; use crate::plan::WindowDetectingVisitor; use arroyo_datastream::WindowType; +use arroyo_rpc::IS_RETRACT_FIELD; use datafusion_common::tree_node::TreeNodeRewriter; use datafusion_common::{ - Column, DFField, DFSchema, DataFusionError, JoinConstraint, JoinType, Result as DFResult, - ScalarValue, + plan_err, Column, DFField, DFSchema, DataFusionError, JoinConstraint, JoinType, + Result as DFResult, ScalarValue, }; use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::{ @@ -55,6 +56,22 @@ impl JoinRewriter { } } + fn check_updating(left: &LogicalPlan, right: &LogicalPlan) -> DFResult<()> { + if left + .schema() + .has_column_with_unqualified_name(IS_RETRACT_FIELD) + { + return plan_err!("can't handle updating left side of join"); + } + if right + .schema() + .has_column_with_unqualified_name(IS_RETRACT_FIELD) + { + return plan_err!("can't handle updating right side of join"); + } + Ok(()) + } + fn create_join_key_plan( &self, input: Arc, @@ -204,6 +221,7 @@ impl TreeNodeRewriter for JoinRewriter { "can't handle join constraint other than ON".into(), )); }; + Self::check_updating(&left, &right)?; let (left_expressions, right_expressions): (Vec<_>, Vec<_>) = on.clone().into_iter().unzip(); diff --git a/crates/arroyo-df/src/plan/mod.rs b/crates/arroyo-df/src/plan/mod.rs index 08310fd57..7c814a88a 100644 --- a/crates/arroyo-df/src/plan/mod.rs +++ b/crates/arroyo-df/src/plan/mod.rs @@ -282,6 +282,9 @@ impl<'a> TreeNodeRewriter for ArroyoRewriter<'a> { .input .schema() .has_column_with_unqualified_name(IS_RETRACT_FIELD) + && !projection + .schema + .has_column_with_unqualified_name(IS_RETRACT_FIELD) { let field = projection .input diff --git a/crates/arroyo-df/src/test/queries/create_table_updating.sql b/crates/arroyo-df/src/test/queries/create_table_updating.sql new file mode 100644 index 000000000..7f557115b --- /dev/null +++ b/crates/arroyo-df/src/test/queries/create_table_updating.sql @@ -0,0 +1,20 @@ +CREATE TABLE nexmark ( + auction bigint, + bidder bigint, + price bigint, + channel text, + url text, + datetime timestamp, + extra text, +) WITH ( + connector = 'filesystem', + format = 'parquet', + type = 'source', + path = '/home/data', + 'source.regex-pattern' = '00001-000.parquet', + event_time_field = 'datetime' +); + +CREATE TABLE counts as (SELECT count(*) FROM nexmark); + +SELECT * FROM counts \ No newline at end of file diff --git a/crates/arroyo-df/src/test/queries/error_missing_window_basic_tumble.sql b/crates/arroyo-df/src/test/queries/error_missing_window_basic_tumble.sql index e3aaef16b..04daf62c6 100644 --- a/crates/arroyo-df/src/test/queries/error_missing_window_basic_tumble.sql +++ b/crates/arroyo-df/src/test/queries/error_missing_window_basic_tumble.sql @@ -1,4 +1,3 @@ ---fail=must have window in aggregate. Make sure you are calling one of the windowing functions (hop, tumble, session) or using the window field of the input CREATE TABLE Nexmark WITH ( connector = 'nexmark', event_rate = '10' diff --git a/crates/arroyo-df/src/test/queries/error_no_aggregate_over_debezium.sql b/crates/arroyo-df/src/test/queries/error_no_aggregate_over_debezium.sql new file mode 100644 index 000000000..dbb966049 --- /dev/null +++ b/crates/arroyo-df/src/test/queries/error_no_aggregate_over_debezium.sql @@ -0,0 +1,12 @@ +--fail=Error during planning: can't currently nest updating aggregates +CREATE TABLE debezium_input ( + count int + ) WITH ( + connector = 'kafka', + bootstrap_servers = 'localhost:9092', + type = 'source', + topic = 'updating', + format = 'debezium_json' +); + +SELECT count(*) FROM debezium_input \ No newline at end of file diff --git a/crates/arroyo-df/src/test/queries/error_no_nested_updating_aggregates.sql b/crates/arroyo-df/src/test/queries/error_no_nested_updating_aggregates.sql new file mode 100644 index 000000000..ebb49b51f --- /dev/null +++ b/crates/arroyo-df/src/test/queries/error_no_nested_updating_aggregates.sql @@ -0,0 +1,7 @@ +--fail=Error during planning: can't currently nest updating aggregates +CREATE TABLE impulse with ( + connector = 'impulse', + event_rate = '10' + ); + SELECT sum(count) , max(counter) FROM( + SELECT count(*) as count, subtask_index, max(counter) as counter FROM impulse group by 2); \ No newline at end of file diff --git a/crates/arroyo-df/src/test/queries/no_inserting_updates_into_non_updating.sql b/crates/arroyo-df/src/test/queries/no_inserting_updates_into_non_updating.sql index 66100bd50..01d54f904 100644 --- a/crates/arroyo-df/src/test/queries/no_inserting_updates_into_non_updating.sql +++ b/crates/arroyo-df/src/test/queries/no_inserting_updates_into_non_updating.sql @@ -15,7 +15,7 @@ CREATE table sink ( bootstrap_servers = 'localhost:9092', type = 'sink', topic = 'sink', - format = 'json' + format = 'debezium_json' ); INSERT into sink diff --git a/crates/arroyo-df/src/test/queries/no_updating_joins.sql b/crates/arroyo-df/src/test/queries/no_updating_joins.sql new file mode 100644 index 000000000..f04b03afd --- /dev/null +++ b/crates/arroyo-df/src/test/queries/no_updating_joins.sql @@ -0,0 +1,21 @@ +--fail=Error during planning: can't handle updating left side of join +CREATE TABLE nexmark ( + auction bigint, + bidder bigint, + price bigint, + channel text, + url text, + datetime timestamp, + extra text, +) WITH ( + connector = 'filesystem', + format = 'parquet', + type = 'source', + path = '/home/data', + 'source.regex-pattern' = '00001-000.parquet', + event_time_field = 'datetime' +); + +CREATE TABLE counts as (SELECT count(*) as counts, bidder FROM nexmark GROUP BY 2); + +SELECT a.counts, b.counts FROM counts A join counts B on A.bidder = b.bidder \ No newline at end of file diff --git a/crates/arroyo-rpc/proto/api.proto b/crates/arroyo-rpc/proto/api.proto index 7912f5e31..eee882c80 100644 --- a/crates/arroyo-rpc/proto/api.proto +++ b/crates/arroyo-rpc/proto/api.proto @@ -106,6 +106,16 @@ message WindowFunctionOperator { bytes window_function_plan = 4; } +message UpdatingAggregateOperator { + string name = 1; + ArroyoSchema partial_schema = 2; + ArroyoSchema state_partial_schema = 3; + ArroyoSchema state_final_schema = 4; + bytes partial_aggregation_plan = 5; + bytes combine_plan = 6; + bytes final_aggregation_plan = 7; +} + message WasmUdfs { string name = 1; repeated WasmFunction wasm_functions = 2; diff --git a/crates/arroyo-rpc/src/df.rs b/crates/arroyo-rpc/src/df.rs index 1fc4fa237..bb7398167 100644 --- a/crates/arroyo-rpc/src/df.rs +++ b/crates/arroyo-rpc/src/df.rs @@ -256,24 +256,67 @@ impl ArroyoSchema { pub fn sort_fields(&self, with_timestamp: bool) -> Vec { let mut sort_fields = vec![]; if let Some(keys) = &self.key_indices { - sort_fields.extend( - keys.iter() - .map(|index| SortField::new(self.schema.field(*index).data_type().clone())), - ); + sort_fields.extend(keys.iter()); } if with_timestamp { - sort_fields.push(SortField::new(DataType::Timestamp( - TimeUnit::Nanosecond, - None, - ))); + sort_fields.push(self.timestamp_index); } - sort_fields + self.sort_fields_by_indices(&sort_fields) + } + + fn sort_fields_by_indices(&self, indices: &[usize]) -> Vec { + indices + .iter() + .map(|index| SortField::new(self.schema.field(*index).data_type().clone())) + .collect() } pub fn converter(&self, with_timestamp: bool) -> Result { Converter::new(self.sort_fields(with_timestamp)) } + pub fn value_converter(&self, with_timestamp: bool) -> Result { + match &self.key_indices { + None => { + let mut indices = (0..self.schema.fields().len()).collect::>(); + + if !with_timestamp { + indices.remove(self.timestamp_index); + } + Converter::new(self.sort_fields_by_indices(&indices)) + } + Some(keys) => { + let indices = (0..self.schema.fields().len()) + .filter(|index| { + !keys.contains(index) && (with_timestamp || *index != self.timestamp_index) + }) + .collect::>(); + Converter::new(self.sort_fields_by_indices(&indices)) + } + } + } + pub fn value_indices(&self, with_timestamp: bool) -> Vec { + let field_count = self.schema.fields().len(); + match &self.key_indices { + None => { + let mut indices = (0..field_count).collect::>(); + + if !with_timestamp { + indices.remove(self.timestamp_index); + } + indices + } + Some(keys) => { + let indices = (0..field_count) + .filter(|index| { + !keys.contains(index) && (with_timestamp || *index != self.timestamp_index) + }) + .collect::>(); + indices + } + } + } + pub fn sort(&self, batch: RecordBatch, with_timestamp: bool) -> Result { if self.key_indices.is_none() && !with_timestamp { return Ok(batch); diff --git a/crates/arroyo-rpc/src/lib.rs b/crates/arroyo-rpc/src/lib.rs index 5bd42576c..a8bc5578a 100644 --- a/crates/arroyo-rpc/src/lib.rs +++ b/crates/arroyo-rpc/src/lib.rs @@ -12,7 +12,7 @@ use crate::api_types::connections::PrimitiveType; use crate::formats::{BadData, Format, Framing}; use crate::grpc::{LoadCompactedDataReq, SubtaskCheckpointMetadata}; use anyhow::Result; -use arrow::row::{OwnedRow, RowConverter, SortField}; +use arrow::row::{OwnedRow, RowConverter, Rows, SortField}; use arrow_array::{Array, ArrayRef, BooleanArray}; use arrow_schema::DataType; use arroyo_types::{CheckpointBarrier, HASH_SEEDS}; @@ -230,12 +230,36 @@ impl Converter { } } + pub fn convert_all_columns(&self, columns: &[Arc]) -> anyhow::Result { + match self { + Converter::RowConverter(row_converter) => Ok(row_converter.convert_columns(columns)?), + Converter::Empty(row_converter, array) => { + Ok(row_converter.convert_columns(&vec![array.clone()])?) + } + } + } + pub fn convert_rows(&self, rows: Vec>) -> anyhow::Result> { match self { Converter::RowConverter(row_converter) => Ok(row_converter.convert_rows(rows)?), Converter::Empty(_row_converter, _array) => Ok(vec![]), } } + + pub fn convert_raw_rows(&self, row_bytes: Vec<&[u8]>) -> anyhow::Result> { + match self { + Converter::RowConverter(row_converter) => { + let parser = row_converter.parser(); + let mut row_list = vec![]; + for bytes in row_bytes { + let row = parser.parse(bytes); + row_list.push(row); + } + Ok(row_converter.convert_rows(row_list)?) + } + Converter::Empty(_row_converter, _array) => Ok(vec![]), + } + } } fn default_async_timeout_seconds() -> u64 { diff --git a/crates/arroyo-state/src/tables/expiring_time_key_map.rs b/crates/arroyo-state/src/tables/expiring_time_key_map.rs index 02c9d40e7..482fab1fa 100644 --- a/crates/arroyo-state/src/tables/expiring_time_key_map.rs +++ b/crates/arroyo-state/src/tables/expiring_time_key_map.rs @@ -1,16 +1,17 @@ use std::{ collections::{BTreeMap, HashMap, HashSet}, + mem, sync::Arc, time::{Duration, SystemTime}, }; use anyhow::{anyhow, bail, Ok, Result}; -use arrow::compute::{concat_batches, kernels::aggregate, take}; -use arrow::row::{OwnedRow, Row}; +use arrow::compute::{concat_batches, filter_record_batch, kernels::aggregate, take}; +use arrow::row::OwnedRow; use arrow_array::{ cast::AsArray, types::{TimestampNanosecondType, UInt64Type}, - PrimitiveArray, RecordBatch, + BooleanArray, PrimitiveArray, RecordBatch, TimestampNanosecondArray, }; use arrow_ord::{partition::partition, sort::sort_to_indices}; use arroyo_rpc::{ @@ -22,7 +23,9 @@ use arroyo_rpc::{ Converter, }; use arroyo_storage::StorageProviderRef; -use arroyo_types::{from_micros, from_nanos, print_time, server_for_hash, to_micros, TaskInfoRef}; +use arroyo_types::{ + from_micros, from_nanos, print_time, server_for_hash, to_micros, to_nanos, TaskInfoRef, +}; use futures::{StreamExt, TryStreamExt}; use parquet::{ @@ -37,7 +40,7 @@ use crate::{ TableData, }; use arroyo_rpc::df::{ArroyoSchema, ArroyoSchemaRef}; -use tracing::{debug, info}; +use tracing::{debug, warn}; use super::{table_checkpoint_path, CompactionConfig, Table, TableEpochCheckpointer}; @@ -57,61 +60,13 @@ impl ExpiringTimeKeyTable { state_tx: Sender, watermark: Option, ) -> Result { - let cutoff = watermark - .map(|watermark| (watermark - self.retention)) - .unwrap_or_else(|| SystemTime::UNIX_EPOCH); - info!( - "watermark is {:?}, cutoff is {:?}", - watermark.map(print_time), - print_time(cutoff) - ); - let files: Vec<_> = self - .checkpoint_files - .iter() - .filter_map(|file| { - // file must have some data greater than the cutoff and routing keys within the range. - if cutoff <= from_micros(file.max_timestamp_micros) - && (file.max_routing_key >= *self.task_info.key_range.start() - && *self.task_info.key_range.end() >= file.min_routing_key) - { - let needs_hash_filtering = *self.task_info.key_range.end() - < file.max_routing_key - || *self.task_info.key_range.start() > file.min_routing_key; - Some((file.file.clone(), needs_hash_filtering)) - } else { - None - } - }) - .collect(); + let cutoff = self.get_cutoff(watermark); + let files = self.get_files_with_filtering(cutoff); let mut data: BTreeMap> = BTreeMap::new(); - for (file, needs_filtering) in files { - let object_meta = self - .storage_provider - .get_backing_store() - .head(&(file.into())) - .await?; - let object_reader = - ParquetObjectReader::new(self.storage_provider.get_backing_store(), object_meta); - let reader_builder = ParquetRecordBatchStreamBuilder::new(object_reader).await?; - let mut stream = reader_builder.build()?; - // projection to trim the metadata fields. Should probably be factored out. - let projection: Vec<_> = (0..(stream.schema().fields().len() - 2)).collect(); - while let Some(batch_result) = stream.next().await { - let mut batch = batch_result?; - if needs_filtering { - match self - .schema - .filter_by_hash_index(batch, &self.task_info.key_range)? - { - None => continue, - Some(filtered_batch) => batch = filtered_batch, - }; - } - if batch.num_rows() == 0 { - continue; - } - batch = batch.project(&projection)?; + let timestamp_index = self.schema.timestamp_index(); + let batches_by_timestamp = self + .call_on_filtered_batches(files, |batch| { let timestamp_array: &PrimitiveArray = batch .column(self.schema.timestamp_index()) .as_primitive_opt() @@ -128,9 +83,8 @@ impl ExpiringTimeKeyTable { ); let batches = if max_timestamp != min_timestamp { // assume monotonic for now - let partitions = partition( - vec![batch.column(self.schema.timestamp_index()).clone()].as_slice(), - )?; + let partitions = + partition(vec![batch.column(timestamp_index).clone()].as_slice())?; partitions .ranges() .into_iter() @@ -142,11 +96,13 @@ impl ExpiringTimeKeyTable { } else { vec![(min_timestamp, batch)] }; - for (timestamp, batch) in batches { - if cutoff <= timestamp { - data.entry(timestamp).or_default().push(batch) - } - } + Ok(batches) + }) + .await?; + + for (timestamp, batch) in batches_by_timestamp { + if cutoff <= timestamp { + data.entry(timestamp).or_default().push(batch) } } @@ -157,40 +113,15 @@ impl ExpiringTimeKeyTable { state_tx, }) } - - pub(crate) async fn get_key_time_view( + async fn call_on_filtered_batches( &self, - state_tx: Sender, - watermark: Option, - ) -> Result { - let cutoff = watermark - .map(|watermark| (watermark - self.retention)) - .unwrap_or_else(|| SystemTime::UNIX_EPOCH); - info!( - "watermark is {:?}, cutoff is {:?}", - watermark.map(print_time), - print_time(cutoff) - ); - let files: Vec<_> = self - .checkpoint_files - .iter() - .filter_map(|file| { - // file must have some data greater than the cutoff and routing keys within the range. - if cutoff <= from_micros(file.max_timestamp_micros) - && (file.max_routing_key >= *self.task_info.key_range.start() - && *self.task_info.key_range.end() >= file.min_routing_key) - { - let needs_hash_filtering = *self.task_info.key_range.end() - < file.max_routing_key - || *self.task_info.key_range.start() > file.min_routing_key; - Some((file.file.clone(), needs_hash_filtering)) - } else { - None - } - }) - .collect(); - - let mut view = KeyTimeView::new(self.clone(), state_tx)?; + files: Vec<(String, bool)>, + batch_processor: F, + ) -> Result> + where + F: Fn(RecordBatch) -> Result> + Send + Sync, // Ensure `F` is a closure that can be sent and synced between threads + { + let mut result = vec![]; for (file, needs_filtering) in files { let object_meta = self .storage_provider @@ -214,7 +145,53 @@ impl ExpiringTimeKeyTable { Some(filtered_batch) => batch = filtered_batch, }; } + if batch.num_rows() == 0 { + continue; + } batch = batch.project(&projection)?; + result.extend(batch_processor(batch)?) + } + } + Ok(result) + } + + fn get_cutoff(&self, watermark: Option) -> SystemTime { + watermark + .map(|watermark| watermark - self.retention) + .unwrap_or(SystemTime::UNIX_EPOCH) + } + + fn get_files_with_filtering(&self, cutoff: SystemTime) -> Vec<(String, bool)> { + self.checkpoint_files + .iter() + .filter_map(|file| { + // file must have some data greater than the cutoff and routing keys within the range. + if cutoff <= from_micros(file.max_timestamp_micros) + && (file.max_routing_key >= *self.task_info.key_range.start() + && *self.task_info.key_range.end() >= file.min_routing_key) + { + let needs_hash_filtering = *self.task_info.key_range.end() + < file.max_routing_key + || *self.task_info.key_range.start() > file.min_routing_key; + Some((file.file.clone(), needs_hash_filtering)) + } else { + None + } + }) + .collect() + } + + pub(crate) async fn get_key_time_view( + &self, + state_tx: Sender, + watermark: Option, + ) -> Result { + let cutoff = self.get_cutoff(watermark); + let files = self.get_files_with_filtering(cutoff); + + let mut view = KeyTimeView::new(self.clone(), state_tx)?; + let batches_to_add = self + .call_on_filtered_batches(files, |batch| { let timestamp_array: &PrimitiveArray = batch .column(self.schema.timestamp_index()) .as_primitive_opt() @@ -225,11 +202,46 @@ impl ExpiringTimeKeyTable { as u128, ); if max_timestamp < cutoff { - continue; + Ok(vec![]) + } else { + Ok(vec![batch]) } - // TODO: more time filtering - view.insert_internal(batch)?; - } + }) + .await?; + for batch in batches_to_add { + view.insert_internal(batch)?; + } + Ok(view) + } + + pub(crate) async fn get_last_key_value_view( + &self, + state_tx: Sender, + watermark: Option, + ) -> Result { + let cutoff = self.get_cutoff(watermark); + let files = self.get_files_with_filtering(cutoff); + let mut view = LastKeyValueView::new(self.clone(), state_tx)?; + let batches = self + .call_on_filtered_batches(files, |batch| { + let timestamp_array: &PrimitiveArray = batch + .column(self.schema.timestamp_index()) + .as_primitive_opt() + .ok_or_else(|| anyhow!("failed to find timestamp column"))?; + let max_timestamp = from_nanos( + aggregate::max(timestamp_array) + .ok_or_else(|| anyhow!("should have max timestamp"))? + as u128, + ); + if max_timestamp < cutoff { + Ok(vec![]) + } else { + Ok(vec![batch]) + } + }) + .await?; + for batch in batches { + view.insert_batch_internal(batch, false).await?; } Ok(view) } @@ -899,18 +911,39 @@ enum BatchData { } impl KeyTimeView { - pub fn get_batch(&mut self, row: Row) -> Result> { - if !self.keyed_data.contains_key(row.as_ref()) { + fn new(parent: ExpiringTimeKeyTable, state_tx: Sender) -> Result { + let schema = parent.schema.memory_schema(); + let key_converter = schema.converter(false)?; + let value_schema = Arc::new(schema.schema_without_keys()?); + let value_indices = schema.value_indices(true); + Ok(Self { + key_converter, + parent, + keyed_data: HashMap::new(), + schema, + value_indices, + value_schema, + state_tx, + }) + } + + pub fn get_batch(&mut self, row: &[u8]) -> Result> { + if !self.keyed_data.contains_key(row) { + warn!( + "couldn't find data for {:?}, map has {} keys", + row, + self.keyed_data.len() + ); return Ok(None); } - let Some(value) = self.keyed_data.get_mut(row.as_ref()) else { + let Some(value) = self.keyed_data.get_mut(row) else { unreachable!("just checked") }; if let BatchData::BatchVec(batches) = value { let coalesced_batches = concat_batches(&self.value_schema.schema, batches.iter())?; *value = BatchData::SingleBatch(coalesced_batches); } - let Some(BatchData::SingleBatch(single_batch)) = self.keyed_data.get(row.as_ref()) else { + let Some(BatchData::SingleBatch(single_batch)) = self.keyed_data.get(row) else { unreachable!("just inserted") }; Ok(Some(single_batch)) @@ -927,6 +960,22 @@ impl KeyTimeView { } pub async fn insert(&mut self, batch: RecordBatch) -> Result> { + self.state_tx + .send(StateMessage::TableData { + table: self.parent.table_name.to_string(), + data: TableData::RecordBatch(batch.clone()), + }) + .await?; + Ok(self + .insert_internal(batch)? + .into_iter() + .map(|(row, _)| row) + .collect()) + } + pub async fn insert_and_report_prior_presence( + &mut self, + batch: RecordBatch, + ) -> Result> { self.state_tx .send(StateMessage::TableData { table: self.parent.table_name.to_string(), @@ -936,7 +985,7 @@ impl KeyTimeView { self.insert_internal(batch) } - fn insert_internal(&mut self, batch: RecordBatch) -> Result> { + fn insert_internal(&mut self, batch: RecordBatch) -> Result> { let sorted_batch = self.schema.sort(batch, false)?; let value_batch = sorted_batch.project(&self.value_indices)?; let mut rows = vec![]; @@ -952,8 +1001,8 @@ impl KeyTimeView { .to_vec() }; let key_row = self.key_converter.convert_columns(&key_columns)?; - rows.push(key_row.clone()); let contents = self.keyed_data.get_mut(key_row.as_ref()); + rows.push((key_row.clone(), contents.is_some())); let batch = match contents { Some(BatchData::BatchVec(vec)) => { vec.push(value_batch); @@ -975,27 +1024,189 @@ impl KeyTimeView { } Ok(rows) } +} + +#[derive(Debug)] +pub struct LastKeyValueView { + parent: ExpiringTimeKeyTable, + key_converter: Converter, + value_converter: Converter, + value_indices: Vec, + // indices of schema that aren't keys, used for projection + backing_map: HashMap, (Vec, SystemTime)>, + expirations: BTreeMap>>, + state_tx: Sender, +} +impl LastKeyValueView { fn new(parent: ExpiringTimeKeyTable, state_tx: Sender) -> Result { let schema = parent.schema.memory_schema(); let key_converter = schema.converter(false)?; - let value_schema = Arc::new(schema.schema_without_keys()?); - let value_indices = if schema.key_indices.is_some() { - let key_indices = schema.key_indices.as_ref().unwrap(); - (0..schema.schema.fields().len()) - .filter(|i| !key_indices.contains(i)) - .collect() - } else { - (0..schema.schema.fields().len()).collect() - }; + let value_converter = schema.value_converter(false)?; + let value_indices = schema.value_indices(false); + let backing_map = HashMap::new(); + let expirations = BTreeMap::new(); Ok(Self { key_converter, - parent, - keyed_data: HashMap::new(), - schema, + value_converter, value_indices, - value_schema, + backing_map, + expirations, + parent, state_tx, }) } + pub async fn insert_batch(&mut self, batch: RecordBatch) -> Result<()> { + self.insert_batch_internal(batch, true).await + } + + pub fn get_current_matching_values( + &self, + batch: &RecordBatch, + ) -> Result> { + let key_batch: RecordBatch = batch.project( + &self + .parent + .schema + .memory_schema() + .key_indices + .as_ref() + .unwrap(), + )?; + let key_rows = self + .key_converter + .convert_all_columns(key_batch.columns())?; + let capacity = batch.num_rows().min(self.backing_map.len()); + let mut prior_values = Vec::with_capacity(capacity); + let mut prior_timestamp_builder = TimestampNanosecondArray::builder(capacity); + let mut prior_row_filter = BooleanArray::builder(capacity); + for i in 0..batch.num_rows() { + match self.backing_map.get(key_rows.row(i).as_ref()) { + None => { + prior_row_filter.append_value(false); + } + Some((value, timestamp)) => { + prior_row_filter.append_value(true); + prior_timestamp_builder.append_value(to_nanos(*timestamp) as i64); + prior_values.push(value.as_slice()); + } + } + } + let filter = prior_row_filter.finish(); + let filtered_key_batch = filter_record_batch(&key_batch, &filter)?; + if filtered_key_batch.num_rows() == 0 { + return Ok(None); + } + let mut columns = filtered_key_batch.columns().to_vec(); + let value_columns = self.value_converter.convert_raw_rows(prior_values)?; + columns.extend(value_columns); + columns.push(Arc::new(prior_timestamp_builder.finish())); + + Ok(Some(( + RecordBatch::try_new(batch.schema(), columns)?, + filter, + ))) + } + + async fn insert_batch_internal( + &mut self, + batch: RecordBatch, + should_write: bool, + ) -> Result<()> { + let key_batch: RecordBatch = batch.project( + &self + .parent + .schema + .memory_schema() + .key_indices + .as_ref() + .unwrap(), + )?; + let value_batch = batch.project(&self.value_indices)?; + let key_rows = self + .key_converter + .convert_all_columns(key_batch.columns())?; + let value_rows = self + .value_converter + .convert_all_columns(value_batch.columns())?; + let timestamp_columns = batch + .column(self.parent.schema.memory_schema().timestamp_index) + .as_any() + .downcast_ref::() + .ok_or_else(|| anyhow!("should be able to extract timestamp array"))?; + let mut max_timestamps = Vec::with_capacity(batch.num_rows()); + for i in 0..batch.num_rows() { + let max_timestamp = self.insert_entry( + key_rows.row(i).as_ref(), + value_rows.row(i).as_ref(), + from_nanos(timestamp_columns.value(i) as u128), + )?; + + max_timestamps.push(to_nanos(max_timestamp) as i64); + } + if !should_write { + return Ok(()); + } + let mut columns = batch.columns().to_vec(); + columns[self.parent.schema.memory_schema().timestamp_index] = + Arc::new(TimestampNanosecondArray::from(max_timestamps)); + let batch = + RecordBatch::try_new(self.parent.schema.memory_schema().schema.clone(), columns)?; + self.state_tx + .send(StateMessage::TableData { + table: self.parent.table_name.to_string(), + data: TableData::RecordBatch(batch.clone()), + }) + .await?; + Ok(()) + } + + fn insert_entry( + &mut self, + key_row: &[u8], + value_row: &[u8], + timestamp: SystemTime, + ) -> Result { + match self.backing_map.get_mut(key_row) { + None => { + self.backing_map + .insert(key_row.to_vec(), (value_row.to_vec(), timestamp)); + self.expirations + .entry(timestamp) + .or_default() + .insert(key_row.to_vec()); + Ok(timestamp) + } + Some((old_value, existing_timestamp)) => { + if *existing_timestamp < timestamp { + self.expirations + .get_mut(existing_timestamp) + .unwrap() + .remove(key_row); + self.expirations + .entry(timestamp) + .or_default() + .insert(key_row.to_vec()); + *existing_timestamp = timestamp; + } + *old_value = value_row.to_vec(); + Ok(*existing_timestamp) + } + } + } + + pub fn expire(&mut self, watermark: Option) -> Result<()> { + let Some(watermark) = watermark else { + return Ok(()); + }; + let cutoff = watermark - self.parent.retention; + let mut to_delete = self.expirations.split_off(&cutoff); + mem::swap(&mut self.expirations, &mut to_delete); + for keys in to_delete.values() { + for key in keys { + self.backing_map.remove(key); + } + } + Ok(()) + } } diff --git a/crates/arroyo-state/src/tables/mod.rs b/crates/arroyo-state/src/tables/mod.rs index cf532932d..d0f5306e6 100644 --- a/crates/arroyo-state/src/tables/mod.rs +++ b/crates/arroyo-state/src/tables/mod.rs @@ -16,11 +16,6 @@ pub mod expiring_time_key_map; pub mod global_keyed_map; pub mod table_manager; -pub enum Compactor { - TimeKeyMap, - KeyTimeMultiMap, -} - pub(crate) fn table_checkpoint_path( job_id: &str, operator_id: &str, diff --git a/crates/arroyo-state/src/tables/table_manager.rs b/crates/arroyo-state/src/tables/table_manager.rs index c0bfa1e80..38e538afc 100644 --- a/crates/arroyo-state/src/tables/table_manager.rs +++ b/crates/arroyo-state/src/tables/table_manager.rs @@ -23,7 +23,9 @@ use tracing::{debug, info, warn}; use crate::{tables::global_keyed_map::GlobalKeyedTable, StateMessage}; use crate::{CheckpointMessage, TableData}; -use super::expiring_time_key_map::{ExpiringTimeKeyTable, ExpiringTimeKeyView, KeyTimeView}; +use super::expiring_time_key_map::{ + ExpiringTimeKeyTable, ExpiringTimeKeyView, KeyTimeView, LastKeyValueView, +}; use super::global_keyed_map::GlobalKeyedView; use super::{ErasedCheckpointer, ErasedTable}; @@ -458,4 +460,33 @@ impl TableManager { .ok_or_else(|| anyhow!("Failed to downcast table {}", table_name))?; Ok(cache) } + + pub async fn get_last_key_value_table( + &mut self, + table_name: &str, + watermark: Option, + ) -> Result<&mut LastKeyValueView> { + if let std::collections::hash_map::Entry::Vacant(e) = + self.caches.entry(table_name.to_string()) + { + let table_implementation = self + .tables + .get(table_name) + .ok_or_else(|| anyhow!("no registered table {}", table_name))?; + let expiring_time_key_table = table_implementation + .as_any() + .downcast_ref::() + .ok_or_else(|| anyhow!("wrong table type for table {}", table_name))?; + let saved_data = expiring_time_key_table + .get_last_key_value_view(self.writer.sender.clone(), watermark) + .await?; + let cache: Box = Box::new(saved_data); + e.insert(cache); + } + let cache = self.caches.get_mut(table_name).unwrap(); + let cache: &mut LastKeyValueView = cache + .downcast_mut() + .ok_or_else(|| anyhow!("Failed to downcast table {}", table_name))?; + Ok(cache) + } } diff --git a/crates/arroyo-worker/src/arrow/join_with_expiration.rs b/crates/arroyo-worker/src/arrow/join_with_expiration.rs index 40f3d0ebb..6336dc21a 100644 --- a/crates/arroyo-worker/src/arrow/join_with_expiration.rs +++ b/crates/arroyo-worker/src/arrow/join_with_expiration.rs @@ -57,7 +57,7 @@ impl JoinWithExpiration { let mut right_batches = vec![]; for row in left_rows { if let Some(batch) = right_table - .get_batch(row.row()) + .get_batch(row.as_ref()) .expect("shouldn't error getting batch") { right_batches.push(batch.clone()); @@ -95,7 +95,7 @@ impl JoinWithExpiration { let mut left_batches = vec![]; for row in right_rows { if let Some(batch) = left_table - .get_batch(row.row()) + .get_batch(row.as_ref()) .expect("shouldn't error getting batch") { left_batches.push(batch.clone()); diff --git a/crates/arroyo-worker/src/arrow/mod.rs b/crates/arroyo-worker/src/arrow/mod.rs index 5ed9f3c16..979e43f60 100644 --- a/crates/arroyo-worker/src/arrow/mod.rs +++ b/crates/arroyo-worker/src/arrow/mod.rs @@ -27,6 +27,7 @@ pub mod session_aggregating_window; pub mod sliding_aggregating_window; pub(crate) mod sync; pub mod tumbling_aggregating_window; +pub mod updating_aggregator; pub mod window_fn; pub struct ValueExecutionOperator { diff --git a/crates/arroyo-worker/src/arrow/updating_aggregator.rs b/crates/arroyo-worker/src/arrow/updating_aggregator.rs new file mode 100644 index 000000000..ee7d04b82 --- /dev/null +++ b/crates/arroyo-worker/src/arrow/updating_aggregator.rs @@ -0,0 +1,336 @@ +use std::{ + any::Any, + collections::HashMap, + pin::Pin, + sync::{Arc, RwLock}, + time::SystemTime, +}; + +use anyhow::{anyhow, Result}; +use arrow::compute::concat_batches; +use arrow_array::RecordBatch; + +use arroyo_operator::{ + context::ArrowContext, + operator::{ArrowOperator, OperatorConstructor, OperatorNode}, +}; +use arroyo_rpc::grpc::{api::UpdatingAggregateOperator, TableConfig}; +use arroyo_state::timestamp_table_config; +use arroyo_types::{CheckpointBarrier, Watermark}; +use datafusion::{execution::context::SessionContext, physical_plan::ExecutionPlan}; + +use arroyo_df::physical::{ArroyoPhysicalExtensionCodec, DecodingContext}; +use arroyo_operator::operator::Registry; +use arroyo_rpc::df::ArroyoSchemaRef; +use datafusion_common::ScalarValue; +use datafusion_execution::{ + runtime_env::{RuntimeConfig, RuntimeEnv}, + SendableRecordBatchStream, +}; +use datafusion_expr::ColumnarValue; +use datafusion_proto::{physical_plan::AsExecutionPlan, protobuf::PhysicalPlanNode}; +use futures::{lock::Mutex, Future}; +use prost::Message; +use std::time::Duration; +use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; +use tokio_stream::StreamExt; +use tracing::info; + +pub struct UpdatingAggregatingFunc { + partial_aggregation_plan: Arc, + partial_schema: ArroyoSchemaRef, + state_partial_schema: ArroyoSchemaRef, + state_final_schema: ArroyoSchemaRef, + combine_plan: Arc, + finish_execution_plan: Arc, + receiver: Arc>>>, + sender: Option>, + // this is optional because an exec with no input has unreliable behavior. + // In particular, if it is a global aggregate it will emit a record batch with 1 row initialized with the empty aggregate state, + // while if it does have group by keys it will emit a record batch with 0 rows. + exec: Arc>>, +} + +impl UpdatingAggregatingFunc { + async fn flush(&mut self, ctx: &mut ArrowContext) -> Result<()> { + if self.sender.is_none() { + return Ok(()); + } + let flush_start = SystemTime::now(); + { + self.sender.take(); + } + let mut partial_batches = vec![]; + let mut flushing_exec = self.exec.lock().await.take().unwrap(); + while let Some(batch) = flushing_exec.next().await { + partial_batches.push(batch?); + } + info!( + "finished active exec in {:?}", + flush_start.elapsed().unwrap() + ); + let new_partial_batch = + concat_batches(&self.state_partial_schema.schema, &partial_batches)?; + let prior_partials = ctx + .table_manager + .get_last_key_value_table("p", ctx.last_present_watermark()) + .await?; + let mut final_input_batches = vec![]; + if let Some((prior_partial_batch, _filter)) = + prior_partials.get_current_matching_values(&new_partial_batch)? + { + let combining_batches = vec![new_partial_batch, prior_partial_batch]; + let combine_batch = concat_batches(&self.partial_schema.schema, &combining_batches)?; + let mut combine_exec = { + let (sender, receiver) = unbounded_channel(); + sender.send(combine_batch)?; + self.receiver.write().unwrap().replace(receiver); + self.combine_plan + .execute(0, SessionContext::new().task_ctx())? + }; + while let Some(batch) = combine_exec.next().await { + let batch = batch?; + let renamed_batch = RecordBatch::try_new( + self.state_partial_schema.schema.clone(), + batch.columns().to_vec(), + )?; + prior_partials.insert_batch(renamed_batch).await?; + final_input_batches.push(batch); + } + } else { + // all the new data is disjoint from what's in state, no need to combine. + prior_partials + .insert_batch(new_partial_batch.clone()) + .await?; + final_input_batches.push(new_partial_batch); + } + let final_input_batch = concat_batches(&self.partial_schema.schema, &final_input_batches)?; + let mut final_exec = { + let (sender, receiver) = unbounded_channel(); + sender.send(final_input_batch)?; + self.receiver.write().unwrap().replace(receiver); + let final_exec = self + .finish_execution_plan + .execute(0, SessionContext::new().task_ctx())?; + final_exec + }; + let final_output_table = ctx + .table_manager + .get_last_key_value_table("f", ctx.last_present_watermark()) + .await?; + let mut batches_to_write = vec![]; + while let Some(results) = final_exec.next().await { + let results = results?; + let renamed_results = RecordBatch::try_new( + self.state_final_schema.schema.clone(), + results.columns().to_vec(), + )?; + if let Some((prior_batch, _filter)) = + final_output_table.get_current_matching_values(&renamed_results)? + { + let is_retract = ColumnarValue::Scalar(ScalarValue::Boolean(Some(true))) + .into_array(prior_batch.num_rows()) + .unwrap(); + let mut columns = prior_batch.columns().to_vec(); + columns.push(is_retract); + let retract_batch = + RecordBatch::try_new(ctx.out_schema.as_ref().unwrap().schema.clone(), columns)?; + batches_to_write.push(retract_batch); + } + final_output_table + .insert_batch(renamed_results.clone()) + .await?; + let is_retract = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))) + .into_array(results.num_rows()) + .unwrap(); + let mut columns = results.columns().to_vec(); + columns.push(is_retract); + let result_batch = + RecordBatch::try_new(ctx.out_schema.as_ref().unwrap().schema.clone(), columns)?; + batches_to_write.push(result_batch); + } + for batch in batches_to_write.into_iter() { + ctx.collect(batch).await; + } + info!("flushed in {:?}", flush_start.elapsed().unwrap()); + Ok(()) + } + + fn init_exec(&mut self) { + let (sender, receiver) = unbounded_channel(); + { + let mut internal_receiver = self.receiver.write().unwrap(); + *internal_receiver = Some(receiver); + } + let new_exec = self + .partial_aggregation_plan + .execute(0, SessionContext::new().task_ctx()) + .unwrap(); + self.exec = Arc::new(Mutex::new(Some(new_exec))); + self.sender = Some(sender); + } +} + +#[async_trait::async_trait] +impl ArrowOperator for UpdatingAggregatingFunc { + fn name(&self) -> String { + "UpdatingAggregatingFunc".to_string() + } + + async fn process_batch(&mut self, batch: RecordBatch, _ctx: &mut ArrowContext) { + if self.sender.is_none() { + self.init_exec(); + } + self.sender.as_ref().unwrap().send(batch).unwrap(); + } + + async fn handle_checkpoint(&mut self, _b: CheckpointBarrier, ctx: &mut ArrowContext) { + self.flush(ctx).await.unwrap(); + } + + fn tables(&self) -> HashMap { + vec![ + ( + "f".to_string(), + timestamp_table_config( + "f", + "final_table", + Duration::from_secs(60 * 60 * 24), + self.state_final_schema.as_ref().clone(), + ), + ), + ( + "p".to_string(), + timestamp_table_config( + "p", + "partial_table", + Duration::from_secs(60 * 60 * 24), + self.state_partial_schema.as_ref().clone(), + ), + ), + ] + .into_iter() + .collect() + } + fn tick_interval(&self) -> Option { + Some(Duration::from_secs(1)) + } + + async fn handle_tick(&mut self, _tick: u64, ctx: &mut ArrowContext) { + self.flush(ctx).await.unwrap(); + } + + async fn handle_watermark( + &mut self, + watermark: Watermark, + ctx: &mut ArrowContext, + ) -> Option { + let last_watermark = ctx.last_present_watermark(); + let final_table = ctx + .table_manager + .get_last_key_value_table("f", last_watermark) + .await + .expect("should have final table"); + final_table + .expire(last_watermark) + .expect("should expire final table"); + let partial_table = ctx + .table_manager + .get_last_key_value_table("p", last_watermark) + .await + .expect("should have partial table"); + partial_table + .expire(last_watermark) + .expect("should expire partial table"); + Some(watermark) + } + + fn future_to_poll( + &mut self, + ) -> Option> + Send>>> { + if self.sender.is_none() { + return None; + } + let exec = self.exec.clone(); + Some(Box::pin(async move { + let batch = exec.lock().await.as_mut().unwrap().next().await; + Box::new(batch) as Box + })) + } + + async fn handle_future_result(&mut self, _result: Box, _: &mut ArrowContext) { + unreachable!("should not have future result") + } +} + +pub struct UpdatingAggregatingConstructor; + +impl OperatorConstructor for UpdatingAggregatingConstructor { + type ConfigT = UpdatingAggregateOperator; + + fn with_config( + &self, + config: Self::ConfigT, + registry: Arc, + ) -> anyhow::Result { + let receiver = Arc::new(RwLock::new(None)); + + let codec = ArroyoPhysicalExtensionCodec { + context: DecodingContext::UnboundedBatchStream(receiver.clone()), + }; + + let partial_aggregation_plan = + PhysicalPlanNode::decode(&mut config.partial_aggregation_plan.as_slice())?; + + // deserialize partial aggregation into execution plan with an UnboundedBatchStream source. + let partial_aggregation_plan = partial_aggregation_plan.try_into_physical_plan( + registry.as_ref(), + &RuntimeEnv::new(RuntimeConfig::new()).unwrap(), + &codec, + )?; + + let partial_schema = config + .partial_schema + .ok_or_else(|| anyhow!("requires partial schema"))? + .try_into()?; + + let combine_plan = PhysicalPlanNode::decode(&mut config.combine_plan.as_slice())?; + let combine_execution_plan = combine_plan.try_into_physical_plan( + registry.as_ref(), + &RuntimeEnv::new(RuntimeConfig::new()).unwrap(), + &codec, + )?; + + let finish_plan = PhysicalPlanNode::decode(&mut config.final_aggregation_plan.as_slice())?; + + let finish_execution_plan = finish_plan.try_into_physical_plan( + registry.as_ref(), + &RuntimeEnv::new(RuntimeConfig::new()).unwrap(), + &codec, + )?; + + Ok(OperatorNode::from_operator(Box::new( + UpdatingAggregatingFunc { + partial_aggregation_plan, + partial_schema: Arc::new(partial_schema), + combine_plan: combine_execution_plan, + state_partial_schema: Arc::new( + config + .state_partial_schema + .ok_or_else(|| anyhow!("requires partial schema"))? + .try_into()?, + ), + state_final_schema: Arc::new( + config + .state_final_schema + .ok_or_else(|| anyhow!("requires final schema"))? + .try_into()?, + ), + finish_execution_plan, + receiver, + sender: None, + exec: Arc::new(Mutex::new(None)), + }, + ))) + } +} diff --git a/crates/arroyo-worker/src/engine.rs b/crates/arroyo-worker/src/engine.rs index 1cff55b49..d343f52df 100644 --- a/crates/arroyo-worker/src/engine.rs +++ b/crates/arroyo-worker/src/engine.rs @@ -18,6 +18,7 @@ use crate::arrow::join_with_expiration::JoinWithExpirationConstructor; use crate::arrow::session_aggregating_window::SessionAggregatingWindowConstructor; use crate::arrow::sliding_aggregating_window::SlidingAggregatingWindowConstructor; use crate::arrow::tumbling_aggregating_window::TumblingAggregateWindowConstructor; +use crate::arrow::updating_aggregator::UpdatingAggregatingConstructor; use crate::arrow::window_fn::WindowFunctionConstructor; use crate::arrow::{KeyExecutionConstructor, ValueExecutionConstructor}; use crate::network_manager::{NetworkManager, Quad, Senders}; @@ -801,6 +802,7 @@ pub fn construct_operator( OperatorName::TumblingWindowAggregate => Box::new(TumblingAggregateWindowConstructor), OperatorName::SlidingWindowAggregate => Box::new(SlidingAggregatingWindowConstructor), OperatorName::SessionWindowAggregate => Box::new(SessionAggregatingWindowConstructor), + OperatorName::UpdatingAggregate => Box::new(UpdatingAggregatingConstructor), OperatorName::ExpressionWatermark => Box::new(WatermarkGeneratorConstructor), OperatorName::Join => Box::new(JoinWithExpirationConstructor), OperatorName::InstantJoin => Box::new(InstantJoinConstructor),