1717
1818//! Rewrite expressions based on external expression value range guarantees.
1919
20- use std:: borrow:: Cow ;
2120use crate :: { expr:: InList , lit, Between , BinaryExpr , Expr } ;
2221use datafusion_common:: tree_node:: { Transformed , TreeNode , TreeNodeRewriter } ;
2322use datafusion_common:: { DataFusionError , HashMap , Result , ScalarValue } ;
2423use datafusion_expr_common:: interval_arithmetic:: { Interval , NullableInterval } ;
24+ use std:: borrow:: Cow ;
2525
2626/// Rewrite expressions to incorporate guarantees.
2727///
@@ -164,13 +164,15 @@ fn rewrite_between(
164164 let expr_interval = match expr_interval {
165165 NullableInterval :: Null { datatype } => {
166166 // Value is guaranteed to be null, so we can simplify to null.
167- return Ok ( Some ( lit ( ScalarValue :: try_new_null ( datatype) . unwrap_or ( ScalarValue :: Null ) ) ) )
168- } ,
167+ return Ok ( Some ( lit (
168+ ScalarValue :: try_new_null ( datatype) . unwrap_or ( ScalarValue :: Null )
169+ ) ) ) ;
170+ }
169171 NullableInterval :: MaybeNull { .. } => {
170172 // Value may or may not be null, so we can't simplify the expression.
171- return Ok ( None )
172- } ,
173- NullableInterval :: NotNull { values } => values
173+ return Ok ( None ) ;
174+ }
175+ NullableInterval :: NotNull { values } => values,
174176 } ;
175177
176178 Ok ( if between_interval. lower ( ) . is_null ( ) {
@@ -181,7 +183,8 @@ fn rewrite_between(
181183 Some ( lit ( between. negated ) )
182184 } else if expr_interval. lt_eq ( & upper_bound) ?. eq ( & Interval :: TRUE ) {
183185 // if <expr> <= high, then certainly null
184- Some ( lit ( ScalarValue :: try_new_null ( & expr_interval. data_type ( ) ) . unwrap_or ( ScalarValue :: Null ) ) )
186+ Some ( lit ( ScalarValue :: try_new_null ( & expr_interval. data_type ( ) )
187+ . unwrap_or ( ScalarValue :: Null ) ) )
185188 } else {
186189 // otherwise unknown
187190 None
@@ -194,7 +197,8 @@ fn rewrite_between(
194197 Some ( lit ( between. negated ) )
195198 } else if expr_interval. gt_eq ( & lower_bound) ?. eq ( & Interval :: TRUE ) {
196199 // if <expr> >= low, then certainly null
197- Some ( lit ( ScalarValue :: try_new_null ( & expr_interval. data_type ( ) ) . unwrap_or ( ScalarValue :: Null ) ) )
200+ Some ( lit ( ScalarValue :: try_new_null ( & expr_interval. data_type ( ) )
201+ . unwrap_or ( ScalarValue :: Null ) ) )
198202 } else {
199203 // otherwise unknown
200204 None
@@ -311,7 +315,6 @@ mod tests {
311315
312316 #[ test]
313317 fn test_not_null_guarantee ( ) {
314-
315318 let guarantees = [
316319 // Note: AlwaysNull case handled by test_column_single_value test,
317320 // since it's a special case of a column with a single value.
@@ -328,55 +331,86 @@ mod tests {
328331 ( col( "x" ) . is_null( ) , Some ( lit( false ) ) ) ,
329332 // x IS NOT NULL => guaranteed true
330333 ( col( "x" ) . is_not_null( ) , Some ( lit( true ) ) ) ,
331-
332334 // [1, 3] BETWEEN 0 AND 10 => guaranteed true
333335 ( col( "x" ) . between( lit( 0 ) , lit( 10 ) ) , Some ( lit( true ) ) ) ,
334336 // x BETWEEN 1 AND -2 => unknown (actually guaranteed false)
335337 ( col( "x" ) . between( lit( 1 ) , lit( -2 ) ) , None ) ,
336-
337338 // [1, 3] BETWEEN NULL AND 0 => guaranteed false
338- ( col( "x" ) . between( lit( ScalarValue :: Null ) , lit( 0 ) ) , Some ( lit( false ) ) ) ,
339+ (
340+ col( "x" ) . between( lit( ScalarValue :: Null ) , lit( 0 ) ) ,
341+ Some ( lit( false ) ) ,
342+ ) ,
339343 // [1, 3] BETWEEN NULL AND 1 => unknown
340344 ( col( "x" ) . between( lit( ScalarValue :: Null ) , lit( 1 ) ) , None ) ,
341345 // [1, 3] BETWEEN NULL AND 2 => unknown
342346 ( col( "x" ) . between( lit( ScalarValue :: Null ) , lit( 2 ) ) , None ) ,
343347 // [1, 3] BETWEEN NULL AND 3 => guaranteed NULL
344- ( col( "x" ) . between( lit( ScalarValue :: Null ) , lit( 3 ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
348+ (
349+ col( "x" ) . between( lit( ScalarValue :: Null ) , lit( 3 ) ) ,
350+ Some ( lit( ScalarValue :: Int32 ( None ) ) ) ,
351+ ) ,
345352 // [1, 3] BETWEEN NULL AND 4 => guaranteed NULL
346- ( col( "x" ) . between( lit( ScalarValue :: Null ) , lit( 4 ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
347-
353+ (
354+ col( "x" ) . between( lit( ScalarValue :: Null ) , lit( 4 ) ) ,
355+ Some ( lit( ScalarValue :: Int32 ( None ) ) ) ,
356+ ) ,
348357 // [1, 3] BETWEEN 1 AND NULL => guaranteed NULL
349- ( col( "x" ) . between( lit( 0 ) , lit( ScalarValue :: Null ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
358+ (
359+ col( "x" ) . between( lit( 0 ) , lit( ScalarValue :: Null ) ) ,
360+ Some ( lit( ScalarValue :: Int32 ( None ) ) ) ,
361+ ) ,
350362 // [1, 3] BETWEEN 1 AND NULL => guaranteed NULL
351- ( col( "x" ) . between( lit( 1 ) , lit( ScalarValue :: Null ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
363+ (
364+ col( "x" ) . between( lit( 1 ) , lit( ScalarValue :: Null ) ) ,
365+ Some ( lit( ScalarValue :: Int32 ( None ) ) ) ,
366+ ) ,
352367 // [1, 3] BETWEEN 2 AND NULL => unknown
353368 ( col( "x" ) . between( lit( 2 ) , lit( ScalarValue :: Null ) ) , None ) ,
354369 // [1, 3] BETWEEN 3 AND NULL => unknown
355370 ( col( "x" ) . between( lit( 3 ) , lit( ScalarValue :: Null ) ) , None ) ,
356371 // [1, 3] BETWEEN 4 AND NULL => guaranteed false
357- ( col( "x" ) . between( lit( 4 ) , lit( ScalarValue :: Null ) ) , Some ( lit( false ) ) ) ,
358-
372+ (
373+ col( "x" ) . between( lit( 4 ) , lit( ScalarValue :: Null ) ) ,
374+ Some ( lit( false ) ) ,
375+ ) ,
359376 // [1, 3] NOT BETWEEN NULL AND 0 => guaranteed false
360- ( col( "x" ) . not_between( lit( ScalarValue :: Null ) , lit( 0 ) ) , Some ( lit( true ) ) ) ,
377+ (
378+ col( "x" ) . not_between( lit( ScalarValue :: Null ) , lit( 0 ) ) ,
379+ Some ( lit( true ) ) ,
380+ ) ,
361381 // [1, 3] NOT BETWEEN NULL AND 1 => unknown
362382 ( col( "x" ) . not_between( lit( ScalarValue :: Null ) , lit( 1 ) ) , None ) ,
363383 // [1, 3] NOT BETWEEN NULL AND 2 => unknown
364384 ( col( "x" ) . not_between( lit( ScalarValue :: Null ) , lit( 2 ) ) , None ) ,
365385 // [1, 3] NOT BETWEEN NULL AND 3 => guaranteed NULL
366- ( col( "x" ) . not_between( lit( ScalarValue :: Null ) , lit( 3 ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
386+ (
387+ col( "x" ) . not_between( lit( ScalarValue :: Null ) , lit( 3 ) ) ,
388+ Some ( lit( ScalarValue :: Int32 ( None ) ) ) ,
389+ ) ,
367390 // [1, 3] NOT BETWEEN NULL AND 4 => guaranteed NULL
368- ( col( "x" ) . not_between( lit( ScalarValue :: Null ) , lit( 4 ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
369-
391+ (
392+ col( "x" ) . not_between( lit( ScalarValue :: Null ) , lit( 4 ) ) ,
393+ Some ( lit( ScalarValue :: Int32 ( None ) ) ) ,
394+ ) ,
370395 // [1, 3] NOT BETWEEN 1 AND NULL => guaranteed NULL
371- ( col( "x" ) . not_between( lit( 0 ) , lit( ScalarValue :: Null ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
396+ (
397+ col( "x" ) . not_between( lit( 0 ) , lit( ScalarValue :: Null ) ) ,
398+ Some ( lit( ScalarValue :: Int32 ( None ) ) ) ,
399+ ) ,
372400 // [1, 3] NOT BETWEEN 1 AND NULL => guaranteed NULL
373- ( col( "x" ) . not_between( lit( 1 ) , lit( ScalarValue :: Null ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
401+ (
402+ col( "x" ) . not_between( lit( 1 ) , lit( ScalarValue :: Null ) ) ,
403+ Some ( lit( ScalarValue :: Int32 ( None ) ) ) ,
404+ ) ,
374405 // [1, 3] NOT BETWEEN 2 AND NULL => unknown
375406 ( col( "x" ) . not_between( lit( 2 ) , lit( ScalarValue :: Null ) ) , None ) ,
376407 // [1, 3] NOT BETWEEN 3 AND NULL => unknown
377408 ( col( "x" ) . not_between( lit( 3 ) , lit( ScalarValue :: Null ) ) , None ) ,
378409 // [1, 3] NOT BETWEEN 4 AND NULL => guaranteed false
379- ( col( "x" ) . not_between( lit( 4 ) , lit( ScalarValue :: Null ) ) , Some ( lit( true ) ) ) ,
410+ (
411+ col( "x" ) . not_between( lit( 4 ) , lit( ScalarValue :: Null ) ) ,
412+ Some ( lit( true ) ) ,
413+ ) ,
380414 ] ;
381415
382416 for case in is_null_cases {
0 commit comments