@@ -29,7 +29,7 @@ use datafusion_common::alias::AliasGenerator;
29
29
use datafusion_common:: tree_node:: {
30
30
Transformed , TransformedResult , TreeNode , TreeNodeRecursion , TreeNodeRewriter ,
31
31
} ;
32
- use datafusion_common:: { plan_err, Column , Result , ScalarValue } ;
32
+ use datafusion_common:: { internal_err , plan_err, Column , Result , ScalarValue } ;
33
33
use datafusion_expr:: expr_rewriter:: create_col_from_scalar_expr;
34
34
use datafusion_expr:: logical_plan:: { JoinType , Subquery } ;
35
35
use datafusion_expr:: utils:: conjunction;
@@ -50,7 +50,7 @@ impl ScalarSubqueryToJoin {
50
50
/// # Arguments
51
51
/// * `predicate` - A conjunction to split and search
52
52
///
53
- /// Returns a tuple (subqueries, rewrite expression )
53
+ /// Returns a tuple (subqueries, alias )
54
54
fn extract_subquery_exprs (
55
55
& self ,
56
56
predicate : & Expr ,
@@ -71,19 +71,36 @@ impl ScalarSubqueryToJoin {
71
71
impl OptimizerRule for ScalarSubqueryToJoin {
72
72
fn try_optimize (
73
73
& self ,
74
- plan : & LogicalPlan ,
75
- config : & dyn OptimizerConfig ,
74
+ _plan : & LogicalPlan ,
75
+ _config : & dyn OptimizerConfig ,
76
76
) -> Result < Option < LogicalPlan > > {
77
+ internal_err ! ( "Should have called ScalarSubqueryToJoin::rewrite" )
78
+ }
79
+
80
+ fn supports_rewrite ( & self ) -> bool {
81
+ true
82
+ }
83
+
84
+ fn rewrite (
85
+ & self ,
86
+ plan : LogicalPlan ,
87
+ config : & dyn OptimizerConfig ,
88
+ ) -> Result < Transformed < LogicalPlan > > {
77
89
match plan {
78
90
LogicalPlan :: Filter ( filter) => {
91
+ // Optimization: skip the rest of the rule and its copies if
92
+ // there are no scalar subqueries
93
+ if !contains_scalar_subquery ( & filter. predicate ) {
94
+ return Ok ( Transformed :: no ( LogicalPlan :: Filter ( filter) ) ) ;
95
+ }
96
+
79
97
let ( subqueries, mut rewrite_expr) = self . extract_subquery_exprs (
80
98
& filter. predicate ,
81
99
config. alias_generator ( ) ,
82
100
) ?;
83
101
84
102
if subqueries. is_empty ( ) {
85
- // regular filter, no subquery exists clause here
86
- return Ok ( None ) ;
103
+ return internal_err ! ( "Expected subqueries not found in filter" ) ;
87
104
}
88
105
89
106
// iterate through all subqueries in predicate, turning each into a left join
@@ -94,16 +111,13 @@ impl OptimizerRule for ScalarSubqueryToJoin {
94
111
{
95
112
if !expr_check_map. is_empty ( ) {
96
113
rewrite_expr = rewrite_expr
97
- . clone ( )
98
114
. transform_up ( |expr| {
99
- if let Expr :: Column ( col) = & expr {
100
- if let Some ( map_expr) =
101
- expr_check_map. get ( & col. name )
102
- {
103
- Ok ( Transformed :: yes ( map_expr. clone ( ) ) )
104
- } else {
105
- Ok ( Transformed :: no ( expr) )
106
- }
115
+ // replace column references with entry in map, if it exists
116
+ if let Some ( map_expr) = expr
117
+ . try_as_col ( )
118
+ . and_then ( |col| expr_check_map. get ( & col. name ) )
119
+ {
120
+ Ok ( Transformed :: yes ( map_expr. clone ( ) ) )
107
121
} else {
108
122
Ok ( Transformed :: no ( expr) )
109
123
}
@@ -113,15 +127,21 @@ impl OptimizerRule for ScalarSubqueryToJoin {
113
127
cur_input = optimized_subquery;
114
128
} else {
115
129
// if we can't handle all of the subqueries then bail for now
116
- return Ok ( None ) ;
130
+ return Ok ( Transformed :: no ( LogicalPlan :: Filter ( filter ) ) ) ;
117
131
}
118
132
}
119
133
let new_plan = LogicalPlanBuilder :: from ( cur_input)
120
134
. filter ( rewrite_expr) ?
121
135
. build ( ) ?;
122
- Ok ( Some ( new_plan) )
136
+ Ok ( Transformed :: yes ( new_plan) )
123
137
}
124
138
LogicalPlan :: Projection ( projection) => {
139
+ // Optimization: skip the rest of the rule and its copies if
140
+ // there are no scalar subqueries
141
+ if !projection. expr . iter ( ) . any ( contains_scalar_subquery) {
142
+ return Ok ( Transformed :: no ( LogicalPlan :: Projection ( projection) ) ) ;
143
+ }
144
+
125
145
let mut all_subqueryies = vec ! [ ] ;
126
146
let mut expr_to_rewrite_expr_map = HashMap :: new ( ) ;
127
147
let mut subquery_to_expr_map = HashMap :: new ( ) ;
@@ -135,8 +155,7 @@ impl OptimizerRule for ScalarSubqueryToJoin {
135
155
expr_to_rewrite_expr_map. insert ( expr, rewrite_exprs) ;
136
156
}
137
157
if all_subqueryies. is_empty ( ) {
138
- // regular projection, no subquery exists clause here
139
- return Ok ( None ) ;
158
+ return internal_err ! ( "Expected subqueries not found in projection" ) ;
140
159
}
141
160
// iterate through all subqueries in predicate, turning each into a left join
142
161
let mut cur_input = projection. input . as_ref ( ) . clone ( ) ;
@@ -153,14 +172,13 @@ impl OptimizerRule for ScalarSubqueryToJoin {
153
172
let new_expr = rewrite_expr
154
173
. clone ( )
155
174
. transform_up ( |expr| {
156
- if let Expr :: Column ( col) = & expr {
157
- if let Some ( map_expr) =
175
+ // replace column references with entry in map, if it exists
176
+ if let Some ( map_expr) =
177
+ expr. try_as_col ( ) . and_then ( |col| {
158
178
expr_check_map. get ( & col. name )
159
- {
160
- Ok ( Transformed :: yes ( map_expr. clone ( ) ) )
161
- } else {
162
- Ok ( Transformed :: no ( expr) )
163
- }
179
+ } )
180
+ {
181
+ Ok ( Transformed :: yes ( map_expr. clone ( ) ) )
164
182
} else {
165
183
Ok ( Transformed :: no ( expr) )
166
184
}
@@ -172,7 +190,7 @@ impl OptimizerRule for ScalarSubqueryToJoin {
172
190
}
173
191
} else {
174
192
// if we can't handle all of the subqueries then bail for now
175
- return Ok ( None ) ;
193
+ return Ok ( Transformed :: no ( LogicalPlan :: Projection ( projection ) ) ) ;
176
194
}
177
195
}
178
196
@@ -190,10 +208,10 @@ impl OptimizerRule for ScalarSubqueryToJoin {
190
208
let new_plan = LogicalPlanBuilder :: from ( cur_input)
191
209
. project ( proj_exprs) ?
192
210
. build ( ) ?;
193
- Ok ( Some ( new_plan) )
211
+ Ok ( Transformed :: yes ( new_plan) )
194
212
}
195
213
196
- _ => Ok ( None ) ,
214
+ plan => Ok ( Transformed :: no ( plan ) ) ,
197
215
}
198
216
}
199
217
@@ -206,6 +224,13 @@ impl OptimizerRule for ScalarSubqueryToJoin {
206
224
}
207
225
}
208
226
227
+ /// Returns true if the expression has a scalar subquery somewhere in it
228
+ /// false otherwise
229
+ fn contains_scalar_subquery ( expr : & Expr ) -> bool {
230
+ expr. exists ( |expr| Ok ( matches ! ( expr, Expr :: ScalarSubquery ( _) ) ) )
231
+ . expect ( "Inner is always Ok" )
232
+ }
233
+
209
234
struct ExtractScalarSubQuery {
210
235
sub_query_info : Vec < ( Subquery , String ) > ,
211
236
alias_gen : Arc < AliasGenerator > ,
0 commit comments