Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2044,6 +2044,53 @@ impl TreeNodeRewriter for Simplifier<'_> {
Transformed::no(Expr::BinaryExpr(BinaryExpr { left, op, right }))
}
}
// For case:
// date_part('YEAR', expr) IN (literal1, literal2, ...)
Expr::InList(InList {
expr,
list,
negated,
}) => {
if list.len() > THRESHOLD_INLINE_INLIST || list.iter().any(is_null) {
return Ok(Transformed::no(Expr::InList(InList {
expr,
list,
negated,
})));
}

let (op, combiner): (Operator, fn(Expr, Expr) -> Expr) =
if negated { (NotEq, and) } else { (Eq, or) };

let mut rewritten: Option<Expr> = None;
for item in &list {
let PreimageResult::Range { interval, expr } =
get_preimage(expr.as_ref(), item, info)?
else {
return Ok(Transformed::no(Expr::InList(InList {
expr,
list,
negated,
})));
};

let range_expr = rewrite_with_preimage(*interval, op, expr)?.data;
rewritten = Some(match rewritten {
None => range_expr,
Some(acc) => combiner(acc, range_expr),
});
}

if let Some(rewritten) = rewritten {
Transformed::yes(rewritten)
} else {
Transformed::no(Expr::InList(InList {
expr,
list,
negated,
}))
}
}

// no additional rewrites possible
expr => Transformed::no(expr),
Expand Down
43 changes: 42 additions & 1 deletion datafusion/optimizer/src/simplify_expressions/udf_preimage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ mod test {
use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue};
use datafusion_expr::{
ColumnarValue, Expr, Operator, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
Signature, Volatility, and, binary_expr, col, lit, preimage::PreimageResult,
Signature, Volatility, and, binary_expr, col, lit, or, preimage::PreimageResult,
simplify::SimplifyContext,
};

Expand Down Expand Up @@ -164,6 +164,15 @@ mod test {
)?),
})
}
Expr::Literal(ScalarValue::Int32(Some(600)), _) => {
Ok(PreimageResult::Range {
expr,
interval: Box::new(Interval::try_new(
ScalarValue::Int32(Some(300)),
ScalarValue::Int32(Some(400)),
)?),
})
}
_ => Ok(PreimageResult::None),
}
}
Expand Down Expand Up @@ -311,6 +320,38 @@ mod test {
assert_eq!(optimize_test(expr, &schema), expected);
}

#[test]
fn test_preimage_in_list_rewrite() {
let schema = test_schema();
let expr = preimage_udf_expr().in_list(vec![lit(500), lit(600)], false);
let expected = or(
and(col("x").gt_eq(lit(100)), col("x").lt(lit(200))),
and(col("x").gt_eq(lit(300)), col("x").lt(lit(400))),
);

assert_eq!(optimize_test(expr, &schema), expected);
}

#[test]
fn test_preimage_not_in_list_rewrite() {
let schema = test_schema();
let expr = preimage_udf_expr().in_list(vec![lit(500), lit(600)], true);
let expected = and(
or(col("x").lt(lit(100)), col("x").gt_eq(lit(200))),
or(col("x").lt(lit(300)), col("x").gt_eq(lit(400))),
);

assert_eq!(optimize_test(expr, &schema), expected);
}

#[test]
fn test_preimage_in_list_long_list_no_rewrite() {
let schema = test_schema();
let expr = preimage_udf_expr().in_list((1..100).map(lit).collect(), false);

assert_eq!(optimize_test(expr.clone(), &schema), expr);
}

#[test]
fn test_preimage_non_literal_rhs_no_rewrite() {
// Non-literal RHS should not be rewritten.
Expand Down
25 changes: 24 additions & 1 deletion datafusion/sqllogictest/test_files/datetime/date_part.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,19 @@ NULL
1990-01-01
2030-01-01

# IN list optimization
query D
select c from t1 where extract(year from c) in (1990, 2024);
----
1990-01-01
2024-01-01

# NOT IN list optimization (NULL does not satisfy NOT IN)
query D
select c from t1 where extract(year from c) not in (1990, 2024);
----
2030-01-01

# Check that date_part is not in the explain statements

query TT
Expand Down Expand Up @@ -1329,6 +1342,16 @@ physical_plan
01)FilterExec: c@0 < 2024-01-01 OR c@0 >= 2025-01-01 OR c@0 IS NULL
02)--DataSourceExec: partitions=1, partition_sizes=[1]

query TT
explain select c from t1 where extract (year from c) in (1990, 2024)
----
logical_plan
01)Filter: t1.c >= Date32("1990-01-01") AND t1.c < Date32("1991-01-01") OR t1.c >= Date32("2024-01-01") AND t1.c < Date32("2025-01-01")
02)--TableScan: t1 projection=[c]
physical_plan
01)FilterExec: c@0 >= 1990-01-01 AND c@0 < 1991-01-01 OR c@0 >= 2024-01-01 AND c@0 < 2025-01-01
02)--DataSourceExec: partitions=1, partition_sizes=[1]

# Simple optimizations, column on RHS

query D
Expand Down Expand Up @@ -1730,4 +1753,4 @@ logical_plan
02)--TableScan: t1 projection=[c]
physical_plan
01)FilterExec: c@0 >= 2024-01-01 AND c@0 < 2025-01-01
02)--DataSourceExec: partitions=1, partition_sizes=[1]
02)--DataSourceExec: partitions=1, partition_sizes=[1]
23 changes: 23 additions & 0 deletions datafusion/sqllogictest/test_files/floor_preimage.slt
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,21 @@ query I rowsort
SELECT id FROM test_data WHERE floor(float_val) = arrow_cast(5.5, 'Float64');
----

# IN list: floor(x) IN (5, 7) matches [5.0, 6.0) and [7.0, 8.0)
query I rowsort
SELECT id FROM test_data WHERE floor(float_val) IN (arrow_cast(5, 'Float64'), arrow_cast(7, 'Float64'));
----
1
2
5

# NOT IN list: floor(x) NOT IN (5, 7) excludes matching ranges and NULLs
query I rowsort
SELECT id FROM test_data WHERE floor(float_val) NOT IN (arrow_cast(5, 'Float64'), arrow_cast(7, 'Float64'));
----
3
4

##########
## EXPLAIN Tests - Plan Optimization
##########
Expand Down Expand Up @@ -177,6 +192,14 @@ logical_plan
01)Filter: floor(test_data.float_val) = Float64(9007199254740992)
02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val]

# 9. IN list: each list item is rewritten with preimage and OR-ed together
query TT
EXPLAIN SELECT * FROM test_data WHERE floor(float_val) IN (arrow_cast(5, 'Float64'), arrow_cast(7, 'Float64'));
----
logical_plan
01)Filter: test_data.float_val >= Float64(5) AND test_data.float_val < Float64(6) OR test_data.float_val >= Float64(7) AND test_data.float_val < Float64(8)
02)--TableScan: test_data projection=[id, float_val, int_val, decimal_val]

# Data correctness: floor(col) = 2^53 returns no rows (no value in test_data has floor exactly 2^53)
query I rowsort
SELECT id FROM test_data WHERE floor(float_val) = 9007199254740992;
Expand Down