Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions datafusion/optimizer/src/eliminate_duplicated_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,23 @@ impl OptimizerRule for EliminateDuplicatedExpr {
#[cfg(test)]
mod tests {
use super::*;
use crate::assert_optimized_plan_eq_snapshot;
use crate::test::*;
use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder};
use std::sync::Arc;

fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
crate::test::assert_optimized_plan_eq(
Arc::new(EliminateDuplicatedExpr::new()),
plan,
expected,
)
macro_rules! assert_optimized_plan_equal {
(
$plan:expr,
@ $expected:literal $(,)?
) => {{
let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(EliminateDuplicatedExpr::new());
assert_optimized_plan_eq_snapshot!(
rule,
$plan,
@ $expected,
)
}};
}

#[test]
Expand All @@ -137,10 +144,12 @@ mod tests {
.sort_by(vec![col("a"), col("a"), col("b"), col("c")])?
.limit(5, Some(10))?
.build()?;
let expected = "Limit: skip=5, fetch=10\
\n Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST\
\n TableScan: test";
assert_optimized_plan_eq(plan, expected)

assert_optimized_plan_equal!(plan, @r"
Limit: skip=5, fetch=10
Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST
TableScan: test
")
}

#[test]
Expand All @@ -156,9 +165,11 @@ mod tests {
.sort(sort_exprs)?
.limit(5, Some(10))?
.build()?;
let expected = "Limit: skip=5, fetch=10\
\n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\
\n TableScan: test";
assert_optimized_plan_eq(plan, expected)

assert_optimized_plan_equal!(plan, @r"
Limit: skip=5, fetch=10
Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST
TableScan: test
")
}
}
63 changes: 37 additions & 26 deletions datafusion/optimizer/src/eliminate_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,26 @@ impl OptimizerRule for EliminateFilter {
mod tests {
use std::sync::Arc;

use crate::assert_optimized_plan_eq_snapshot;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{
col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan,
};
use datafusion_expr::{col, lit, logical_plan::builder::LogicalPlanBuilder, Expr};

use crate::eliminate_filter::EliminateFilter;
use crate::test::*;
use datafusion_expr::test::function_stub::sum;

fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected)
macro_rules! assert_optimized_plan_equal {
(
$plan:expr,
@ $expected:literal $(,)?
) => {{
let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(EliminateFilter::new());
assert_optimized_plan_eq_snapshot!(
rule,
$plan,
@ $expected,
)
}};
}

#[test]
Expand All @@ -105,8 +114,7 @@ mod tests {
.build()?;

// No aggregate / scan / limit
let expected = "EmptyRelation";
assert_optimized_plan_equal(plan, expected)
assert_optimized_plan_equal!(plan, @"EmptyRelation")
}

#[test]
Expand All @@ -120,8 +128,7 @@ mod tests {
.build()?;

// No aggregate / scan / limit
let expected = "EmptyRelation";
assert_optimized_plan_equal(plan, expected)
assert_optimized_plan_equal!(plan, @"EmptyRelation")
}

#[test]
Expand All @@ -139,11 +146,12 @@ mod tests {
.build()?;

// Left side is removed
let expected = "Union\
\n EmptyRelation\
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
\n TableScan: test";
assert_optimized_plan_equal(plan, expected)
assert_optimized_plan_equal!(plan, @r"
Union
EmptyRelation
Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
TableScan: test
")
}

#[test]
Expand All @@ -156,9 +164,10 @@ mod tests {
.filter(filter_expr)?
.build()?;

let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
\n TableScan: test";
assert_optimized_plan_equal(plan, expected)
assert_optimized_plan_equal!(plan, @r"
Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
TableScan: test
")
}

#[test]
Expand All @@ -176,12 +185,13 @@ mod tests {
.build()?;

// Filter is removed
let expected = "Union\
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
\n TableScan: test\
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
\n TableScan: test";
assert_optimized_plan_equal(plan, expected)
assert_optimized_plan_equal!(plan, @r"
Union
Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
TableScan: test
Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
TableScan: test
")
}

#[test]
Expand All @@ -202,8 +212,9 @@ mod tests {
.build()?;

// Filter is removed
let expected = "Projection: test.a\
\n EmptyRelation";
assert_optimized_plan_equal(plan, expected)
assert_optimized_plan_equal!(plan, @r"
Projection: test.a
EmptyRelation
")
}
}
121 changes: 47 additions & 74 deletions datafusion/optimizer/src/eliminate_group_by_constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ fn is_constant_expression(expr: &Expr) -> bool {
#[cfg(test)]
mod tests {
use super::*;
use crate::assert_optimized_plan_eq_snapshot;
use crate::test::*;

use arrow::datatypes::DataType;
Expand All @@ -129,6 +130,20 @@ mod tests {

use std::sync::Arc;

macro_rules! assert_optimized_plan_equal {
(
$plan:expr,
@ $expected:literal $(,)?
) => {{
let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(EliminateGroupByConstant::new());
assert_optimized_plan_eq_snapshot!(
rule,
$plan,
@ $expected,
)
}};
}

#[derive(Debug)]
struct ScalarUDFMock {
signature: Signature,
Expand Down Expand Up @@ -167,17 +182,11 @@ mod tests {
.aggregate(vec![col("a"), lit(1u32)], vec![count(col("c"))])?
.build()?;

let expected = "\
Projection: test.a, UInt32(1), count(test.c)\
\n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
\n TableScan: test\
";

assert_optimized_plan_eq(
Arc::new(EliminateGroupByConstant::new()),
plan,
expected,
)
assert_optimized_plan_equal!(plan, @r"
Projection: test.a, UInt32(1), count(test.c)
Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
TableScan: test
")
}

#[test]
Expand All @@ -187,17 +196,11 @@ mod tests {
.aggregate(vec![lit("test"), lit(123u32)], vec![count(col("c"))])?
.build()?;

let expected = "\
Projection: Utf8(\"test\"), UInt32(123), count(test.c)\
\n Aggregate: groupBy=[[]], aggr=[[count(test.c)]]\
\n TableScan: test\
";

assert_optimized_plan_eq(
Arc::new(EliminateGroupByConstant::new()),
plan,
expected,
)
assert_optimized_plan_equal!(plan, @r#"
Projection: Utf8("test"), UInt32(123), count(test.c)
Aggregate: groupBy=[[]], aggr=[[count(test.c)]]
TableScan: test
"#)
}

#[test]
Expand All @@ -207,16 +210,10 @@ mod tests {
.aggregate(vec![col("a"), col("b")], vec![count(col("c"))])?
.build()?;

let expected = "\
Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]\
\n TableScan: test\
";

assert_optimized_plan_eq(
Arc::new(EliminateGroupByConstant::new()),
plan,
expected,
)
assert_optimized_plan_equal!(plan, @r"
Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]
TableScan: test
")
}

#[test]
Expand All @@ -226,16 +223,10 @@ mod tests {
.aggregate(vec![lit(123u32)], Vec::<Expr>::new())?
.build()?;

let expected = "\
Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]\
\n TableScan: test\
";

assert_optimized_plan_eq(
Arc::new(EliminateGroupByConstant::new()),
plan,
expected,
)
assert_optimized_plan_equal!(plan, @r"
Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]
TableScan: test
")
}

#[test]
Expand All @@ -248,17 +239,11 @@ mod tests {
)?
.build()?;

let expected = "\
Projection: UInt32(123) AS const, test.a, count(test.c)\
\n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
\n TableScan: test\
";

assert_optimized_plan_eq(
Arc::new(EliminateGroupByConstant::new()),
plan,
expected,
)
assert_optimized_plan_equal!(plan, @r"
Projection: UInt32(123) AS const, test.a, count(test.c)
Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
TableScan: test
")
}

#[test]
Expand All @@ -273,17 +258,11 @@ mod tests {
.aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])?
.build()?;

let expected = "\
Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)\
\n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
\n TableScan: test\
";

assert_optimized_plan_eq(
Arc::new(EliminateGroupByConstant::new()),
plan,
expected,
)
assert_optimized_plan_equal!(plan, @r"
Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)
Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
TableScan: test
")
}

#[test]
Expand All @@ -298,15 +277,9 @@ mod tests {
.aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])?
.build()?;

let expected = "\
Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]]\
\n TableScan: test\
";

assert_optimized_plan_eq(
Arc::new(EliminateGroupByConstant::new()),
plan,
expected,
)
assert_optimized_plan_equal!(plan, @r"
Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]]
TableScan: test
")
}
}
Loading