From eb0bb915838de0927bac2b356c2e3d27fad79a17 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Tue, 29 Oct 2024 17:47:57 -0700 Subject: [PATCH] fix agg decimal --- src/daft-core/src/array/ops/apply.rs | 1 - src/daft-core/src/array/ops/sum.rs | 23 +++++++++++++---------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/daft-core/src/array/ops/apply.rs b/src/daft-core/src/array/ops/apply.rs index bf456868e8..904ab0a057 100644 --- a/src/daft-core/src/array/ops/apply.rs +++ b/src/daft-core/src/array/ops/apply.rs @@ -34,7 +34,6 @@ where R: DaftNumericType, F: Fn(T::Native, R::Native) -> T::Native + Copy, { - assert_eq!(self.data_type(), rhs.data_type()); match (self.len(), rhs.len()) { (x, y) if x == y => { let lhs_arr: &PrimitiveArray = diff --git a/src/daft-core/src/array/ops/sum.rs b/src/daft-core/src/array/ops/sum.rs index 9091e84efe..c4d7e3157b 100644 --- a/src/daft-core/src/array/ops/sum.rs +++ b/src/daft-core/src/array/ops/sum.rs @@ -11,16 +11,18 @@ macro_rules! impl_daft_numeric_agg { fn sum(&self) -> Self::Output { let primitive_arr = self.as_arrow(); let sum_value = arrow2::compute::aggregate::sum_primitive(primitive_arr); - let arrow_array = Box::new(arrow2::array::PrimitiveArray::from([sum_value])); - DataArray::new(self.field.clone(), arrow_array) + Ok(DataArray::<$T>::from_iter( + self.field.clone(), + std::iter::once(sum_value), + )) } fn grouped_sum(&self, groups: &GroupIndices) -> Self::Output { - use arrow2::array::PrimitiveArray; let arrow_array = self.as_arrow(); let sum_per_group = if arrow_array.null_count() > 0 { - Box::new(PrimitiveArray::from_trusted_len_iter(groups.iter().map( - |g| { + DataArray::<$T>::from_iter( + self.field.clone(), + groups.iter().map(|g| { g.iter().fold(None, |acc, index| { let idx = *index as usize; match (acc, arrow_array.is_null(idx)) { @@ -29,20 +31,21 @@ macro_rules! impl_daft_numeric_agg { (Some(acc), false) => Some(acc + arrow_array.value(idx)), } }) - }, - ))) + }), + ) } else { - Box::new(PrimitiveArray::from_trusted_len_values_iter( + DataArray::<$T>::from_values_iter( + self.field.clone(), groups.iter().map(|g| { g.iter().fold(0 as $AggType, |acc, index| { let idx = *index as usize; acc + unsafe { arrow_array.value_unchecked(idx) } }) }), - )) + ) }; - DataArray::new(self.field.clone(), sum_per_group) + Ok(sum_per_group) } } };