Skip to content

Commit

Permalink
Rewrite table scans as projections
Browse files Browse the repository at this point in the history
Add a SourceRewriter that separates table scans into a projection and a
table scan so that the projection can include virtual fields.
  • Loading branch information
jbeisen committed Jan 22, 2024
1 parent c381cec commit 9d4c259
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 53 deletions.
73 changes: 39 additions & 34 deletions arroyo-df/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub mod logical;
pub mod physical;
mod plan_graph;
pub mod schemas;
mod source_rewriter;
mod tables;
pub mod types;

Expand All @@ -41,6 +42,7 @@ use schemas::{
window_arrow_struct,
};

use source_rewriter::SourceRewriter;
use tables::{Insert, Table};
use types::interval_month_day_nanos_to_duration;

Expand Down Expand Up @@ -375,6 +377,31 @@ impl ArroyoSchemaProvider {

Ok(name)
}

/// Builds a table source for the table `name`, restricted to the given `fields`.
///
/// The resulting schema keeps the table's own field order: the table's fields
/// are walked in sequence and each one is kept only if it also appears in
/// `fields`. Entries of `fields` that the table does not have are ignored.
/// The schema is created with empty metadata.
///
/// # Errors
///
/// Returns `DataFusionError::Plan` when `get_table` finds no table called `name`.
fn get_table_source_with_fields(
    &self,
    name: &str,
    fields: Vec<Field>,
) -> datafusion_common::Result<Arc<dyn TableSource>> {
    let Some(table) = self.get_table(name) else {
        return Err(DataFusionError::Plan(format!("Table {} not found", name)));
    };

    // Iterate over the table's fields (not the requested ones) so the output
    // order follows the table definition.
    let selected: Vec<_> = table
        .get_fields()
        .iter()
        .filter(|candidate| fields.contains(*candidate))
        .cloned()
        .collect();

    let schema = Schema::new_with_metadata(selected, HashMap::new());
    Ok(create_table(name.to_string(), Arc::new(schema)))
}
}

