Skip to content

Commit 0f3cf27

Browse files
2010YOUY01vegarsti
andauthored
perf: Faster string_agg() aggregate function (1000x speed for no DISTINCT and ORDER case) (#17837)
* impl SimpleStringAggAccumulator for performance * lint * review: rename in_progress_string --> accumualted_string * Update datafusion/functions-aggregate/src/string_agg.rs Co-authored-by: Vegard Stikbakke <vegard.stikbakke@gmail.com> * Update datafusion/functions-aggregate/src/string_agg.rs Co-authored-by: Vegard Stikbakke <vegard.stikbakke@gmail.com> --------- Co-authored-by: Vegard Stikbakke <vegard.stikbakke@gmail.com>
1 parent 21b6df1 commit 0f3cf27

File tree

1 file changed

+148
-17
lines changed

1 file changed

+148
-17
lines changed

datafusion/functions-aggregate/src/string_agg.rs

Lines changed: 148 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,14 @@ use crate::array_agg::ArrayAgg;
2525

2626
use arrow::array::ArrayRef;
2727
use arrow::datatypes::{DataType, Field, FieldRef};
28-
use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
29-
use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue};
28+
use datafusion_common::cast::{
29+
as_generic_string_array, as_string_array, as_string_view_array,
30+
};
31+
use datafusion_common::{
32+
internal_datafusion_err, internal_err, not_impl_err, Result, ScalarValue,
33+
};
3034
use datafusion_expr::function::AccumulatorArgs;
35+
use datafusion_expr::utils::format_state_name;
3136
use datafusion_expr::{
3237
Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, Volatility,
3338
};
@@ -120,6 +125,8 @@ impl Default for StringAgg {
120125
}
121126
}
122127

128+
/// If there is no `distinct` and `order by` required by the `string_agg` call, a
129+
/// more efficient accumulator `SimpleStringAggAccumulator` will be used.
123130
impl AggregateUDFImpl for StringAgg {
124131
fn as_any(&self) -> &dyn Any {
125132
self
@@ -138,7 +145,21 @@ impl AggregateUDFImpl for StringAgg {
138145
}
139146

140147
fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
141-
self.array_agg.state_fields(args)
148+
// See comments in `impl AggregateUDFImpl ...` for more detail
149+
let no_order_no_distinct =
150+
(args.ordering_fields.is_empty()) && (!args.is_distinct);
151+
if no_order_no_distinct {
152+
// Case `SimpleStringAggAccumulator`
153+
Ok(vec![Field::new(
154+
format_state_name(args.name, "string_agg"),
155+
DataType::LargeUtf8,
156+
true,
157+
)
158+
.into()])
159+
} else {
160+
// Case `StringAggAccumulator`
161+
self.array_agg.state_fields(args)
162+
}
142163
}
143164

144165
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
@@ -161,21 +182,31 @@ impl AggregateUDFImpl for StringAgg {
161182
);
162183
};
163184

164-
let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
165-
return_field: Field::new(
166-
"f",
167-
DataType::new_list(acc_args.return_field.data_type().clone(), true),
168-
true,
169-
)
170-
.into(),
171-
exprs: &filter_index(acc_args.exprs, 1),
172-
..acc_args
173-
})?;
185+
// See comments in `impl AggregateUDFImpl ...` for more detail
186+
let no_order_no_distinct =
187+
acc_args.order_bys.is_empty() && (!acc_args.is_distinct);
174188

175-
Ok(Box::new(StringAggAccumulator::new(
176-
array_agg_acc,
177-
delimiter,
178-
)))
189+
if no_order_no_distinct {
190+
// simple case (more efficient)
191+
Ok(Box::new(SimpleStringAggAccumulator::new(delimiter)))
192+
} else {
193+
// general case
194+
let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
195+
return_field: Field::new(
196+
"f",
197+
DataType::new_list(acc_args.return_field.data_type().clone(), true),
198+
true,
199+
)
200+
.into(),
201+
exprs: &filter_index(acc_args.exprs, 1),
202+
..acc_args
203+
})?;
204+
205+
Ok(Box::new(StringAggAccumulator::new(
206+
array_agg_acc,
207+
delimiter,
208+
)))
209+
}
179210
}
180211

