1515// specific language governing permissions and limitations
1616// under the License.
1717
18+ mod native;
1819mod strings;
1920
2021use std:: any:: Any ;
21- use std:: cmp:: Eq ;
2222use std:: collections:: HashSet ;
2323use std:: fmt:: Debug ;
24- use std:: hash:: Hash ;
2524use std:: sync:: Arc ;
2625
2726use ahash:: RandomState ;
2827use arrow:: array:: { Array , ArrayRef } ;
2928use arrow:: datatypes:: { DataType , Field , TimeUnit } ;
3029use arrow_array:: types:: {
31- ArrowPrimitiveType , Date32Type , Date64Type , Decimal128Type , Decimal256Type ,
32- Float16Type , Float32Type , Float64Type , Int16Type , Int32Type , Int64Type , Int8Type ,
33- Time32MillisecondType , Time32SecondType , Time64MicrosecondType , Time64NanosecondType ,
30+ Date32Type , Date64Type , Decimal128Type , Decimal256Type , Float16Type , Float32Type ,
31+ Float64Type , Int16Type , Int32Type , Int64Type , Int8Type , Time32MillisecondType ,
32+ Time32SecondType , Time64MicrosecondType , Time64NanosecondType ,
3433 TimestampMicrosecondType , TimestampMillisecondType , TimestampNanosecondType ,
3534 TimestampSecondType , UInt16Type , UInt32Type , UInt64Type , UInt8Type ,
3635} ;
37- use arrow_array:: PrimitiveArray ;
3836
39- use datafusion_common:: cast:: { as_list_array, as_primitive_array} ;
40- use datafusion_common:: utils:: array_into_list_array;
4137use datafusion_common:: { Result , ScalarValue } ;
4238use datafusion_expr:: Accumulator ;
4339
40+ use crate :: aggregate:: count_distinct:: native:: {
41+ FloatDistinctCountAccumulator , PrimitiveDistinctCountAccumulator ,
42+ } ;
4443use crate :: aggregate:: count_distinct:: strings:: StringDistinctCountAccumulator ;
45- use crate :: aggregate:: utils:: { down_cast_any_ref, Hashable } ;
44+ use crate :: aggregate:: utils:: down_cast_any_ref;
4645use crate :: expressions:: format_state_name;
4746use crate :: { AggregateExpr , PhysicalExpr } ;
4847
49- type DistinctScalarValues = ScalarValue ;
50-
5148/// Expression for a COUNT(DISTINCT) aggregation.
5249#[ derive( Debug ) ]
5350pub struct DistinctCount {
@@ -101,46 +98,46 @@ impl AggregateExpr for DistinctCount {
10198 use TimeUnit :: * ;
10299
103100 Ok ( match & self . state_data_type {
104- Int8 => Box :: new ( NativeDistinctCountAccumulator :: < Int8Type > :: new ( ) ) ,
105- Int16 => Box :: new ( NativeDistinctCountAccumulator :: < Int16Type > :: new ( ) ) ,
106- Int32 => Box :: new ( NativeDistinctCountAccumulator :: < Int32Type > :: new ( ) ) ,
107- Int64 => Box :: new ( NativeDistinctCountAccumulator :: < Int64Type > :: new ( ) ) ,
108- UInt8 => Box :: new ( NativeDistinctCountAccumulator :: < UInt8Type > :: new ( ) ) ,
109- UInt16 => Box :: new ( NativeDistinctCountAccumulator :: < UInt16Type > :: new ( ) ) ,
110- UInt32 => Box :: new ( NativeDistinctCountAccumulator :: < UInt32Type > :: new ( ) ) ,
111- UInt64 => Box :: new ( NativeDistinctCountAccumulator :: < UInt64Type > :: new ( ) ) ,
101+ Int8 => Box :: new ( PrimitiveDistinctCountAccumulator :: < Int8Type > :: new ( ) ) ,
102+ Int16 => Box :: new ( PrimitiveDistinctCountAccumulator :: < Int16Type > :: new ( ) ) ,
103+ Int32 => Box :: new ( PrimitiveDistinctCountAccumulator :: < Int32Type > :: new ( ) ) ,
104+ Int64 => Box :: new ( PrimitiveDistinctCountAccumulator :: < Int64Type > :: new ( ) ) ,
105+ UInt8 => Box :: new ( PrimitiveDistinctCountAccumulator :: < UInt8Type > :: new ( ) ) ,
106+ UInt16 => Box :: new ( PrimitiveDistinctCountAccumulator :: < UInt16Type > :: new ( ) ) ,
107+ UInt32 => Box :: new ( PrimitiveDistinctCountAccumulator :: < UInt32Type > :: new ( ) ) ,
108+ UInt64 => Box :: new ( PrimitiveDistinctCountAccumulator :: < UInt64Type > :: new ( ) ) ,
112109 Decimal128 ( _, _) => {
113- Box :: new ( NativeDistinctCountAccumulator :: < Decimal128Type > :: new ( ) )
110+ Box :: new ( PrimitiveDistinctCountAccumulator :: < Decimal128Type > :: new ( ) )
114111 }
115112 Decimal256 ( _, _) => {
116- Box :: new ( NativeDistinctCountAccumulator :: < Decimal256Type > :: new ( ) )
113+ Box :: new ( PrimitiveDistinctCountAccumulator :: < Decimal256Type > :: new ( ) )
117114 }
118115
119- Date32 => Box :: new ( NativeDistinctCountAccumulator :: < Date32Type > :: new ( ) ) ,
120- Date64 => Box :: new ( NativeDistinctCountAccumulator :: < Date64Type > :: new ( ) ) ,
121- Time32 ( Millisecond ) => {
122- Box :: new ( NativeDistinctCountAccumulator :: < Time32MillisecondType > :: new ( ) )
123- }
116+ Date32 => Box :: new ( PrimitiveDistinctCountAccumulator :: < Date32Type > :: new ( ) ) ,
117+ Date64 => Box :: new ( PrimitiveDistinctCountAccumulator :: < Date64Type > :: new ( ) ) ,
118+ Time32 ( Millisecond ) => Box :: new ( PrimitiveDistinctCountAccumulator :: <
119+ Time32MillisecondType ,
120+ > :: new ( ) ) ,
124121 Time32 ( Second ) => {
125- Box :: new ( NativeDistinctCountAccumulator :: < Time32SecondType > :: new ( ) )
126- }
127- Time64 ( Microsecond ) => {
128- Box :: new ( NativeDistinctCountAccumulator :: < Time64MicrosecondType > :: new ( ) )
122+ Box :: new ( PrimitiveDistinctCountAccumulator :: < Time32SecondType > :: new ( ) )
129123 }
124+ Time64 ( Microsecond ) => Box :: new ( PrimitiveDistinctCountAccumulator :: <
125+ Time64MicrosecondType ,
126+ > :: new ( ) ) ,
130127 Time64 ( Nanosecond ) => {
131- Box :: new ( NativeDistinctCountAccumulator :: < Time64NanosecondType > :: new ( ) )
128+ Box :: new ( PrimitiveDistinctCountAccumulator :: < Time64NanosecondType > :: new ( ) )
132129 }
133- Timestamp ( Microsecond , _) => Box :: new ( NativeDistinctCountAccumulator :: <
130+ Timestamp ( Microsecond , _) => Box :: new ( PrimitiveDistinctCountAccumulator :: <
134131 TimestampMicrosecondType ,
135132 > :: new ( ) ) ,
136- Timestamp ( Millisecond , _) => Box :: new ( NativeDistinctCountAccumulator :: <
133+ Timestamp ( Millisecond , _) => Box :: new ( PrimitiveDistinctCountAccumulator :: <
137134 TimestampMillisecondType ,
138135 > :: new ( ) ) ,
139- Timestamp ( Nanosecond , _) => {
140- Box :: new ( NativeDistinctCountAccumulator :: < TimestampNanosecondType > :: new ( ) )
141- }
136+ Timestamp ( Nanosecond , _) => Box :: new ( PrimitiveDistinctCountAccumulator :: <
137+ TimestampNanosecondType ,
138+ > :: new ( ) ) ,
142139 Timestamp ( Second , _) => {
143- Box :: new ( NativeDistinctCountAccumulator :: < TimestampSecondType > :: new ( ) )
140+ Box :: new ( PrimitiveDistinctCountAccumulator :: < TimestampSecondType > :: new ( ) )
144141 }
145142
146143 Float16 => Box :: new ( FloatDistinctCountAccumulator :: < Float16Type > :: new ( ) ) ,
@@ -175,9 +172,13 @@ impl PartialEq<dyn Any> for DistinctCount {
175172 }
176173}
177174
175+ /// General-purpose distinct accumulator that works for any DataType by using
176+ /// [`ScalarValue`]. Some types have specialized accumulators that are (much)
177+ /// more efficient, such as [`PrimitiveDistinctCountAccumulator`] and
178+ /// [`StringDistinctCountAccumulator`].
178179#[ derive( Debug ) ]
179180struct DistinctCountAccumulator {
180- values : HashSet < DistinctScalarValues , RandomState > ,
181+ values : HashSet < ScalarValue , RandomState > ,
181182 state_data_type : DataType ,
182183}
183184
@@ -186,7 +187,7 @@ impl DistinctCountAccumulator {
186187 // This method is faster than .full_size(), but it is not suitable for variable-length values such as strings or complex types.
187188 fn fixed_size ( & self ) -> usize {
188189 std:: mem:: size_of_val ( self )
189- + ( std:: mem:: size_of :: < DistinctScalarValues > ( ) * self . values . capacity ( ) )
190+ + ( std:: mem:: size_of :: < ScalarValue > ( ) * self . values . capacity ( ) )
190191 + self
191192 . values
192193 . iter ( )
@@ -199,7 +200,7 @@ impl DistinctCountAccumulator {
199200 // Calculates the size as accurately as possible; calling this method is expensive.
200201 fn full_size ( & self ) -> usize {
201202 std:: mem:: size_of_val ( self )
202- + ( std:: mem:: size_of :: < DistinctScalarValues > ( ) * self . values . capacity ( ) )
203+ + ( std:: mem:: size_of :: < ScalarValue > ( ) * self . values . capacity ( ) )
203204 + self
204205 . values
205206 . iter ( )
@@ -260,182 +261,6 @@ impl Accumulator for DistinctCountAccumulator {
260261 }
261262}
262263
263- #[ derive( Debug ) ]
264- struct NativeDistinctCountAccumulator < T >
265- where
266- T : ArrowPrimitiveType + Send ,
267- T :: Native : Eq + Hash ,
268- {
269- values : HashSet < T :: Native , RandomState > ,
270- }
271-
272- impl < T > NativeDistinctCountAccumulator < T >
273- where
274- T : ArrowPrimitiveType + Send ,
275- T :: Native : Eq + Hash ,
276- {
277- fn new ( ) -> Self {
278- Self {
279- values : HashSet :: default ( ) ,
280- }
281- }
282- }
283-
284- impl < T > Accumulator for NativeDistinctCountAccumulator < T >
285- where
286- T : ArrowPrimitiveType + Send + Debug ,
287- T :: Native : Eq + Hash ,
288- {
289- fn state ( & mut self ) -> Result < Vec < ScalarValue > > {
290- let arr = Arc :: new ( PrimitiveArray :: < T > :: from_iter_values (
291- self . values . iter ( ) . cloned ( ) ,
292- ) ) as ArrayRef ;
293- let list = Arc :: new ( array_into_list_array ( arr) ) ;
294- Ok ( vec ! [ ScalarValue :: List ( list) ] )
295- }
296-
297- fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
298- if values. is_empty ( ) {
299- return Ok ( ( ) ) ;
300- }
301-
302- let arr = as_primitive_array :: < T > ( & values[ 0 ] ) ?;
303- arr. iter ( ) . for_each ( |value| {
304- if let Some ( value) = value {
305- self . values . insert ( value) ;
306- }
307- } ) ;
308-
309- Ok ( ( ) )
310- }
311-
312- fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
313- if states. is_empty ( ) {
314- return Ok ( ( ) ) ;
315- }
316- assert_eq ! (
317- states. len( ) ,
318- 1 ,
319- "count_distinct states must be single array"
320- ) ;
321-
322- let arr = as_list_array ( & states[ 0 ] ) ?;
323- arr. iter ( ) . try_for_each ( |maybe_list| {
324- if let Some ( list) = maybe_list {
325- let list = as_primitive_array :: < T > ( & list) ?;
326- self . values . extend ( list. values ( ) )
327- } ;
328- Ok ( ( ) )
329- } )
330- }
331-
332- fn evaluate ( & mut self ) -> Result < ScalarValue > {
333- Ok ( ScalarValue :: Int64 ( Some ( self . values . len ( ) as i64 ) ) )
334- }
335-
336- fn size ( & self ) -> usize {
337- let estimated_buckets = ( self . values . len ( ) . checked_mul ( 8 ) . unwrap_or ( usize:: MAX )
338- / 7 )
339- . next_power_of_two ( ) ;
340-
341- // Size of accumulator
342- // + size of entry * number of buckets
343- // + 1 byte for each bucket
344- // + fixed size of HashSet
345- std:: mem:: size_of_val ( self )
346- + std:: mem:: size_of :: < T :: Native > ( ) * estimated_buckets
347- + estimated_buckets
348- + std:: mem:: size_of_val ( & self . values )
349- }
350- }
351-
352- #[ derive( Debug ) ]
353- struct FloatDistinctCountAccumulator < T >
354- where
355- T : ArrowPrimitiveType + Send ,
356- {
357- values : HashSet < Hashable < T :: Native > , RandomState > ,
358- }
359-
360- impl < T > FloatDistinctCountAccumulator < T >
361- where
362- T : ArrowPrimitiveType + Send ,
363- {
364- fn new ( ) -> Self {
365- Self {
366- values : HashSet :: default ( ) ,
367- }
368- }
369- }
370-
371- impl < T > Accumulator for FloatDistinctCountAccumulator < T >
372- where
373- T : ArrowPrimitiveType + Send + Debug ,
374- {
375- fn state ( & mut self ) -> Result < Vec < ScalarValue > > {
376- let arr = Arc :: new ( PrimitiveArray :: < T > :: from_iter_values (
377- self . values . iter ( ) . map ( |v| v. 0 ) ,
378- ) ) as ArrayRef ;
379- let list = Arc :: new ( array_into_list_array ( arr) ) ;
380- Ok ( vec ! [ ScalarValue :: List ( list) ] )
381- }
382-
383- fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
384- if values. is_empty ( ) {
385- return Ok ( ( ) ) ;
386- }
387-
388- let arr = as_primitive_array :: < T > ( & values[ 0 ] ) ?;
389- arr. iter ( ) . for_each ( |value| {
390- if let Some ( value) = value {
391- self . values . insert ( Hashable ( value) ) ;
392- }
393- } ) ;
394-
395- Ok ( ( ) )
396- }
397-
398- fn merge_batch ( & mut self , states : & [ ArrayRef ] ) -> Result < ( ) > {
399- if states. is_empty ( ) {
400- return Ok ( ( ) ) ;
401- }
402- assert_eq ! (
403- states. len( ) ,
404- 1 ,
405- "count_distinct states must be single array"
406- ) ;
407-
408- let arr = as_list_array ( & states[ 0 ] ) ?;
409- arr. iter ( ) . try_for_each ( |maybe_list| {
410- if let Some ( list) = maybe_list {
411- let list = as_primitive_array :: < T > ( & list) ?;
412- self . values
413- . extend ( list. values ( ) . iter ( ) . map ( |v| Hashable ( * v) ) ) ;
414- } ;
415- Ok ( ( ) )
416- } )
417- }
418-
419- fn evaluate ( & mut self ) -> Result < ScalarValue > {
420- Ok ( ScalarValue :: Int64 ( Some ( self . values . len ( ) as i64 ) ) )
421- }
422-
423- fn size ( & self ) -> usize {
424- let estimated_buckets = ( self . values . len ( ) . checked_mul ( 8 ) . unwrap_or ( usize:: MAX )
425- / 7 )
426- . next_power_of_two ( ) ;
427-
428- // Size of accumulator
429- // + size of entry * number of buckets
430- // + 1 byte for each bucket
431- // + fixed size of HashSet
432- std:: mem:: size_of_val ( self )
433- + std:: mem:: size_of :: < T :: Native > ( ) * estimated_buckets
434- + estimated_buckets
435- + std:: mem:: size_of_val ( & self . values )
436- }
437- }
438-
439264#[ cfg( test) ]
440265mod tests {
441266 use arrow:: array:: {
0 commit comments