@@ -25,9 +25,14 @@ use crate::array_agg::ArrayAgg;
25
25
26
26
use arrow:: array:: ArrayRef ;
27
27
use arrow:: datatypes:: { DataType , Field , FieldRef } ;
28
- use datafusion_common:: cast:: { as_generic_string_array, as_string_view_array} ;
29
- use datafusion_common:: { internal_err, not_impl_err, Result , ScalarValue } ;
28
+ use datafusion_common:: cast:: {
29
+ as_generic_string_array, as_string_array, as_string_view_array,
30
+ } ;
31
+ use datafusion_common:: {
32
+ internal_datafusion_err, internal_err, not_impl_err, Result , ScalarValue ,
33
+ } ;
30
34
use datafusion_expr:: function:: AccumulatorArgs ;
35
+ use datafusion_expr:: utils:: format_state_name;
31
36
use datafusion_expr:: {
32
37
Accumulator , AggregateUDFImpl , Documentation , Signature , TypeSignature , Volatility ,
33
38
} ;
@@ -120,6 +125,8 @@ impl Default for StringAgg {
120
125
}
121
126
}
122
127
128
+ /// If there is no `distinct` and `order by` required by the `string_agg` call, a
129
+ /// more efficient accumulator `SimpleStringAggAccumulator` will be used.
123
130
impl AggregateUDFImpl for StringAgg {
124
131
fn as_any ( & self ) -> & dyn Any {
125
132
self
@@ -138,7 +145,21 @@ impl AggregateUDFImpl for StringAgg {
138
145
}
139
146
140
147
fn state_fields ( & self , args : StateFieldsArgs ) -> Result < Vec < FieldRef > > {
141
- self . array_agg . state_fields ( args)
148
+ // See comments in `impl AggregateUDFImpl ...` for more detail
149
+ let no_order_no_distinct =
150
+ ( args. ordering_fields . is_empty ( ) ) && ( !args. is_distinct ) ;
151
+ if no_order_no_distinct {
152
+ // Case `SimpleStringAggAccumulator`
153
+ Ok ( vec ! [ Field :: new(
154
+ format_state_name( args. name, "string_agg" ) ,
155
+ DataType :: LargeUtf8 ,
156
+ true ,
157
+ )
158
+ . into( ) ] )
159
+ } else {
160
+ // Case `StringAggAccumulator`
161
+ self . array_agg . state_fields ( args)
162
+ }
142
163
}
143
164
144
165
fn accumulator ( & self , acc_args : AccumulatorArgs ) -> Result < Box < dyn Accumulator > > {
@@ -161,21 +182,31 @@ impl AggregateUDFImpl for StringAgg {
161
182
) ;
162
183
} ;
163
184
164
- let array_agg_acc = self . array_agg . accumulator ( AccumulatorArgs {
165
- return_field : Field :: new (
166
- "f" ,
167
- DataType :: new_list ( acc_args. return_field . data_type ( ) . clone ( ) , true ) ,
168
- true ,
169
- )
170
- . into ( ) ,
171
- exprs : & filter_index ( acc_args. exprs , 1 ) ,
172
- ..acc_args
173
- } ) ?;
185
+ // See comments in `impl AggregateUDFImpl ...` for more detail
186
+ let no_order_no_distinct =
187
+ acc_args. order_bys . is_empty ( ) && ( !acc_args. is_distinct ) ;
174
188
175
- Ok ( Box :: new ( StringAggAccumulator :: new (
176
- array_agg_acc,
177
- delimiter,
178
- ) ) )
189
+ if no_order_no_distinct {
190
+ // simple case (more efficient)
191
+ Ok ( Box :: new ( SimpleStringAggAccumulator :: new ( delimiter) ) )
192
+ } else {
193
+ // general case
194
+ let array_agg_acc = self . array_agg . accumulator ( AccumulatorArgs {
195
+ return_field : Field :: new (
196
+ "f" ,
197
+ DataType :: new_list ( acc_args. return_field . data_type ( ) . clone ( ) , true ) ,
198
+ true ,
199
+ )
200
+ . into ( ) ,
201
+ exprs : & filter_index ( acc_args. exprs , 1 ) ,
202
+ ..acc_args
203
+ } ) ?;
204
+
205
+ Ok ( Box :: new ( StringAggAccumulator :: new (
206
+ array_agg_acc,
207
+ delimiter,
208
+ ) ) )
209
+ }
179
210
}
180
211
181
212
fn reverse_expr ( & self ) -> datafusion_expr:: ReversedUDAF {
@@ -187,6 +218,7 @@ impl AggregateUDFImpl for StringAgg {
187
218
}
188
219
}
189
220
221
+ /// StringAgg accumulator for the general case (with order or distinct specified)
190
222
#[ derive( Debug ) ]
191
223
pub ( crate ) struct StringAggAccumulator {
192
224
array_agg_acc : Box < dyn Accumulator > ,
@@ -269,6 +301,105 @@ fn filter_index<T: Clone>(values: &[T], index: usize) -> Vec<T> {
269
301
. collect :: < Vec < _ > > ( )
270
302
}
271
303
304
/// Accumulator for `string_agg` calls that specify neither an ordering nor
/// `DISTINCT`.
///
/// Input strings are appended straight onto one growing buffer, which avoids
/// the intermediate array that `StringAggAccumulator` builds through
/// `ArrayAggAccumulator`.
#[derive(Debug)]
pub(crate) struct SimpleStringAggAccumulator {
    /// Text inserted between two consecutive accumulated values.
    delimiter: String,
    /// Everything accumulated so far, already joined, e.g. "foo,bar".
    accumulated_string: String,
    /// Whether at least one non-null value has been appended. Needed because
    /// an empty buffer alone cannot tell "no input" from "one empty string".
    has_value: bool,
}

impl SimpleStringAggAccumulator {
    /// Creates an empty accumulator that joins values with `delimiter`.
    pub fn new(delimiter: &str) -> Self {
        let delimiter = delimiter.to_owned();
        Self {
            delimiter,
            accumulated_string: String::new(),
            has_value: false,
        }
    }

    /// Appends every `Some` item of `iter` to the buffer, writing the
    /// delimiter before each value except the very first one accumulated.
    /// `None` items are skipped.
    #[inline]
    fn append_strings<'a, I>(&mut self, iter: I)
    where
        I: Iterator<Item = Option<&'a str>>,
    {
        for item in iter {
            let Some(text) = item else {
                continue;
            };

            if self.has_value {
                self.accumulated_string.push_str(&self.delimiter);
            }
            self.accumulated_string.push_str(text);
            self.has_value = true;
        }
    }
}
340
+
341
+ impl Accumulator for SimpleStringAggAccumulator {
342
+ fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
343
+ let string_arr = values. first ( ) . ok_or_else ( || {
344
+ internal_datafusion_err ! (
345
+ "Planner should ensure its first arg is Utf8/Utf8View"
346
+ )
347
+ } ) ?;
348
+
349
+ match string_arr. data_type ( ) {
350
+ DataType :: Utf8 => {
351
+ let array = as_string_array ( string_arr) ?;
352
+ self . append_strings ( array. iter ( ) ) ;
353
+ }
354
+ DataType :: LargeUtf8 => {
355
+ let array = as_generic_string_array :: < i64 > ( string_arr) ?;
356
+ self . append_strings ( array. iter ( ) ) ;
357
+ }
358
+ DataType :: Utf8View => {
359
+ let array = as_string_view_array ( string_arr) ?;
360
+ self . append_strings ( array. iter ( ) ) ;
361
+ }
362
+ other => {
363
+ return internal_err ! (
364
+ "Planner should ensure string_agg first argument is Utf8-like, found {other}"
365
+ ) ;
366
+ }
367
+ }
368
+
369
+ Ok ( ( ) )
370
+ }
371
+
372
+ fn evaluate ( & mut self ) -> Result < ScalarValue > {
373
+ let result = if self . has_value {
374
+ ScalarValue :: LargeUtf8 ( Some ( std:: mem:: take ( & mut self . accumulated_string ) ) )
375
+ } else {
376
+ ScalarValue :: LargeUtf8 ( None )
377
+ } ;
378
+
379
+ self . has_value = false ;
380
+ Ok ( result)
381
+ }
382
+
383
+ fn size ( & self ) -> usize {
384
+ size_of_val ( self ) + self . delimiter . capacity ( ) + self . accumulated_string . capacity ( )
385
+ }
386
+
387
+ fn state ( & mut self ) -> Result < Vec < ScalarValue > > {
388
+ let result = if self . has_value {
389
+ ScalarValue :: LargeUtf8 ( Some ( std:: mem:: take ( & mut self . accumulated_string ) ) )
390
+ } else {
391
+ ScalarValue :: LargeUtf8 ( None )
392
+ } ;
393
+ self . has_value = false ;
394
+
395
+ Ok ( vec ! [ result] )
396
+ }
397
+
398
+ fn merge_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
399
+ self . update_batch ( values)
400
+ }
401
+ }
402
+
272
403
#[ cfg( test) ]
273
404
mod tests {
274
405
use super :: * ;
0 commit comments