Skip to content

Commit 5ab75b1

Browse files
committed
Fix aggregate nullability calculation
1 parent a98b6a0 commit 5ab75b1

File tree

5 files changed

+157
-290
lines changed

5 files changed

+157
-290
lines changed

datafusion/physical-expr/src/aggregate/average.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -565,11 +565,12 @@ where
565565
let sums = std::mem::take(&mut self.sums);
566566
let nulls = self.null_state.build();
567567

568+
assert_eq!(nulls.len(), sums.len());
568569
assert_eq!(counts.len(), sums.len());
569570

570571
// don't evaluate averages with null inputs to avoid errors on null values
571-
let array: PrimitiveArray<T> = if let Some(nulls) = nulls.as_ref() {
572-
assert_eq!(nulls.len(), sums.len());
572+
573+
let array: PrimitiveArray<T> = if nulls.null_count() > 0 {
573574
let mut builder = PrimitiveBuilder::<T>::with_capacity(nulls.len());
574575
let iter = sums.into_iter().zip(counts.into_iter()).zip(nulls.iter());
575576

@@ -587,7 +588,7 @@ where
587588
.zip(counts.into_iter())
588589
.map(|(sum, count)| (self.avg_fn)(sum, count))
589590
.collect::<Result<Vec<_>>>()?;
590-
PrimitiveArray::new(averages.into(), nulls) // no copy
591+
PrimitiveArray::new(averages.into(), Some(nulls)) // no copy
591592
};
592593

593594
// fix up decimal precision and scale for decimals
@@ -598,9 +599,9 @@ where
598599

599600
// return arrays for sums and counts
600601
fn state(&mut self) -> Result<Vec<ArrayRef>> {
601-
let nulls = self.null_state.build();
602+
let nulls = Some(self.null_state.build());
602603
let counts = std::mem::take(&mut self.counts);
603-
let counts = UInt64Array::from(counts); // zero copy
604+
let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy
604605

605606
let sums = std::mem::take(&mut self.sums);
606607
let sums = PrimitiveArray::<T>::new(sums.into(), nulls); // zero copy

0 commit comments

Comments
 (0)