diff --git a/arroyo-df/src/lib.rs b/arroyo-df/src/lib.rs index 4e7b32a2c..c9a6a4d1a 100644 --- a/arroyo-df/src/lib.rs +++ b/arroyo-df/src/lib.rs @@ -15,6 +15,7 @@ pub mod logical; pub mod physical; mod plan_graph; pub mod schemas; +mod source_rewriter; mod tables; pub mod types; @@ -41,6 +42,7 @@ use schemas::{ window_arrow_struct, }; +use source_rewriter::SourceRewriter; use tables::{Insert, Table}; use types::interval_month_day_nanos_to_duration; @@ -375,6 +377,31 @@ impl ArroyoSchemaProvider { Ok(name) } + + fn get_table_source_with_fields( + &self, + name: &str, + fields: Vec, + ) -> datafusion_common::Result> { + let table = self + .get_table(name) + .ok_or_else(|| DataFusionError::Plan(format!("Table {} not found", name)))?; + + let fields = table + .get_fields() + .iter() + .filter_map(|field| { + if fields.contains(field) { + Some(field.clone()) + } else { + None + } + }) + .collect::>(); + + let schema = Arc::new(Schema::new_with_metadata(fields, HashMap::new())); + Ok(create_table(name.to_string(), schema)) + } } pub fn parse_dependencies(definition: &str) -> Result { @@ -922,39 +949,12 @@ impl TreeNodeRewriter for QueryToGraphVisitor { })) } LogicalPlan::TableScan(table_scan) => { - if let Some(projection_indices) = table_scan.projection { - let qualifier = table_scan.table_name.clone(); - let projected_schema = DFSchema::try_from_qualified_schema( - qualifier.clone(), - table_scan.source.schema().as_ref(), - )?; - let input_table_scan = LogicalPlan::TableScan(TableScan { - table_name: table_scan.table_name.clone(), - source: table_scan.source.clone(), - projection: None, - projected_schema: Arc::new(projected_schema), - filters: table_scan.filters.clone(), - fetch: table_scan.fetch, - }); - let projection_expressions: Vec<_> = projection_indices - .into_iter() - .map(|index| { - Expr::Column(Column { - relation: Some(qualifier.clone()), - name: table_scan.source.schema().fields()[index] - .name() - .to_string(), - }) - }) - .collect(); - let projection = LogicalPlan::Projection(datafusion_expr::Projection::try_new( - projection_expressions, - Arc::new(input_table_scan), - )?); - let mut timestamp_rewriter = TimestampRewriter {}; - let projection = projection.rewrite(&mut timestamp_rewriter)?; - return projection.rewrite(self); + if table_scan.projection.is_some() { + return Err(DataFusionError::Internal( + "Unexpected projection in table scan".to_string(), + )); } + let node_index = match self.table_source_to_nodes.get(&table_scan.table_name) { Some(node_index) => *node_index, None => { @@ -1061,8 +1061,13 @@ pub async fn parse_and_get_arrow_program( Insert::Anonymous { logical_plan } => (logical_plan, None), }; - let plan_with_timestamp = plan.rewrite(&mut TimestampRewriter {})?; - let plan_rewrite = plan_with_timestamp.rewrite(&mut rewriter).unwrap(); + let plan_rewrite = plan + .rewrite(&mut SourceRewriter { + schema_provider: schema_provider.clone(), + })? + .rewrite(&mut TimestampRewriter {})? + .rewrite(&mut rewriter) + .unwrap(); println!("REWRITE: {}", plan_rewrite.display_graphviz()); diff --git a/arroyo-df/src/source_rewriter.rs b/arroyo-df/src/source_rewriter.rs new file mode 100644 index 000000000..3183739ce --- /dev/null +++ b/arroyo-df/src/source_rewriter.rs @@ -0,0 +1,82 @@ +use crate::tables::FieldSpec; +use crate::tables::Table::ConnectorTable; +use crate::ArroyoSchemaProvider; +use arrow_schema::Schema; +use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::{Column, DFSchema, DataFusionError, Result as DFResult}; +use datafusion_expr::{Expr, LogicalPlan, TableScan}; +use std::sync::Arc; + +/// Rewrites a logical plan to move projections out of table scans +/// and into a separate projection node which may include virtual fields. +pub struct SourceRewriter { + pub(crate) schema_provider: ArroyoSchemaProvider, +} + +impl TreeNodeRewriter for SourceRewriter { + type N = LogicalPlan; + + fn mutate(&mut self, node: Self::N) -> DFResult { + let LogicalPlan::TableScan(table_scan) = node.clone() else { + return Ok(node); + }; + + let table_name = table_scan.table_name.table(); + let table = self + .schema_provider + .get_table(table_name) + .ok_or_else(|| DataFusionError::Plan(format!("Table {} not found", table_name)))?; + + let ConnectorTable(table) = table else { + return Ok(node); + }; + + let qualifier = table_scan.table_name.clone(); + + let expressions = table + .fields + .iter() + .map(|field| match field { + FieldSpec::StructField(f) => Expr::Column(Column { + relation: Some(qualifier.clone()), + name: f.name().to_string(), + }), + FieldSpec::VirtualField { field, expression } => { + expression.clone().alias(field.name().to_string()) + } + }) + .collect::>(); + + let non_virtual_fields = table + .fields + .iter() + .filter_map(|field| match field { + FieldSpec::StructField(f) => Some(f.clone()), + _ => None, + }) + .collect::>(); + + let table_scan_schema = DFSchema::try_from_qualified_schema( + qualifier.clone(), + &Schema::new(non_virtual_fields.clone()), + )?; + + let table_scan_table_source = self + .schema_provider + .get_table_source_with_fields(table_name, non_virtual_fields) + .unwrap(); + + let input_table_scan = LogicalPlan::TableScan(TableScan { + table_name: table_scan.table_name.clone(), + source: table_scan_table_source, + projection: None, // None because we are taking it out + projected_schema: Arc::new(table_scan_schema), + filters: table_scan.filters.clone(), + fetch: table_scan.fetch, + }); + + Ok(LogicalPlan::Projection( + datafusion_expr::Projection::try_new(expressions, Arc::new(input_table_scan))?, + )) + } +} diff --git a/arroyo-df/src/tables.rs b/arroyo-df/src/tables.rs index ba5dfd3f9..0d84d00f5 100644 --- a/arroyo-df/src/tables.rs +++ b/arroyo-df/src/tables.rs @@ -11,6 +11,7 @@ use arroyo_rpc::api_types::connections::{ }; use arroyo_rpc::formats::{BadData, Format, Framing}; use arroyo_types::ArroyoExtensionType; +use datafusion::sql::planner::PlannerContext; use datafusion::sql::sqlparser::ast::Query; use datafusion::{ optimizer::{analyzer::Analyzer, optimizer::Optimizer, OptimizerContext}, @@ -22,8 +23,7 @@ use datafusion::{ use datafusion_common::Column; use datafusion_common::{config::ConfigOptions, DFField, DFSchema}; use datafusion_expr::{ - CreateMemoryTable, CreateView, DdlStatement, DmlStatement, Expr, LogicalPlan, Projection, - WriteOp, + CreateMemoryTable, CreateView, DdlStatement, DmlStatement, Expr, LogicalPlan, WriteOp, }; use tracing::info; @@ -227,14 +227,6 @@ impl ConnectorTable { .unwrap_or(false) } - fn virtual_field_projection(&self) -> Result> { - if self.has_virtual_fields() { - bail!("virtual fields not supported in Arrow"); - } else { - Ok(None) - } - } - fn timestamp_override(&self) -> Result> { if let Some(field_name) = &self.event_time_field { if self.is_update() { @@ -317,7 +309,6 @@ impl ConnectorTable { bail!("can't read from a source with virtual fields and update mode.") } - let _virtual_field_projection = self.virtual_field_projection()?; let timestamp_override = self.timestamp_override()?; let watermark_column = self.watermark_column()?; @@ -491,7 +482,7 @@ impl Table { ) .collect(); - let _physical_schema = DFSchema::new_with_metadata( + let physical_schema = DFSchema::new_with_metadata( physical_fields .iter() .map(|f| { @@ -505,24 +496,24 @@ impl Table { HashMap::new(), )?; - let _sql_to_rel = SqlToRel::new(schema_provider); + let sql_to_rel = SqlToRel::new(schema_provider); struct_field_pairs .into_iter() .map(|(struct_field, generating_expression)| { - if let Some(_generating_expression) = generating_expression { + if let Some(generating_expression) = generating_expression { // TODO: Implement automatic type coercion here, as we have elsewhere. // It is done by calling the Analyzer which inserts CAST operators where necessary. - todo!("support generating expressions"); - /*let df_expr = sql_to_rel.sql_to_expr( + + let df_expr = sql_to_rel.sql_to_expr( generating_expression, &physical_schema, &mut PlannerContext::default(), )?; - let expression = expression_context.compile_expr(&df_expr)?; + Ok(FieldSpec::VirtualField { field: struct_field, - expression, - })*/ + expression: df_expr, + }) } else { Ok(FieldSpec::StructField(struct_field)) }