-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: native types in DistinctCountAccumulator
for primitive types
#8721
Changes from all commits
aa7199e
891b541
c560ca5
251fed2
541fca8
ac870f9
5e7dfdb
dff53b5
b6772dd
a3944cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,21 +15,32 @@ | |
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
use arrow::datatypes::{DataType, Field}; | ||
use arrow::datatypes::{DataType, Field, TimeUnit}; | ||
use arrow_array::types::{ | ||
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, | ||
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, | ||
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, | ||
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, | ||
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, | ||
}; | ||
use arrow_array::PrimitiveArray; | ||
|
||
use std::any::Any; | ||
use std::cmp::Eq; | ||
use std::fmt::Debug; | ||
use std::hash::Hash; | ||
use std::sync::Arc; | ||
|
||
use ahash::RandomState; | ||
use arrow::array::{Array, ArrayRef}; | ||
use std::collections::HashSet; | ||
|
||
use crate::aggregate::utils::down_cast_any_ref; | ||
use crate::aggregate::utils::{down_cast_any_ref, Hashable}; | ||
use crate::expressions::format_state_name; | ||
use crate::{AggregateExpr, PhysicalExpr}; | ||
use datafusion_common::Result; | ||
use datafusion_common::ScalarValue; | ||
use datafusion_common::cast::{as_list_array, as_primitive_array}; | ||
use datafusion_common::utils::array_into_list_array; | ||
use datafusion_common::{Result, ScalarValue}; | ||
use datafusion_expr::Accumulator; | ||
|
||
type DistinctScalarValues = ScalarValue; | ||
|
@@ -60,6 +71,18 @@ impl DistinctCount { | |
} | ||
} | ||
|
||
macro_rules! native_distinct_count_accumulator { | ||
($TYPE:ident) => {{ | ||
Ok(Box::new(NativeDistinctCountAccumulator::<$TYPE>::new())) | ||
}}; | ||
} | ||
|
||
macro_rules! float_distinct_count_accumulator { | ||
($TYPE:ident) => {{ | ||
Ok(Box::new(FloatDistinctCountAccumulator::<$TYPE>::new())) | ||
}}; | ||
} | ||
|
||
impl AggregateExpr for DistinctCount { | ||
/// Return a reference to Any that can be used for downcasting | ||
fn as_any(&self) -> &dyn Any { | ||
|
@@ -83,10 +106,57 @@ impl AggregateExpr for DistinctCount { | |
} | ||
|
||
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { | ||
Ok(Box::new(DistinctCountAccumulator { | ||
values: HashSet::default(), | ||
state_data_type: self.state_data_type.clone(), | ||
})) | ||
use DataType::*; | ||
use TimeUnit::*; | ||
|
||
match &self.state_data_type { | ||
Int8 => native_distinct_count_accumulator!(Int8Type), | ||
Int16 => native_distinct_count_accumulator!(Int16Type), | ||
Int32 => native_distinct_count_accumulator!(Int32Type), | ||
Int64 => native_distinct_count_accumulator!(Int64Type), | ||
UInt8 => native_distinct_count_accumulator!(UInt8Type), | ||
UInt16 => native_distinct_count_accumulator!(UInt16Type), | ||
UInt32 => native_distinct_count_accumulator!(UInt32Type), | ||
UInt64 => native_distinct_count_accumulator!(UInt64Type), | ||
Decimal128(_, _) => native_distinct_count_accumulator!(Decimal128Type), | ||
Decimal256(_, _) => native_distinct_count_accumulator!(Decimal256Type), | ||
|
||
Date32 => native_distinct_count_accumulator!(Date32Type), | ||
Date64 => native_distinct_count_accumulator!(Date64Type), | ||
Time32(Millisecond) => { | ||
native_distinct_count_accumulator!(Time32MillisecondType) | ||
} | ||
Time32(Second) => { | ||
native_distinct_count_accumulator!(Time32SecondType) | ||
} | ||
Time64(Microsecond) => { | ||
native_distinct_count_accumulator!(Time64MicrosecondType) | ||
} | ||
Time64(Nanosecond) => { | ||
native_distinct_count_accumulator!(Time64NanosecondType) | ||
} | ||
Timestamp(Microsecond, _) => { | ||
native_distinct_count_accumulator!(TimestampMicrosecondType) | ||
} | ||
Timestamp(Millisecond, _) => { | ||
native_distinct_count_accumulator!(TimestampMillisecondType) | ||
} | ||
Timestamp(Nanosecond, _) => { | ||
native_distinct_count_accumulator!(TimestampNanosecondType) | ||
} | ||
Timestamp(Second, _) => { | ||
native_distinct_count_accumulator!(TimestampSecondType) | ||
} | ||
|
||
Float16 => float_distinct_count_accumulator!(Float16Type), | ||
Float32 => float_distinct_count_accumulator!(Float32Type), | ||
Float64 => float_distinct_count_accumulator!(Float64Type), | ||
|
||
_ => Ok(Box::new(DistinctCountAccumulator { | ||
values: HashSet::default(), | ||
state_data_type: self.state_data_type.clone(), | ||
})), | ||
} | ||
} | ||
|
||
fn name(&self) -> &str { | ||
|
@@ -192,6 +262,182 @@ impl Accumulator for DistinctCountAccumulator { | |
} | ||
} | ||
|
||
#[derive(Debug)] | ||
struct NativeDistinctCountAccumulator<T> | ||
where | ||
T: ArrowPrimitiveType + Send, | ||
T::Native: Eq + Hash, | ||
{ | ||
values: HashSet<T::Native, RandomState>, | ||
} | ||
|
||
impl<T> NativeDistinctCountAccumulator<T> | ||
where | ||
T: ArrowPrimitiveType + Send, | ||
T::Native: Eq + Hash, | ||
{ | ||
fn new() -> Self { | ||
Self { | ||
values: HashSet::default(), | ||
} | ||
} | ||
} | ||
|
||
impl<T> Accumulator for NativeDistinctCountAccumulator<T> | ||
where | ||
T: ArrowPrimitiveType + Send + Debug, | ||
T::Native: Eq + Hash, | ||
{ | ||
fn state(&self) -> Result<Vec<ScalarValue>> { | ||
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values( | ||
self.values.iter().cloned(), | ||
)) as ArrayRef; | ||
let list = Arc::new(array_into_list_array(arr)) as ArrayRef; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👏 for @jayzhan211 for switching the native implementation of |
||
Ok(vec![ScalarValue::List(list)]) | ||
} | ||
|
||
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { | ||
if values.is_empty() { | ||
return Ok(()); | ||
} | ||
|
||
let arr = as_primitive_array::<T>(&values[0])?; | ||
arr.iter().for_each(|value| { | ||
if let Some(value) = value { | ||
self.values.insert(value); | ||
} | ||
}); | ||
|
||
Ok(()) | ||
} | ||
|
||
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { | ||
if states.is_empty() { | ||
return Ok(()); | ||
} | ||
assert_eq!( | ||
states.len(), | ||
1, | ||
"count_distinct states must be single array" | ||
); | ||
|
||
let arr = as_list_array(&states[0])?; | ||
arr.iter().try_for_each(|maybe_list| { | ||
if let Some(list) = maybe_list { | ||
let list = as_primitive_array::<T>(&list)?; | ||
self.values.extend(list.values()) | ||
}; | ||
Ok(()) | ||
}) | ||
} | ||
|
||
fn evaluate(&self) -> Result<ScalarValue> { | ||
Ok(ScalarValue::Int64(Some(self.values.len() as i64))) | ||
} | ||
|
||
fn size(&self) -> usize { | ||
let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe we can (as a follow on PR) put this logic into its own function (with comments) as estimating the size of hashbrown hashtables is likely to come up again |
||
/ 7) | ||
.next_power_of_two(); | ||
|
||
// Size of accumulator | ||
// + size of entry * number of buckets | ||
// + 1 byte for each bucket | ||
// + fixed size of HashSet | ||
std::mem::size_of_val(self) | ||
+ std::mem::size_of::<T::Native>() * estimated_buckets | ||
+ estimated_buckets | ||
+ std::mem::size_of_val(&self.values) | ||
} | ||
} | ||
|
||
#[derive(Debug)] | ||
struct FloatDistinctCountAccumulator<T> | ||
where | ||
T: ArrowPrimitiveType + Send, | ||
{ | ||
values: HashSet<Hashable<T::Native>, RandomState>, | ||
korowa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
impl<T> FloatDistinctCountAccumulator<T> | ||
where | ||
T: ArrowPrimitiveType + Send, | ||
{ | ||
fn new() -> Self { | ||
Self { | ||
values: HashSet::default(), | ||
} | ||
} | ||
} | ||
|
||
impl<T> Accumulator for FloatDistinctCountAccumulator<T> | ||
where | ||
T: ArrowPrimitiveType + Send + Debug, | ||
{ | ||
fn state(&self) -> Result<Vec<ScalarValue>> { | ||
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values( | ||
self.values.iter().map(|v| v.0), | ||
)) as ArrayRef; | ||
let list = Arc::new(array_into_list_array(arr)) as ArrayRef; | ||
Ok(vec![ScalarValue::List(list)]) | ||
} | ||
|
||
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { | ||
if values.is_empty() { | ||
return Ok(()); | ||
} | ||
|
||
let arr = as_primitive_array::<T>(&values[0])?; | ||
arr.iter().for_each(|value| { | ||
if let Some(value) = value { | ||
self.values.insert(Hashable(value)); | ||
} | ||
}); | ||
|
||
Ok(()) | ||
} | ||
|
||
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { | ||
if states.is_empty() { | ||
return Ok(()); | ||
} | ||
assert_eq!( | ||
states.len(), | ||
1, | ||
"count_distinct states must be single array" | ||
); | ||
|
||
let arr = as_list_array(&states[0])?; | ||
arr.iter().try_for_each(|maybe_list| { | ||
if let Some(list) = maybe_list { | ||
let list = as_primitive_array::<T>(&list)?; | ||
self.values | ||
.extend(list.values().iter().map(|v| Hashable(*v))); | ||
}; | ||
Ok(()) | ||
}) | ||
} | ||
|
||
fn evaluate(&self) -> Result<ScalarValue> { | ||
Ok(ScalarValue::Int64(Some(self.values.len() as i64))) | ||
} | ||
|
||
fn size(&self) -> usize { | ||
let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX) | ||
/ 7) | ||
.next_power_of_two(); | ||
|
||
// Size of accumulator | ||
// + size of entry * number of buckets | ||
// + 1 byte for each bucket | ||
// + fixed size of HashSet | ||
std::mem::size_of_val(self) | ||
+ std::mem::size_of::<T::Native>() * estimated_buckets | ||
+ estimated_buckets | ||
+ std::mem::size_of_val(&self.values) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use crate::expressions::NoOp; | ||
|
@@ -206,6 +452,8 @@ mod tests { | |
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, | ||
UInt32Type, UInt64Type, UInt8Type, | ||
}; | ||
use arrow_array::Decimal256Array; | ||
use arrow_buffer::i256; | ||
use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; | ||
use datafusion_common::internal_err; | ||
use datafusion_common::DataFusionError; | ||
|
@@ -367,6 +615,35 @@ mod tests { | |
}}; | ||
} | ||
|
||
macro_rules! test_count_distinct_update_batch_bigint { | ||
($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{ | ||
let values: Vec<Option<$PRIM_TYPE>> = vec![ | ||
Some(i256::from(1)), | ||
Some(i256::from(1)), | ||
None, | ||
Some(i256::from(3)), | ||
Some(i256::from(2)), | ||
None, | ||
Some(i256::from(2)), | ||
Some(i256::from(3)), | ||
Some(i256::from(1)), | ||
]; | ||
|
||
let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef]; | ||
|
||
let (states, result) = run_update_batch(&arrays)?; | ||
|
||
let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE); | ||
state_vec.sort(); | ||
|
||
assert_eq!(states.len(), 1); | ||
assert_eq!(state_vec, vec![i256::from(1), i256::from(2), i256::from(3)]); | ||
assert_eq!(result, ScalarValue::Int64(Some(3))); | ||
|
||
Ok(()) | ||
}}; | ||
} | ||
|
||
#[test] | ||
fn count_distinct_update_batch_i8() -> Result<()> { | ||
test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8) | ||
|
@@ -417,6 +694,11 @@ mod tests { | |
test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64) | ||
} | ||
|
||
#[test] | ||
fn count_distinct_update_batch_i256() -> Result<()> { | ||
test_count_distinct_update_batch_bigint!(Decimal256Array, Decimal256Type, i256) | ||
} | ||
|
||
#[test] | ||
fn count_distinct_update_batch_boolean() -> Result<()> { | ||
let get_count = |data: BooleanArray| -> Result<(Vec<bool>, i64)> { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not for now, but if we would like to do it for strings / bytes, we could do use a datastructure like this to get maximal performance:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is similar to the idea in #7064
Maybe we can eventually use the same data structure (specialized for storing string values not using a
String
)