Skip to content

Commit 58868ea

Browse files
committed
Merge remote-tracking branch 'alamb/alamb/less_clone' into guarantees
# Conflicts:
#	datafusion/expr/src/expr_rewriter/guarantees.rs
2 parents 68a9dea + b53d2cf commit 58868ea

File tree

1 file changed

+67
-73
lines changed

1 file changed

+67
-73
lines changed

datafusion/expr/src/expr_rewriter/guarantees.rs

Lines changed: 67 additions & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -101,50 +101,42 @@ fn rewrite_expr(
101101
expr: Expr,
102102
guarantees: &HashMap<&Expr, &NullableInterval>,
103103
) -> Result<Transformed<Expr>> {
104-
let new_expr = match &expr {
104+
// If an expression collapses to a single value, replace it with a literal
105+
if let Some(interval) = guarantees.get(&expr) {
106+
if let Some(value) = interval.single_value() {
107+
return Ok(Transformed::yes(lit(value)));
108+
}
109+
}
110+
111+
let result = match expr {
105112
Expr::IsNull(inner) => match guarantees.get(inner.as_ref()) {
106-
Some(NullableInterval::Null { .. }) => Some(lit(true)),
107-
Some(NullableInterval::NotNull { .. }) => Some(lit(false)),
108-
_ => None,
113+
Some(NullableInterval::Null { .. }) => Transformed::yes(lit(true)),
114+
Some(NullableInterval::NotNull { .. }) => Transformed::yes(lit(false)),
115+
_ => Transformed::no(Expr::IsNull(inner)),
109116
},
110117
Expr::IsNotNull(inner) => match guarantees.get(inner.as_ref()) {
111-
Some(NullableInterval::Null { .. }) => Some(lit(false)),
112-
Some(NullableInterval::NotNull { .. }) => Some(lit(true)),
113-
_ => None,
118+
Some(NullableInterval::Null { .. }) => Transformed::yes(lit(false)),
119+
Some(NullableInterval::NotNull { .. }) => Transformed::yes(lit(true)),
120+
_ => Transformed::no(Expr::IsNotNull(inner)),
114121
},
115122
Expr::Between(b) => rewrite_between(b, guarantees)?,
116123
Expr::BinaryExpr(b) => rewrite_binary_expr(b, guarantees)?,
117124
Expr::InList(i) => rewrite_inlist(i, guarantees)?,
118-
_ => None,
125+
expr => Transformed::no(expr),
119126
};
120-
121-
if let Some(e) = new_expr {
122-
return Ok(Transformed::yes(e));
123-
}
124-
125-
match guarantees.get(&expr) {
126-
Some(interval) => {
127-
// If an expression collapses to a single value, replace it with a literal
128-
if let Some(value) = interval.single_value() {
129-
Ok(Transformed::yes(lit(value)))
130-
} else {
131-
Ok(Transformed::no(expr))
132-
}
133-
}
134-
_ => Ok(Transformed::no(expr)),
135-
}
127+
Ok(result)
136128
}
137129

138130
fn rewrite_between(
139-
between: &Between,
131+
between: Between,
140132
guarantees: &HashMap<&Expr, &NullableInterval>,
141-
) -> Result<Option<Expr>, DataFusionError> {
133+
) -> Result<Transformed<Expr>> {
142134
let (Some(expr_interval), Expr::Literal(low, _), Expr::Literal(high, _)) = (
143135
guarantees.get(between.expr.as_ref()),
144136
between.low.as_ref(),
145137
between.high.as_ref(),
146138
) else {
147-
return Ok(None);
139+
return Ok(Transformed::no(Expr::Between(between)));
148140
};
149141

150142
// Ensure that, if low or high are null, their type matches the other bound
@@ -154,65 +146,66 @@ fn rewrite_between(
154146
let Ok(between_interval) = Interval::try_new(low, high) else {
155147
// If we can't create an interval from the literals, be conservative and simply leave
156148
// the expression unmodified.
157-
return Ok(None);
149+
return Ok(Transformed::no(Expr::Between(between)));
158150
};
159151

160152
if between_interval.lower().is_null() && between_interval.upper().is_null() {
161-
return Ok(Some(lit(between_interval.lower().clone())));
153+
return Ok(Transformed::yes(lit(between_interval.lower().clone())));
162154
}
163155

164156
let expr_interval = match expr_interval {
165157
NullableInterval::Null { datatype } => {
166158
// Value is guaranteed to be null, so we can simplify to null.
167-
return Ok(Some(lit(
159+
return Ok(Transformed::yes(lit(
168160
ScalarValue::try_new_null(datatype).unwrap_or(ScalarValue::Null)
169161
)));
170162
}
171163
NullableInterval::MaybeNull { .. } => {
172164
// Value may or may not be null, so we can't simplify the expression.
173-
return Ok(None);
165+
return Ok(Transformed::no(Expr::Between(between)));
174166
}
175167
NullableInterval::NotNull { values } => values,
176168
};
177169

178-
Ok(if between_interval.lower().is_null() {
170+
let result = if between_interval.lower().is_null() {
179171
// <expr> (NOT) BETWEEN NULL AND <high>
180172
let upper_bound = Interval::from(between_interval.upper().clone());
181173
if expr_interval.gt(&upper_bound)?.eq(&Interval::TRUE) {
182174
// if <expr> > high, then certainly false
183-
Some(lit(between.negated))
175+
Transformed::yes(lit(between.negated))
184176
} else if expr_interval.lt_eq(&upper_bound)?.eq(&Interval::TRUE) {
185177
// if <expr> <= high, then certainly null
186-
Some(lit(ScalarValue::try_new_null(&expr_interval.data_type())
178+
Transformed::yes(lit(ScalarValue::try_new_null(&expr_interval.data_type())
187179
.unwrap_or(ScalarValue::Null)))
188180
} else {
189181
// otherwise unknown
190-
None
182+
Transformed::no(Expr::Between(between))
191183
}
192184
} else if between_interval.upper().is_null() {
193185
// <expr> (NOT) BETWEEN <low> AND NULL
194186
let lower_bound = Interval::from(between_interval.lower().clone());
195187
if expr_interval.lt(&lower_bound)?.eq(&Interval::TRUE) {
196188
// if <expr> < low, then certainly false
197-
Some(lit(between.negated))
189+
Transformed::yes(lit(between.negated))
198190
} else if expr_interval.gt_eq(&lower_bound)?.eq(&Interval::TRUE) {
199191
// if <expr> >= low, then certainly null
200-
Some(lit(ScalarValue::try_new_null(&expr_interval.data_type())
192+
Transformed::yes(lit(ScalarValue::try_new_null(&expr_interval.data_type())
201193
.unwrap_or(ScalarValue::Null)))
202194
} else {
203195
// otherwise unknown
204-
None
196+
Transformed::no(Expr::Between(between))
205197
}
206198
} else {
207199
let contains = between_interval.contains(expr_interval)?;
208200
if contains.eq(&Interval::TRUE) {
209-
Some(lit(!between.negated))
201+
Transformed::yes(lit(!between.negated))
210202
} else if contains.eq(&Interval::FALSE) {
211-
Some(lit(between.negated))
203+
Transformed::yes(lit(between.negated))
212204
} else {
213-
None
205+
Transformed::no(Expr::Between(between))
214206
}
215-
})
207+
};
208+
Ok(result)
216209
}
217210

218211
fn ensure_typed_null(
@@ -229,9 +222,9 @@ fn ensure_typed_null(
229222
}
230223

231224
fn rewrite_binary_expr(
232-
binary: &BinaryExpr,
225+
binary: BinaryExpr,
233226
guarantees: &HashMap<&Expr, &NullableInterval>,
234-
) -> Result<Option<Expr>, DataFusionError> {
227+
) -> Result<Transformed<Expr>, DataFusionError> {
235228
// The left or right side of expression might either have a guarantee
236229
// or be a literal. Either way, we can resolve them to a NullableInterval.
237230
let left_interval = guarantees
@@ -255,53 +248,53 @@ fn rewrite_binary_expr(
255248
}
256249
});
257250

258-
Ok(match (left_interval, right_interval) {
259-
(Some(left_interval), Some(right_interval)) => {
260-
let result =
261-
left_interval.apply_operator(&binary.op, right_interval.as_ref())?;
262-
if result.is_certainly_true() {
263-
Some(lit(true))
264-
} else if result.is_certainly_false() {
265-
Some(lit(false))
266-
} else {
267-
None
268-
}
251+
if let (Some(left_interval), Some(right_interval)) = (left_interval, right_interval) {
252+
let result = left_interval.apply_operator(&binary.op, right_interval.as_ref())?;
253+
if result.is_certainly_true() {
254+
return Ok(Transformed::yes(lit(true)));
255+
} else if result.is_certainly_false() {
256+
return Ok(Transformed::yes(lit(false)));
269257
}
270-
_ => None,
271-
})
258+
}
259+
Ok(Transformed::no(Expr::BinaryExpr(binary)))
272260
}
273261

274262
fn rewrite_inlist(
275-
inlist: &InList,
263+
inlist: InList,
276264
guarantees: &HashMap<&Expr, &NullableInterval>,
277-
) -> Result<Option<Expr>, DataFusionError> {
265+
) -> Result<Transformed<Expr>, DataFusionError> {
278266
let Some(interval) = guarantees.get(inlist.expr.as_ref()) else {
279-
return Ok(None);
267+
return Ok(Transformed::no(Expr::InList(inlist)));
280268
};
281269

270+
let InList {
271+
expr,
272+
list,
273+
negated,
274+
} = inlist;
275+
282276
// Can remove items from the list that don't match the guarantee
283-
let new_list: Vec<Expr> = inlist
284-
.list
285-
.iter()
277+
let list: Vec<Expr> = list
278+
.into_iter()
286279
.filter_map(|expr| {
287-
if let Expr::Literal(item, _) = expr {
280+
if let Expr::Literal(item, _) = &expr {
288281
match interval.contains(NullableInterval::from(item.clone())) {
289282
// If we know for certain the value isn't in the column's interval,
290283
// we can skip checking it.
291284
Ok(interval) if interval.is_certainly_false() => None,
292-
Ok(_) => Some(Ok(expr.clone())),
285+
Ok(_) => Some(Ok(expr)),
293286
Err(e) => Some(Err(e)),
294287
}
295288
} else {
296-
Some(Ok(expr.clone()))
289+
Some(Ok(expr))
297290
}
298291
})
299292
.collect::<Result<_, DataFusionError>>()?;
300293

301-
Ok(Some(Expr::InList(InList {
302-
expr: inlist.expr.clone(),
303-
list: new_list,
304-
negated: inlist.negated,
294+
Ok(Transformed::yes(Expr::InList(InList {
295+
expr,
296+
list,
297+
negated,
305298
})))
306299
}
307300

@@ -315,6 +308,7 @@ mod tests {
315308

316309
#[test]
317310
fn test_not_null_guarantee() {
311+
// IsNull / IsNotNull can be rewritten to true / false
318312
let guarantees = [
319313
// Note: AlwaysNull case handled by test_column_single_value test,
320314
// since it's a special case of a column with a single value.
@@ -468,7 +462,7 @@ mod tests {
468462
ScalarValue::Date32(Some(18628)),
469463
ScalarValue::Date32(None),
470464
)
471-
.unwrap(),
465+
.unwrap(),
472466
},
473467
),
474468
];
@@ -546,7 +540,7 @@ mod tests {
546540
ScalarValue::from("abc"),
547541
ScalarValue::from("def"),
548542
)
549-
.unwrap(),
543+
.unwrap(),
550544
},
551545
),
552546
];
@@ -627,7 +621,7 @@ mod tests {
627621
ScalarValue::Int32(Some(1)),
628622
ScalarValue::Int32(Some(10)),
629623
)
630-
.unwrap(),
624+
.unwrap(),
631625
},
632626
),
633627
];

0 commit comments

Comments
 (0)