Skip to content

Commit 76ced31

Browse files
haohuaijinuniversalmind303alamb
authored
feat: impl the basic string_agg function (#8148)
* init impl * add support for larget utf8 * add some test * support null * remove redundance code * remove redundance code * add more test * Update datafusion/physical-expr/src/aggregate/string_agg.rs Co-authored-by: universalmind303 <universalmind.candy@gmail.com> * Update datafusion/physical-expr/src/aggregate/string_agg.rs Co-authored-by: universalmind303 <universalmind.candy@gmail.com> * add suggest * Update datafusion/physical-expr/src/aggregate/string_agg.rs Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * Update datafusion/sqllogictest/test_files/aggregate.slt Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * Update datafusion/sqllogictest/test_files/aggregate.slt Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org> * fix ci --------- Co-authored-by: universalmind303 <universalmind.candy@gmail.com> Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent a984f08 commit 76ced31

File tree

12 files changed

+386
-0
lines changed

12 files changed

+386
-0
lines changed

datafusion/expr/src/aggregate_function.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ pub enum AggregateFunction {
100100
BoolAnd,
101101
/// Bool Or
102102
BoolOr,
103+
/// string_agg
104+
StringAgg,
103105
}
104106

105107
impl AggregateFunction {
@@ -141,6 +143,7 @@ impl AggregateFunction {
141143
BitXor => "BIT_XOR",
142144
BoolAnd => "BOOL_AND",
143145
BoolOr => "BOOL_OR",
146+
StringAgg => "STRING_AGG",
144147
}
145148
}
146149
}
@@ -171,6 +174,7 @@ impl FromStr for AggregateFunction {
171174
"array_agg" => AggregateFunction::ArrayAgg,
172175
"first_value" => AggregateFunction::FirstValue,
173176
"last_value" => AggregateFunction::LastValue,
177+
"string_agg" => AggregateFunction::StringAgg,
174178
// statistical
175179
"corr" => AggregateFunction::Correlation,
176180
"covar" => AggregateFunction::Covariance,
@@ -299,6 +303,7 @@ impl AggregateFunction {
299303
AggregateFunction::FirstValue | AggregateFunction::LastValue => {
300304
Ok(coerced_data_types[0].clone())
301305
}
306+
AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
302307
}
303308
}
304309
}
@@ -408,6 +413,9 @@ impl AggregateFunction {
408413
.collect(),
409414
Volatility::Immutable,
410415
),
416+
AggregateFunction::StringAgg => {
417+
Signature::uniform(2, STRINGS.to_vec(), Volatility::Immutable)
418+
}
411419
}
412420
}
413421
}

datafusion/expr/src/type_coercion/aggregates.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,23 @@ pub fn coerce_types(
298298
| AggregateFunction::FirstValue
299299
| AggregateFunction::LastValue => Ok(input_types.to_vec()),
300300
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
301+
AggregateFunction::StringAgg => {
302+
if !is_string_agg_supported_arg_type(&input_types[0]) {
303+
return plan_err!(
304+
"The function {:?} does not support inputs of type {:?}",
305+
agg_fun,
306+
input_types[0]
307+
);
308+
}
309+
if !is_string_agg_supported_arg_type(&input_types[1]) {
310+
return plan_err!(
311+
"The function {:?} does not support inputs of type {:?}",
312+
agg_fun,
313+
input_types[1]
314+
);
315+
}
316+
Ok(vec![LargeUtf8, input_types[1].clone()])
317+
}
301318
}
302319
}
303320

@@ -565,6 +582,15 @@ pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool
565582
)
566583
}
567584

585+
/// Return `true` if `arg_type` is of a [`DataType`] that the
586+
/// [`AggregateFunction::StringAgg`] aggregation can operate on.
587+
pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool {
588+
matches!(
589+
arg_type,
590+
DataType::Utf8 | DataType::LargeUtf8 | DataType::Null
591+
)
592+
}
593+
568594
#[cfg(test)]
569595
mod tests {
570596
use super::*;

datafusion/physical-expr/src/aggregate/build_in.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,22 @@ pub fn create_aggregate_expr(
369369
ordering_req.to_vec(),
370370
ordering_types,
371371
)),
372+
(AggregateFunction::StringAgg, false) => {
373+
if !ordering_req.is_empty() {
374+
return not_impl_err!(
375+
"STRING_AGG(ORDER BY a ASC) order-sensitive aggregations are not available"
376+
);
377+
}
378+
Arc::new(expressions::StringAgg::new(
379+
input_phy_exprs[0].clone(),
380+
input_phy_exprs[1].clone(),
381+
name,
382+
data_type,
383+
))
384+
}
385+
(AggregateFunction::StringAgg, true) => {
386+
return not_impl_err!("STRING_AGG(DISTINCT) aggregations are not available");
387+
}
372388
})
373389
}
374390

