Skip to content

Commit 8b50774

Browse files
authored
Split count_distinct.rs into separate modules (#9087)
* Split count_distinct.rs into separate modules * Remove unecessary typedef * Rename * improve module comments
1 parent 968c05f commit 8b50774

File tree

3 files changed

+261
-217
lines changed

3 files changed

+261
-217
lines changed

datafusion/physical-expr/src/aggregate/count_distinct/mod.rs

Lines changed: 41 additions & 216 deletions
Original file line numberDiff line numberDiff line change
@@ -15,39 +15,36 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
mod native;
1819
mod strings;
1920

2021
use std::any::Any;
21-
use std::cmp::Eq;
2222
use std::collections::HashSet;
2323
use std::fmt::Debug;
24-
use std::hash::Hash;
2524
use std::sync::Arc;
2625

2726
use ahash::RandomState;
2827
use arrow::array::{Array, ArrayRef};
2928
use arrow::datatypes::{DataType, Field, TimeUnit};
3029
use arrow_array::types::{
31-
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
32-
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
33-
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
30+
Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type, Float32Type,
31+
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType,
32+
Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
3433
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
3534
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
3635
};
37-
use arrow_array::PrimitiveArray;
3836

39-
use datafusion_common::cast::{as_list_array, as_primitive_array};
40-
use datafusion_common::utils::array_into_list_array;
4137
use datafusion_common::{Result, ScalarValue};
4238
use datafusion_expr::Accumulator;
4339

40+
use crate::aggregate::count_distinct::native::{
41+
FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator,
42+
};
4443
use crate::aggregate::count_distinct::strings::StringDistinctCountAccumulator;
45-
use crate::aggregate::utils::{down_cast_any_ref, Hashable};
44+
use crate::aggregate::utils::down_cast_any_ref;
4645
use crate::expressions::format_state_name;
4746
use crate::{AggregateExpr, PhysicalExpr};
4847

49-
type DistinctScalarValues = ScalarValue;
50-
5148
/// Expression for a COUNT(DISTINCT) aggregation.
5249
#[derive(Debug)]
5350
pub struct DistinctCount {
@@ -101,46 +98,46 @@ impl AggregateExpr for DistinctCount {
10198
use TimeUnit::*;
10299

103100
Ok(match &self.state_data_type {
104-
Int8 => Box::new(NativeDistinctCountAccumulator::<Int8Type>::new()),
105-
Int16 => Box::new(NativeDistinctCountAccumulator::<Int16Type>::new()),
106-
Int32 => Box::new(NativeDistinctCountAccumulator::<Int32Type>::new()),
107-
Int64 => Box::new(NativeDistinctCountAccumulator::<Int64Type>::new()),
108-
UInt8 => Box::new(NativeDistinctCountAccumulator::<UInt8Type>::new()),
109-
UInt16 => Box::new(NativeDistinctCountAccumulator::<UInt16Type>::new()),
110-
UInt32 => Box::new(NativeDistinctCountAccumulator::<UInt32Type>::new()),
111-
UInt64 => Box::new(NativeDistinctCountAccumulator::<UInt64Type>::new()),
101+
Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new()),
102+
Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new()),
103+
Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new()),
104+
Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new()),
105+
UInt8 => Box::new(PrimitiveDistinctCountAccumulator::<UInt8Type>::new()),
106+
UInt16 => Box::new(PrimitiveDistinctCountAccumulator::<UInt16Type>::new()),
107+
UInt32 => Box::new(PrimitiveDistinctCountAccumulator::<UInt32Type>::new()),
108+
UInt64 => Box::new(PrimitiveDistinctCountAccumulator::<UInt64Type>::new()),
112109
Decimal128(_, _) => {
113-
Box::new(NativeDistinctCountAccumulator::<Decimal128Type>::new())
110+
Box::new(PrimitiveDistinctCountAccumulator::<Decimal128Type>::new())
114111
}
115112
Decimal256(_, _) => {
116-
Box::new(NativeDistinctCountAccumulator::<Decimal256Type>::new())
113+
Box::new(PrimitiveDistinctCountAccumulator::<Decimal256Type>::new())
117114
}
118115

119-
Date32 => Box::new(NativeDistinctCountAccumulator::<Date32Type>::new()),
120-
Date64 => Box::new(NativeDistinctCountAccumulator::<Date64Type>::new()),
121-
Time32(Millisecond) => {
122-
Box::new(NativeDistinctCountAccumulator::<Time32MillisecondType>::new())
123-
}
116+
Date32 => Box::new(PrimitiveDistinctCountAccumulator::<Date32Type>::new()),
117+
Date64 => Box::new(PrimitiveDistinctCountAccumulator::<Date64Type>::new()),
118+
Time32(Millisecond) => Box::new(PrimitiveDistinctCountAccumulator::<
119+
Time32MillisecondType,
120+
>::new()),
124121
Time32(Second) => {
125-
Box::new(NativeDistinctCountAccumulator::<Time32SecondType>::new())
126-
}
127-
Time64(Microsecond) => {
128-
Box::new(NativeDistinctCountAccumulator::<Time64MicrosecondType>::new())
122+
Box::new(PrimitiveDistinctCountAccumulator::<Time32SecondType>::new())
129123
}
124+
Time64(Microsecond) => Box::new(PrimitiveDistinctCountAccumulator::<
125+
Time64MicrosecondType,
126+
>::new()),
130127
Time64(Nanosecond) => {
131-
Box::new(NativeDistinctCountAccumulator::<Time64NanosecondType>::new())
128+
Box::new(PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new())
132129
}
133-
Timestamp(Microsecond, _) => Box::new(NativeDistinctCountAccumulator::<
130+
Timestamp(Microsecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
134131
TimestampMicrosecondType,
135132
>::new()),
136-
Timestamp(Millisecond, _) => Box::new(NativeDistinctCountAccumulator::<
133+
Timestamp(Millisecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
137134
TimestampMillisecondType,
138135
>::new()),
139-
Timestamp(Nanosecond, _) => {
140-
Box::new(NativeDistinctCountAccumulator::<TimestampNanosecondType>::new())
141-
}
136+
Timestamp(Nanosecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
137+
TimestampNanosecondType,
138+
>::new()),
142139
Timestamp(Second, _) => {
143-
Box::new(NativeDistinctCountAccumulator::<TimestampSecondType>::new())
140+
Box::new(PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new())
144141
}
145142

146143
Float16 => Box::new(FloatDistinctCountAccumulator::<Float16Type>::new()),
@@ -175,9 +172,13 @@ impl PartialEq<dyn Any> for DistinctCount {
175172
}
176173
}
177174

175+
/// General purpose distinct accumulator that works for any DataType by using
176+
/// [`ScalarValue`]. Some types have specialized accumulators that are (much)
177+
/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
178+
/// [`StringDistinctCountAccumulator`]
178179
#[derive(Debug)]
179180
struct DistinctCountAccumulator {
180-
values: HashSet<DistinctScalarValues, RandomState>,
181+
values: HashSet<ScalarValue, RandomState>,
181182
state_data_type: DataType,
182183
}
183184

@@ -186,7 +187,7 @@ impl DistinctCountAccumulator {
186187
// This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types
187188
fn fixed_size(&self) -> usize {
188189
std::mem::size_of_val(self)
189-
+ (std::mem::size_of::<DistinctScalarValues>() * self.values.capacity())
190+
+ (std::mem::size_of::<ScalarValue>() * self.values.capacity())
190191
+ self
191192
.values
192193
.iter()
@@ -199,7 +200,7 @@ impl DistinctCountAccumulator {
199200
// calculates the size as accurate as possible, call to this method is expensive
200201
fn full_size(&self) -> usize {
201202
std::mem::size_of_val(self)
202-
+ (std::mem::size_of::<DistinctScalarValues>() * self.values.capacity())
203+
+ (std::mem::size_of::<ScalarValue>() * self.values.capacity())
203204
+ self
204205
.values
205206
.iter()
@@ -260,182 +261,6 @@ impl Accumulator for DistinctCountAccumulator {
260261
}
261262
}
262263

263-
#[derive(Debug)]
264-
struct NativeDistinctCountAccumulator<T>
265-
where
266-
T: ArrowPrimitiveType + Send,
267-
T::Native: Eq + Hash,
268-
{
269-
values: HashSet<T::Native, RandomState>,
270-
}
271-
272-
impl<T> NativeDistinctCountAccumulator<T>
273-
where
274-
T: ArrowPrimitiveType + Send,
275-
T::Native: Eq + Hash,
276-
{
277-
fn new() -> Self {
278-
Self {
279-
values: HashSet::default(),
280-
}
281-
}
282-
}
283-
284-
impl<T> Accumulator for NativeDistinctCountAccumulator<T>
285-
where
286-
T: ArrowPrimitiveType + Send + Debug,
287-
T::Native: Eq + Hash,
288-
{
289-
fn state(&mut self) -> Result<Vec<ScalarValue>> {
290-
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
291-
self.values.iter().cloned(),
292-
)) as ArrayRef;
293-
let list = Arc::new(array_into_list_array(arr));
294-
Ok(vec![ScalarValue::List(list)])
295-
}
296-
297-
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
298-
if values.is_empty() {
299-
return Ok(());
300-
}
301-
302-
let arr = as_primitive_array::<T>(&values[0])?;
303-
arr.iter().for_each(|value| {
304-
if let Some(value) = value {
305-
self.values.insert(value);
306-
}
307-
});
308-
309-
Ok(())
310-
}
311-
312-
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
313-
if states.is_empty() {
314-
return Ok(());
315-
}
316-
assert_eq!(
317-
states.len(),
318-
1,
319-
"count_distinct states must be single array"
320-
);
321-
322-
let arr = as_list_array(&states[0])?;
323-
arr.iter().try_for_each(|maybe_list| {
324-
if let Some(list) = maybe_list {
325-
let list = as_primitive_array::<T>(&list)?;
326-
self.values.extend(list.values())
327-
};
328-
Ok(())
329-
})
330-
}
331-
332-
fn evaluate(&mut self) -> Result<ScalarValue> {
333-
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
334-
}
335-
336-
fn size(&self) -> usize {
337-
let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX)
338-
/ 7)
339-
.next_power_of_two();
340-
341-
// Size of accumulator
342-
// + size of entry * number of buckets
343-
// + 1 byte for each bucket
344-
// + fixed size of HashSet
345-
std::mem::size_of_val(self)
346-
+ std::mem::size_of::<T::Native>() * estimated_buckets
347-
+ estimated_buckets
348-
+ std::mem::size_of_val(&self.values)
349-
}
350-
}
351-
352-
#[derive(Debug)]
353-
struct FloatDistinctCountAccumulator<T>
354-
where
355-
T: ArrowPrimitiveType + Send,
356-
{
357-
values: HashSet<Hashable<T::Native>, RandomState>,
358-
}
359-
360-
impl<T> FloatDistinctCountAccumulator<T>
361-
where
362-
T: ArrowPrimitiveType + Send,
363-
{
364-
fn new() -> Self {
365-
Self {
366-
values: HashSet::default(),
367-
}
368-
}
369-
}
370-
371-
impl<T> Accumulator for FloatDistinctCountAccumulator<T>
372-
where
373-
T: ArrowPrimitiveType + Send + Debug,
374-
{
375-
fn state(&mut self) -> Result<Vec<ScalarValue>> {
376-
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
377-
self.values.iter().map(|v| v.0),
378-
)) as ArrayRef;
379-
let list = Arc::new(array_into_list_array(arr));
380-
Ok(vec![ScalarValue::List(list)])
381-
}
382-
383-
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
384-
if values.is_empty() {
385-
return Ok(());
386-
}
387-
388-
let arr = as_primitive_array::<T>(&values[0])?;
389-
arr.iter().for_each(|value| {
390-
if let Some(value) = value {
391-
self.values.insert(Hashable(value));
392-
}
393-
});
394-
395-
Ok(())
396-
}
397-
398-
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
399-
if states.is_empty() {
400-
return Ok(());
401-
}
402-
assert_eq!(
403-
states.len(),
404-
1,
405-
"count_distinct states must be single array"
406-
);
407-
408-
let arr = as_list_array(&states[0])?;
409-
arr.iter().try_for_each(|maybe_list| {
410-
if let Some(list) = maybe_list {
411-
let list = as_primitive_array::<T>(&list)?;
412-
self.values
413-
.extend(list.values().iter().map(|v| Hashable(*v)));
414-
};
415-
Ok(())
416-
})
417-
}
418-
419-
fn evaluate(&mut self) -> Result<ScalarValue> {
420-
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
421-
}
422-
423-
fn size(&self) -> usize {
424-
let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX)
425-
/ 7)
426-
.next_power_of_two();
427-
428-
// Size of accumulator
429-
// + size of entry * number of buckets
430-
// + 1 byte for each bucket
431-
// + fixed size of HashSet
432-
std::mem::size_of_val(self)
433-
+ std::mem::size_of::<T::Native>() * estimated_buckets
434-
+ estimated_buckets
435-
+ std::mem::size_of_val(&self.values)
436-
}
437-
}
438-
439264
#[cfg(test)]
440265
mod tests {
441266
use arrow::array::{

0 commit comments

Comments
 (0)