Skip to content

Commit 0cf5630

Browse files
authored
Fix output schema generated by CommonSubExprEliminate (#3726)
* CommonSubexprEliminate: Fix additional col schema
* Use correct types in test id_array_visitor
* Re-enable fallback schema for datatype resolution

  Fall back to the merged schema from the whole logical plan if the input schema was not sufficient to resolve the datatype of a sub-expression. This re-enables the fallback logic added in 3860cd3 (#1925).
* Add comment on fallback logic using all schemas

  Point out that it can likely be removed.
1 parent e10d647 commit 0cf5630

File tree

1 file changed

+149
-34
lines changed

1 file changed

+149
-34
lines changed

datafusion/optimizer/src/common_subexpr_eliminate.rs

Lines changed: 149 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
2020
use crate::{OptimizerConfig, OptimizerRule};
2121
use arrow::datatypes::DataType;
22-
use datafusion_common::{DFField, DFSchema, DataFusionError, Result};
22+
use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, Result};
2323
use datafusion_expr::{
2424
col,
2525
expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion},
@@ -94,7 +94,10 @@ fn optimize(
9494
schema,
9595
alias,
9696
}) => {
97-
let arrays = to_arrays(expr, input, &mut expr_set)?;
97+
let input_schema = Arc::clone(input.schema());
98+
let all_schemas: Vec<DFSchemaRef> =
99+
plan.all_schemas().into_iter().cloned().collect();
100+
let arrays = to_arrays(expr, input_schema, all_schemas, &mut expr_set)?;
98101

99102
let (mut new_expr, new_input) = rewrite_expr(
100103
&[expr],
@@ -112,22 +115,18 @@ fn optimize(
112115
)?))
113116
}
114117
LogicalPlan::Filter(Filter { predicate, input }) => {
115-
let schema = plan.schema().as_ref().clone();
116-
let data_type = if let Ok(data_type) = predicate.get_type(&schema) {
117-
data_type
118-
} else {
119-
// predicate type could not be resolved in schema, fall back to all schemas
120-
let schemas = plan.all_schemas();
121-
let all_schema =
122-
schemas.into_iter().fold(DFSchema::empty(), |mut lhs, rhs| {
123-
lhs.merge(rhs);
124-
lhs
125-
});
126-
predicate.get_type(&all_schema)?
127-
};
118+
let input_schema = Arc::clone(input.schema());
119+
let all_schemas: Vec<DFSchemaRef> =
120+
plan.all_schemas().into_iter().cloned().collect();
128121

129122
let mut id_array = vec![];
130-
expr_to_identifier(predicate, &mut expr_set, &mut id_array, data_type)?;
123+
expr_to_identifier(
124+
predicate,
125+
&mut expr_set,
126+
&mut id_array,
127+
input_schema,
128+
all_schemas,
129+
)?;
131130

132131
let (mut new_expr, new_input) = rewrite_expr(
133132
&[&[predicate.clone()]],
@@ -153,7 +152,11 @@ fn optimize(
153152
window_expr,
154153
schema,
155154
}) => {
156-
let arrays = to_arrays(window_expr, input, &mut expr_set)?;
155+
let input_schema = Arc::clone(input.schema());
156+
let all_schemas: Vec<DFSchemaRef> =
157+
plan.all_schemas().into_iter().cloned().collect();
158+
let arrays =
159+
to_arrays(window_expr, input_schema, all_schemas, &mut expr_set)?;
157160

158161
let (mut new_expr, new_input) = rewrite_expr(
159162
&[window_expr],
@@ -175,8 +178,17 @@ fn optimize(
175178
input,
176179
schema,
177180
}) => {
178-
let group_arrays = to_arrays(group_expr, input, &mut expr_set)?;
179-
let aggr_arrays = to_arrays(aggr_expr, input, &mut expr_set)?;
181+
let input_schema = Arc::clone(input.schema());
182+
let all_schemas: Vec<DFSchemaRef> =
183+
plan.all_schemas().into_iter().cloned().collect();
184+
let group_arrays = to_arrays(
185+
group_expr,
186+
Arc::clone(&input_schema),
187+
all_schemas.clone(),
188+
&mut expr_set,
189+
)?;
190+
let aggr_arrays =
191+
to_arrays(aggr_expr, input_schema, all_schemas, &mut expr_set)?;
180192

181193
let (mut new_expr, new_input) = rewrite_expr(
182194
&[group_expr, aggr_expr],
@@ -197,7 +209,10 @@ fn optimize(
197209
)?))
198210
}
199211
LogicalPlan::Sort(Sort { expr, input, fetch }) => {
200-
let arrays = to_arrays(expr, input, &mut expr_set)?;
212+
let input_schema = Arc::clone(input.schema());
213+
let all_schemas: Vec<DFSchemaRef> =
214+
plan.all_schemas().into_iter().cloned().collect();
215+
let arrays = to_arrays(expr, input_schema, all_schemas, &mut expr_set)?;
201216

202217
let (mut new_expr, new_input) = rewrite_expr(
203218
&[expr],
@@ -255,14 +270,20 @@ fn pop_expr(new_expr: &mut Vec<Vec<Expr>>) -> Result<Vec<Expr>> {
255270

256271
fn to_arrays(
257272
expr: &[Expr],
258-
input: &LogicalPlan,
273+
input_schema: DFSchemaRef,
274+
all_schemas: Vec<DFSchemaRef>,
259275
expr_set: &mut ExprSet,
260276
) -> Result<Vec<Vec<(usize, String)>>> {
261277
expr.iter()
262278
.map(|e| {
263-
let data_type = e.get_type(input.schema())?;
264279
let mut id_array = vec![];
265-
expr_to_identifier(e, expr_set, &mut id_array, data_type)?;
280+
expr_to_identifier(
281+
e,
282+
expr_set,
283+
&mut id_array,
284+
Arc::clone(&input_schema),
285+
all_schemas.clone(),
286+
)?;
266287

267288
Ok(id_array)
268289
})
@@ -370,7 +391,15 @@ struct ExprIdentifierVisitor<'a> {
370391
expr_set: &'a mut ExprSet,
371392
/// series number (usize) and identifier.
372393
id_array: &'a mut Vec<(usize, Identifier)>,
373-
data_type: DataType,
394+
/// input schema for the node that we're optimizing, so we can determine the correct datatype
395+
/// for each subexpression
396+
input_schema: DFSchemaRef,
397+
/// all schemas in the logical plan, as a fall back if we cannot resolve an expression type
398+
/// from the input schema alone
399+
// This fallback should never be necessary as the expression datatype should always be
400+
// resolvable from the input schema of the node that's being optimized.
401+
// todo: This can likely be removed if we are sure it's safe to do so.
402+
all_schemas: Vec<DFSchemaRef>,
374403

375404
// inner states
376405
visit_stack: Vec<VisitRecord>,
@@ -448,7 +477,25 @@ impl ExpressionVisitor for ExprIdentifierVisitor<'_> {
448477

449478
self.id_array[idx] = (self.series_number, desc.clone());
450479
self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));
451-
let data_type = self.data_type.clone();
480+
481+
let data_type = if let Ok(data_type) = expr.get_type(&self.input_schema) {
482+
data_type
483+
} else {
484+
// Expression type could not be resolved in schema, fall back to all schemas.
485+
//
486+
// This fallback should never be necessary as the expression datatype should always be
487+
// resolvable from the input schema of the node that's being optimized.
488+
// todo: This else-branch can likely be removed if we are sure it's safe to do so.
489+
let merged_schema =
490+
self.all_schemas
491+
.iter()
492+
.fold(DFSchema::empty(), |mut lhs, rhs| {
493+
lhs.merge(rhs);
494+
lhs
495+
});
496+
expr.get_type(&merged_schema)?
497+
};
498+
452499
self.expr_set
453500
.entry(desc)
454501
.or_insert_with(|| (expr.clone(), 0, data_type))
@@ -462,12 +509,14 @@ fn expr_to_identifier(
462509
expr: &Expr,
463510
expr_set: &mut ExprSet,
464511
id_array: &mut Vec<(usize, Identifier)>,
465-
data_type: DataType,
512+
input_schema: DFSchemaRef,
513+
all_schemas: Vec<DFSchemaRef>,
466514
) -> Result<()> {
467515
expr.accept(ExprIdentifierVisitor {
468516
expr_set,
469517
id_array,
470-
data_type,
518+
input_schema,
519+
all_schemas,
471520
visit_stack: vec![],
472521
node_count: 0,
473522
series_number: 0,
@@ -577,7 +626,8 @@ fn replace_common_expr(
577626
mod test {
578627
use super::*;
579628
use crate::test::*;
580-
use datafusion_expr::logical_plan::JoinType;
629+
use arrow::datatypes::{Field, Schema};
630+
use datafusion_expr::logical_plan::{table_scan, JoinType};
581631
use datafusion_expr::{
582632
avg, binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, sum,
583633
Operator,
@@ -597,22 +647,36 @@ mod test {
597647
fn id_array_visitor() -> Result<()> {
598648
let expr = binary_expr(
599649
binary_expr(
600-
sum(binary_expr(col("a"), Operator::Plus, lit("1"))),
650+
sum(binary_expr(col("a"), Operator::Plus, lit(1))),
601651
Operator::Minus,
602652
avg(col("c")),
603653
),
604654
Operator::Multiply,
605655
lit(2),
606656
);
607657

658+
let schema = Arc::new(DFSchema::new_with_metadata(
659+
vec![
660+
DFField::new(None, "a", DataType::Int64, false),
661+
DFField::new(None, "c", DataType::Int64, false),
662+
],
663+
Default::default(),
664+
)?);
665+
608666
let mut id_array = vec![];
609-
expr_to_identifier(&expr, &mut HashMap::new(), &mut id_array, DataType::Int64)?;
667+
expr_to_identifier(
668+
&expr,
669+
&mut HashMap::new(),
670+
&mut id_array,
671+
Arc::clone(&schema),
672+
vec![schema],
673+
)?;
610674

611675
let expected = vec![
612-
(9, "SUM(a + Utf8(\"1\")) - AVG(c) * Int32(2)Int32(2)SUM(a + Utf8(\"1\")) - AVG(c)AVG(c)cSUM(a + Utf8(\"1\"))a + Utf8(\"1\")Utf8(\"1\")a"),
613-
(7, "SUM(a + Utf8(\"1\")) - AVG(c)AVG(c)cSUM(a + Utf8(\"1\"))a + Utf8(\"1\")Utf8(\"1\")a"),
614-
(4, "SUM(a + Utf8(\"1\"))a + Utf8(\"1\")Utf8(\"1\")a"),
615-
(3, "a + Utf8(\"1\")Utf8(\"1\")a"),
676+
(9, "SUM(a + Int32(1)) - AVG(c) * Int32(2)Int32(2)SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
677+
(7, "SUM(a + Int32(1)) - AVG(c)AVG(c)cSUM(a + Int32(1))a + Int32(1)Int32(1)a"),
678+
(4, "SUM(a + Int32(1))a + Int32(1)Int32(1)a"),
679+
(3, "a + Int32(1)Int32(1)a"),
616680
(1, ""),
617681
(2, ""),
618682
(6, "AVG(c)c"),
@@ -796,4 +860,55 @@ mod test {
796860
assert!(field_set.insert(field.qualified_name()));
797861
}
798862
}
863+
864+
#[test]
865+
fn eliminated_subexpr_datatype() {
866+
use datafusion_expr::cast;
867+
868+
let schema = Schema::new(vec![
869+
Field::new("a", DataType::UInt64, false),
870+
Field::new("b", DataType::UInt64, false),
871+
Field::new("c", DataType::UInt64, false),
872+
]);
873+
874+
let plan = table_scan(Some("table"), &schema, None)
875+
.unwrap()
876+
.filter(
877+
cast(col("a"), DataType::Int64)
878+
.lt(lit(1_i64))
879+
.and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
880+
)
881+
.unwrap()
882+
.build()
883+
.unwrap();
884+
let rule = CommonSubexprEliminate {};
885+
let optimized_plan = rule.optimize(&plan, &mut OptimizerConfig::new()).unwrap();
886+
887+
let schema = optimized_plan.schema();
888+
let fields_with_datatypes: Vec<_> = schema
889+
.fields()
890+
.iter()
891+
.map(|field| (field.name(), field.data_type()))
892+
.collect();
893+
let formatted_fields_with_datatype = format!("{fields_with_datatypes:#?}");
894+
let expected = r###"[
895+
(
896+
"CAST(table.a AS Int64)table.a",
897+
Int64,
898+
),
899+
(
900+
"a",
901+
UInt64,
902+
),
903+
(
904+
"b",
905+
UInt64,
906+
),
907+
(
908+
"c",
909+
UInt64,
910+
),
911+
]"###;
912+
assert_eq!(expected, formatted_fields_with_datatype);
913+
}
799914
}

0 commit comments

Comments
 (0)