datafusion/physical-expr/src/aggregate/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ pub(crate) mod covariance;
4343
pub(crate) mod first_last;
4444
pub(crate) mod grouping;
4545
pub(crate) mod median;
46+
pub(crate) mod string_agg;
4647
#[macro_use]
4748
pub(crate) mod min_max;
4849
pub mod build_in;
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! [`StringAgg`] and [`StringAggAccumulator`] accumulator for the `string_agg` function
19+
20+
use crate::aggregate::utils::down_cast_any_ref;
21+
use crate::expressions::{format_state_name, Literal};
22+
use crate::{AggregateExpr, PhysicalExpr};
23+
use arrow::array::ArrayRef;
24+
use arrow::datatypes::{DataType, Field};
25+
use datafusion_common::cast::as_generic_string_array;
26+
use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
27+
use datafusion_expr::Accumulator;
28+
use std::any::Any;
29+
use std::sync::Arc;
30+
31+
/// STRING_AGG aggregate expression
32+
#[derive(Debug)]
33+
pub struct StringAgg {
34+
name: String,
35+
data_type: DataType,
36+
expr: Arc<dyn PhysicalExpr>,
37+
delimiter: Arc<dyn PhysicalExpr>,
38+
nullable: bool,
39+
}
40+
41+
impl StringAgg {
42+
/// Create a new StringAgg aggregate function
43+
pub fn new(
44+
expr: Arc<dyn PhysicalExpr>,
45+
delimiter: Arc<dyn PhysicalExpr>,
46+
name: impl Into<String>,
47+
data_type: DataType,
48+
) -> Self {
49+
Self {
50+
name: name.into(),
51+
data_type,
52+
delimiter,
53+
expr,
54+
nullable: true,
55+
}
56+
}
57+
}
58+
59+
impl AggregateExpr for StringAgg {
60+
fn as_any(&self) -> &dyn Any {
61+
self
62+
}
63+
64+
fn field(&self) -> Result<Field> {
65+
Ok(Field::new(
66+
&self.name,
67+
self.data_type.clone(),
68+
self.nullable,
69+
))
70+
}
71+
72+
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
73+
if let Some(delimiter) = self.delimiter.as_any().downcast_ref::<Literal>() {
74+
match delimiter.value() {
75+
ScalarValue::Utf8(Some(delimiter))
76+
| ScalarValue::LargeUtf8(Some(delimiter)) => {
77+
return Ok(Box::new(StringAggAccumulator::new(delimiter)));
78+
}
79+
ScalarValue::Null => {
80+
return Ok(Box::new(StringAggAccumulator::new("")));
81+
}
82+
_ => return not_impl_err!("StringAgg not supported for {}", self.name),
83+
}
84+
}
85+
not_impl_err!("StringAgg not supported for {}", self.name)
86+
}
87+
88+
fn state_fields(&self) -> Result<Vec<Field>> {
89+
Ok(vec![Field::new(
90+
format_state_name(&self.name, "string_agg"),
91+
self.data_type.clone(),
92+
self.nullable,
93+
)])
94+
}
95+
96+
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
97+
vec![self.expr.clone(), self.delimiter.clone()]
98+
}
99+
100+
fn name(&self) -> &str {
101+
&self.name
102+
}
103+
}
104+
105+
impl PartialEq<dyn Any> for StringAgg {
106+
fn eq(&self, other: &dyn Any) -> bool {
107+
down_cast_any_ref(other)
108+
.downcast_ref::<Self>()
109+
.map(|x| {
110+
self.name == x.name
111+
&& self.data_type == x.data_type
112+
&& self.expr.eq(&x.expr)
113+
&& self.delimiter.eq(&x.delimiter)
114+
})
115+
.unwrap_or(false)
116+
}
117+
}
118+
119+
#[derive(Debug)]
120+
pub(crate) struct StringAggAccumulator {
121+
values: Option<String>,
122+
delimiter: String,
123+
}
124+
125+
impl StringAggAccumulator {
126+
pub fn new(delimiter: &str) -> Self {
127+
Self {
128+
values: None,
129+
delimiter: delimiter.to_string(),
130+
}
131+
}
132+
}
133+
134+
impl Accumulator for StringAggAccumulator {
135+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
136+
let string_array: Vec<_> = as_generic_string_array::<i64>(&values[0])?
137+
.iter()
138+
.filter_map(|v| v.as_ref().map(ToString::to_string))
139+
.collect();
140+
if !string_array.is_empty() {
141+
let s = string_array.join(self.delimiter.as_str());
142+
let v = self.values.get_or_insert("".to_string());
143+
if !v.is_empty() {
144+
v.push_str(self.delimiter.as_str());
145+
}
146+
v.push_str(s.as_str());
147+
}
148+
Ok(())
149+
}
150+
151+
fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
152+
self.update_batch(values)?;
153+
Ok(())
154+
}
155+
156+
fn state(&self) -> Result<Vec<ScalarValue>> {
157+
Ok(vec![self.evaluate()?])
158+
}
159+
160+
fn evaluate(&self) -> Result<ScalarValue> {
161+
Ok(ScalarValue::LargeUtf8(self.values.clone()))
162+
}
163+
164+
fn size(&self) -> usize {
165+
std::mem::size_of_val(self)
166+
+ self.values.as_ref().map(|v| v.capacity()).unwrap_or(0)
167+
+ self.delimiter.capacity()
168+
}
169+
}
170+
171+
#[cfg(test)]
172+
mod tests {
173+
use super::*;
174+
use crate::expressions::tests::aggregate;
175+
use crate::expressions::{col, create_aggregate_expr, try_cast};
176+
use arrow::array::ArrayRef;
177+
use arrow::datatypes::*;
178+
use arrow::record_batch::RecordBatch;
179+
use arrow_array::LargeStringArray;
180+
use arrow_array::StringArray;
181+
use datafusion_expr::type_coercion::aggregates::coerce_types;
182+
use datafusion_expr::AggregateFunction;
183+
184+
fn assert_string_aggregate(
185+
array: ArrayRef,
186+
function: AggregateFunction,
187+
distinct: bool,
188+
expected: ScalarValue,
189+
delimiter: String,
190+
) {
191+
let data_type = array.data_type();
192+
let sig = function.signature();
193+
let coerced =
194+
coerce_types(&function, &[data_type.clone(), DataType::Utf8], &sig).unwrap();
195+
196+
let input_schema = Schema::new(vec![Field::new("a", data_type.clone(), true)]);
197+
let batch =
198+
RecordBatch::try_new(Arc::new(input_schema.clone()), vec![array]).unwrap();
199+
200+
let input = try_cast(
201+
col("a", &input_schema).unwrap(),
202+
&input_schema,
203+
coerced[0].clone(),
204+
)
205+
.unwrap();
206+
207+
let delimiter = Arc::new(Literal::new(ScalarValue::Utf8(Some(delimiter))));
208+
let schema = Schema::new(vec![Field::new("a", coerced[0].clone(), true)]);
209+
let agg = create_aggregate_expr(
210+
&function,
211+
distinct,
212+
&[input, delimiter],
213+
&[],
214+
&schema,
215+
"agg",
216+
)
217+
.unwrap();
218+
219+
let result = aggregate(&batch, agg).unwrap();
220+
assert_eq!(expected, result);
221+
}
222+
223+
#[test]
224+
fn string_agg_utf8() {
225+
let a: ArrayRef = Arc::new(StringArray::from(vec!["h", "e", "l", "l", "o"]));
226+
assert_string_aggregate(
227+
a,
228+
AggregateFunction::StringAgg,
229+
false,
230+
ScalarValue::LargeUtf8(Some("h,e,l,l,o".to_owned())),
231+
",".to_owned(),
232+
);
233+
}
234+
235+
#[test]
236+
fn string_agg_largeutf8() {
237+
let a: ArrayRef = Arc::new(LargeStringArray::from(vec!["h", "e", "l", "l", "o"]));
238+
assert_string_aggregate(
239+
a,
240+
AggregateFunction::StringAgg,
241+
false,
242+
ScalarValue::LargeUtf8(Some("h|e|l|l|o".to_owned())),
243+
"|".to_owned(),
244+
);
245+
}
246+
}

datafusion/physical-expr/src/expressions/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ pub use crate::aggregate::min_max::{MaxAccumulator, MinAccumulator};
6363
pub use crate::aggregate::regr::{Regr, RegrType};
6464
pub use crate::aggregate::stats::StatsType;
6565
pub use crate::aggregate::stddev::{Stddev, StddevPop};
66+
pub use crate::aggregate::string_agg::StringAgg;
6667
pub use crate::aggregate::sum::Sum;
6768
pub use crate::aggregate::sum_distinct::DistinctSum;
6869
pub use crate::aggregate::variance::{Variance, VariancePop};

datafusion/proto/proto/datafusion.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,7 @@ enum AggregateFunction {
686686
REGR_SXX = 32;
687687
REGR_SYY = 33;
688688
REGR_SXY = 34;
689+
STRING_AGG = 35;
689690
}
690691

691692
message AggregateExprNode {

0 commit comments

Comments
 (0)