Skip to content

Commit 45fb919

Browse files
authored
Add sum_dyn to calculate sum for dictionary array (#2566)
* Add sum_dyn * Add null values test case
1 parent b34adcc commit 45fb919

File tree

1 file changed

+59
-4
lines changed

1 file changed

+59
-4
lines changed

arrow/src/compute/kernels/aggregate.rs

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ use multiversion::multiversion;
2121
use std::ops::Add;
2222

2323
use crate::array::{
24-
Array, BooleanArray, GenericBinaryArray, GenericStringArray, OffsetSizeTrait,
25-
PrimitiveArray,
24+
as_primitive_array, Array, ArrayAccessor, ArrayIter, BooleanArray,
25+
GenericBinaryArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
2626
};
27-
use crate::datatypes::{ArrowNativeType, ArrowNumericType};
27+
use crate::datatypes::{ArrowNativeType, ArrowNumericType, DataType};
2828

2929
/// Generic test for NaN, the optimizer should be able to remove this for integer types.
3030
#[inline]
@@ -185,6 +185,37 @@ pub fn min_string<T: OffsetSizeTrait>(array: &GenericStringArray<T>) -> Option<&
185185
}
186186

187187
/// Returns the sum of values in the array.
188+
pub fn sum_dyn<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
189+
where
190+
T: ArrowNumericType,
191+
T::Native: Add<Output = T::Native>,
192+
{
193+
match array.data_type() {
194+
DataType::Dictionary(_, _) => {
195+
let null_count = array.null_count();
196+
197+
if null_count == array.len() {
198+
return None;
199+
}
200+
201+
let iter = ArrayIter::new(array);
202+
let sum = iter
203+
.into_iter()
204+
.fold(T::default_value(), |accumulator, value| {
205+
if let Some(value) = value {
206+
accumulator + value
207+
} else {
208+
accumulator
209+
}
210+
});
211+
212+
Some(sum)
213+
}
214+
_ => sum::<T>(as_primitive_array(&array)),
215+
}
216+
}
217+
218+
/// Returns the sum of values in the primitive array.
188219
///
189220
/// Returns `None` if the array is empty or only contains null values.
190221
#[cfg(not(feature = "simd"))]
@@ -583,7 +614,7 @@ mod simd {
583614
}
584615
}
585616

586-
/// Returns the sum of values in the array.
617+
/// Returns the sum of values in the primitive array.
587618
///
588619
/// Returns `None` if the array is empty or only contains null values.
589620
#[cfg(feature = "simd")]
@@ -625,6 +656,7 @@ mod tests {
625656
use super::*;
626657
use crate::array::*;
627658
use crate::compute::add;
659+
use crate::datatypes::{Int32Type, Int8Type};
628660

629661
#[test]
630662
fn test_primitive_array_sum() {
@@ -1003,4 +1035,27 @@ mod tests {
10031035
assert_eq!(Some(true), min_boolean(&a));
10041036
assert_eq!(Some(true), max_boolean(&a));
10051037
}
1038+
1039+
#[test]
1040+
fn test_sum_dyn() {
1041+
let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]);
1042+
let keys = Int8Array::from_iter_values([2_i8, 3, 4]);
1043+
1044+
let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
1045+
let array = dict_array.downcast_dict::<Int8Array>().unwrap();
1046+
assert_eq!(39, sum_dyn::<Int8Type, _>(array).unwrap());
1047+
1048+
let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
1049+
assert_eq!(15, sum_dyn::<Int32Type, _>(&a).unwrap());
1050+
1051+
let keys = Int8Array::from(vec![Some(2_i8), None, Some(4)]);
1052+
let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
1053+
let array = dict_array.downcast_dict::<Int8Array>().unwrap();
1054+
assert_eq!(26, sum_dyn::<Int8Type, _>(array).unwrap());
1055+
1056+
let keys = Int8Array::from(vec![None, None, None]);
1057+
let dict_array = DictionaryArray::try_new(&keys, &values).unwrap();
1058+
let array = dict_array.downcast_dict::<Int8Array>().unwrap();
1059+
assert!(sum_dyn::<Int8Type, _>(array).is_none());
1060+
}
10061061
}

0 commit comments

Comments
 (0)