pub fn parse_dependencies(definition: &str) -> Result<String> {
Expand Down Expand Up @@ -922,39 +949,12 @@ impl TreeNodeRewriter for QueryToGraphVisitor {
}))
}
LogicalPlan::TableScan(table_scan) => {
if let Some(projection_indices) = table_scan.projection {
let qualifier = table_scan.table_name.clone();
let projected_schema = DFSchema::try_from_qualified_schema(
qualifier.clone(),
table_scan.source.schema().as_ref(),
)?;
let input_table_scan = LogicalPlan::TableScan(TableScan {
table_name: table_scan.table_name.clone(),
source: table_scan.source.clone(),
projection: None,
projected_schema: Arc::new(projected_schema),
filters: table_scan.filters.clone(),
fetch: table_scan.fetch,
});
let projection_expressions: Vec<_> = projection_indices
.into_iter()
.map(|index| {
Expr::Column(Column {
relation: Some(qualifier.clone()),
name: table_scan.source.schema().fields()[index]
.name()
.to_string(),
})
})
.collect();
let projection = LogicalPlan::Projection(datafusion_expr::Projection::try_new(
projection_expressions,
Arc::new(input_table_scan),
)?);
let mut timestamp_rewriter = TimestampRewriter {};
let projection = projection.rewrite(&mut timestamp_rewriter)?;
return projection.rewrite(self);
if table_scan.projection.is_some() {
return Err(DataFusionError::Internal(
"Unexpected projection in table scan".to_string(),
));
}

let node_index = match self.table_source_to_nodes.get(&table_scan.table_name) {
Some(node_index) => *node_index,
None => {
Expand Down Expand Up @@ -1061,8 +1061,13 @@ pub async fn parse_and_get_arrow_program(
Insert::Anonymous { logical_plan } => (logical_plan, None),
};

let plan_with_timestamp = plan.rewrite(&mut TimestampRewriter {})?;
let plan_rewrite = plan_with_timestamp.rewrite(&mut rewriter).unwrap();
let plan_rewrite = plan
.rewrite(&mut SourceRewriter {
schema_provider: schema_provider.clone(),
})?
.rewrite(&mut TimestampRewriter {})?
.rewrite(&mut rewriter)
.unwrap();

println!("REWRITE: {}", plan_rewrite.display_graphviz());

Expand Down
82 changes: 82 additions & 0 deletions arroyo-df/src/source_rewriter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use crate::tables::FieldSpec;
use crate::tables::Table::ConnectorTable;
use crate::ArroyoSchemaProvider;
use arrow_schema::Schema;
use datafusion_common::tree_node::TreeNodeRewriter;
use datafusion_common::{Column, DFSchema, DataFusionError, Result as DFResult};
use datafusion_expr::{Expr, LogicalPlan, TableScan};
use std::sync::Arc;

/// Rewrites a logical plan to move projections out of table scans
/// and into a separate projection node which may include virtual fields.
///
/// Each `TableScan` over a connector table is replaced by a `Projection`
/// whose input is a scan of only the table's non-virtual fields; virtual
/// fields are produced by the projection as aliased expressions. Scans of
/// tables that are not connector tables, and all other plan nodes, are
/// returned unchanged.
pub struct SourceRewriter {
/// Provider used to look up the scanned table by name and to build the
/// table source for the rewritten scan.
pub(crate) schema_provider: ArroyoSchemaProvider,
}

impl TreeNodeRewriter for SourceRewriter {
    type N = LogicalPlan;

    /// Replaces a scan of a connector table with a projection over a scan of
    /// that table's non-virtual fields.
    ///
    /// Nodes that are not `TableScan`s, and scans of tables that are not
    /// connector tables, are returned untouched. For a connector table the
    /// result is:
    ///
    /// ```text
    /// Projection(struct fields as qualified columns, virtual fields as aliased expressions)
    ///   └─ TableScan(non-virtual fields only, projection = None)
    /// ```
    ///
    /// The filters and fetch limit of the original scan are carried over to
    /// the new scan.
    ///
    /// # Errors
    ///
    /// Returns `DataFusionError::Plan` if the scanned table is unknown to the
    /// schema provider, and propagates any error from building the scan
    /// schema, the table source, or the projection.
    fn mutate(&mut self, node: Self::N) -> DFResult<Self::N> {
        // Borrow instead of cloning: this runs for every node in the plan and
        // only table scans need to be inspected.
        let LogicalPlan::TableScan(table_scan) = &node else {
            return Ok(node);
        };

        let table_name = table_scan.table_name.table();
        let table = self
            .schema_provider
            .get_table(table_name)
            .ok_or_else(|| DataFusionError::Plan(format!("Table {} not found", table_name)))?;

        let ConnectorTable(table) = table else {
            return Ok(node);
        };

        let qualifier = table_scan.table_name.clone();

        // One output expression per declared field, in declaration order.
        let expressions = table
            .fields
            .iter()
            .map(|field| match field {
                FieldSpec::StructField(f) => Expr::Column(Column {
                    relation: Some(qualifier.clone()),
                    name: f.name().to_string(),
                }),
                FieldSpec::VirtualField { field, expression } => {
                    expression.clone().alias(field.name().to_string())
                }
            })
            .collect::<Vec<_>>();

        // Only non-virtual fields are read by the scan itself; virtual fields
        // are produced by the projection above it.
        let non_virtual_fields = table
            .fields
            .iter()
            .filter_map(|field| match field {
                FieldSpec::StructField(f) => Some(f.clone()),
                FieldSpec::VirtualField { .. } => None,
            })
            .collect::<Vec<_>>();

        let table_scan_schema = DFSchema::try_from_qualified_schema(
            qualifier.clone(),
            &Schema::new(non_virtual_fields.clone()),
        )?;

        // Propagate lookup failures as planning errors instead of panicking.
        let table_scan_table_source = self
            .schema_provider
            .get_table_source_with_fields(table_name, non_virtual_fields)?;

        let input_table_scan = LogicalPlan::TableScan(TableScan {
            table_name: qualifier,
            source: table_scan_table_source,
            // The projection now lives in the Projection node built below.
            projection: None,
            projected_schema: Arc::new(table_scan_schema),
            filters: table_scan.filters.clone(),
            fetch: table_scan.fetch,
        });

        Ok(LogicalPlan::Projection(
            datafusion_expr::Projection::try_new(expressions, Arc::new(input_table_scan))?,
        ))
    }
}
29 changes: 10 additions & 19 deletions arroyo-df/src/tables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use arroyo_rpc::api_types::connections::{
};
use arroyo_rpc::formats::{BadData, Format, Framing};
use arroyo_types::ArroyoExtensionType;
use datafusion::sql::planner::PlannerContext;
use datafusion::sql::sqlparser::ast::Query;
use datafusion::{
optimizer::{analyzer::Analyzer, optimizer::Optimizer, OptimizerContext},
Expand All @@ -22,8 +23,7 @@ use datafusion::{
use datafusion_common::Column;
use datafusion_common::{config::ConfigOptions, DFField, DFSchema};
use datafusion_expr::{
CreateMemoryTable, CreateView, DdlStatement, DmlStatement, Expr, LogicalPlan, Projection,
WriteOp,
CreateMemoryTable, CreateView, DdlStatement, DmlStatement, Expr, LogicalPlan, WriteOp,
};
use tracing::info;

Expand Down Expand Up @@ -227,14 +227,6 @@ impl ConnectorTable {
.unwrap_or(false)
}

/// Placeholder for the projection that would compute virtual fields.
///
/// Returns `Ok(None)` when the table has no virtual fields and fails
/// otherwise.
fn virtual_field_projection(&self) -> Result<Option<Projection>> {
    if !self.has_virtual_fields() {
        return Ok(None);
    }
    bail!("virtual fields not supported in Arrow");
}

fn timestamp_override(&self) -> Result<Option<Expr>> {
if let Some(field_name) = &self.event_time_field {
if self.is_update() {
Expand Down Expand Up @@ -317,7 +309,6 @@ impl ConnectorTable {
bail!("can't read from a source with virtual fields and update mode.")
}

let _virtual_field_projection = self.virtual_field_projection()?;
let timestamp_override = self.timestamp_override()?;
let watermark_column = self.watermark_column()?;

Expand Down Expand Up @@ -491,7 +482,7 @@ impl Table {
)
.collect();

let _physical_schema = DFSchema::new_with_metadata(
let physical_schema = DFSchema::new_with_metadata(
physical_fields
.iter()
.map(|f| {
Expand All @@ -505,24 +496,24 @@ impl Table {
HashMap::new(),
)?;

let _sql_to_rel = SqlToRel::new(schema_provider);
let sql_to_rel = SqlToRel::new(schema_provider);
struct_field_pairs
.into_iter()
.map(|(struct_field, generating_expression)| {
if let Some(_generating_expression) = generating_expression {
if let Some(generating_expression) = generating_expression {
// TODO: Implement automatic type coercion here, as we have elsewhere.
// It is done by calling the Analyzer which inserts CAST operators where necessary.
todo!("support generating expressions");
/*let df_expr = sql_to_rel.sql_to_expr(

let df_expr = sql_to_rel.sql_to_expr(
generating_expression,
&physical_schema,
&mut PlannerContext::default(),
)?;
let expression = expression_context.compile_expr(&df_expr)?;

Ok(FieldSpec::VirtualField {
field: struct_field,
expression,
})*/
expression: df_expr,
})
} else {
Ok(FieldSpec::StructField(struct_field))
}
Expand Down

0 comments on commit 9d4c259

Please sign in to comment.