Skip to content

Commit ac1631a

Browse files
authored
Add simplification rules for the CONCAT function (#3684)
* simpl concat Signed-off-by: remzi <13716567376yh@gmail.com> * update after type coercion Signed-off-by: remzi <13716567376yh@gmail.com> Signed-off-by: remzi <13716567376yh@gmail.com>
1 parent 0cf5630 commit ac1631a

File tree

2 files changed

+81
-0
lines changed

2 files changed

+81
-0
lines changed

datafusion/optimizer/src/simplify_expressions.rs

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -878,12 +878,56 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
878878
out_expr.rewrite(self)?
879879
}
880880

881+
// concat
882+
ScalarFunction {
883+
fun: BuiltinScalarFunction::Concat,
884+
args,
885+
} => {
886+
let mut new_args = Vec::with_capacity(args.len());
887+
let mut contiguous_scalar = "".to_string();
888+
for e in args {
889+
match e {
890+
// All literals have been converted to Utf8 or LargeUtf8 in type_coercion.
891+
// Concatenate it with `contiguous_scalar`.
892+
Expr::Literal(
893+
ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x),
894+
) => {
895+
if let Some(s) = x {
896+
contiguous_scalar += &s;
897+
}
898+
}
899+
// If the arg is not a literal, we should first push the current `contiguous_scalar`
900+
// to the `new_args` (if it is not empty) and reset it to empty string.
901+
// Then pushing this arg to the `new_args`.
902+
e => {
903+
if !contiguous_scalar.is_empty() {
904+
new_args.push(Expr::Literal(ScalarValue::Utf8(Some(
905+
contiguous_scalar.clone(),
906+
))));
907+
contiguous_scalar = "".to_string();
908+
}
909+
new_args.push(e);
910+
}
911+
}
912+
}
913+
if !contiguous_scalar.is_empty() {
914+
new_args
915+
.push(Expr::Literal(ScalarValue::Utf8(Some(contiguous_scalar))));
916+
}
917+
918+
ScalarFunction {
919+
fun: BuiltinScalarFunction::Concat,
920+
args: new_args,
921+
}
922+
}
923+
881924
// concat_ws
882925
ScalarFunction {
883926
fun: BuiltinScalarFunction::ConcatWithSeparator,
884927
args,
885928
} => {
886929
match &args[..] {
930+
// concat_ws(null, ..) --> null
887931
[Expr::Literal(sp), ..] if sp.is_null() => {
888932
Expr::Literal(ScalarValue::Utf8(None))
889933
}
@@ -1352,6 +1396,30 @@ mod tests {
13521396
}
13531397
}
13541398

1399+
#[test]
1400+
fn test_simplify_concat() {
1401+
fn build_concat_expr(args: &[Expr]) -> Expr {
1402+
Expr::ScalarFunction {
1403+
fun: BuiltinScalarFunction::Concat,
1404+
args: args.to_vec(),
1405+
}
1406+
}
1407+
1408+
let null = Expr::Literal(ScalarValue::Utf8(None));
1409+
let expr = build_concat_expr(&[
1410+
null.clone(),
1411+
col("c0"),
1412+
lit("hello "),
1413+
null.clone(),
1414+
lit("rust"),
1415+
col("c1"),
1416+
lit(""),
1417+
null,
1418+
]);
1419+
let expected = build_concat_expr(&[col("c0"), lit("hello rust"), col("c1")]);
1420+
assert_eq!(simplify(expr), expected)
1421+
}
1422+
13551423
// ------------------------------
13561424
// --- ConstEvaluator tests -----
13571425
// ------------------------------

datafusion/optimizer/tests/integration-test.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,19 @@ fn between_date64_plus_interval() -> Result<()> {
199199
Ok(())
200200
}
201201

202+
#[test]
203+
fn concat_literals() -> Result<()> {
204+
let sql = "SELECT concat(true, col_int32, false, null, 'hello', col_utf8, 12, 3.4) \
205+
AS col
206+
FROM test";
207+
let plan = test_sql(sql)?;
208+
let expected =
209+
"Projection: concat(Utf8(\"1\"), CAST(test.col_int32 AS Utf8), Utf8(\"0hello\"), test.col_utf8, Utf8(\"123.4\")) AS col\
210+
\n TableScan: test projection=[col_int32, col_utf8]";
211+
assert_eq!(expected, format!("{:?}", plan));
212+
Ok(())
213+
}
214+
202215
fn test_sql(sql: &str) -> Result<LogicalPlan> {
203216
// parse the SQL
204217
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...

0 commit comments

Comments
 (0)