181212
fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
@@ -187,6 +218,7 @@ impl AggregateUDFImpl for StringAgg {
187218
}
188219
}
189220

221+
/// StringAgg accumulator for the general case (with order or distinct specified)
190222
#[derive(Debug)]
191223
pub(crate) struct StringAggAccumulator {
192224
array_agg_acc: Box<dyn Accumulator>,
@@ -269,6 +301,105 @@ fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
269301
.collect::<Vec<_>>()
270302
}
271303

304+
/// StringAgg accumulator for the simple case (no order or distinct specified)
305+
/// This accumulator is more efficient than `StringAggAccumulator`
306+
/// because it accumulates the string directly,
307+
/// whereas `StringAggAccumulator` uses `ArrayAggAccumulator`.
308+
#[derive(Debug)]
309+
pub(crate) struct SimpleStringAggAccumulator {
310+
delimiter: String,
311+
/// Updated during `update_batch()`. e.g. "foo,bar"
312+
accumulated_string: String,
313+
has_value: bool,
314+
}
315+
316+
impl SimpleStringAggAccumulator {
317+
pub fn new(delimiter: &str) -> Self {
318+
Self {
319+
delimiter: delimiter.to_string(),
320+
accumulated_string: "".to_string(),
321+
has_value: false,
322+
}
323+
}
324+
325+
#[inline]
326+
fn append_strings<'a, I>(&mut self, iter: I)
327+
where
328+
I: Iterator<Item = Option<&'a str>>,
329+
{
330+
for value in iter.flatten() {
331+
if self.has_value {
332+
self.accumulated_string.push_str(&self.delimiter);
333+
}
334+
335+
self.accumulated_string.push_str(value);
336+
self.has_value = true;
337+
}
338+
}
339+
}
340+
341+
impl Accumulator for SimpleStringAggAccumulator {
342+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
343+
let string_arr = values.first().ok_or_else(|| {
344+
internal_datafusion_err!(
345+
"Planner should ensure its first arg is Utf8/Utf8View"
346+
)
347+
})?;
348+
349+
match string_arr.data_type() {
350+
DataType::Utf8 => {
351+
let array = as_string_array(string_arr)?;
352+
self.append_strings(array.iter());
353+
}
354+
DataType::LargeUtf8 => {
355+
let array = as_generic_string_array::<i64>(string_arr)?;
356+
self.append_strings(array.iter());
357+
}
358+
DataType::Utf8View => {
359+
let array = as_string_view_array(string_arr)?;
360+
self.append_strings(array.iter());
361+
}
362+
other => {
363+
return internal_err!(
364+
"Planner should ensure string_agg first argument is Utf8-like, found {other}"
365+
);
366+
}
367+
}
368+
369+
Ok(())
370+
}
371+
372+
fn evaluate(&mut self) -> Result<ScalarValue> {
373+
let result = if self.has_value {
374+
ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
375+
} else {
376+
ScalarValue::LargeUtf8(None)
377+
};
378+
379+
self.has_value = false;
380+
Ok(result)
381+
}
382+
383+
fn size(&self) -> usize {
384+
size_of_val(self) + self.delimiter.capacity() + self.accumulated_string.capacity()
385+
}
386+
387+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
388+
let result = if self.has_value {
389+
ScalarValue::LargeUtf8(Some(std::mem::take(&mut self.accumulated_string)))
390+
} else {
391+
ScalarValue::LargeUtf8(None)
392+
};
393+
self.has_value = false;
394+
395+
Ok(vec![result])
396+
}
397+
398+
fn merge_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
399+
self.update_batch(values)
400+
}
401+
}
402+
272403
#[cfg(test)]
273404
mod tests {
274405
use super::*;

0 commit comments

Comments
 (0)