Skip to content

Commit 35c12fe

Browse files
committed
Add more between tests
1 parent d706186 commit 35c12fe

File tree

1 file changed

+113
-22
lines changed

1 file changed

+113
-22
lines changed

datafusion/expr/src/expr_rewriter/guarantees.rs

Lines changed: 113 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
//! Rewrite expressions based on external expression value range guarantees.
1919
2020
use std::borrow::Cow;
21-
2221
use crate::{expr::InList, lit, Between, BinaryExpr, Expr};
2322
use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter};
2423
use datafusion_common::{DataFusionError, HashMap, Result, ScalarValue};
@@ -79,13 +78,21 @@ pub fn rewrite_with_guarantees_map<'a>(
7978
expr: Expr,
8079
guarantees: &'a HashMap<&'a Expr, &'a NullableInterval>,
8180
) -> Result<Transformed<Expr>> {
81+
if guarantees.is_empty() {
82+
return Ok(Transformed::no(expr));
83+
}
84+
8285
expr.transform_up(|e| rewrite_expr(e, guarantees))
8386
}
8487

8588
impl TreeNodeRewriter for GuaranteeRewriter<'_> {
8689
type Node = Expr;
8790

8891
fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
92+
if self.guarantees.is_empty() {
93+
return Ok(Transformed::no(expr));
94+
}
95+
8996
rewrite_expr(expr, &self.guarantees)
9097
}
9198
}
@@ -94,10 +101,6 @@ fn rewrite_expr(
94101
expr: Expr,
95102
guarantees: &HashMap<&Expr, &NullableInterval>,
96103
) -> Result<Transformed<Expr>> {
97-
if guarantees.is_empty() {
98-
return Ok(Transformed::no(expr));
99-
}
100-
101104
let new_expr = match &expr {
102105
Expr::IsNull(inner) => match guarantees.get(inner.as_ref()) {
103106
Some(NullableInterval::Null { .. }) => Some(lit(true)),
@@ -136,7 +139,7 @@ fn rewrite_between(
136139
between: &Between,
137140
guarantees: &HashMap<&Expr, &NullableInterval>,
138141
) -> Result<Option<Expr>, DataFusionError> {
139-
let (Some(interval), Expr::Literal(low, _), Expr::Literal(high, _)) = (
142+
let (Some(expr_interval), Expr::Literal(low, _), Expr::Literal(high, _)) = (
140143
guarantees.get(between.expr.as_ref()),
141144
between.low.as_ref(),
142145
between.high.as_ref(),
@@ -148,23 +151,64 @@ fn rewrite_between(
148151
let low = ensure_typed_null(low, high)?;
149152
let high = ensure_typed_null(high, &low)?;
150153

151-
let Ok(values) = Interval::try_new(low, high) else {
154+
let Ok(between_interval) = Interval::try_new(low, high) else {
152155
// If we can't create an interval from the literals, be conservative and simply leave
153156
// the expression unmodified.
154157
return Ok(None);
155158
};
156159

157-
let expr_interval = NullableInterval::NotNull { values };
160+
if between_interval.lower().is_null() && between_interval.upper().is_null() {
161+
return Ok(Some(lit(between_interval.lower().clone())));
162+
}
158163

159-
let contains = expr_interval.contains(*interval)?;
164+
let expr_interval = match expr_interval {
165+
NullableInterval::Null { datatype } => {
166+
// Value is guaranteed to be null, so we can simplify to null.
167+
return Ok(Some(lit(ScalarValue::try_new_null(datatype).unwrap_or(ScalarValue::Null))))
168+
},
169+
NullableInterval::MaybeNull { .. } => {
170+
// Value may or may not be null, so we can't simplify the expression.
171+
return Ok(None)
172+
},
173+
NullableInterval::NotNull { values } => values
174+
};
160175

161-
if contains.is_certainly_true() {
162-
Ok(Some(lit(!between.negated)))
163-
} else if contains.is_certainly_false() {
164-
Ok(Some(lit(between.negated)))
176+
Ok(if between_interval.lower().is_null() {
177+
// <expr> (NOT) BETWEEN NULL AND <high>
178+
let upper_bound = Interval::from(between_interval.upper().clone());
179+
if expr_interval.gt(&upper_bound)?.eq(&Interval::TRUE) {
180+
// if <expr> > high, then BETWEEN is certainly false (and NOT BETWEEN certainly true)
181+
Some(lit(between.negated))
182+
} else if expr_interval.lt_eq(&upper_bound)?.eq(&Interval::TRUE) {
183+
// if <expr> <= high, then certainly null
184+
Some(lit(ScalarValue::try_new_null(&expr_interval.data_type()).unwrap_or(ScalarValue::Null)))
185+
} else {
186+
// otherwise unknown
187+
None
188+
}
189+
} else if between_interval.upper().is_null() {
190+
// <expr> (NOT) BETWEEN <low> AND NULL
191+
let lower_bound = Interval::from(between_interval.lower().clone());
192+
if expr_interval.lt(&lower_bound)?.eq(&Interval::TRUE) {
193+
// if <expr> < low, then BETWEEN is certainly false (and NOT BETWEEN certainly true)
194+
Some(lit(between.negated))
195+
} else if expr_interval.gt_eq(&lower_bound)?.eq(&Interval::TRUE) {
196+
// if <expr> >= low, then certainly null
197+
Some(lit(ScalarValue::try_new_null(&expr_interval.data_type()).unwrap_or(ScalarValue::Null)))
198+
} else {
199+
// otherwise unknown
200+
None
201+
}
165202
} else {
166-
Ok(None)
167-
}
203+
let contains = between_interval.contains(expr_interval)?;
204+
if contains.eq(&Interval::TRUE) {
205+
Some(lit(!between.negated))
206+
} else if contains.eq(&Interval::FALSE) {
207+
Some(lit(between.negated))
208+
} else {
209+
None
210+
}
211+
})
168212
}
169213

170214
fn ensure_typed_null(
@@ -262,42 +306,89 @@ mod tests {
262306
use super::*;
263307

264308
use crate::{col, Operator};
265-
use arrow::datatypes::DataType;
266309
use datafusion_common::tree_node::TransformedResult;
267310
use datafusion_common::ScalarValue;
268311

269312
#[test]
270313
fn test_not_null_guarantee() {
271-
// IsNull / IsNotNull can be rewritten to true / false
314+
272315
let guarantees = [
273316
// Note: AlwaysNull case handled by test_column_single_value test,
274317
// since it's a special case of a column with a single value.
275318
(
276319
col("x"),
277320
NullableInterval::NotNull {
278-
values: Interval::make_unbounded(&DataType::Int32).unwrap(),
321+
values: Interval::make(Some(1), Some(3)).unwrap(),
279322
},
280323
),
281324
];
282325

283-
// x IS NULL => guaranteed false
284326
let is_null_cases = vec![
327+
// x IS NULL => guaranteed false
285328
(col("x").is_null(), Some(lit(false))),
329+
// x IS NOT NULL => guaranteed true
286330
(col("x").is_not_null(), Some(lit(true))),
287-
(col("x").between(lit(1), lit(2)), None),
331+
332+
// [1, 3] BETWEEN 0 AND 10 => guaranteed true
333+
(col("x").between(lit(0), lit(10)), Some(lit(true))),
334+
// x BETWEEN 1 AND -2 => unknown (actually guaranteed false)
288335
(col("x").between(lit(1), lit(-2)), None),
336+
337+
// [1, 3] BETWEEN NULL AND 0 => guaranteed false
338+
(col("x").between(lit(ScalarValue::Null), lit(0)), Some(lit(false))),
339+
// [1, 3] BETWEEN NULL AND 1 => unknown
340+
(col("x").between(lit(ScalarValue::Null), lit(1)), None),
341+
// [1, 3] BETWEEN NULL AND 2 => unknown
342+
(col("x").between(lit(ScalarValue::Null), lit(2)), None),
343+
// [1, 3] BETWEEN NULL AND 3 => guaranteed NULL
344+
(col("x").between(lit(ScalarValue::Null), lit(3)), Some(lit(ScalarValue::Int32(None)))),
345+
// [1, 3] BETWEEN NULL AND 4 => guaranteed NULL
346+
(col("x").between(lit(ScalarValue::Null), lit(4)), Some(lit(ScalarValue::Int32(None)))),
347+
348+
// [1, 3] BETWEEN 0 AND NULL => guaranteed NULL
349+
(col("x").between(lit(0), lit(ScalarValue::Null)), Some(lit(ScalarValue::Int32(None)))),
350+
// [1, 3] BETWEEN 1 AND NULL => guaranteed NULL
351+
(col("x").between(lit(1), lit(ScalarValue::Null)), Some(lit(ScalarValue::Int32(None)))),
352+
// [1, 3] BETWEEN 2 AND NULL => unknown
353+
(col("x").between(lit(2), lit(ScalarValue::Null)), None),
354+
// [1, 3] BETWEEN 3 AND NULL => unknown
355+
(col("x").between(lit(3), lit(ScalarValue::Null)), None),
356+
// [1, 3] BETWEEN 4 AND NULL => guaranteed false
357+
(col("x").between(lit(4), lit(ScalarValue::Null)), Some(lit(false))),
358+
359+
// [1, 3] NOT BETWEEN NULL AND 0 => guaranteed true
360+
(col("x").not_between(lit(ScalarValue::Null), lit(0)), Some(lit(true))),
361+
// [1, 3] NOT BETWEEN NULL AND 1 => unknown
362+
(col("x").not_between(lit(ScalarValue::Null), lit(1)), None),
363+
// [1, 3] NOT BETWEEN NULL AND 2 => unknown
364+
(col("x").not_between(lit(ScalarValue::Null), lit(2)), None),
365+
// [1, 3] NOT BETWEEN NULL AND 3 => guaranteed NULL
366+
(col("x").not_between(lit(ScalarValue::Null), lit(3)), Some(lit(ScalarValue::Int32(None)))),
367+
// [1, 3] NOT BETWEEN NULL AND 4 => guaranteed NULL
368+
(col("x").not_between(lit(ScalarValue::Null), lit(4)), Some(lit(ScalarValue::Int32(None)))),
369+
370+
// [1, 3] NOT BETWEEN 0 AND NULL => guaranteed NULL
371+
(col("x").not_between(lit(0), lit(ScalarValue::Null)), Some(lit(ScalarValue::Int32(None)))),
372+
// [1, 3] NOT BETWEEN 1 AND NULL => guaranteed NULL
373+
(col("x").not_between(lit(1), lit(ScalarValue::Null)), Some(lit(ScalarValue::Int32(None)))),
374+
// [1, 3] NOT BETWEEN 2 AND NULL => unknown
375+
(col("x").not_between(lit(2), lit(ScalarValue::Null)), None),
376+
// [1, 3] NOT BETWEEN 3 AND NULL => unknown
377+
(col("x").not_between(lit(3), lit(ScalarValue::Null)), None),
378+
// [1, 3] NOT BETWEEN 4 AND NULL => guaranteed true
379+
(col("x").not_between(lit(4), lit(ScalarValue::Null)), Some(lit(true))),
289380
];
290381

291382
for case in is_null_cases {
292383
let output = rewrite_with_guarantees(case.0.clone(), guarantees.iter())
293384
.data()
294385
.unwrap();
295386
let expected = match case.1 {
296-
None => case.0,
387+
None => case.0.clone(),
297388
Some(expected) => expected,
298389
};
299390

300-
assert_eq!(output, expected);
391+
assert_eq!(output, expected, "Failed for {}", case.0);
301392
}
302393
}
303394

0 commit comments

Comments
 (0)