Skip to content

Commit

Permalink
fix agg decimal
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 committed Oct 30, 2024
1 parent 87c4bf1 commit eb0bb91
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
1 change: 0 additions & 1 deletion src/daft-core/src/array/ops/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ where
R: DaftNumericType,
F: Fn(T::Native, R::Native) -> T::Native + Copy,
{
assert_eq!(self.data_type(), rhs.data_type());
match (self.len(), rhs.len()) {
(x, y) if x == y => {
let lhs_arr: &PrimitiveArray<T::Native> =
Expand Down
23 changes: 13 additions & 10 deletions src/daft-core/src/array/ops/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ macro_rules! impl_daft_numeric_agg {
fn sum(&self) -> Self::Output {
let primitive_arr = self.as_arrow();
let sum_value = arrow2::compute::aggregate::sum_primitive(primitive_arr);
let arrow_array = Box::new(arrow2::array::PrimitiveArray::from([sum_value]));
DataArray::new(self.field.clone(), arrow_array)
Ok(DataArray::<$T>::from_iter(
self.field.clone(),
std::iter::once(sum_value),
))
}

fn grouped_sum(&self, groups: &GroupIndices) -> Self::Output {
use arrow2::array::PrimitiveArray;
let arrow_array = self.as_arrow();
let sum_per_group = if arrow_array.null_count() > 0 {
Box::new(PrimitiveArray::from_trusted_len_iter(groups.iter().map(
|g| {
DataArray::<$T>::from_iter(
self.field.clone(),
groups.iter().map(|g| {
g.iter().fold(None, |acc, index| {
let idx = *index as usize;
match (acc, arrow_array.is_null(idx)) {
Expand All @@ -29,20 +31,21 @@ macro_rules! impl_daft_numeric_agg {
(Some(acc), false) => Some(acc + arrow_array.value(idx)),
}
})
},
)))
}),
)
} else {
Box::new(PrimitiveArray::from_trusted_len_values_iter(
DataArray::<$T>::from_values_iter(
self.field.clone(),
groups.iter().map(|g| {
g.iter().fold(0 as $AggType, |acc, index| {
let idx = *index as usize;
acc + unsafe { arrow_array.value_unchecked(idx) }
})
}),
))
)
};

DataArray::new(self.field.clone(), sum_per_group)
Ok(sum_per_group)
}
}
};
Expand Down

0 comments on commit eb0bb91

Please sign in to comment.