Skip to content

Commit ff404cd

Browse files
authored
Migrate Optimizer tests to insta, part2 (#15884)
* migrate tests in `replace_distinct_aggregate.rs` * migrate tests in `replace_distinct_aggregate.rs` * migrate tests in `push_down_limit.rs` * migrate tests in `eliminate_duplicated_expr.rs` * migrate tests in `eliminate_filter.rs` * migrate tests in `eliminate_group_by_constant.rs` to insta * migrate tests in `eliminate_join.rs` to use snapshot assertions * migrate tests in `eliminate_nested_union.rs` to use snapshot assertions * migrate tests in `eliminate_outer_join.rs` to use snapshot assertions * migrate tests in `filter_null_join_keys.rs` to use snapshot assertions * fix Type inferance * fix macro to use crate path for OptimizerContext and Optimizer * clean up
1 parent b782cff commit ff404cd

10 files changed

+700
-494
lines changed

datafusion/optimizer/src/eliminate_duplicated_expr.rs

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -118,16 +118,23 @@ impl OptimizerRule for EliminateDuplicatedExpr {
118118
#[cfg(test)]
119119
mod tests {
120120
use super::*;
121+
use crate::assert_optimized_plan_eq_snapshot;
121122
use crate::test::*;
122123
use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder};
123124
use std::sync::Arc;
124125

125-
fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
126-
crate::test::assert_optimized_plan_eq(
127-
Arc::new(EliminateDuplicatedExpr::new()),
128-
plan,
129-
expected,
130-
)
126+
macro_rules! assert_optimized_plan_equal {
127+
(
128+
$plan:expr,
129+
@ $expected:literal $(,)?
130+
) => {{
131+
let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(EliminateDuplicatedExpr::new());
132+
assert_optimized_plan_eq_snapshot!(
133+
rule,
134+
$plan,
135+
@ $expected,
136+
)
137+
}};
131138
}
132139

133140
#[test]
@@ -137,10 +144,12 @@ mod tests {
137144
.sort_by(vec![col("a"), col("a"), col("b"), col("c")])?
138145
.limit(5, Some(10))?
139146
.build()?;
140-
let expected = "Limit: skip=5, fetch=10\
141-
\n Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST\
142-
\n TableScan: test";
143-
assert_optimized_plan_eq(plan, expected)
147+
148+
assert_optimized_plan_equal!(plan, @r"
149+
Limit: skip=5, fetch=10
150+
Sort: test.a ASC NULLS LAST, test.b ASC NULLS LAST, test.c ASC NULLS LAST
151+
TableScan: test
152+
")
144153
}
145154

146155
#[test]
@@ -156,9 +165,11 @@ mod tests {
156165
.sort(sort_exprs)?
157166
.limit(5, Some(10))?
158167
.build()?;
159-
let expected = "Limit: skip=5, fetch=10\
160-
\n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\
161-
\n TableScan: test";
162-
assert_optimized_plan_eq(plan, expected)
168+
169+
assert_optimized_plan_equal!(plan, @r"
170+
Limit: skip=5, fetch=10
171+
Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST
172+
TableScan: test
173+
")
163174
}
164175
}

datafusion/optimizer/src/eliminate_filter.rs

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,26 @@ impl OptimizerRule for EliminateFilter {
8181
mod tests {
8282
use std::sync::Arc;
8383

84+
use crate::assert_optimized_plan_eq_snapshot;
8485
use datafusion_common::{Result, ScalarValue};
85-
use datafusion_expr::{
86-
col, lit, logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan,
87-
};
86+
use datafusion_expr::{col, lit, logical_plan::builder::LogicalPlanBuilder, Expr};
8887

8988
use crate::eliminate_filter::EliminateFilter;
9089
use crate::test::*;
9190
use datafusion_expr::test::function_stub::sum;
9291

93-
fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
94-
assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected)
92+
macro_rules! assert_optimized_plan_equal {
93+
(
94+
$plan:expr,
95+
@ $expected:literal $(,)?
96+
) => {{
97+
let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(EliminateFilter::new());
98+
assert_optimized_plan_eq_snapshot!(
99+
rule,
100+
$plan,
101+
@ $expected,
102+
)
103+
}};
95104
}
96105

97106
#[test]
@@ -105,8 +114,7 @@ mod tests {
105114
.build()?;
106115

107116
// No aggregate / scan / limit
108-
let expected = "EmptyRelation";
109-
assert_optimized_plan_equal(plan, expected)
117+
assert_optimized_plan_equal!(plan, @"EmptyRelation")
110118
}
111119

112120
#[test]
@@ -120,8 +128,7 @@ mod tests {
120128
.build()?;
121129

122130
// No aggregate / scan / limit
123-
let expected = "EmptyRelation";
124-
assert_optimized_plan_equal(plan, expected)
131+
assert_optimized_plan_equal!(plan, @"EmptyRelation")
125132
}
126133

127134
#[test]
@@ -139,11 +146,12 @@ mod tests {
139146
.build()?;
140147

141148
// Left side is removed
142-
let expected = "Union\
143-
\n EmptyRelation\
144-
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
145-
\n TableScan: test";
146-
assert_optimized_plan_equal(plan, expected)
149+
assert_optimized_plan_equal!(plan, @r"
150+
Union
151+
EmptyRelation
152+
Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
153+
TableScan: test
154+
")
147155
}
148156

149157
#[test]
@@ -156,9 +164,10 @@ mod tests {
156164
.filter(filter_expr)?
157165
.build()?;
158166

159-
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
160-
\n TableScan: test";
161-
assert_optimized_plan_equal(plan, expected)
167+
assert_optimized_plan_equal!(plan, @r"
168+
Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
169+
TableScan: test
170+
")
162171
}
163172

164173
#[test]
@@ -176,12 +185,13 @@ mod tests {
176185
.build()?;
177186

178187
// Filter is removed
179-
let expected = "Union\
180-
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
181-
\n TableScan: test\
182-
\n Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]\
183-
\n TableScan: test";
184-
assert_optimized_plan_equal(plan, expected)
188+
assert_optimized_plan_equal!(plan, @r"
189+
Union
190+
Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
191+
TableScan: test
192+
Aggregate: groupBy=[[test.a]], aggr=[[sum(test.b)]]
193+
TableScan: test
194+
")
185195
}
186196

187197
#[test]
@@ -202,8 +212,9 @@ mod tests {
202212
.build()?;
203213

204214
// Filter is removed
205-
let expected = "Projection: test.a\
206-
\n EmptyRelation";
207-
assert_optimized_plan_equal(plan, expected)
215+
assert_optimized_plan_equal!(plan, @r"
216+
Projection: test.a
217+
EmptyRelation
218+
")
208219
}
209220
}

datafusion/optimizer/src/eliminate_group_by_constant.rs

Lines changed: 47 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ fn is_constant_expression(expr: &Expr) -> bool {
115115
#[cfg(test)]
116116
mod tests {
117117
use super::*;
118+
use crate::assert_optimized_plan_eq_snapshot;
118119
use crate::test::*;
119120

120121
use arrow::datatypes::DataType;
@@ -129,6 +130,20 @@ mod tests {
129130

130131
use std::sync::Arc;
131132

133+
macro_rules! assert_optimized_plan_equal {
134+
(
135+
$plan:expr,
136+
@ $expected:literal $(,)?
137+
) => {{
138+
let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(EliminateGroupByConstant::new());
139+
assert_optimized_plan_eq_snapshot!(
140+
rule,
141+
$plan,
142+
@ $expected,
143+
)
144+
}};
145+
}
146+
132147
#[derive(Debug)]
133148
struct ScalarUDFMock {
134149
signature: Signature,
@@ -167,17 +182,11 @@ mod tests {
167182
.aggregate(vec![col("a"), lit(1u32)], vec![count(col("c"))])?
168183
.build()?;
169184

170-
let expected = "\
171-
Projection: test.a, UInt32(1), count(test.c)\
172-
\n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
173-
\n TableScan: test\
174-
";
175-
176-
assert_optimized_plan_eq(
177-
Arc::new(EliminateGroupByConstant::new()),
178-
plan,
179-
expected,
180-
)
185+
assert_optimized_plan_equal!(plan, @r"
186+
Projection: test.a, UInt32(1), count(test.c)
187+
Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
188+
TableScan: test
189+
")
181190
}
182191

183192
#[test]
@@ -187,17 +196,11 @@ mod tests {
187196
.aggregate(vec![lit("test"), lit(123u32)], vec![count(col("c"))])?
188197
.build()?;
189198

190-
let expected = "\
191-
Projection: Utf8(\"test\"), UInt32(123), count(test.c)\
192-
\n Aggregate: groupBy=[[]], aggr=[[count(test.c)]]\
193-
\n TableScan: test\
194-
";
195-
196-
assert_optimized_plan_eq(
197-
Arc::new(EliminateGroupByConstant::new()),
198-
plan,
199-
expected,
200-
)
199+
assert_optimized_plan_equal!(plan, @r#"
200+
Projection: Utf8("test"), UInt32(123), count(test.c)
201+
Aggregate: groupBy=[[]], aggr=[[count(test.c)]]
202+
TableScan: test
203+
"#)
201204
}
202205

203206
#[test]
@@ -207,16 +210,10 @@ mod tests {
207210
.aggregate(vec![col("a"), col("b")], vec![count(col("c"))])?
208211
.build()?;
209212

210-
let expected = "\
211-
Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]\
212-
\n TableScan: test\
213-
";
214-
215-
assert_optimized_plan_eq(
216-
Arc::new(EliminateGroupByConstant::new()),
217-
plan,
218-
expected,
219-
)
213+
assert_optimized_plan_equal!(plan, @r"
214+
Aggregate: groupBy=[[test.a, test.b]], aggr=[[count(test.c)]]
215+
TableScan: test
216+
")
220217
}
221218

222219
#[test]
@@ -226,16 +223,10 @@ mod tests {
226223
.aggregate(vec![lit(123u32)], Vec::<Expr>::new())?
227224
.build()?;
228225

229-
let expected = "\
230-
Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]\
231-
\n TableScan: test\
232-
";
233-
234-
assert_optimized_plan_eq(
235-
Arc::new(EliminateGroupByConstant::new()),
236-
plan,
237-
expected,
238-
)
226+
assert_optimized_plan_equal!(plan, @r"
227+
Aggregate: groupBy=[[UInt32(123)]], aggr=[[]]
228+
TableScan: test
229+
")
239230
}
240231

241232
#[test]
@@ -248,17 +239,11 @@ mod tests {
248239
)?
249240
.build()?;
250241

251-
let expected = "\
252-
Projection: UInt32(123) AS const, test.a, count(test.c)\
253-
\n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
254-
\n TableScan: test\
255-
";
256-
257-
assert_optimized_plan_eq(
258-
Arc::new(EliminateGroupByConstant::new()),
259-
plan,
260-
expected,
261-
)
242+
assert_optimized_plan_equal!(plan, @r"
243+
Projection: UInt32(123) AS const, test.a, count(test.c)
244+
Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
245+
TableScan: test
246+
")
262247
}
263248

264249
#[test]
@@ -273,17 +258,11 @@ mod tests {
273258
.aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])?
274259
.build()?;
275260

276-
let expected = "\
277-
Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)\
278-
\n Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]\
279-
\n TableScan: test\
280-
";
281-
282-
assert_optimized_plan_eq(
283-
Arc::new(EliminateGroupByConstant::new()),
284-
plan,
285-
expected,
286-
)
261+
assert_optimized_plan_equal!(plan, @r"
262+
Projection: scalar_fn_mock(UInt32(123)), test.a, count(test.c)
263+
Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
264+
TableScan: test
265+
")
287266
}
288267

289268
#[test]
@@ -298,15 +277,9 @@ mod tests {
298277
.aggregate(vec![udf_expr, col("a")], vec![count(col("c"))])?
299278
.build()?;
300279

301-
let expected = "\
302-
Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]]\
303-
\n TableScan: test\
304-
";
305-
306-
assert_optimized_plan_eq(
307-
Arc::new(EliminateGroupByConstant::new()),
308-
plan,
309-
expected,
310-
)
280+
assert_optimized_plan_equal!(plan, @r"
281+
Aggregate: groupBy=[[scalar_fn_mock(UInt32(123)), test.a]], aggr=[[count(test.c)]]
282+
TableScan: test
283+
")
311284
}
312285
}

0 commit comments

Comments
 (0)