Skip to content
Merged
26 changes: 15 additions & 11 deletions datafusion/core/tests/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use arrow::datatypes::{
DataType, Field, Fields, Schema, SchemaBuilder, SchemaRef, TimeUnit,
};
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{TransformedResult, TreeNode};
use datafusion_common::tree_node::TransformedResult;
use datafusion_common::{plan_err, DFSchema, Result, ScalarValue, TableReference};
use datafusion_expr::interval_arithmetic::{Interval, NullableInterval};
use datafusion_expr::{
Expand All @@ -37,14 +37,14 @@ use datafusion_expr::{
use datafusion_functions::core::expr_ext::FieldAccessor;
use datafusion_optimizer::analyzer::Analyzer;
use datafusion_optimizer::optimizer::Optimizer;
use datafusion_optimizer::simplify_expressions::GuaranteeRewriter;
use datafusion_optimizer::{OptimizerConfig, OptimizerContext};
use datafusion_sql::planner::{ContextProvider, SqlToRel};
use datafusion_sql::sqlparser::ast::Statement;
use datafusion_sql::sqlparser::dialect::GenericDialect;
use datafusion_sql::sqlparser::parser::Parser;

use chrono::DateTime;
use datafusion_expr::expr_rewriter::rewrite_with_guarantees;
use datafusion_functions::datetime;

#[cfg(test)]
Expand Down Expand Up @@ -304,8 +304,6 @@ fn test_inequalities_non_null_bounded() {
),
];

let mut rewriter = GuaranteeRewriter::new(guarantees.iter());

// (original_expr, expected_simplification)
let simplified_cases = &[
(col("x").lt(lit(0)), false),
Expand Down Expand Up @@ -337,7 +335,7 @@ fn test_inequalities_non_null_bounded() {
),
];

validate_simplified_cases(&mut rewriter, simplified_cases);
validate_simplified_cases(&guarantees, simplified_cases);

let unchanged_cases = &[
col("x").gt(lit(2)),
Expand All @@ -348,26 +346,32 @@ fn test_inequalities_non_null_bounded() {
col("x").not_between(lit(3), lit(10)),
];

validate_unchanged_cases(&mut rewriter, unchanged_cases);
validate_unchanged_cases(&guarantees, unchanged_cases);
}

fn validate_simplified_cases<T>(rewriter: &mut GuaranteeRewriter, cases: &[(Expr, T)])
where
fn validate_simplified_cases<T>(
guarantees: &[(Expr, NullableInterval)],
cases: &[(Expr, T)],
) where
ScalarValue: From<T>,
T: Clone,
{
for (expr, expected_value) in cases {
let output = expr.clone().rewrite(rewriter).data().unwrap();
let output = rewrite_with_guarantees(expr.clone(), guarantees)
.data()
.unwrap();
let expected = lit(ScalarValue::from(expected_value.clone()));
assert_eq!(
output, expected,
"{expr} simplified to {output}, but expected {expected}"
);
}
}
fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) {
fn validate_unchanged_cases(guarantees: &[(Expr, NullableInterval)], cases: &[Expr]) {
for expr in cases {
let output = expr.clone().rewrite(rewriter).data().unwrap();
let output = rewrite_with_guarantees(expr.clone(), guarantees)
.data()
.unwrap();
assert_eq!(
&output, expr,
"{expr} was simplified to {output}, but expected it to be unchanged"
Expand Down
Loading