// Review summary for this hunk (`-` = removed line, `+` = added line).
//
// `rewrite_expr` takes an owned `Expr` plus a map of guarantees and returns a
// `Transformed<Expr>`:
//   * New flow: `expr` itself is looked up in `guarantees` first; if its interval
//     yields a `single_value()`, the whole expression becomes `lit(value)`.
//   * Otherwise `expr` is matched by value. `IsNull` / `IsNotNull` become boolean
//     literals when the inner expression's guarantee is `Null` or `NotNull`;
//     `Between`, `BinaryExpr` and `InList` are delegated to their helpers; every
//     other variant is handed back as `Transformed::no(expr)`.
//   * Old flow: the per-variant match ran first (helpers returned `Option<Expr>`),
//     and the single-value lookup was only reached when the match produced `None`.
//
// NOTE(review): the order of the two checks is swapped by this change, so an
// expression that has both a single-value guarantee of its own and a per-variant
// rewrite now resolves through the guarantee. Confirm this is intended.
@@ -101,50 +101,42 @@ fn rewrite_expr(
101101 expr : Expr ,
102102 guarantees : & HashMap < & Expr , & NullableInterval > ,
103103) -> Result < Transformed < Expr > > {
104- let new_expr = match & expr {
104+ // If an expression collapses to a single value, replace it with a literal
105+ if let Some ( interval) = guarantees. get ( & expr) {
106+ if let Some ( value) = interval. single_value ( ) {
107+ return Ok ( Transformed :: yes ( lit ( value) ) ) ;
108+ }
109+ }
110+
// `expr` is matched by value from here on, so every arm that declines to rewrite
// has to give the expression back inside `Transformed::no`.
111+ let result = match expr {
105112 Expr :: IsNull ( inner) => match guarantees. get ( inner. as_ref ( ) ) {
106- Some ( NullableInterval :: Null { .. } ) => Some ( lit ( true ) ) ,
107- Some ( NullableInterval :: NotNull { .. } ) => Some ( lit ( false ) ) ,
108- _ => None ,
113+ Some ( NullableInterval :: Null { .. } ) => Transformed :: yes ( lit ( true ) ) ,
114+ Some ( NullableInterval :: NotNull { .. } ) => Transformed :: yes ( lit ( false ) ) ,
115+ _ => Transformed :: no ( Expr :: IsNull ( inner ) ) ,
109116 } ,
110117 Expr :: IsNotNull ( inner) => match guarantees. get ( inner. as_ref ( ) ) {
111- Some ( NullableInterval :: Null { .. } ) => Some ( lit ( false ) ) ,
112- Some ( NullableInterval :: NotNull { .. } ) => Some ( lit ( true ) ) ,
113- _ => None ,
118+ Some ( NullableInterval :: Null { .. } ) => Transformed :: yes ( lit ( false ) ) ,
119+ Some ( NullableInterval :: NotNull { .. } ) => Transformed :: yes ( lit ( true ) ) ,
120+ _ => Transformed :: no ( Expr :: IsNotNull ( inner ) ) ,
114121 } ,
115122 Expr :: Between ( b) => rewrite_between ( b, guarantees) ?,
116123 Expr :: BinaryExpr ( b) => rewrite_binary_expr ( b, guarantees) ?,
117124 Expr :: InList ( i) => rewrite_inlist ( i, guarantees) ?,
118- _ => None ,
125+ expr => Transformed :: no ( expr ) ,
119126 } ;
// Removed below: the old tail that returned early on a helper rewrite and only
// afterwards consulted the single-value guarantee for `expr`.
120-
121- if let Some ( e) = new_expr {
122- return Ok ( Transformed :: yes ( e) ) ;
123- }
124-
125- match guarantees. get ( & expr) {
126- Some ( interval) => {
127- // If an expression collapses to a single value, replace it with a literal
128- if let Some ( value) = interval. single_value ( ) {
129- Ok ( Transformed :: yes ( lit ( value) ) )
130- } else {
131- Ok ( Transformed :: no ( expr) )
132- }
133- }
134- _ => Ok ( Transformed :: no ( expr) ) ,
135- }
127+ Ok ( result)
136128}
137129
// Review summary for `rewrite_between` (`-` = removed line, `+` = added line).
//
// Simplifies `<expr> (NOT) BETWEEN <low> AND <high>` when `guarantees` has an
// entry for `between.expr` and both bounds are `Expr::Literal`.
// After this change the function takes `Between` by value and returns
// `Transformed<Expr>`: every "leave it alone" path rebuilds
// `Expr::Between(between)` under `Transformed::no`, and every simplification
// returns a literal under `Transformed::yes`.
//
// The `@@` hunk break in the middle hides old lines 151-153, i.e. the statements
// that follow the "Ensure that, if low or high are null..." comment; they are not
// described here because they are not visible.
138130fn rewrite_between (
139- between : & Between ,
131+ between : Between ,
140132 guarantees : & HashMap < & Expr , & NullableInterval > ,
141- ) -> Result < Option < Expr > , DataFusionError > {
133+ ) -> Result < Transformed < Expr > > {
142134 let ( Some ( expr_interval) , Expr :: Literal ( low, _) , Expr :: Literal ( high, _) ) = (
143135 guarantees. get ( between. expr . as_ref ( ) ) ,
144136 between. low . as_ref ( ) ,
145137 between. high . as_ref ( ) ,
146138 ) else {
147- return Ok ( None ) ;
139+ return Ok ( Transformed :: no ( Expr :: Between ( between ) ) ) ;
148140 } ;
149141
150142 // Ensure that, if low or high are null, their type matches the other bound
@@ -154,65 +146,66 @@ fn rewrite_between(
154146 let Ok ( between_interval) = Interval :: try_new ( low, high) else {
155147 // If we can't create an interval from the literals, be conservative and simply leave
156148 // the expression unmodified.
157- return Ok ( None ) ;
149+ return Ok ( Transformed :: no ( Expr :: Between ( between ) ) ) ;
158150 } ;
159151
// Both bounds are null: the result is the (null) lower-bound literal itself.
160152 if between_interval. lower ( ) . is_null ( ) && between_interval. upper ( ) . is_null ( ) {
161- return Ok ( Some ( lit ( between_interval. lower ( ) . clone ( ) ) ) ) ;
153+ return Ok ( Transformed :: yes ( lit ( between_interval. lower ( ) . clone ( ) ) ) ) ;
162154 }
163155
// Only a `NotNull` guarantee carries on past this match; `Null` and `MaybeNull`
// return early.
164156 let expr_interval = match expr_interval {
165157 NullableInterval :: Null { datatype } => {
166158 // Value is guaranteed to be null, so we can simplify to null.
167- return Ok ( Some ( lit (
159+ return Ok ( Transformed :: yes ( lit (
168160 ScalarValue :: try_new_null ( datatype) . unwrap_or ( ScalarValue :: Null )
169161 ) ) ) ;
170162 }
171163 NullableInterval :: MaybeNull { .. } => {
172164 // Value may or may not be null, so we can't simplify the expression.
173- return Ok ( None ) ;
165+ return Ok ( Transformed :: no ( Expr :: Between ( between ) ) ) ;
174166 }
175167 NullableInterval :: NotNull { values } => values,
176168 } ;
177169
// The old code wrapped this whole `if` chain in `Ok(..)`; the new code binds it to
// `result` and returns `Ok(result)` at the end.
178- Ok ( if between_interval. lower ( ) . is_null ( ) {
170+ let result = if between_interval. lower ( ) . is_null ( ) {
179171 // <expr> (NOT) BETWEEN NULL AND <high>
180172 let upper_bound = Interval :: from ( between_interval. upper ( ) . clone ( ) ) ;
181173 if expr_interval. gt ( & upper_bound) ?. eq ( & Interval :: TRUE ) {
182174 // if <expr> > high, then certainly false
183- Some ( lit ( between. negated ) )
175+ Transformed :: yes ( lit ( between. negated ) )
184176 } else if expr_interval. lt_eq ( & upper_bound) ?. eq ( & Interval :: TRUE ) {
185177 // if <expr> <= high, then certainly null
186- Some ( lit ( ScalarValue :: try_new_null ( & expr_interval. data_type ( ) )
178+ Transformed :: yes ( lit ( ScalarValue :: try_new_null ( & expr_interval. data_type ( ) )
187179 . unwrap_or ( ScalarValue :: Null ) ) )
188180 } else {
189181 // otherwise unknown
190- None
182+ Transformed :: no ( Expr :: Between ( between ) )
191183 }
192184 } else if between_interval. upper ( ) . is_null ( ) {
193185 // <expr> (NOT) BETWEEN <low> AND NULL
194186 let lower_bound = Interval :: from ( between_interval. lower ( ) . clone ( ) ) ;
195187 if expr_interval. lt ( & lower_bound) ?. eq ( & Interval :: TRUE ) {
196188 // if <expr> < low, then certainly false
197- Some ( lit ( between. negated ) )
189+ Transformed :: yes ( lit ( between. negated ) )
198190 } else if expr_interval. gt_eq ( & lower_bound) ?. eq ( & Interval :: TRUE ) {
199191 // if <expr> >= low, then certainly null
200- Some ( lit ( ScalarValue :: try_new_null ( & expr_interval. data_type ( ) )
192+ Transformed :: yes ( lit ( ScalarValue :: try_new_null ( & expr_interval. data_type ( ) )
201193 . unwrap_or ( ScalarValue :: Null ) ) )
202194 } else {
203195 // otherwise unknown
204- None
196+ Transformed :: no ( Expr :: Between ( between ) )
205197 }
206198 } else {
// Neither bound is null: decide from whether the bound interval certainly
// contains, or certainly excludes, the guaranteed values of the expression.
207199 let contains = between_interval. contains ( expr_interval) ?;
208200 if contains. eq ( & Interval :: TRUE ) {
209- Some ( lit ( !between. negated ) )
201+ Transformed :: yes ( lit ( !between. negated ) )
210202 } else if contains. eq ( & Interval :: FALSE ) {
211- Some ( lit ( between. negated ) )
203+ Transformed :: yes ( lit ( between. negated ) )
212204 } else {
213- None
205+ Transformed :: no ( Expr :: Between ( between ) )
214206 }
215- } )
207+ } ;
208+ Ok ( result)
216209}
217210
218211fn ensure_typed_null (
@@ -229,9 +222,9 @@ fn ensure_typed_null(
229222}
230223
// Review summary for `rewrite_binary_expr` (`-` = removed line, `+` = added line).
//
// When both operands resolve to an interval, `binary.op` is applied to them and
// the expression becomes `lit(true)` / `lit(false)` if the result is certainly
// true / certainly false. In every other case the new code hands the original
// back as `Transformed::no(Expr::BinaryExpr(binary))`, where the old code
// returned `None`.
//
// The `@@` hunk break below hides most of how `left_interval` and
// `right_interval` are built, so that part is not described here.
//
// NOTE(review): this helper and `rewrite_inlist` spell the return type as
// `Result<Transformed<Expr>, DataFusionError>` while `rewrite_between` uses
// `Result<Transformed<Expr>>`; presumably the same type via an alias. Confirm
// and consider unifying.
231224fn rewrite_binary_expr (
232- binary : & BinaryExpr ,
225+ binary : BinaryExpr ,
233226 guarantees : & HashMap < & Expr , & NullableInterval > ,
234- ) -> Result < Option < Expr > , DataFusionError > {
227+ ) -> Result < Transformed < Expr > , DataFusionError > {
235228 // The left or right side of expression might either have a guarantee
236229 // or be a literal. Either way, we can resolve them to a NullableInterval.
237230 let left_interval = guarantees
@@ -255,53 +248,53 @@ fn rewrite_binary_expr(
255248 }
256249 } ) ;
257250
// New flow: an `if let` with early returns replaces the old `match` that produced
// an `Option<Expr>`; the fall-through at the bottom covers both "an interval is
// missing" and "the result is not certain".
258- Ok ( match ( left_interval, right_interval) {
259- ( Some ( left_interval) , Some ( right_interval) ) => {
260- let result =
261- left_interval. apply_operator ( & binary. op , right_interval. as_ref ( ) ) ?;
262- if result. is_certainly_true ( ) {
263- Some ( lit ( true ) )
264- } else if result. is_certainly_false ( ) {
265- Some ( lit ( false ) )
266- } else {
267- None
268- }
251+ if let ( Some ( left_interval) , Some ( right_interval) ) = ( left_interval, right_interval) {
252+ let result = left_interval. apply_operator ( & binary. op , right_interval. as_ref ( ) ) ?;
253+ if result. is_certainly_true ( ) {
254+ return Ok ( Transformed :: yes ( lit ( true ) ) ) ;
255+ } else if result. is_certainly_false ( ) {
256+ return Ok ( Transformed :: yes ( lit ( false ) ) ) ;
269257 }
270- _ => None ,
271- } )
258+ }
259+ Ok ( Transformed :: no ( Expr :: BinaryExpr ( binary ) ) )
272260}
273261
// Review summary for `rewrite_inlist` (`-` = removed line, `+` = added line).
//
// Prunes IN-list items that can never match the guarantee on `inlist.expr`:
//   * no guarantee for `inlist.expr` -> the list is returned untouched under
//     `Transformed::no`;
//   * literal items for which `interval.contains(..)` is certainly false are
//     dropped;
//   * every other literal and all non-literal items are kept;
//   * an `Err` from `contains` is propagated through
//     `collect::<Result<_, DataFusionError>>()?`.
// `negated` is carried over unchanged. After this change the `InList` is taken by
// value and destructured, so `expr` and the kept items are moved instead of
// cloned.
//
// NOTE(review): once a guarantee exists the result is always `Transformed::yes`,
// even when no item was removed (the old code likewise always returned `Some`).
// Confirm callers are fine with a "transformed" flag on an unchanged expression.
274262fn rewrite_inlist (
275- inlist : & InList ,
263+ inlist : InList ,
276264 guarantees : & HashMap < & Expr , & NullableInterval > ,
277- ) -> Result < Option < Expr > , DataFusionError > {
265+ ) -> Result < Transformed < Expr > , DataFusionError > {
278266 let Some ( interval) = guarantees. get ( inlist. expr . as_ref ( ) ) else {
279- return Ok ( None ) ;
267+ return Ok ( Transformed :: no ( Expr :: InList ( inlist ) ) ) ;
280268 } ;
281269
// Take the `InList` apart so its fields can be moved into the rebuilt expression.
270+ let InList {
271+ expr,
272+ list,
273+ negated,
274+ } = inlist;
275+
282276 // Can remove items from the list that don't match the guarantee
283- let new_list: Vec < Expr > = inlist
284- . list
285- . iter ( )
277+ let list: Vec < Expr > = list
278+ . into_iter ( )
286279 . filter_map ( |expr| {
// The closure's `expr` is a list item and shadows the outer `expr` (the tested
// expression) inside this closure only.
287- if let Expr :: Literal ( item, _) = expr {
280+ if let Expr :: Literal ( item, _) = & expr {
288281 match interval. contains ( NullableInterval :: from ( item. clone ( ) ) ) {
289282 // If we know for certain the value isn't in the column's interval,
290283 // we can skip checking it.
291284 Ok ( interval) if interval. is_certainly_false ( ) => None ,
292- Ok ( _) => Some ( Ok ( expr. clone ( ) ) ) ,
285+ Ok ( _) => Some ( Ok ( expr) ) ,
293286 Err ( e) => Some ( Err ( e) ) ,
294287 }
295288 } else {
296- Some ( Ok ( expr. clone ( ) ) )
289+ Some ( Ok ( expr) )
297290 }
298291 } )
299292 . collect :: < Result < _ , DataFusionError > > ( ) ?;
300293
301- Ok ( Some ( Expr :: InList ( InList {
302- expr : inlist . expr . clone ( ) ,
303- list : new_list ,
304- negated : inlist . negated ,
294+ Ok ( Transformed :: yes ( Expr :: InList ( InList {
295+ expr,
296+ list,
297+ negated,
305298 } ) ) )
306299}
307300
@@ -315,6 +308,7 @@ mod tests {
315308
316309 #[ test]
317310 fn test_not_null_guarantee ( ) {
311+ // IsNull / IsNotNull can be rewritten to true / false
318312 let guarantees = [
319313 // Note: AlwaysNull case handled by test_column_single_value test,
320314 // since it's a special case of a column with a single value.
@@ -468,7 +462,7 @@ mod tests {
468462 ScalarValue :: Date32 ( Some ( 18628 ) ) ,
469463 ScalarValue :: Date32 ( None ) ,
470464 )
471- . unwrap ( ) ,
465+ . unwrap ( ) ,
472466 } ,
473467 ) ,
474468 ] ;
@@ -546,7 +540,7 @@ mod tests {
546540 ScalarValue :: from ( "abc" ) ,
547541 ScalarValue :: from ( "def" ) ,
548542 )
549- . unwrap ( ) ,
543+ . unwrap ( ) ,
550544 } ,
551545 ) ,
552546 ] ;
@@ -627,7 +621,7 @@ mod tests {
627621 ScalarValue :: Int32 ( Some ( 1 ) ) ,
628622 ScalarValue :: Int32 ( Some ( 10 ) ) ,
629623 )
630- . unwrap ( ) ,
624+ . unwrap ( ) ,
631625 } ,
632626 ) ,
633627 ] ;
0 commit comments