Skip to content

Commit e1ca364

Browse files
committed
tmp
1 parent b78231e commit e1ca364

File tree

6 files changed

+90
-70
lines changed

6 files changed

+90
-70
lines changed

benchmarks/expected-plans/q11.txt

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,26 @@
11
Sort: value DESC NULLS FIRST
22
Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value
33
Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__sq_1.__value AS Decimal128(38, 15))
4-
Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty), __sq_1.__value
5-
CrossJoin:
6-
Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
7-
Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost
8-
Inner Join: supplier.s_nationkey = nation.n_nationkey
9-
Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey
10-
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
11-
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost]
12-
TableScan: supplier projection=[s_suppkey, s_nationkey]
13-
Projection: nation.n_nationkey
14-
Filter: nation.n_name = Utf8("GERMANY")
15-
TableScan: nation projection=[n_nationkey, n_name]
16-
SubqueryAlias: __sq_1
17-
Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value
18-
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
19-
Projection: partsupp.ps_availqty, partsupp.ps_supplycost
20-
Inner Join: supplier.s_nationkey = nation.n_nationkey
21-
Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey
22-
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
23-
TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost]
24-
TableScan: supplier projection=[s_suppkey, s_nationkey]
25-
Projection: nation.n_nationkey
26-
Filter: nation.n_name = Utf8("GERMANY")
27-
TableScan: nation projection=[n_nationkey, n_name]
4+
CrossJoin:
5+
Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
6+
Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost
7+
Inner Join: supplier.s_nationkey = nation.n_nationkey
8+
Projection: partsupp.ps_partkey, partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey
9+
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
10+
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost]
11+
TableScan: supplier projection=[s_suppkey, s_nationkey]
12+
Projection: nation.n_nationkey
13+
Filter: nation.n_name = Utf8("GERMANY")
14+
TableScan: nation projection=[n_nationkey, n_name]
15+
SubqueryAlias: __sq_1
16+
Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value
17+
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
18+
Projection: partsupp.ps_availqty, partsupp.ps_supplycost
19+
Inner Join: supplier.s_nationkey = nation.n_nationkey
20+
Projection: partsupp.ps_availqty, partsupp.ps_supplycost, supplier.s_nationkey
21+
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
22+
TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost]
23+
TableScan: supplier projection=[s_suppkey, s_nationkey]
24+
Projection: nation.n_nationkey
25+
Filter: nation.n_name = Utf8("GERMANY")
26+
TableScan: nation projection=[n_nationkey, n_name]

benchmarks/expected-plans/q22.txt

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ Sort: custsale.cntrycode ASC NULLS LAST
44
SubqueryAlias: custsale
55
Projection: substr(customer.c_phone, Int64(1), Int64(2)) AS cntrycode, customer.c_acctbal
66
Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __sq_1.__value
7-
Projection: customer.c_phone, customer.c_acctbal, __sq_1.__value
8-
CrossJoin:
7+
CrossJoin:
8+
Projection: customer.c_phone, customer.c_acctbal
99
LeftAnti Join: customer.c_custkey = orders.o_custkey
1010
Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
11-
TableScan: customer projection=[c_custkey, c_name, c_address, c_nationkey, c_phone, c_acctbal, c_mktsegment, c_comment]
12-
TableScan: orders projection=[o_orderkey, o_custkey, o_orderstatus, o_totalprice, o_orderdate, o_orderpriority, o_clerk, o_shippriority, o_comment]
13-
SubqueryAlias: __sq_1
14-
Projection: AVG(customer.c_acctbal) AS __value
15-
Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]]
16-
Projection: customer.c_acctbal
17-
Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
18-
TableScan: customer projection=[c_phone, c_acctbal]
11+
TableScan: customer projection=[c_custkey, c_phone, c_acctbal]
12+
TableScan: orders projection=[o_custkey]
13+
SubqueryAlias: __sq_1
14+
Projection: AVG(customer.c_acctbal) AS __value
15+
Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]]
16+
Projection: customer.c_acctbal
17+
Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
18+
TableScan: customer projection=[c_phone, c_acctbal]

