19
19
20
20
use crate :: { OptimizerConfig , OptimizerRule } ;
21
21
use arrow:: datatypes:: DataType ;
22
- use datafusion_common:: { DFField , DFSchema , DataFusionError , Result } ;
22
+ use datafusion_common:: { DFField , DFSchema , DFSchemaRef , DataFusionError , Result } ;
23
23
use datafusion_expr:: {
24
24
col,
25
25
expr_rewriter:: { ExprRewritable , ExprRewriter , RewriteRecursion } ,
@@ -94,7 +94,10 @@ fn optimize(
94
94
schema,
95
95
alias,
96
96
} ) => {
97
- let arrays = to_arrays ( expr, input, & mut expr_set) ?;
97
+ let input_schema = Arc :: clone ( input. schema ( ) ) ;
98
+ let all_schemas: Vec < DFSchemaRef > =
99
+ plan. all_schemas ( ) . into_iter ( ) . cloned ( ) . collect ( ) ;
100
+ let arrays = to_arrays ( expr, input_schema, all_schemas, & mut expr_set) ?;
98
101
99
102
let ( mut new_expr, new_input) = rewrite_expr (
100
103
& [ expr] ,
@@ -112,22 +115,18 @@ fn optimize(
112
115
) ?) )
113
116
}
114
117
LogicalPlan :: Filter ( Filter { predicate, input } ) => {
115
- let schema = plan. schema ( ) . as_ref ( ) . clone ( ) ;
116
- let data_type = if let Ok ( data_type) = predicate. get_type ( & schema) {
117
- data_type
118
- } else {
119
- // predicate type could not be resolved in schema, fall back to all schemas
120
- let schemas = plan. all_schemas ( ) ;
121
- let all_schema =
122
- schemas. into_iter ( ) . fold ( DFSchema :: empty ( ) , |mut lhs, rhs| {
123
- lhs. merge ( rhs) ;
124
- lhs
125
- } ) ;
126
- predicate. get_type ( & all_schema) ?
127
- } ;
118
+ let input_schema = Arc :: clone ( input. schema ( ) ) ;
119
+ let all_schemas: Vec < DFSchemaRef > =
120
+ plan. all_schemas ( ) . into_iter ( ) . cloned ( ) . collect ( ) ;
128
121
129
122
let mut id_array = vec ! [ ] ;
130
- expr_to_identifier ( predicate, & mut expr_set, & mut id_array, data_type) ?;
123
+ expr_to_identifier (
124
+ predicate,
125
+ & mut expr_set,
126
+ & mut id_array,
127
+ input_schema,
128
+ all_schemas,
129
+ ) ?;
131
130
132
131
let ( mut new_expr, new_input) = rewrite_expr (
133
132
& [ & [ predicate. clone ( ) ] ] ,
@@ -153,7 +152,11 @@ fn optimize(
153
152
window_expr,
154
153
schema,
155
154
} ) => {
156
- let arrays = to_arrays ( window_expr, input, & mut expr_set) ?;
155
+ let input_schema = Arc :: clone ( input. schema ( ) ) ;
156
+ let all_schemas: Vec < DFSchemaRef > =
157
+ plan. all_schemas ( ) . into_iter ( ) . cloned ( ) . collect ( ) ;
158
+ let arrays =
159
+ to_arrays ( window_expr, input_schema, all_schemas, & mut expr_set) ?;
157
160
158
161
let ( mut new_expr, new_input) = rewrite_expr (
159
162
& [ window_expr] ,
@@ -175,8 +178,17 @@ fn optimize(
175
178
input,
176
179
schema,
177
180
} ) => {
178
- let group_arrays = to_arrays ( group_expr, input, & mut expr_set) ?;
179
- let aggr_arrays = to_arrays ( aggr_expr, input, & mut expr_set) ?;
181
+ let input_schema = Arc :: clone ( input. schema ( ) ) ;
182
+ let all_schemas: Vec < DFSchemaRef > =
183
+ plan. all_schemas ( ) . into_iter ( ) . cloned ( ) . collect ( ) ;
184
+ let group_arrays = to_arrays (
185
+ group_expr,
186
+ Arc :: clone ( & input_schema) ,
187
+ all_schemas. clone ( ) ,
188
+ & mut expr_set,
189
+ ) ?;
190
+ let aggr_arrays =
191
+ to_arrays ( aggr_expr, input_schema, all_schemas, & mut expr_set) ?;
180
192
181
193
let ( mut new_expr, new_input) = rewrite_expr (
182
194
& [ group_expr, aggr_expr] ,
@@ -197,7 +209,10 @@ fn optimize(
197
209
) ?) )
198
210
}
199
211
LogicalPlan :: Sort ( Sort { expr, input, fetch } ) => {
200
- let arrays = to_arrays ( expr, input, & mut expr_set) ?;
212
+ let input_schema = Arc :: clone ( input. schema ( ) ) ;
213
+ let all_schemas: Vec < DFSchemaRef > =
214
+ plan. all_schemas ( ) . into_iter ( ) . cloned ( ) . collect ( ) ;
215
+ let arrays = to_arrays ( expr, input_schema, all_schemas, & mut expr_set) ?;
201
216
202
217
let ( mut new_expr, new_input) = rewrite_expr (
203
218
& [ expr] ,
@@ -255,14 +270,20 @@ fn pop_expr(new_expr: &mut Vec<Vec<Expr>>) -> Result<Vec<Expr>> {
255
270
256
271
fn to_arrays (
257
272
expr : & [ Expr ] ,
258
- input : & LogicalPlan ,
273
+ input_schema : DFSchemaRef ,
274
+ all_schemas : Vec < DFSchemaRef > ,
259
275
expr_set : & mut ExprSet ,
260
276
) -> Result < Vec < Vec < ( usize , String ) > > > {
261
277
expr. iter ( )
262
278
. map ( |e| {
263
- let data_type = e. get_type ( input. schema ( ) ) ?;
264
279
let mut id_array = vec ! [ ] ;
265
- expr_to_identifier ( e, expr_set, & mut id_array, data_type) ?;
280
+ expr_to_identifier (
281
+ e,
282
+ expr_set,
283
+ & mut id_array,
284
+ Arc :: clone ( & input_schema) ,
285
+ all_schemas. clone ( ) ,
286
+ ) ?;
266
287
267
288
Ok ( id_array)
268
289
} )
@@ -370,7 +391,15 @@ struct ExprIdentifierVisitor<'a> {
370
391
expr_set : & ' a mut ExprSet ,
371
392
/// series number (usize) and identifier.
372
393
id_array : & ' a mut Vec < ( usize , Identifier ) > ,
373
- data_type : DataType ,
394
+ /// input schema for the node that we're optimizing, so we can determine the correct datatype
395
+ /// for each subexpression
396
+ input_schema : DFSchemaRef ,
397
+ /// all schemas in the logical plan, as a fall back if we cannot resolve an expression type
398
+ /// from the input schema alone
399
+ // This fallback should never be necessary as the expression datatype should always be
400
+ // resolvable from the input schema of the node that's being optimized.
401
+ // todo: This can likely be removed if we are sure it's safe to do so.
402
+ all_schemas : Vec < DFSchemaRef > ,
374
403
375
404
// inner states
376
405
visit_stack : Vec < VisitRecord > ,
@@ -448,7 +477,25 @@ impl ExpressionVisitor for ExprIdentifierVisitor<'_> {
448
477
449
478
self . id_array [ idx] = ( self . series_number , desc. clone ( ) ) ;
450
479
self . visit_stack . push ( VisitRecord :: ExprItem ( desc. clone ( ) ) ) ;
451
- let data_type = self . data_type . clone ( ) ;
480
+
481
+ let data_type = if let Ok ( data_type) = expr. get_type ( & self . input_schema ) {
482
+ data_type
483
+ } else {
484
+ // Expression type could not be resolved in schema, fall back to all schemas.
485
+ //
486
+ // This fallback should never be necessary as the expression datatype should always be
487
+ // resolvable from the input schema of the node that's being optimized.
488
+ // todo: This else-branch can likely be removed if we are sure it's safe to do so.
489
+ let merged_schema =
490
+ self . all_schemas
491
+ . iter ( )
492
+ . fold ( DFSchema :: empty ( ) , |mut lhs, rhs| {
493
+ lhs. merge ( rhs) ;
494
+ lhs
495
+ } ) ;
496
+ expr. get_type ( & merged_schema) ?
497
+ } ;
498
+
452
499
self . expr_set
453
500
. entry ( desc)
454
501
. or_insert_with ( || ( expr. clone ( ) , 0 , data_type) )
@@ -462,12 +509,14 @@ fn expr_to_identifier(
462
509
expr : & Expr ,
463
510
expr_set : & mut ExprSet ,
464
511
id_array : & mut Vec < ( usize , Identifier ) > ,
465
- data_type : DataType ,
512
+ input_schema : DFSchemaRef ,
513
+ all_schemas : Vec < DFSchemaRef > ,
466
514
) -> Result < ( ) > {
467
515
expr. accept ( ExprIdentifierVisitor {
468
516
expr_set,
469
517
id_array,
470
- data_type,
518
+ input_schema,
519
+ all_schemas,
471
520
visit_stack : vec ! [ ] ,
472
521
node_count : 0 ,
473
522
series_number : 0 ,
@@ -577,7 +626,8 @@ fn replace_common_expr(
577
626
mod test {
578
627
use super :: * ;
579
628
use crate :: test:: * ;
580
- use datafusion_expr:: logical_plan:: JoinType ;
629
+ use arrow:: datatypes:: { Field , Schema } ;
630
+ use datafusion_expr:: logical_plan:: { table_scan, JoinType } ;
581
631
use datafusion_expr:: {
582
632
avg, binary_expr, col, lit, logical_plan:: builder:: LogicalPlanBuilder , sum,
583
633
Operator ,
@@ -597,22 +647,36 @@ mod test {
597
647
fn id_array_visitor ( ) -> Result < ( ) > {
598
648
let expr = binary_expr (
599
649
binary_expr (
600
- sum ( binary_expr ( col ( "a" ) , Operator :: Plus , lit ( "1" ) ) ) ,
650
+ sum ( binary_expr ( col ( "a" ) , Operator :: Plus , lit ( 1 ) ) ) ,
601
651
Operator :: Minus ,
602
652
avg ( col ( "c" ) ) ,
603
653
) ,
604
654
Operator :: Multiply ,
605
655
lit ( 2 ) ,
606
656
) ;
607
657
658
+ let schema = Arc :: new ( DFSchema :: new_with_metadata (
659
+ vec ! [
660
+ DFField :: new( None , "a" , DataType :: Int64 , false ) ,
661
+ DFField :: new( None , "c" , DataType :: Int64 , false ) ,
662
+ ] ,
663
+ Default :: default ( ) ,
664
+ ) ?) ;
665
+
608
666
let mut id_array = vec ! [ ] ;
609
- expr_to_identifier ( & expr, & mut HashMap :: new ( ) , & mut id_array, DataType :: Int64 ) ?;
667
+ expr_to_identifier (
668
+ & expr,
669
+ & mut HashMap :: new ( ) ,
670
+ & mut id_array,
671
+ Arc :: clone ( & schema) ,
672
+ vec ! [ schema] ,
673
+ ) ?;
610
674
611
675
let expected = vec ! [
612
- ( 9 , "SUM(a + Utf8( \" 1 \" )) - AVG(c) * Int32(2)Int32(2)SUM(a + Utf8( \" 1 \" )) - AVG(c)AVG(c)cSUM(a + Utf8( \" 1 \" ))a + Utf8( \" 1 \" )Utf8( \" 1 \" )a" ) ,
613
- ( 7 , "SUM(a + Utf8( \" 1 \" )) - AVG(c)AVG(c)cSUM(a + Utf8( \" 1 \" ))a + Utf8( \" 1 \" )Utf8( \" 1 \" )a" ) ,
614
- ( 4 , "SUM(a + Utf8( \" 1 \" ))a + Utf8( \" 1 \" )Utf8( \" 1 \" )a" ) ,
615
- ( 3 , "a + Utf8( \" 1 \" )Utf8( \" 1 \" )a" ) ,
676
+ ( 9 , "SUM(a + Int32(1 )) - AVG(c) * Int32(2)Int32(2)SUM(a + Int32(1 )) - AVG(c)AVG(c)cSUM(a + Int32(1 ))a + Int32(1)Int32(1 )a" ) ,
677
+ ( 7 , "SUM(a + Int32(1 )) - AVG(c)AVG(c)cSUM(a + Int32(1 ))a + Int32(1)Int32(1 )a" ) ,
678
+ ( 4 , "SUM(a + Int32(1 ))a + Int32(1)Int32(1 )a" ) ,
679
+ ( 3 , "a + Int32(1)Int32(1 )a" ) ,
616
680
( 1 , "" ) ,
617
681
( 2 , "" ) ,
618
682
( 6 , "AVG(c)c" ) ,
@@ -796,4 +860,55 @@ mod test {
796
860
assert ! ( field_set. insert( field. qualified_name( ) ) ) ;
797
861
}
798
862
}
863
+
864
+ #[ test]
865
+ fn eliminated_subexpr_datatype ( ) {
866
+ use datafusion_expr:: cast;
867
+
868
+ let schema = Schema :: new ( vec ! [
869
+ Field :: new( "a" , DataType :: UInt64 , false ) ,
870
+ Field :: new( "b" , DataType :: UInt64 , false ) ,
871
+ Field :: new( "c" , DataType :: UInt64 , false ) ,
872
+ ] ) ;
873
+
874
+ let plan = table_scan ( Some ( "table" ) , & schema, None )
875
+ . unwrap ( )
876
+ . filter (
877
+ cast ( col ( "a" ) , DataType :: Int64 )
878
+ . lt ( lit ( 1_i64 ) )
879
+ . and ( cast ( col ( "a" ) , DataType :: Int64 ) . not_eq ( lit ( 1_i64 ) ) ) ,
880
+ )
881
+ . unwrap ( )
882
+ . build ( )
883
+ . unwrap ( ) ;
884
+ let rule = CommonSubexprEliminate { } ;
885
+ let optimized_plan = rule. optimize ( & plan, & mut OptimizerConfig :: new ( ) ) . unwrap ( ) ;
886
+
887
+ let schema = optimized_plan. schema ( ) ;
888
+ let fields_with_datatypes: Vec < _ > = schema
889
+ . fields ( )
890
+ . iter ( )
891
+ . map ( |field| ( field. name ( ) , field. data_type ( ) ) )
892
+ . collect ( ) ;
893
+ let formatted_fields_with_datatype = format ! ( "{fields_with_datatypes:#?}" ) ;
894
+ let expected = r###"[
895
+ (
896
+ "CAST(table.a AS Int64)table.a",
897
+ Int64,
898
+ ),
899
+ (
900
+ "a",
901
+ UInt64,
902
+ ),
903
+ (
904
+ "b",
905
+ UInt64,
906
+ ),
907
+ (
908
+ "c",
909
+ UInt64,
910
+ ),
911
+ ]"### ;
912
+ assert_eq ! ( expected, formatted_fields_with_datatype) ;
913
+ }
799
914
}
0 commit comments