1818//! Rewrite expressions based on external expression value range guarantees.
1919
2020use std:: borrow:: Cow ;
21-
2221use crate :: { expr:: InList , lit, Between , BinaryExpr , Expr } ;
2322use datafusion_common:: tree_node:: { Transformed , TreeNode , TreeNodeRewriter } ;
2423use datafusion_common:: { DataFusionError , HashMap , Result , ScalarValue } ;
@@ -79,13 +78,21 @@ pub fn rewrite_with_guarantees_map<'a>(
7978 expr : Expr ,
8079 guarantees : & ' a HashMap < & ' a Expr , & ' a NullableInterval > ,
8180) -> Result < Transformed < Expr > > {
81+ if guarantees. is_empty ( ) {
82+ return Ok ( Transformed :: no ( expr) ) ;
83+ }
84+
8285 expr. transform_up ( |e| rewrite_expr ( e, guarantees) )
8386}
8487
8588impl TreeNodeRewriter for GuaranteeRewriter < ' _ > {
8689 type Node = Expr ;
8790
8891 fn f_up ( & mut self , expr : Expr ) -> Result < Transformed < Expr > > {
92+ if self . guarantees . is_empty ( ) {
93+ return Ok ( Transformed :: no ( expr) ) ;
94+ }
95+
8996 rewrite_expr ( expr, & self . guarantees )
9097 }
9198}
@@ -94,10 +101,6 @@ fn rewrite_expr(
94101 expr : Expr ,
95102 guarantees : & HashMap < & Expr , & NullableInterval > ,
96103) -> Result < Transformed < Expr > > {
97- if guarantees. is_empty ( ) {
98- return Ok ( Transformed :: no ( expr) ) ;
99- }
100-
101104 let new_expr = match & expr {
102105 Expr :: IsNull ( inner) => match guarantees. get ( inner. as_ref ( ) ) {
103106 Some ( NullableInterval :: Null { .. } ) => Some ( lit ( true ) ) ,
@@ -136,7 +139,7 @@ fn rewrite_between(
136139 between : & Between ,
137140 guarantees : & HashMap < & Expr , & NullableInterval > ,
138141) -> Result < Option < Expr > , DataFusionError > {
139- let ( Some ( interval ) , Expr :: Literal ( low, _) , Expr :: Literal ( high, _) ) = (
142+ let ( Some ( expr_interval ) , Expr :: Literal ( low, _) , Expr :: Literal ( high, _) ) = (
140143 guarantees. get ( between. expr . as_ref ( ) ) ,
141144 between. low . as_ref ( ) ,
142145 between. high . as_ref ( ) ,
@@ -148,23 +151,64 @@ fn rewrite_between(
148151 let low = ensure_typed_null ( low, high) ?;
149152 let high = ensure_typed_null ( high, & low) ?;
150153
151- let Ok ( values ) = Interval :: try_new ( low, high) else {
154+ let Ok ( between_interval ) = Interval :: try_new ( low, high) else {
152155 // If we can't create an interval from the literals, be conservative and simply leave
153156 // the expression unmodified.
154157 return Ok ( None ) ;
155158 } ;
156159
157- let expr_interval = NullableInterval :: NotNull { values } ;
160+ if between_interval. lower ( ) . is_null ( ) && between_interval. upper ( ) . is_null ( ) {
161+ return Ok ( Some ( lit ( between_interval. lower ( ) . clone ( ) ) ) ) ;
162+ }
158163
159- let contains = expr_interval. contains ( * interval) ?;
164+ let expr_interval = match expr_interval {
165+ NullableInterval :: Null { datatype } => {
166+ // Value is guaranteed to be null, so we can simplify to null.
167+ return Ok ( Some ( lit ( ScalarValue :: try_new_null ( datatype) . unwrap_or ( ScalarValue :: Null ) ) ) )
168+ } ,
169+ NullableInterval :: MaybeNull { .. } => {
170+ // Value may or may not be null, so we can't simplify the expression.
171+ return Ok ( None )
172+ } ,
173+ NullableInterval :: NotNull { values } => values
174+ } ;
160175
161- if contains. is_certainly_true ( ) {
162- Ok ( Some ( lit ( !between. negated ) ) )
163- } else if contains. is_certainly_false ( ) {
164- Ok ( Some ( lit ( between. negated ) ) )
176+ Ok ( if between_interval. lower ( ) . is_null ( ) {
177+ // <expr> (NOT) BETWEEN NULL AND <high>
178+ let upper_bound = Interval :: from ( between_interval. upper ( ) . clone ( ) ) ;
179+ if expr_interval. gt ( & upper_bound) ?. eq ( & Interval :: TRUE ) {
180+ // if <expr> > high, then BETWEEN is certainly false (and NOT BETWEEN certainly true)
181+ Some ( lit ( between. negated ) )
182+ } else if expr_interval. lt_eq ( & upper_bound) ?. eq ( & Interval :: TRUE ) {
183+ // if <expr> <= high, then certainly null
184+ Some ( lit ( ScalarValue :: try_new_null ( & expr_interval. data_type ( ) ) . unwrap_or ( ScalarValue :: Null ) ) )
185+ } else {
186+ // otherwise unknown
187+ None
188+ }
189+ } else if between_interval. upper ( ) . is_null ( ) {
190+ // <expr> (NOT) BETWEEN <low> AND NULL
191+ let lower_bound = Interval :: from ( between_interval. lower ( ) . clone ( ) ) ;
192+ if expr_interval. lt ( & lower_bound) ?. eq ( & Interval :: TRUE ) {
193+ // if <expr> < low, then BETWEEN is certainly false (and NOT BETWEEN certainly true)
194+ Some ( lit ( between. negated ) )
195+ } else if expr_interval. gt_eq ( & lower_bound) ?. eq ( & Interval :: TRUE ) {
196+ // if <expr> >= low, then certainly null
197+ Some ( lit ( ScalarValue :: try_new_null ( & expr_interval. data_type ( ) ) . unwrap_or ( ScalarValue :: Null ) ) )
198+ } else {
199+ // otherwise unknown
200+ None
201+ }
165202 } else {
166- Ok ( None )
167- }
203+ let contains = between_interval. contains ( expr_interval) ?;
204+ if contains. eq ( & Interval :: TRUE ) {
205+ Some ( lit ( !between. negated ) )
206+ } else if contains. eq ( & Interval :: FALSE ) {
207+ Some ( lit ( between. negated ) )
208+ } else {
209+ None
210+ }
211+ } )
168212}
169213
170214fn ensure_typed_null (
@@ -262,42 +306,89 @@ mod tests {
262306 use super :: * ;
263307
264308 use crate :: { col, Operator } ;
265- use arrow:: datatypes:: DataType ;
266309 use datafusion_common:: tree_node:: TransformedResult ;
267310 use datafusion_common:: ScalarValue ;
268311
269312 #[ test]
270313 fn test_not_null_guarantee ( ) {
271- // IsNull / IsNotNull can be rewritten to true / false
314+
272315 let guarantees = [
273316 // Note: AlwaysNull case handled by test_column_single_value test,
274317 // since it's a special case of a column with a single value.
275318 (
276319 col ( "x" ) ,
277320 NullableInterval :: NotNull {
278- values : Interval :: make_unbounded ( & DataType :: Int32 ) . unwrap ( ) ,
321+ values : Interval :: make ( Some ( 1 ) , Some ( 3 ) ) . unwrap ( ) ,
279322 } ,
280323 ) ,
281324 ] ;
282325
283- // x IS NULL => guaranteed false
284326 let is_null_cases = vec ! [
327+ // x IS NULL => guaranteed false
285328 ( col( "x" ) . is_null( ) , Some ( lit( false ) ) ) ,
329+ // x IS NOT NULL => guaranteed true
286330 ( col( "x" ) . is_not_null( ) , Some ( lit( true ) ) ) ,
287- ( col( "x" ) . between( lit( 1 ) , lit( 2 ) ) , None ) ,
331+
332+ // [1, 3] BETWEEN 0 AND 10 => guaranteed true
333+ ( col( "x" ) . between( lit( 0 ) , lit( 10 ) ) , Some ( lit( true ) ) ) ,
334+ // x BETWEEN 1 AND -2 => unknown (actually guaranteed false)
288335 ( col( "x" ) . between( lit( 1 ) , lit( -2 ) ) , None ) ,
336+
337+ // [1, 3] BETWEEN NULL AND 0 => guaranteed false
338+ ( col( "x" ) . between( lit( ScalarValue :: Null ) , lit( 0 ) ) , Some ( lit( false ) ) ) ,
339+ // [1, 3] BETWEEN NULL AND 1 => unknown
340+ ( col( "x" ) . between( lit( ScalarValue :: Null ) , lit( 1 ) ) , None ) ,
341+ // [1, 3] BETWEEN NULL AND 2 => unknown
342+ ( col( "x" ) . between( lit( ScalarValue :: Null ) , lit( 2 ) ) , None ) ,
343+ // [1, 3] BETWEEN NULL AND 3 => guaranteed NULL
344+ ( col( "x" ) . between( lit( ScalarValue :: Null ) , lit( 3 ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
345+ // [1, 3] BETWEEN NULL AND 4 => guaranteed NULL
346+ ( col( "x" ) . between( lit( ScalarValue :: Null ) , lit( 4 ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
347+
348+ // [1, 3] BETWEEN 1 AND NULL => guaranteed NULL
349+ ( col( "x" ) . between( lit( 0 ) , lit( ScalarValue :: Null ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
350+ // [1, 3] BETWEEN 1 AND NULL => guaranteed NULL
351+ ( col( "x" ) . between( lit( 1 ) , lit( ScalarValue :: Null ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
352+ // [1, 3] BETWEEN 2 AND NULL => unknown
353+ ( col( "x" ) . between( lit( 2 ) , lit( ScalarValue :: Null ) ) , None ) ,
354+ // [1, 3] BETWEEN 3 AND NULL => unknown
355+ ( col( "x" ) . between( lit( 3 ) , lit( ScalarValue :: Null ) ) , None ) ,
356+ // [1, 3] BETWEEN 4 AND NULL => guaranteed false
357+ ( col( "x" ) . between( lit( 4 ) , lit( ScalarValue :: Null ) ) , Some ( lit( false ) ) ) ,
358+
359+ // [1, 3] NOT BETWEEN NULL AND 0 => guaranteed true
360+ ( col( "x" ) . not_between( lit( ScalarValue :: Null ) , lit( 0 ) ) , Some ( lit( true ) ) ) ,
361+ // [1, 3] NOT BETWEEN NULL AND 1 => unknown
362+ ( col( "x" ) . not_between( lit( ScalarValue :: Null ) , lit( 1 ) ) , None ) ,
363+ // [1, 3] NOT BETWEEN NULL AND 2 => unknown
364+ ( col( "x" ) . not_between( lit( ScalarValue :: Null ) , lit( 2 ) ) , None ) ,
365+ // [1, 3] NOT BETWEEN NULL AND 3 => guaranteed NULL
366+ ( col( "x" ) . not_between( lit( ScalarValue :: Null ) , lit( 3 ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
367+ // [1, 3] NOT BETWEEN NULL AND 4 => guaranteed NULL
368+ ( col( "x" ) . not_between( lit( ScalarValue :: Null ) , lit( 4 ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
369+
370+ // [1, 3] NOT BETWEEN 0 AND NULL => guaranteed NULL
371+ ( col( "x" ) . not_between( lit( 0 ) , lit( ScalarValue :: Null ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
372+ // [1, 3] NOT BETWEEN 1 AND NULL => guaranteed NULL
373+ ( col( "x" ) . not_between( lit( 1 ) , lit( ScalarValue :: Null ) ) , Some ( lit( ScalarValue :: Int32 ( None ) ) ) ) ,
374+ // [1, 3] NOT BETWEEN 2 AND NULL => unknown
375+ ( col( "x" ) . not_between( lit( 2 ) , lit( ScalarValue :: Null ) ) , None ) ,
376+ // [1, 3] NOT BETWEEN 3 AND NULL => unknown
377+ ( col( "x" ) . not_between( lit( 3 ) , lit( ScalarValue :: Null ) ) , None ) ,
378+ // [1, 3] NOT BETWEEN 4 AND NULL => guaranteed true
379+ ( col( "x" ) . not_between( lit( 4 ) , lit( ScalarValue :: Null ) ) , Some ( lit( true ) ) ) ,
289380 ] ;
290381
291382 for case in is_null_cases {
292383 let output = rewrite_with_guarantees ( case. 0 . clone ( ) , guarantees. iter ( ) )
293384 . data ( )
294385 . unwrap ( ) ;
295386 let expected = match case. 1 {
296- None => case. 0 ,
387+ None => case. 0 . clone ( ) ,
297388 Some ( expected) => expected,
298389 } ;
299390
300- assert_eq ! ( output, expected) ;
391+ assert_eq ! ( output, expected, "Failed for {}" , case . 0 ) ;
301392 }
302393 }
303394
0 commit comments