datafusion/optimizer/src/optimizer.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,6 @@ impl Optimizer {
227227
match result {
228228
Ok(Some(plan)) => {
229229
if !plan.schema().equivalent_names_and_types(new_plan.schema()) {
230-
let _str = format!("{:?}", new_plan);
231-
let _new_str = format!("{:?}", plan);
232-
233230
return Err(DataFusionError::Internal(format!(
234231
"Optimizer rule '{}' failed, due to generate a different schema, original schema: {:?}, new schema: {:?}",
235232
rule.name(),

datafusion/optimizer/src/push_down_filter.rs

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use datafusion_expr::{
2323
logical_plan::{CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union},
2424
or,
2525
utils::from_plan,
26-
BinaryExpr, Expr, Filter, Operator, Projection, TableProviderFilterPushDown,
26+
BinaryExpr, Expr, Filter, Operator, TableProviderFilterPushDown,
2727
};
2828
use std::collections::{HashMap, HashSet};
2929
use std::iter::once;
@@ -769,28 +769,6 @@ pub fn replace_cols_by_name(
769769
e.rewrite(&mut ColumnReplacer { replace_map })
770770
}
771771

772-
pub fn collect_projection_expr(projection: &Projection) -> HashMap<String, Expr> {
773-
projection
774-
.schema
775-
.fields()
776-
.iter()
777-
.enumerate()
778-
.flat_map(|(i, field)| {
779-
// strip alias, as they should not be part of filters
780-
let expr = match &projection.expr[i] {
781-
Expr::Alias(expr, _) => expr.as_ref().clone(),
782-
expr => expr.clone(),
783-
};
784-
785-
// Convert both qualified and unqualified fields
786-
[
787-
(field.name().clone(), expr.clone()),
788-
(field.qualified_name(), expr),
789-
]
790-
})
791-
.collect::<HashMap<_, _>>()
792-
}
793-
794772
#[cfg(test)]
795773
mod tests {
796774
use super::*;

datafusion/optimizer/src/push_down_projection.rs

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
//! Projection Push Down optimizer rule ensures that only referenced columns are
1919
//! loaded into memory
2020
21-
use crate::push_down_filter::{collect_projection_expr, replace_cols_by_name};
21+
use crate::push_down_filter::replace_cols_by_name;
2222
use crate::{utils, OptimizerConfig, OptimizerRule};
2323
use arrow::error::Result as ArrowResult;
2424
use datafusion_common::{
@@ -178,6 +178,31 @@ impl OptimizerRule for PushDownProjection {
178178
plan.with_new_inputs(&[new_join])?
179179
}
180180
}
181+
LogicalPlan::CrossJoin(join) => {
182+
// collect the columns referenced by the projection (a cross join has no on/filter).
183+
let mut push_columns: HashSet<Column> = HashSet::new();
184+
for e in projection.expr.iter() {
185+
expr_to_columns(e, &mut push_columns)?;
186+
}
187+
let new_left = generate_projection(
188+
&push_columns,
189+
join.left.schema().clone(),
190+
join.left.clone(),
191+
)?;
192+
let new_right = generate_projection(
193+
&push_columns,
194+
join.right.schema().clone(),
195+
join.right.clone(),
196+
)?;
197+
let new_join = child_plan.with_new_inputs(&[new_left, new_right])?;
198+
199+
if can_eliminate(projection, child_plan.schema()) {
200+
// can eliminate projection
201+
new_join
202+
} else {
203+
plan.with_new_inputs(&[new_join])?
204+
}
205+
}
181206
LogicalPlan::TableScan(scan) => {
182207
let mut used_columns: HashSet<Column> = HashSet::new();
183208
for expr in projection.expr.iter() {
@@ -409,6 +434,28 @@ fn generate_column_replace_map(
409434
.collect()
410435
}
411436

437+
pub fn collect_projection_expr(projection: &Projection) -> HashMap<String, Expr> {
438+
projection
439+
.schema
440+
.fields()
441+
.iter()
442+
.enumerate()
443+
.flat_map(|(i, field)| {
444+
// strip alias, as they should not be part of filters
445+
let expr = match &projection.expr[i] {
446+
Expr::Alias(expr, _) => expr.as_ref().clone(),
447+
expr => expr.clone(),
448+
};
449+
450+
// Convert both qualified and unqualified fields
451+
[
452+
(field.name().clone(), expr.clone()),
453+
(field.qualified_name(), expr),
454+
]
455+
})
456+
.collect::<HashMap<_, _>>()
457+
}
458+
412459
fn can_eliminate(projection: &Projection, schema: &DFSchemaRef) -> bool {
413460
if projection.expr.len() != schema.fields().len() {
414461
return false;

datafusion/optimizer/tests/integration-test.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,14 @@ fn subquery_filter_with_cast() -> Result<()> {
6565
let plan = test_sql(sql)?;
6666
let expected = "Projection: test.col_int32\
6767
\n Filter: CAST(test.col_int32 AS Float64) > __sq_1.__value\
68-
\n Projection: test.col_int32, __sq_1.__value\
69-
\n CrossJoin:\
70-
\n TableScan: test projection=[col_int32, col_uint32, col_utf8, col_date32, col_date64, col_ts_nano_none, col_ts_nano_utc]\
71-
\n SubqueryAlias: __sq_1\
72-
\n Projection: AVG(test.col_int32) AS __value\
73-
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\
74-
\n Projection: test.col_int32\
75-
\n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\
76-
\n TableScan: test projection=[col_int32, col_utf8]";
68+
\n CrossJoin:\
69+
\n TableScan: test projection=[col_int32]\
70+
\n SubqueryAlias: __sq_1\
71+
\n Projection: AVG(test.col_int32) AS __value\
72+
\n Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\
73+
\n Projection: test.col_int32\
74+
\n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\
75+
\n TableScan: test projection=[col_int32, col_utf8]";
7776
assert_eq!(expected, format!("{:?}", plan));
7877
Ok(())
7978
}

0 commit comments

Comments
 (0)